from html import escape
from time import time
from typing import Iterable

import streamlit as st
import torch
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from config import DB_CONFIG, DB_E5_CONFIG


@st.cache_resource
def load_e5_embeddings() -> HuggingFaceEmbeddings:
    """Load the multilingual E5 embedding model once per Streamlit process.

    Runs on the first CUDA device when one is available, otherwise on CPU.
    """
    model_kwargs = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"}
    encode_kwargs = {"normalize_embeddings": False}
    return HuggingFaceEmbeddings(
        model_name="intfloat/multilingual-e5-large",
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )


@st.cache_resource
def load_rinna_model():
    """Load the rinna tokenizer and model, or return ``(None, None)`` without CUDA.

    The model is loaded in 8-bit / float16 with ``device_map="auto"``, so it
    is only attempted when a CUDA device is present.
    """
    if not torch.cuda.is_available():
        return None, None
    model_name = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return tokenizer, model


E5_EMBEDDINGS = load_e5_embeddings()
RINNA_TOKENIZER, RINNA_MODEL = load_rinna_model()


def _get_config_and_embeddings(collection_name: str | None) -> tuple:
    """Return ``(db_config, embeddings)`` for the selected vector collection.

    ``None`` and ``"E5"`` select the E5 collection; ``"OpenAI"`` selects the
    OpenAI-embedding collection.

    Raises:
        ValueError: if *collection_name* is not recognised.
    """
    if collection_name is None or collection_name == "E5":
        return DB_E5_CONFIG, E5_EMBEDDINGS
    if collection_name == "OpenAI":
        return DB_CONFIG, OpenAIEmbeddings()
    raise ValueError(f"Unknown collection name: {collection_name!r}")


@st.cache_resource
def _get_rinna_llm(temperature: float) -> HuggingFacePipeline | None:
    """Wrap the preloaded rinna model in a LangChain LLM (cached per temperature).

    Returns ``None`` when the rinna model could not be loaded (no CUDA).
    """
    if RINNA_MODEL is None:
        return None
    pipe = pipeline(
        "text-generation",
        model=RINNA_MODEL,
        tokenizer=RINNA_TOKENIZER,
        max_new_tokens=1024,
        temperature=temperature,
    )
    return HuggingFacePipeline(pipeline=pipe)


# UI model label -> underlying model id; ``None`` selects the default model.
_LLM_IDS = {
    None: "gpt-3.5-turbo",
    "GPT-3.5": "gpt-3.5-turbo",
    "GPT-4": "gpt-4",
    "rinna": "rinna",
}


def _get_llm_model(
    model_name: str | None,
    temperature: float,
):
    """Build the LLM selected in the UI.

    Raises:
        ValueError: if *model_name* is unknown, or "rinna" is requested but
            the rinna model is not loaded.
    """
    try:
        model = _LLM_IDS[model_name]
    except KeyError:
        raise ValueError(f"Unknown model name: {model_name!r}") from None

    if model == "rinna":
        llm = _get_rinna_llm(temperature)
        # Fail loudly here rather than handing None to RetrievalQA.
        if llm is None:
            raise ValueError("rinna model is not available (requires a CUDA GPU)")
        return llm
    return ChatOpenAI(model=model, temperature=temperature)


def get_retrieval_qa(
    collection_name: str | None,
    model_name: str | None,
    temperature: float,
    option: str | None,
):
    """Build a "stuff" RetrievalQA chain over the selected Qdrant collection.

    Args:
        collection_name: embedding/collection selector ("E5", "OpenAI" or None).
        model_name: LLM selector ("GPT-3.5", "GPT-4", "rinna" or None).
        temperature: sampling temperature passed to the LLM.
        option: document category to filter on; None or "All" disables filtering.

    Returns:
        A RetrievalQA chain that also returns its source documents.
    """
    db_config, embeddings = _get_config_and_embeddings(collection_name)
    db_url, db_api_key, db_collection_name = db_config
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(
        client=client, collection_name=db_collection_name, embeddings=embeddings
    )
    if option is None or option == "All":
        retriever = db.as_retriever()
    else:
        retriever = db.as_retriever(search_kwargs={"filter": {"category": option}})

    llm = _get_llm_model(model_name, temperature)
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )


def get_related_url(metadata) -> Iterable[str]:
    """Yield one HTML paragraph per distinct source URL, in first-seen order.

    Each metadata mapping must provide ``"url"`` and ``"category"`` keys.
    """
    seen = set()
    for m in metadata:
        url = m["url"]
        if url in seen:
            continue
        seen.add(url)
        category = m["category"]
        # Values come from the vector store and are rendered with
        # unsafe_allow_html=True, so they must be HTML-escaped.
        safe_url = escape(str(url))
        yield (
            f'<p>URL: <a href="{safe_url}">{safe_url}</a>'
            f" (category: {escape(str(category))})</p>"
        )


def run_qa(query: str, qa: RetrievalQA) -> tuple[str, str]:
    """Run *qa* on *query* and return ``(answer, html)``.

    ``html`` contains the elapsed time and the de-duplicated source URLs.
    If the OpenAI API raises ``InvalidRequestError``, a fixed "no answer found,
    try another question" message is returned with the escaped error text.
    """
    start = time()
    try:
        result = qa(query)
    except InvalidRequestError as e:
        return "回答が見つかりませんでした。別な質問をしてみてください", escape(str(e))

    metadata = [doc.metadata for doc in result["source_documents"]]
    sec_html = f"<p>実行時間: {(time() - start):.2f}秒</p>"
    sources = "\n".join(get_related_url(metadata))
    html = f"<div>{sec_html}\n{sources}</div>"
    return result["result"], html


# Prefix prepended to the user query for each selectable E5 option.
_E5_QUERY_PREFIXES = {"No": "", "Query": "query: ", "Passage": "passage: "}


def main(
    query: str,
    collection_name: str | None,
    model_name: str | None,
    option: str | None,
    temperature: float,
    e5_option: list[str],
) -> Iterable[tuple[str, tuple[str, str]]]:
    """Lazily yield ``(label, (answer, html))`` for each requested query variant.

    For the E5 collection (also the default when *collection_name* is None),
    one result is produced per entry in *e5_option*. Otherwise a single
    "OpenAI" result is produced.

    Raises:
        ValueError: when an unknown E5 option is reached.
    """
    qa = get_retrieval_qa(collection_name, model_name, temperature, option)
    # Match _get_config_and_embeddings, which treats None as the E5 collection.
    if collection_name not in (None, "E5"):
        yield "OpenAI", run_qa(query, qa)
        return

    for e5_mode in e5_option:
        if e5_mode not in _E5_QUERY_PREFIXES:
            raise ValueError(f"Unknown option: {e5_mode!r}")
        yield f"E5 {e5_mode}", run_qa(_E5_QUERY_PREFIXES[e5_mode] + query, qa)


AVAILABLE_LLMS = ["GPT-3.5", "GPT-4"]
if RINNA_MODEL is not None:
    AVAILABLE_LLMS.append("rinna")

with st.form("my_form"):
    query = st.text_input(label="query")
    collection_name = st.radio(options=["E5", "OpenAI"], label="Embedding")
    # TODO: show the E5 options only when "E5" is selected above.
    e5_option = st.multiselect("E5 option", ["No", "Query", "Passage"], default="No")
    model_name = st.radio(
        options=AVAILABLE_LLMS,
        label="Model",
        help="GPU環境だとrinnaが選択可能",
    )
    option = st.radio(
        options=["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
        label="絞り込み",
        help="ドキュメント制限する?",
    )
    # Float bounds give a continuous slider; integer bounds only allowed 0, 1, 2.
    temperature = st.slider(
        label="temperature", min_value=0.0, max_value=2.0, step=0.1
    )
    submitted = st.form_submit_button("Submit")

if submitted:
    # main() is a generator: retrieval and LLM calls run while it is iterated,
    # so the loop must be inside the spinner for the spinner to be meaningful.
    with st.spinner("Searching..."):
        results = main(
            query, collection_name, model_name, option, temperature, e5_option
        )
        for type_, (answer, html) in results:
            with st.container():
                st.header(type_)
                st.write(answer)
                st.markdown(html, unsafe_allow_html=True)
                st.divider()