terapyon commited on
Commit
227586c
·
1 Parent(s): 8d3ec3e

added NVDA User guide content and added filter QA

Browse files
Files changed (4) hide show
  1. app.py +17 -6
  2. config.py +1 -1
  3. nvda_ug_loader.py +107 -0
  4. store.py +26 -10
app.py CHANGED
@@ -11,14 +11,21 @@ from config import DB_CONFIG
11
  PERSIST_DIR_NAME = "nvdajp-book"
12
 
13
 
14
- def get_retrieval_qa() -> RetrievalQA:
15
  embeddings = OpenAIEmbeddings()
16
  db_url, db_api_key, db_collection_name = DB_CONFIG
17
  client = QdrantClient(url=db_url, api_key=db_api_key)
18
  db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
19
- retriever = db.as_retriever()
 
 
 
 
 
 
 
20
  return RetrievalQA.from_chain_type(
21
- llm=OpenAI(temperature=0), chain_type="stuff", retriever=retriever, return_source_documents=True,
22
  )
23
 
24
 
@@ -35,8 +42,8 @@ def get_related_url(metadata):
35
  yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
36
 
37
 
38
- def main(query: str):
39
- qa = get_retrieval_qa()
40
  try:
41
  result = qa(query)
42
  except InvalidRequestError as e:
@@ -50,7 +57,11 @@ def main(query: str):
50
 
51
  nvdajp_book_qa = gr.Interface(
52
  fn=main,
53
- inputs=[gr.Textbox(label="query")],
 
 
 
 
54
  outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
55
  )
56
 
 
11
  PERSIST_DIR_NAME = "nvdajp-book"
12
 
13
 
14
+ def get_retrieval_qa(temperature: int, option: str) -> RetrievalQA:
15
  embeddings = OpenAIEmbeddings()
16
  db_url, db_api_key, db_collection_name = DB_CONFIG
17
  client = QdrantClient(url=db_url, api_key=db_api_key)
18
  db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
19
+ if option is None or option == "All":
20
+ retriever = db.as_retriever()
21
+ else:
22
+ retriever = db.as_retriever(
23
+ search_kwargs={
24
+ "filter": {"category": option},
25
+ }
26
+ )
27
  return RetrievalQA.from_chain_type(
28
+ llm=OpenAI(temperature=temperature), chain_type="stuff", retriever=retriever, return_source_documents=True,
29
  )
30
 
31
 
 
42
  yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
43
 
44
 
45
+ def main(query: str, option: str, temperature: int):
46
+ qa = get_retrieval_qa(temperature, option)
47
  try:
48
  result = qa(query)
49
  except InvalidRequestError as e:
 
57
 
58
  nvdajp_book_qa = gr.Interface(
59
  fn=main,
60
+ inputs=[
61
+ gr.Textbox(label="query"),
62
+ gr.Radio(["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?"),
63
+ gr.Slider(0, 2)
64
+ ],
65
  outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
66
  )
67
 
config.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
 
3
 
4
- SAAS = True
5
 
6
 
7
  def get_db_config():
 
1
  import os
2
 
3
 
4
+ SAAS = False
5
 
6
 
7
  def get_db_config():
nvda_ug_loader.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import re
3
+ from typing import Iterator, List
4
+ from langchain.docstore.document import Document
5
+ from langchain.document_loaders.base import BaseLoader
6
+
7
+ from bs4 import BeautifulSoup, Tag, ResultSet
8
+ import requests
9
+
10
+
11
+ RE_HEADERS = re.compile(r"h[23]")
12
+
13
+
14
+ @dataclass
15
+ class Content:
16
+ name: str
17
+ title: str
18
+ text: str
19
+ body: list[Tag]
20
+
21
+
22
+ def _get_anchor_name(header: Tag) -> str:
23
+ for tag in header.previous_elements:
24
+ if tag.name == "a":
25
+ return tag.attrs.get("name", "")
26
+ return ""
27
+
28
+
29
+ def _reversed_remove_last_anchor(body: list[Tag]) -> Iterator[Tag]:
30
+ has_anchor = False
31
+ for tag in reversed(body):
32
+ if not has_anchor:
33
+ if tag.name == "a":
34
+ has_anchor = True
35
+ continue
36
+ else:
37
+ yield tag
38
+
39
+
40
+ def _remove_last_anchor(body: list[Tag]) -> Iterator[Tag]:
41
+ return reversed(list(_reversed_remove_last_anchor(body)))
42
+
43
+
44
+ def _get_bodys_text(body: list[Tag]) -> str:
45
+ text = ""
46
+ for tag in body:
47
+ text += tag.get_text()
48
+ return text
49
+
50
+
51
+ def _get_child_content(header: Tag) -> Content:
52
+ title = header.get_text()
53
+ name = _get_anchor_name(header)
54
+ body = [header]
55
+ for i, child in enumerate(header.next_elements):
56
+ if i == 0:
57
+ continue
58
+ if child.name == "h2" or child.name == "h3":
59
+ break
60
+ body.append(child)
61
+ removed_next_anchor_body = list(_remove_last_anchor(body))
62
+ text = _get_bodys_text(removed_next_anchor_body)
63
+ return Content(name,
64
+ title,
65
+ text,
66
+ removed_next_anchor_body
67
+ )
68
+
69
+
70
+ def get_contents(headers: ResultSet[Tag]) -> Iterator[Content]:
71
+ for header in headers:
72
+ yield _get_child_content(header)
73
+
74
+
75
+ class NVDAUserGuideLoader(BaseLoader):
76
+ """
77
+ """
78
+ def __init__(self, url: str, category: str) -> None:
79
+ self.url = url
80
+ self.category = category
81
+
82
+ def fetch(self) -> BeautifulSoup:
83
+ res = requests.get(self.url)
84
+ soup = BeautifulSoup(res.content, 'lxml')
85
+ return soup
86
+
87
+ def lazy_load(self) -> Iterator[Document]:
88
+ soup = self.fetch()
89
+ # body = soup.body
90
+ headers = soup.find_all(RE_HEADERS)
91
+ for content in get_contents(headers):
92
+ name = content.name
93
+ title = content.title
94
+ text = content.text
95
+ metadata = {"category": self.category, "source": name, "url": f"{self.url}#{name}", "title": title}
96
+ yield Document(page_content=text, metadata=metadata)
97
+
98
+ def load(self) -> List[Document]:
99
+ return list(self.lazy_load())
100
+
101
+
102
+ if __name__ == "__main__":
103
+ url = "https://www.nvaccess.org/files/nvda/documentation/userGuide.html"
104
+ loader = NVDAUserGuideLoader(url, "en-nvda-user-guide")
105
+ data = loader.load()
106
+ print(data)
107
+ # breakpoint()
store.py CHANGED
@@ -3,6 +3,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
  from langchain.embeddings import OpenAIEmbeddings
4
  from langchain.vectorstores import Qdrant
5
  # from qdrant_client import QdrantClient
 
6
  from config import DB_CONFIG
7
 
8
 
@@ -46,24 +47,39 @@ def store(texts):
46
  )
47
 
48
 
49
- def main(path: str):
50
  docs = get_documents(path)
51
  texts = get_text_chunk(docs)
52
  store(texts)
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
55
  if __name__ == "__main__":
56
  """
57
- $ python store.py "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest"
 
 
58
  """
59
  import sys
60
  args = sys.argv
61
- if len(args) != 2:
62
- print("No args, you need two args for html_path")
63
- docs = get_documents("data/rtdocs/nvdajp-book.readthedocs.io/ja/latest")
64
- print(type(docs))
65
- breakpoint()
66
  else:
67
- path = args[1]
68
- # dir_name = args[2]
69
- main(path)
 
 
 
 
 
 
3
  from langchain.embeddings import OpenAIEmbeddings
4
  from langchain.vectorstores import Qdrant
5
  # from qdrant_client import QdrantClient
6
+ from nvda_ug_loader import NVDAUserGuideLoader
7
  from config import DB_CONFIG
8
 
9
 
 
47
  )
48
 
49
 
50
+ def rtd_main(path: str):
51
  docs = get_documents(path)
52
  texts = get_text_chunk(docs)
53
  store(texts)
54
 
55
 
56
+ def nul_main(url: str):
57
+ if "www.nvda.jp" in url:
58
+ category = "ja-nvda-user-guide"
59
+ else:
60
+ category = "en-nvda-user-guide"
61
+ loader = NVDAUserGuideLoader(url, category)
62
+ docs = loader.load()
63
+ texts = get_text_chunk(docs)
64
+ store(texts)
65
+
66
+
67
  if __name__ == "__main__":
68
  """
69
+ $ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest"
70
+ $ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html"
71
+ $ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html"
72
  """
73
  import sys
74
  args = sys.argv
75
+ if len(args) != 3:
76
+ print("No args, you need two args for type, html_path")
 
 
 
77
  else:
78
+ type_ = args[1]
79
+ path = args[2]
80
+ if type_ == "rtd":
81
+ rtd_main(path)
82
+ elif type_ == "nul":
83
+ nul_main(path)
84
+ else:
85
+ print("No type for store")