Spaces:
Paused
Paused
Embeddings multilingual-e5-largeとLLM rinnaを使えるようにした
Browse files- app.py +103 -33
- config.py +12 -6
- requirements.txt +2 -0
- store.py +40 -29
app.py
CHANGED
@@ -1,45 +1,105 @@
|
|
1 |
-
from time import time
|
2 |
import gradio as gr
|
3 |
from langchain.chains import RetrievalQA
|
4 |
-
|
5 |
from langchain.embeddings import HuggingFaceEmbeddings
|
6 |
-
from langchain.
|
7 |
-
|
|
|
|
|
|
|
|
|
8 |
from langchain.chat_models import ChatOpenAI
|
9 |
from langchain.vectorstores import Qdrant
|
10 |
from openai.error import InvalidRequestError
|
11 |
from qdrant_client import QdrantClient
|
12 |
-
from config import DB_CONFIG
|
13 |
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
-
def
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
encode_kwargs=encode_kwargs,
|
30 |
)
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
if model_name is None:
|
36 |
-
model = "
|
|
|
|
|
37 |
elif model_name == "GPT-3.5":
|
38 |
model = "gpt-3.5-turbo"
|
39 |
elif model_name == "GPT-4":
|
40 |
model = "gpt-4"
|
41 |
else:
|
42 |
-
model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
if option is None or option == "All":
|
44 |
retriever = db.as_retriever()
|
45 |
else:
|
@@ -48,14 +108,17 @@ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | Non
|
|
48 |
"filter": {"category": option},
|
49 |
}
|
50 |
)
|
|
|
|
|
|
|
|
|
|
|
51 |
result = RetrievalQA.from_chain_type(
|
52 |
-
llm=
|
53 |
-
model=model,
|
54 |
-
temperature=temperature
|
55 |
-
),
|
56 |
chain_type="stuff",
|
57 |
retriever=retriever,
|
58 |
return_source_documents=True,
|
|
|
59 |
)
|
60 |
return result
|
61 |
|
@@ -73,8 +136,10 @@ def get_related_url(metadata):
|
|
73 |
yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
|
74 |
|
75 |
|
76 |
-
def main(
|
77 |
-
|
|
|
|
|
78 |
try:
|
79 |
result = qa(query)
|
80 |
except InvalidRequestError as e:
|
@@ -90,9 +155,14 @@ nvdajp_book_qa = gr.Interface(
|
|
90 |
fn=main,
|
91 |
inputs=[
|
92 |
gr.Textbox(label="query"),
|
93 |
-
gr.Radio(["
|
94 |
-
gr.Radio(["
|
95 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
96 |
],
|
97 |
outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
|
98 |
)
|
|
|
1 |
+
# from time import time
|
2 |
import gradio as gr
|
3 |
from langchain.chains import RetrievalQA
|
4 |
+
from langchain.embeddings import OpenAIEmbeddings
|
5 |
from langchain.embeddings import HuggingFaceEmbeddings
|
6 |
+
from langchain.prompts import PromptTemplate
|
7 |
+
import torch
|
8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
9 |
+
from langchain.llms import HuggingFacePipeline
|
10 |
+
|
11 |
+
# from langchain.llms import OpenAI
|
12 |
from langchain.chat_models import ChatOpenAI
|
13 |
from langchain.vectorstores import Qdrant
|
14 |
from openai.error import InvalidRequestError
|
15 |
from qdrant_client import QdrantClient
|
16 |
+
from config import DB_CONFIG, DB_E5_CONFIG
|
17 |
|
18 |
|
19 |
+
def _get_config_and_embeddings(collection_name: str | None) -> tuple:
|
20 |
+
if collection_name is None or collection_name == "E5":
|
21 |
+
db_config = DB_E5_CONFIG
|
22 |
+
model_name = "intfloat/multilingual-e5-large"
|
23 |
+
model_kwargs = {"device": "cpu"}
|
24 |
+
encode_kwargs = {"normalize_embeddings": False}
|
25 |
+
embeddings = HuggingFaceEmbeddings(
|
26 |
+
model_name=model_name,
|
27 |
+
model_kwargs=model_kwargs,
|
28 |
+
encode_kwargs=encode_kwargs,
|
29 |
+
)
|
30 |
+
elif collection_name == "OpenAI":
|
31 |
+
db_config = DB_CONFIG
|
32 |
+
embeddings = OpenAIEmbeddings()
|
33 |
+
else:
|
34 |
+
raise ValueError("Unknow collection name")
|
35 |
+
return db_config, embeddings
|
36 |
|
37 |
|
38 |
+
def _get_rinna_llm(temperature: float):
|
39 |
+
model = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
|
40 |
+
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
|
41 |
+
model = AutoModelForCausalLM.from_pretrained(
|
42 |
+
model,
|
43 |
+
load_in_8bit=True,
|
44 |
+
torch_dtype=torch.float16,
|
45 |
+
device_map="auto",
|
|
|
46 |
)
|
47 |
+
pipe = pipeline(
|
48 |
+
"text-generation",
|
49 |
+
model=model,
|
50 |
+
tokenizer=tokenizer,
|
51 |
+
max_new_tokens=1024,
|
52 |
+
temperature=temperature,
|
53 |
+
)
|
54 |
+
llm = HuggingFacePipeline(pipeline=pipe)
|
55 |
+
return llm
|
56 |
+
|
57 |
+
|
58 |
+
def _get_llm_model(
|
59 |
+
model_name: str | None,
|
60 |
+
temperature: float,
|
61 |
+
):
|
62 |
if model_name is None:
|
63 |
+
model = "rinna"
|
64 |
+
elif model_name == "rinna":
|
65 |
+
model = "rinna"
|
66 |
elif model_name == "GPT-3.5":
|
67 |
model = "gpt-3.5-turbo"
|
68 |
elif model_name == "GPT-4":
|
69 |
model = "gpt-4"
|
70 |
else:
|
71 |
+
raise ValueError("Unknow model name")
|
72 |
+
if model.startswith("gpt"):
|
73 |
+
llm = ChatOpenAI(model=model, temperature=temperature)
|
74 |
+
elif model == "rinna":
|
75 |
+
llm = _get_rinna_llm(temperature)
|
76 |
+
return llm
|
77 |
+
|
78 |
+
|
79 |
+
# prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
80 |
+
|
81 |
+
# {context}
|
82 |
+
|
83 |
+
# Question: {question}
|
84 |
+
# Answer in Japanese:"""
|
85 |
+
# PROMPT = PromptTemplate(
|
86 |
+
# template=prompt_template, input_variables=["context", "question"]
|
87 |
+
# )
|
88 |
+
|
89 |
+
|
90 |
+
def get_retrieval_qa(
|
91 |
+
collection_name: str | None,
|
92 |
+
model_name: str | None,
|
93 |
+
temperature: float,
|
94 |
+
option: str | None,
|
95 |
+
) -> RetrievalQA:
|
96 |
+
db_config, embeddings = _get_config_and_embeddings(collection_name)
|
97 |
+
db_url, db_api_key, db_collection_name = db_config
|
98 |
+
client = QdrantClient(url=db_url, api_key=db_api_key)
|
99 |
+
db = Qdrant(
|
100 |
+
client=client, collection_name=db_collection_name, embeddings=embeddings
|
101 |
+
)
|
102 |
+
|
103 |
if option is None or option == "All":
|
104 |
retriever = db.as_retriever()
|
105 |
else:
|
|
|
108 |
"filter": {"category": option},
|
109 |
}
|
110 |
)
|
111 |
+
|
112 |
+
llm = _get_llm_model(model_name, temperature)
|
113 |
+
|
114 |
+
# chain_type_kwargs = {"prompt": PROMPT}
|
115 |
+
|
116 |
result = RetrievalQA.from_chain_type(
|
117 |
+
llm=llm,
|
|
|
|
|
|
|
118 |
chain_type="stuff",
|
119 |
retriever=retriever,
|
120 |
return_source_documents=True,
|
121 |
+
# chain_type_kwargs=chain_type_kwargs,
|
122 |
)
|
123 |
return result
|
124 |
|
|
|
136 |
yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
|
137 |
|
138 |
|
139 |
+
def main(
|
140 |
+
query: str, collection_name: str, model_name: str, option: str, temperature: float
|
141 |
+
):
|
142 |
+
qa = get_retrieval_qa(collection_name, model_name, temperature, option)
|
143 |
try:
|
144 |
result = qa(query)
|
145 |
except InvalidRequestError as e:
|
|
|
155 |
fn=main,
|
156 |
inputs=[
|
157 |
gr.Textbox(label="query"),
|
158 |
+
gr.Radio(["E5", "OpenAI"], label="Embedding", info="選択なしで「E5」を使用"),
|
159 |
+
gr.Radio(["rinna", "GPT-3.5", "GPT-4"], label="Model", info="選択なしで「rinna」を使用"),
|
160 |
+
gr.Radio(
|
161 |
+
["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
|
162 |
+
label="絞り込み",
|
163 |
+
info="ドキュメント制限する?",
|
164 |
+
),
|
165 |
+
gr.Slider(0, 2),
|
166 |
],
|
167 |
outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
|
168 |
)
|
config.py
CHANGED
@@ -1,21 +1,27 @@
|
|
1 |
import os
|
2 |
|
3 |
|
4 |
-
SAAS =
|
5 |
|
6 |
|
7 |
-
def get_db_config():
|
8 |
-
|
|
|
|
|
|
|
9 |
api_key = os.environ["QDRANT_API_KEY"]
|
10 |
-
collection_name =
|
11 |
return url, api_key, collection_name
|
12 |
|
13 |
|
14 |
-
def get_local_db_congin():
|
15 |
url = "localhost"
|
16 |
# api_key = os.environ["QDRANT_API_KEY"]
|
17 |
-
collection_name =
|
18 |
return url, None, collection_name
|
19 |
|
20 |
|
21 |
DB_CONFIG = get_db_config() if SAAS else get_local_db_congin()
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
|
3 |
|
4 |
+
SAAS = True
|
5 |
|
6 |
|
7 |
+
def get_db_config(cname="nvdajp-book"):
|
8 |
+
if cname == "nvdajp-book":
|
9 |
+
url = os.environ["QDRANT_URL"]
|
10 |
+
elif cname == "nvdajp-book-e5":
|
11 |
+
url = os.environ["QDRANT_E5_URL"]
|
12 |
api_key = os.environ["QDRANT_API_KEY"]
|
13 |
+
collection_name = cname
|
14 |
return url, api_key, collection_name
|
15 |
|
16 |
|
17 |
+
def get_local_db_congin(cname="nvdajp-book"):
|
18 |
url = "localhost"
|
19 |
# api_key = os.environ["QDRANT_API_KEY"]
|
20 |
+
collection_name = cname
|
21 |
return url, None, collection_name
|
22 |
|
23 |
|
24 |
DB_CONFIG = get_db_config() if SAAS else get_local_db_congin()
|
25 |
+
DB_E5_CONFIG = (
|
26 |
+
get_db_config("nvdajp-book-e5") if SAAS else get_local_db_congin("nvdajp-book-e5")
|
27 |
+
)
|
requirements.txt
CHANGED
@@ -4,3 +4,5 @@ tiktoken
|
|
4 |
gradio
|
5 |
qdrant-client
|
6 |
beautifulsoup4
|
|
|
|
|
|
4 |
gradio
|
5 |
qdrant-client
|
6 |
beautifulsoup4
|
7 |
+
accelerate
|
8 |
+
bitsandbytes
|
store.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
from langchain.document_loaders import ReadTheDocsLoader
|
2 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
3 |
-
|
4 |
from langchain.embeddings import HuggingFaceEmbeddings
|
5 |
-
from langchain.embeddings import GPT4AllEmbeddings
|
6 |
from langchain.vectorstores import Qdrant
|
|
|
7 |
# from qdrant_client import QdrantClient
|
8 |
from nvda_ug_loader import NVDAUserGuideLoader
|
9 |
-
from config import DB_CONFIG
|
10 |
|
11 |
|
12 |
CHUNK_SIZE = 500
|
@@ -25,46 +25,55 @@ def get_documents(path: str):
|
|
25 |
for doc in docs:
|
26 |
org_metadata = doc.metadata
|
27 |
source = _remove_prefix_path(org_metadata["source"])
|
28 |
-
add_meta = {
|
|
|
|
|
|
|
|
|
29 |
doc.metadata = org_metadata | add_meta
|
30 |
yield doc
|
31 |
|
32 |
|
33 |
def get_text_chunk(docs):
|
34 |
-
text_splitter = RecursiveCharacterTextSplitter(
|
|
|
|
|
35 |
texts = text_splitter.split_documents(docs)
|
36 |
return texts
|
37 |
|
38 |
|
39 |
-
def store(texts):
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
52 |
_ = Qdrant.from_documents(
|
53 |
texts,
|
54 |
embeddings,
|
55 |
url=db_url,
|
56 |
api_key=db_api_key,
|
57 |
-
collection_name=db_collection_name
|
58 |
)
|
59 |
|
60 |
|
61 |
-
def rtd_main(path: str):
|
62 |
docs = get_documents(path)
|
63 |
texts = get_text_chunk(docs)
|
64 |
-
store(texts)
|
65 |
|
66 |
|
67 |
-
def nul_main(url: str):
|
68 |
if "www.nvda.jp" in url:
|
69 |
category = "ja-nvda-user-guide"
|
70 |
else:
|
@@ -72,25 +81,27 @@ def nul_main(url: str):
|
|
72 |
loader = NVDAUserGuideLoader(url, category)
|
73 |
docs = loader.load()
|
74 |
texts = get_text_chunk(docs)
|
75 |
-
store(texts)
|
76 |
|
77 |
|
78 |
if __name__ == "__main__":
|
79 |
"""
|
80 |
-
$ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest"
|
81 |
-
$ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html"
|
82 |
-
$ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html"
|
83 |
"""
|
84 |
import sys
|
|
|
85 |
args = sys.argv
|
86 |
-
if len(args) !=
|
87 |
print("No args, you need two args for type, html_path")
|
88 |
else:
|
89 |
type_ = args[1]
|
90 |
path = args[2]
|
|
|
91 |
if type_ == "rtd":
|
92 |
-
rtd_main(path)
|
93 |
elif type_ == "nul":
|
94 |
-
nul_main(path)
|
95 |
else:
|
96 |
print("No type for store")
|
|
|
1 |
from langchain.document_loaders import ReadTheDocsLoader
|
2 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
3 |
+
from langchain.embeddings import OpenAIEmbeddings
|
4 |
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
5 |
from langchain.vectorstores import Qdrant
|
6 |
+
|
7 |
# from qdrant_client import QdrantClient
|
8 |
from nvda_ug_loader import NVDAUserGuideLoader
|
9 |
+
from config import DB_CONFIG, DB_E5_CONFIG
|
10 |
|
11 |
|
12 |
CHUNK_SIZE = 500
|
|
|
25 |
for doc in docs:
|
26 |
org_metadata = doc.metadata
|
27 |
source = _remove_prefix_path(org_metadata["source"])
|
28 |
+
add_meta = {
|
29 |
+
"category": category,
|
30 |
+
"source": source,
|
31 |
+
"url": f"{base_url}{source}",
|
32 |
+
}
|
33 |
doc.metadata = org_metadata | add_meta
|
34 |
yield doc
|
35 |
|
36 |
|
37 |
def get_text_chunk(docs):
|
38 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
39 |
+
chunk_size=CHUNK_SIZE, chunk_overlap=0
|
40 |
+
)
|
41 |
texts = text_splitter.split_documents(docs)
|
42 |
return texts
|
43 |
|
44 |
|
45 |
+
def store(texts, mname):
|
46 |
+
if mname == "openai":
|
47 |
+
embeddings = OpenAIEmbeddings()
|
48 |
+
db_url, db_api_key, db_collection_name = DB_CONFIG
|
49 |
+
elif mname == "e5":
|
50 |
+
model_name = "intfloat/multilingual-e5-large"
|
51 |
+
model_kwargs = {"device": "cuda"}
|
52 |
+
encode_kwargs = {"normalize_embeddings": False}
|
53 |
+
embeddings = HuggingFaceEmbeddings(
|
54 |
+
model_name=model_name,
|
55 |
+
model_kwargs=model_kwargs,
|
56 |
+
encode_kwargs=encode_kwargs,
|
57 |
+
)
|
58 |
+
db_url, db_api_key, db_collection_name = DB_E5_CONFIG
|
59 |
+
else:
|
60 |
+
raise ValueError("Invalid mname")
|
61 |
_ = Qdrant.from_documents(
|
62 |
texts,
|
63 |
embeddings,
|
64 |
url=db_url,
|
65 |
api_key=db_api_key,
|
66 |
+
collection_name=db_collection_name,
|
67 |
)
|
68 |
|
69 |
|
70 |
+
def rtd_main(path: str, mname: str):
|
71 |
docs = get_documents(path)
|
72 |
texts = get_text_chunk(docs)
|
73 |
+
store(texts, mname)
|
74 |
|
75 |
|
76 |
+
def nul_main(url: str, mname: str):
|
77 |
if "www.nvda.jp" in url:
|
78 |
category = "ja-nvda-user-guide"
|
79 |
else:
|
|
|
81 |
loader = NVDAUserGuideLoader(url, category)
|
82 |
docs = loader.load()
|
83 |
texts = get_text_chunk(docs)
|
84 |
+
store(texts, mname)
|
85 |
|
86 |
|
87 |
if __name__ == "__main__":
|
88 |
"""
|
89 |
+
$ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest" openai
|
90 |
+
$ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html" e5
|
91 |
+
$ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html" e5
|
92 |
"""
|
93 |
import sys
|
94 |
+
|
95 |
args = sys.argv
|
96 |
+
if len(args) != 4:
|
97 |
print("No args, you need two args for type, html_path")
|
98 |
else:
|
99 |
type_ = args[1]
|
100 |
path = args[2]
|
101 |
+
mname = args[3]
|
102 |
if type_ == "rtd":
|
103 |
+
rtd_main(path, mname)
|
104 |
elif type_ == "nul":
|
105 |
+
nul_main(path, mname)
|
106 |
else:
|
107 |
print("No type for store")
|