terapyon commited on
Commit
2c70642
·
1 Parent(s): 9022e07

dev/modify-embedding-test (#4)

Browse files

- try to change embedding model (6125df0f09d11d466fb34cc188e156a8b8d2d7e0)
- Embeddings multilingual-e5-largeとLLM rinnaを使えるようにした (143e47c91fdbdedb218bb197b96d4a6ed84d892d)

Files changed (4) hide show
  1. app.py +106 -23
  2. config.py +11 -5
  3. requirements.txt +2 -0
  4. store.py +40 -18
app.py CHANGED
@@ -1,33 +1,105 @@
 
1
  import gradio as gr
2
  from langchain.chains import RetrievalQA
3
  from langchain.embeddings import OpenAIEmbeddings
4
- from langchain.llms import OpenAI
 
 
 
 
 
 
5
  from langchain.chat_models import ChatOpenAI
6
  from langchain.vectorstores import Qdrant
7
  from openai.error import InvalidRequestError
8
  from qdrant_client import QdrantClient
9
- from config import DB_CONFIG
10
 
11
 
12
- PERSIST_DIR_NAME = "nvdajp-book"
13
- # MODEL_NAME = "text-davinci-003"
14
- # MODEL_NAME = "gpt-3.5-turbo"
15
- # MODEL_NAME = "gpt-4"
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
- def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
19
- embeddings = OpenAIEmbeddings()
20
- db_url, db_api_key, db_collection_name = DB_CONFIG
21
- client = QdrantClient(url=db_url, api_key=db_api_key)
22
- db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  if model_name is None:
24
- model = "gpt-3.5-turbo"
 
 
25
  elif model_name == "GPT-3.5":
26
  model = "gpt-3.5-turbo"
27
  elif model_name == "GPT-4":
28
  model = "gpt-4"
29
  else:
30
- model = "gpt-3.5-turbo"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  if option is None or option == "All":
32
  retriever = db.as_retriever()
33
  else:
@@ -36,15 +108,19 @@ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | Non
36
  "filter": {"category": option},
37
  }
38
  )
39
- return RetrievalQA.from_chain_type(
40
- llm=ChatOpenAI(
41
- model=model,
42
- temperature=temperature
43
- ),
 
 
44
  chain_type="stuff",
45
  retriever=retriever,
46
  return_source_documents=True,
 
47
  )
 
48
 
49
 
50
  def get_related_url(metadata):
@@ -60,8 +136,10 @@ def get_related_url(metadata):
60
  yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
61
 
62
 
63
- def main(query: str, model_name: str, option: str, temperature: int):
64
- qa = get_retrieval_qa(model_name, temperature, option)
 
 
65
  try:
66
  result = qa(query)
67
  except InvalidRequestError as e:
@@ -77,9 +155,14 @@ nvdajp_book_qa = gr.Interface(
77
  fn=main,
78
  inputs=[
79
  gr.Textbox(label="query"),
80
- gr.Radio(["GPT-3.5", "GPT-4"], label="Model", info="選択なしで「3.5」を使用"),
81
- gr.Radio(["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?"),
82
- gr.Slider(0, 2)
 
 
 
 
 
83
  ],
84
  outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
85
  )
 
1
+ # from time import time
2
  import gradio as gr
3
  from langchain.chains import RetrievalQA
4
  from langchain.embeddings import OpenAIEmbeddings
5
+ from langchain.embeddings import HuggingFaceEmbeddings
6
+ from langchain.prompts import PromptTemplate
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
9
+ from langchain.llms import HuggingFacePipeline
10
+
11
+ # from langchain.llms import OpenAI
12
  from langchain.chat_models import ChatOpenAI
13
  from langchain.vectorstores import Qdrant
14
  from openai.error import InvalidRequestError
15
  from qdrant_client import QdrantClient
16
+ from config import DB_CONFIG, DB_E5_CONFIG
17
 
18
 
19
+ def _get_config_and_embeddings(collection_name: str | None) -> tuple:
20
+ if collection_name is None or collection_name == "E5":
21
+ db_config = DB_E5_CONFIG
22
+ model_name = "intfloat/multilingual-e5-large"
23
+ model_kwargs = {"device": "cpu"}
24
+ encode_kwargs = {"normalize_embeddings": False}
25
+ embeddings = HuggingFaceEmbeddings(
26
+ model_name=model_name,
27
+ model_kwargs=model_kwargs,
28
+ encode_kwargs=encode_kwargs,
29
+ )
30
+ elif collection_name == "OpenAI":
31
+ db_config = DB_CONFIG
32
+ embeddings = OpenAIEmbeddings()
33
+ else:
34
+ raise ValueError("Unknow collection name")
35
+ return db_config, embeddings
36
 
37
 
38
+ def _get_rinna_llm(temperature: float):
39
+ model = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
40
+ tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
41
+ model = AutoModelForCausalLM.from_pretrained(
42
+ model,
43
+ load_in_8bit=True,
44
+ torch_dtype=torch.float16,
45
+ device_map="auto",
46
+ )
47
+ pipe = pipeline(
48
+ "text-generation",
49
+ model=model,
50
+ tokenizer=tokenizer,
51
+ max_new_tokens=1024,
52
+ temperature=temperature,
53
+ )
54
+ llm = HuggingFacePipeline(pipeline=pipe)
55
+ return llm
56
+
57
+
58
+ def _get_llm_model(
59
+ model_name: str | None,
60
+ temperature: float,
61
+ ):
62
  if model_name is None:
63
+ model = "rinna"
64
+ elif model_name == "rinna":
65
+ model = "rinna"
66
  elif model_name == "GPT-3.5":
67
  model = "gpt-3.5-turbo"
68
  elif model_name == "GPT-4":
69
  model = "gpt-4"
70
  else:
71
+ raise ValueError("Unknow model name")
72
+ if model.startswith("gpt"):
73
+ llm = ChatOpenAI(model=model, temperature=temperature)
74
+ elif model == "rinna":
75
+ llm = _get_rinna_llm(temperature)
76
+ return llm
77
+
78
+
79
+ # prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
80
+
81
+ # {context}
82
+
83
+ # Question: {question}
84
+ # Answer in Japanese:"""
85
+ # PROMPT = PromptTemplate(
86
+ # template=prompt_template, input_variables=["context", "question"]
87
+ # )
88
+
89
+
90
+ def get_retrieval_qa(
91
+ collection_name: str | None,
92
+ model_name: str | None,
93
+ temperature: float,
94
+ option: str | None,
95
+ ) -> RetrievalQA:
96
+ db_config, embeddings = _get_config_and_embeddings(collection_name)
97
+ db_url, db_api_key, db_collection_name = db_config
98
+ client = QdrantClient(url=db_url, api_key=db_api_key)
99
+ db = Qdrant(
100
+ client=client, collection_name=db_collection_name, embeddings=embeddings
101
+ )
102
+
103
  if option is None or option == "All":
104
  retriever = db.as_retriever()
105
  else:
 
108
  "filter": {"category": option},
109
  }
110
  )
111
+
112
+ llm = _get_llm_model(model_name, temperature)
113
+
114
+ # chain_type_kwargs = {"prompt": PROMPT}
115
+
116
+ result = RetrievalQA.from_chain_type(
117
+ llm=llm,
118
  chain_type="stuff",
119
  retriever=retriever,
120
  return_source_documents=True,
121
+ # chain_type_kwargs=chain_type_kwargs,
122
  )
123
+ return result
124
 
125
 
126
  def get_related_url(metadata):
 
136
  yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
137
 
138
 
139
+ def main(
140
+ query: str, collection_name: str, model_name: str, option: str, temperature: float
141
+ ):
142
+ qa = get_retrieval_qa(collection_name, model_name, temperature, option)
143
  try:
144
  result = qa(query)
145
  except InvalidRequestError as e:
 
155
  fn=main,
156
  inputs=[
157
  gr.Textbox(label="query"),
158
+ gr.Radio(["E5", "OpenAI"], label="Embedding", info="選択なしで「E5」を使用"),
159
+ gr.Radio(["rinna", "GPT-3.5", "GPT-4"], label="Model", info="選択なしで「rinna」を使用"),
160
+ gr.Radio(
161
+ ["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
162
+ label="絞り込み",
163
+ info="ドキュメント制限する?",
164
+ ),
165
+ gr.Slider(0, 2),
166
  ],
167
  outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
168
  )
config.py CHANGED
@@ -4,18 +4,24 @@ import os
4
  SAAS = True
5
 
6
 
7
- def get_db_config():
8
- url = os.environ["QDRANT_URL"]
 
 
 
9
  api_key = os.environ["QDRANT_API_KEY"]
10
- collection_name = "nvdajp-book"
11
  return url, api_key, collection_name
12
 
13
 
14
- def get_local_db_congin():
15
  url = "localhost"
16
  # api_key = os.environ["QDRANT_API_KEY"]
17
- collection_name = "nvdajp-book"
18
  return url, None, collection_name
19
 
20
 
21
  DB_CONFIG = get_db_config() if SAAS else get_local_db_congin()
 
 
 
 
4
  SAAS = True
5
 
6
 
7
+ def get_db_config(cname="nvdajp-book"):
8
+ if cname == "nvdajp-book":
9
+ url = os.environ["QDRANT_URL"]
10
+ elif cname == "nvdajp-book-e5":
11
+ url = os.environ["QDRANT_E5_URL"]
12
  api_key = os.environ["QDRANT_API_KEY"]
13
+ collection_name = cname
14
  return url, api_key, collection_name
15
 
16
 
17
+ def get_local_db_congin(cname="nvdajp-book"):
18
  url = "localhost"
19
  # api_key = os.environ["QDRANT_API_KEY"]
20
+ collection_name = cname
21
  return url, None, collection_name
22
 
23
 
24
  DB_CONFIG = get_db_config() if SAAS else get_local_db_congin()
25
+ DB_E5_CONFIG = (
26
+ get_db_config("nvdajp-book-e5") if SAAS else get_local_db_congin("nvdajp-book-e5")
27
+ )
requirements.txt CHANGED
@@ -4,3 +4,5 @@ tiktoken
4
  gradio
5
  qdrant-client
6
  beautifulsoup4
 
 
 
4
  gradio
5
  qdrant-client
6
  beautifulsoup4
7
+ accelerate
8
+ bitsandbytes
store.py CHANGED
@@ -1,10 +1,12 @@
1
  from langchain.document_loaders import ReadTheDocsLoader
2
  from langchain.text_splitter import RecursiveCharacterTextSplitter
3
  from langchain.embeddings import OpenAIEmbeddings
 
4
  from langchain.vectorstores import Qdrant
 
5
  # from qdrant_client import QdrantClient
6
  from nvda_ug_loader import NVDAUserGuideLoader
7
- from config import DB_CONFIG
8
 
9
 
10
  CHUNK_SIZE = 500
@@ -23,37 +25,55 @@ def get_documents(path: str):
23
  for doc in docs:
24
  org_metadata = doc.metadata
25
  source = _remove_prefix_path(org_metadata["source"])
26
- add_meta = {"category": category, "source": source, "url": f"{base_url}{source}"}
 
 
 
 
27
  doc.metadata = org_metadata | add_meta
28
  yield doc
29
 
30
 
31
  def get_text_chunk(docs):
32
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=0)
 
 
33
  texts = text_splitter.split_documents(docs)
34
  return texts
35
 
36
 
37
- def store(texts):
38
- embeddings = OpenAIEmbeddings()
39
- db_url, db_api_key, db_collection_name = DB_CONFIG
40
- # client = QdrantClient(url=db_url, api_key=db_api_key, prefer_grpc=True)
 
 
 
 
 
 
 
 
 
 
 
 
41
  _ = Qdrant.from_documents(
42
  texts,
43
  embeddings,
44
  url=db_url,
45
  api_key=db_api_key,
46
- collection_name=db_collection_name
47
  )
48
 
49
 
50
- def rtd_main(path: str):
51
  docs = get_documents(path)
52
  texts = get_text_chunk(docs)
53
- store(texts)
54
 
55
 
56
- def nul_main(url: str):
57
  if "www.nvda.jp" in url:
58
  category = "ja-nvda-user-guide"
59
  else:
@@ -61,25 +81,27 @@ def nul_main(url: str):
61
  loader = NVDAUserGuideLoader(url, category)
62
  docs = loader.load()
63
  texts = get_text_chunk(docs)
64
- store(texts)
65
 
66
 
67
  if __name__ == "__main__":
68
  """
69
- $ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest"
70
- $ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html"
71
- $ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html"
72
  """
73
  import sys
 
74
  args = sys.argv
75
- if len(args) != 3:
76
  print("No args, you need two args for type, html_path")
77
  else:
78
  type_ = args[1]
79
  path = args[2]
 
80
  if type_ == "rtd":
81
- rtd_main(path)
82
  elif type_ == "nul":
83
- nul_main(path)
84
  else:
85
  print("No type for store")
 
1
  from langchain.document_loaders import ReadTheDocsLoader
2
  from langchain.text_splitter import RecursiveCharacterTextSplitter
3
  from langchain.embeddings import OpenAIEmbeddings
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
  from langchain.vectorstores import Qdrant
6
+
7
  # from qdrant_client import QdrantClient
8
  from nvda_ug_loader import NVDAUserGuideLoader
9
+ from config import DB_CONFIG, DB_E5_CONFIG
10
 
11
 
12
  CHUNK_SIZE = 500
 
25
  for doc in docs:
26
  org_metadata = doc.metadata
27
  source = _remove_prefix_path(org_metadata["source"])
28
+ add_meta = {
29
+ "category": category,
30
+ "source": source,
31
+ "url": f"{base_url}{source}",
32
+ }
33
  doc.metadata = org_metadata | add_meta
34
  yield doc
35
 
36
 
37
  def get_text_chunk(docs):
38
+ text_splitter = RecursiveCharacterTextSplitter(
39
+ chunk_size=CHUNK_SIZE, chunk_overlap=0
40
+ )
41
  texts = text_splitter.split_documents(docs)
42
  return texts
43
 
44
 
45
+ def store(texts, mname):
46
+ if mname == "openai":
47
+ embeddings = OpenAIEmbeddings()
48
+ db_url, db_api_key, db_collection_name = DB_CONFIG
49
+ elif mname == "e5":
50
+ model_name = "intfloat/multilingual-e5-large"
51
+ model_kwargs = {"device": "cuda"}
52
+ encode_kwargs = {"normalize_embeddings": False}
53
+ embeddings = HuggingFaceEmbeddings(
54
+ model_name=model_name,
55
+ model_kwargs=model_kwargs,
56
+ encode_kwargs=encode_kwargs,
57
+ )
58
+ db_url, db_api_key, db_collection_name = DB_E5_CONFIG
59
+ else:
60
+ raise ValueError("Invalid mname")
61
  _ = Qdrant.from_documents(
62
  texts,
63
  embeddings,
64
  url=db_url,
65
  api_key=db_api_key,
66
+ collection_name=db_collection_name,
67
  )
68
 
69
 
70
+ def rtd_main(path: str, mname: str):
71
  docs = get_documents(path)
72
  texts = get_text_chunk(docs)
73
+ store(texts, mname)
74
 
75
 
76
+ def nul_main(url: str, mname: str):
77
  if "www.nvda.jp" in url:
78
  category = "ja-nvda-user-guide"
79
  else:
 
81
  loader = NVDAUserGuideLoader(url, category)
82
  docs = loader.load()
83
  texts = get_text_chunk(docs)
84
+ store(texts, mname)
85
 
86
 
87
  if __name__ == "__main__":
88
  """
89
+ $ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest" openai
90
+ $ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html" e5
91
+ $ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html" e5
92
  """
93
  import sys
94
+
95
  args = sys.argv
96
+ if len(args) != 4:
97
  print("No args, you need two args for type, html_path")
98
  else:
99
  type_ = args[1]
100
  path = args[2]
101
+ mname = args[3]
102
  if type_ == "rtd":
103
+ rtd_main(path, mname)
104
  elif type_ == "nul":
105
+ nul_main(path, mname)
106
  else:
107
  print("No type for store")