dev/select-model #3 by terapyon (opened)

Files changed:
- app.py +36 -6
- nvda_ug_loader.py +107 -0
- store.py +28 -13
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 from langchain.chains import RetrievalQA
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.llms import OpenAI
+from langchain.chat_models import ChatOpenAI
 from langchain.vectorstores import Qdrant
 from openai.error import InvalidRequestError
 from qdrant_client import QdrantClient
@@ -9,16 +10,40 @@ from config import DB_CONFIG
 
 
 PERSIST_DIR_NAME = "nvdajp-book"
+# MODEL_NAME = "text-davinci-003"
+# MODEL_NAME = "gpt-3.5-turbo"
+# MODEL_NAME = "gpt-4"
 
 
-def get_retrieval_qa() -> RetrievalQA:
+def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
     embeddings = OpenAIEmbeddings()
     db_url, db_api_key, db_collection_name = DB_CONFIG
     client = QdrantClient(url=db_url, api_key=db_api_key)
     db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
-
+    if model_name is None:
+        model = "gpt-3.5-turbo"
+    elif model_name == "GPT-3.5":
+        model = "gpt-3.5-turbo"
+    elif model_name == "GPT-4":
+        model = "gpt-4"
+    else:
+        model = "gpt-3.5-turbo"
+    if option is None or option == "All":
+        retriever = db.as_retriever()
+    else:
+        retriever = db.as_retriever(
+            search_kwargs={
+                "filter": {"category": option},
+            }
+        )
     return RetrievalQA.from_chain_type(
-        llm=
+        llm=ChatOpenAI(
+            model=model,
+            temperature=temperature
+        ),
+        chain_type="stuff",
+        retriever=retriever,
+        return_source_documents=True,
     )
 
 
@@ -35,8 +60,8 @@ def get_related_url(metadata):
         yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
 
 
-def main(query: str):
-    qa = get_retrieval_qa()
+def main(query: str, model_name: str, option: str, temperature: int):
+    qa = get_retrieval_qa(model_name, temperature, option)
     try:
         result = qa(query)
     except InvalidRequestError as e:
@@ -50,7 +75,12 @@ def main(query: str):
 
 nvdajp_book_qa = gr.Interface(
     fn=main,
-    inputs=[
+    inputs=[
+        gr.Textbox(label="query"),
+        gr.Radio(["GPT-3.5", "GPT-4"], label="Model", info="選択なしで「3.5」を使用"),
+        gr.Radio(["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?"),
+        gr.Slider(0, 2)
+    ],
     outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
 )
 
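For reference, a minimal sketch (not part of this PR) of driving the new get_retrieval_qa() signature directly, mirroring what main() does when the UI selects GPT-4, the "ja-book" filter and temperature 0. It assumes OPENAI_API_KEY is set, DB_CONFIG points at a populated Qdrant collection, and that importing app does not launch the interface:

# Illustrative smoke test only; all names come from the diff above.
from app import get_retrieval_qa

qa = get_retrieval_qa(model_name="GPT-4", temperature=0, option="ja-book")
result = qa("How do I install NVDA?")
print(result["result"])                  # the text shown in the "answer" Textbox
for doc in result["source_documents"]:   # available because return_source_documents=True
    print(doc.metadata["url"], doc.metadata["category"])
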
nvda_ug_loader.py
ADDED
@@ -0,0 +1,107 @@
from dataclasses import dataclass
import re
from typing import Iterator, List
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader

from bs4 import BeautifulSoup, Tag, ResultSet
import requests


RE_HEADERS = re.compile(r"h[23]")


@dataclass
class Content:
    name: str
    title: str
    text: str
    body: list[Tag]


def _get_anchor_name(header: Tag) -> str:
    for tag in header.previous_elements:
        if tag.name == "a":
            return tag.attrs.get("name", "")
    return ""


def _reversed_remove_last_anchor(body: list[Tag]) -> Iterator[Tag]:
    has_anchor = False
    for tag in reversed(body):
        if not has_anchor:
            if tag.name == "a":
                has_anchor = True
                continue
        else:
            yield tag


def _remove_last_anchor(body: list[Tag]) -> Iterator[Tag]:
    return reversed(list(_reversed_remove_last_anchor(body)))


def _get_bodys_text(body: list[Tag]) -> str:
    text = ""
    for tag in body:
        text += tag.get_text()
    return text


def _get_child_content(header: Tag) -> Content:
    title = header.get_text()
    name = _get_anchor_name(header)
    body = [header]
    for i, child in enumerate(header.next_elements):
        if i == 0:
            continue
        if child.name == "h2" or child.name == "h3":
            break
        body.append(child)
    removed_next_anchor_body = list(_remove_last_anchor(body))
    text = _get_bodys_text(removed_next_anchor_body)
    return Content(name,
                   title,
                   text,
                   removed_next_anchor_body
                   )


def get_contents(headers: ResultSet[Tag]) -> Iterator[Content]:
    for header in headers:
        yield _get_child_content(header)


class NVDAUserGuideLoader(BaseLoader):
    """
    """
    def __init__(self, url: str, category: str) -> None:
        self.url = url
        self.category = category

    def fetch(self) -> BeautifulSoup:
        res = requests.get(self.url)
        soup = BeautifulSoup(res.content, 'lxml')
        return soup

    def lazy_load(self) -> Iterator[Document]:
        soup = self.fetch()
        # body = soup.body
        headers = soup.find_all(RE_HEADERS)
        for content in get_contents(headers):
            name = content.name
            title = content.title
            text = content.text
            metadata = {"category": self.category, "source": name, "url": f"{self.url}#{name}", "title": title}
            yield Document(page_content=text, metadata=metadata)

    def load(self) -> List[Document]:
        return list(self.lazy_load())


if __name__ == "__main__":
    url = "https://www.nvaccess.org/files/nvda/documentation/userGuide.html"
    loader = NVDAUserGuideLoader(url, "en-nvda-user-guide")
    data = loader.load()
    print(data)
    # breakpoint()
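The loader walks each h2/h3 heading, gathers the elements up to the next heading, and emits one Document per section; the metadata keys (category, source, url, title) are what get_related_url() in app.py renders later. A short, illustrative sketch of iterating it lazily (needs network access plus the requests, beautifulsoup4 and lxml packages):

# Usage sketch only; NVDAUserGuideLoader and its metadata keys come from the file above.
from nvda_ug_loader import NVDAUserGuideLoader

loader = NVDAUserGuideLoader(
    "https://www.nvaccess.org/files/nvda/documentation/userGuide.html",
    "en-nvda-user-guide",
)
for doc in loader.lazy_load():           # one Document per h2/h3 section
    meta = doc.metadata
    print(meta["title"], meta["url"])    # url carries the #anchor used by get_related_url()
    break                                # inspect just the first section
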
store.py
CHANGED
@@ -3,6 +3,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.vectorstores import Qdrant
 # from qdrant_client import QdrantClient
+from nvda_ug_loader import NVDAUserGuideLoader
 from config import DB_CONFIG
 
 
@@ -18,14 +19,13 @@ def get_documents(path: str):
     loader = ReadTheDocsLoader(path, encoding="utf-8")
     docs = loader.load()
     base_url = "https://nvdajp-book.readthedocs.io/"
-
+    category = "ja-book"
     for doc in docs:
         org_metadata = doc.metadata
         source = _remove_prefix_path(org_metadata["source"])
-        add_meta = {"category":
+        add_meta = {"category": category, "source": source, "url": f"{base_url}{source}"}
         doc.metadata = org_metadata | add_meta
         yield doc
-    # return docs
 
 
 def get_text_chunk(docs):
@@ -47,24 +47,39 @@ def store(texts):
     )
 
 
-def
+def rtd_main(path: str):
     docs = get_documents(path)
     texts = get_text_chunk(docs)
     store(texts)
 
 
+def nul_main(url: str):
+    if "www.nvda.jp" in url:
+        category = "ja-nvda-user-guide"
+    else:
+        category = "en-nvda-user-guide"
+    loader = NVDAUserGuideLoader(url, category)
+    docs = loader.load()
+    texts = get_text_chunk(docs)
+    store(texts)
+
+
 if __name__ == "__main__":
     """
-    $ python store.py "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest"
+    $ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest"
+    $ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html"
+    $ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html"
     """
     import sys
     args = sys.argv
-    if len(args) !=
-        print("No args, you need two args for html_path")
-        docs = get_documents("data/rtdocs/nvdajp-book.readthedocs.io/ja/latest")
-        print(type(docs))
-        breakpoint()
+    if len(args) != 3:
+        print("No args, you need two args for type, html_path")
     else:
-
-
-
+        type_ = args[1]
+        path = args[2]
+        if type_ == "rtd":
+            rtd_main(path)
+        elif type_ == "nul":
+            nul_main(path)
+        else:
+            print("No type for store")
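The category values written here ("ja-book", "ja-nvda-user-guide", "en-nvda-user-guide") match the choices offered in the app's filter Radio, so a freshly stored corpus can be checked end to end roughly as below. This is only a sketch, assuming DB_CONFIG points at a writable Qdrant collection and OPENAI_API_KEY is set:

# Illustrative end-to-end check, not part of this PR: ingest one source, then query it
# restricted to the matching category via the filter added in get_retrieval_qa().
from store import nul_main
from app import get_retrieval_qa

nul_main("https://www.nvaccess.org/files/nvda/documentation/userGuide.html")  # stored as "en-nvda-user-guide"
qa = get_retrieval_qa("GPT-3.5", 0, "en-nvda-user-guide")
print(qa("What are the system requirements for NVDA?")["result"])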