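"""Gradio app for retrieval QA over NVDA-related documents stored in Qdrant.

Embeddings come from either multilingual-e5-large or OpenAI; answers are generated
by GPT-3.5 / GPT-4 via ChatOpenAI, or by rinna's bilingual GPT-NeoX instruction
model when a CUDA device is available.
"""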
from time import time

import gradio as gr
import torch
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
from langchain.llms import HuggingFacePipeline
# from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from config import DB_CONFIG, DB_E5_CONFIG
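# DB_CONFIG / DB_E5_CONFIG are assumed to be (url, api_key, collection_name)
# tuples for the OpenAI- and E5-embedded Qdrant collections (they are unpacked
# that way in get_retrieval_qa below).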

E5_MODEL_NAME = "intfloat/multilingual-e5-large"
E5_MODEL_KWARGS = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"}
E5_ENCODE_KWARGS = {"normalize_embeddings": False}
E5_EMBEDDINGS = HuggingFaceEmbeddings(
    model_name=E5_MODEL_NAME,
    model_kwargs=E5_MODEL_KWARGS,
    encode_kwargs=E5_ENCODE_KWARGS,
)

if torch.cuda.is_available():
    # The rinna model is only loaded when a GPU is available (8-bit, fp16).
    RINNA_MODEL_NAME = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
    RINNA_TOKENIZER = AutoTokenizer.from_pretrained(RINNA_MODEL_NAME, use_fast=False)
    RINNA_MODEL = AutoModelForCausalLM.from_pretrained(
        RINNA_MODEL_NAME,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
else:
    RINNA_MODEL = None


def _get_config_and_embeddings(collection_name: str | None) -> tuple:
    """Return the Qdrant DB config and embeddings model for the chosen collection."""
    if collection_name is None or collection_name == "E5":
        db_config = DB_E5_CONFIG
        embeddings = E5_EMBEDDINGS
    elif collection_name == "OpenAI":
        db_config = DB_CONFIG
        embeddings = OpenAIEmbeddings()
    else:
        raise ValueError("Unknown collection name")
    return db_config, embeddings


def _get_rinna_llm(temperature: float):
    """Wrap the local rinna model in a LangChain HuggingFacePipeline, if it was loaded."""
    if RINNA_MODEL is not None:
        pipe = pipeline(
            "text-generation",
            model=RINNA_MODEL,
            tokenizer=RINNA_TOKENIZER,
            max_new_tokens=1024,
            temperature=temperature,
        )
        llm = HuggingFacePipeline(pipeline=pipe)
    else:
        llm = None
    return llm


def _get_llm_model(
    model_name: str | None,
    temperature: float,
):
    """Map the UI model name to a LangChain LLM (defaults to GPT-3.5)."""
    if model_name is None:
        model = "gpt-3.5-turbo"
    elif model_name == "rinna":
        model = "rinna"
    elif model_name == "GPT-3.5":
        model = "gpt-3.5-turbo"
    elif model_name == "GPT-4":
        model = "gpt-4"
    else:
        raise ValueError("Unknown model name")
    if model.startswith("gpt"):
        llm = ChatOpenAI(model=model, temperature=temperature)
    elif model == "rinna":
        llm = _get_rinna_llm(temperature)
    return llm


# prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
# {context}
# Question: {question}
# Answer in Japanese:"""
# PROMPT = PromptTemplate(
#     template=prompt_template, input_variables=["context", "question"]
# )


def get_retrieval_qa(
    collection_name: str | None,
    model_name: str | None,
    temperature: float,
    option: str | None,
) -> RetrievalQA:
    """Build a RetrievalQA chain over the Qdrant collection, optionally filtered by category."""
    db_config, embeddings = _get_config_and_embeddings(collection_name)
    db_url, db_api_key, db_collection_name = db_config
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(
        client=client, collection_name=db_collection_name, embeddings=embeddings
    )
    if option is None or option == "All":
        retriever = db.as_retriever()
    else:
        retriever = db.as_retriever(
            search_kwargs={
                "filter": {"category": option},
            }
        )
    llm = _get_llm_model(model_name, temperature)
    # chain_type_kwargs = {"prompt": PROMPT}
    result = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        # chain_type_kwargs=chain_type_kwargs,
    )
    return result


def get_related_url(metadata):
    """Yield HTML links for the unique source URLs found in the retrieved documents."""
    urls = set()
    for m in metadata:
        # p = m['source']
        url = m["url"]
        if url in urls:
            continue
        urls.add(url)
        category = m["category"]
        # print(m)
        yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'


def main(
    query: str, collection_name: str, model_name: str, option: str, temperature: float
):
    """Run the QA chain for a query and return the answer plus related-URL HTML."""
    now = time()
    qa = get_retrieval_qa(collection_name, model_name, temperature, option)
    try:
        result = qa(query)
    except InvalidRequestError as e:
        # "No answer was found. Please try a different question."
        return "回答が見つかりませんでした。別な質問をしてみてください", str(e)
    else:
        metadata = [s.metadata for s in result["source_documents"]]
        sec_html = f"<p>実行時間: {(time() - now):.2f}秒</p>"  # elapsed time in seconds
        html = "<div>" + sec_html + "\n".join(get_related_url(metadata)) + "</div>"
        return result["result"], html


AVAILABLE_LLMS = ["GPT-3.5", "GPT-4"]
if RINNA_MODEL is not None:
    # Offer rinna in the UI only when the model was actually loaded (CUDA available).
    AVAILABLE_LLMS.append("rinna")

nvdajp_book_qa = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="query"),
        gr.Radio(["E5", "OpenAI"], label="Embedding", info="選択なしで「E5」を使用"),  # "E5 is used when nothing is selected"
        gr.Radio(AVAILABLE_LLMS, label="Model", info="選択なしで「GPT-3.5」を使用"),  # "GPT-3.5 is used when nothing is selected"
        gr.Radio(
            ["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
            label="絞り込み",  # "filter"
            info="ドキュメント制限する?",  # "restrict to one document set?"
        ),
        gr.Slider(0, 2, label="temperature"),
    ],
    outputs=[gr.Textbox(label="answer"), gr.HTML()],
)
nvdajp_book_qa.launch()