terapyon commited on
Commit
33db183
·
1 Parent(s): 3ec177b

can select model for GPT-4

Browse files
Files changed (1) hide show
  1. app.py +23 -4
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from langchain.chains import RetrievalQA
3
  from langchain.embeddings import OpenAIEmbeddings
4
  from langchain.llms import OpenAI
 
5
  from langchain.vectorstores import Qdrant
6
  from openai.error import InvalidRequestError
7
  from qdrant_client import QdrantClient
@@ -9,13 +10,24 @@ from config import DB_CONFIG
9
 
10
 
11
  PERSIST_DIR_NAME = "nvdajp-book"
 
 
 
12
 
13
 
14
- def get_retrieval_qa(temperature: int, option: str) -> RetrievalQA:
15
  embeddings = OpenAIEmbeddings()
16
  db_url, db_api_key, db_collection_name = DB_CONFIG
17
  client = QdrantClient(url=db_url, api_key=db_api_key)
18
  db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
 
 
 
 
 
 
 
 
19
  if option is None or option == "All":
20
  retriever = db.as_retriever()
21
  else:
@@ -25,7 +37,13 @@ def get_retrieval_qa(temperature: int, option: str) -> RetrievalQA:
25
  }
26
  )
27
  return RetrievalQA.from_chain_type(
28
- llm=OpenAI(temperature=temperature), chain_type="stuff", retriever=retriever, return_source_documents=True,
 
 
 
 
 
 
29
  )
30
 
31
 
@@ -42,8 +60,8 @@ def get_related_url(metadata):
42
  yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
43
 
44
 
45
- def main(query: str, option: str, temperature: int):
46
- qa = get_retrieval_qa(temperature, option)
47
  try:
48
  result = qa(query)
49
  except InvalidRequestError as e:
@@ -59,6 +77,7 @@ nvdajp_book_qa = gr.Interface(
59
  fn=main,
60
  inputs=[
61
  gr.Textbox(label="query"),
 
62
  gr.Radio(["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?"),
63
  gr.Slider(0, 2)
64
  ],
 
2
  from langchain.chains import RetrievalQA
3
  from langchain.embeddings import OpenAIEmbeddings
4
  from langchain.llms import OpenAI
5
+ from langchain.chat_models import ChatOpenAI
6
  from langchain.vectorstores import Qdrant
7
  from openai.error import InvalidRequestError
8
  from qdrant_client import QdrantClient
 
10
 
11
 
12
  PERSIST_DIR_NAME = "nvdajp-book"
13
+ # MODEL_NAME = "text-davinci-003"
14
+ # MODEL_NAME = "gpt-3.5-turbo"
15
+ # MODEL_NAME = "gpt-4"
16
 
17
 
18
+ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
19
  embeddings = OpenAIEmbeddings()
20
  db_url, db_api_key, db_collection_name = DB_CONFIG
21
  client = QdrantClient(url=db_url, api_key=db_api_key)
22
  db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
23
+ if model_name is None:
24
+ model = "gpt-3.5-turbo"
25
+ elif model_name == "GPT-3.5":
26
+ model = "gpt-3.5-turbo"
27
+ elif model_name == "GPT-4":
28
+ model = "gpt-4"
29
+ else:
30
+ model = "gpt-3.5-turbo"
31
  if option is None or option == "All":
32
  retriever = db.as_retriever()
33
  else:
 
37
  }
38
  )
39
  return RetrievalQA.from_chain_type(
40
+ llm=ChatOpenAI(
41
+ model=model,
42
+ temperature=temperature
43
+ ),
44
+ chain_type="stuff",
45
+ retriever=retriever,
46
+ return_source_documents=True,
47
  )
48
 
49
 
 
60
  yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
61
 
62
 
63
+ def main(query: str, model_name: str, option: str, temperature: int):
64
+ qa = get_retrieval_qa(model_name, temperature, option)
65
  try:
66
  result = qa(query)
67
  except InvalidRequestError as e:
 
77
  fn=main,
78
  inputs=[
79
  gr.Textbox(label="query"),
80
+ gr.Radio(["GPT-3.5", "GPT-4"], label="Model", info="選択なしで「3.5」を使用"),
81
  gr.Radio(["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?"),
82
  gr.Slider(0, 2)
83
  ],