dh-mc committed
Commit 4ce9985 · 1 Parent(s): 4ae9830

lib search API ready

app_modules/init.py CHANGED
@@ -79,14 +79,28 @@ def app_init(initQAChain: bool = True):
 
     print(f"Completed in {end - start:.3f}s")
 
-    vectorstore = load_vectorstor(index_path)
+    vectorstore = load_vectorstor(using_faiss, index_path, embeddings)
+
+    doc_id_to_vectorstore_mapping = {}
+    rootdir = index_path
+    for file in os.listdir(rootdir):
+        d = os.path.join(rootdir, file)
+        if os.path.isdir(d):
+            v = load_vectorstor(using_faiss, d, embeddings)
+            doc_id_to_vectorstore_mapping[file] = v
+
+    # print(doc_id_to_vectorstore_mapping)
 
     start = timer()
     llm_loader = LLMLoader(llm_model_type)
     llm_loader.init(
         n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type
     )
-    qa_chain = QAChain(vectorstore, llm_loader) if initQAChain else None
+    qa_chain = (
+        QAChain(vectorstore, llm_loader, doc_id_to_vectorstore_mapping)
+        if initQAChain
+        else None
+    )
     end = timer()
     print(f"Completed in {end - start:.3f}s")
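
The new loop builds one vectorstore per sub-directory of the index path, keyed by directory name, and app_init hands the resulting doc_id_to_vectorstore_mapping to QAChain alongside the combined store. The load_vectorstor helper is defined elsewhere in the repo; a minimal sketch of what it presumably does per path, assuming it wraps LangChain's FAISS/Chroma loaders:

    # Sketch only: the repo's load_vectorstor lives elsewhere; this body is an
    # assumption that it wraps FAISS.load_local / a persisted Chroma collection.
    from langchain.vectorstores import FAISS, Chroma

    def load_vectorstor(using_faiss, index_path, embeddings):
        if using_faiss:
            # expects index.faiss / index.pkl inside index_path
            return FAISS.load_local(index_path, embeddings)
        # otherwise assume a Chroma collection persisted in the same directory
        return Chroma(embedding_function=embeddings, persist_directory=index_path)

Under that assumption, index_path holds the combined index at its root plus one sub-directory per document, e.g. index_path/<doc_id>/index.faiss.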
app_modules/llm_chat_chain.py CHANGED
@@ -27,7 +27,7 @@ class ChatChain(LLMInference):
     def __init__(self, llm_loader):
         super().__init__(llm_loader)
 
-    def create_chain(self) -> Chain:
+    def create_chain(self, inputs) -> Chain:
         template = (
             get_llama_2_prompt_template()
             if os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
app_modules/llm_inference.py CHANGED
@@ -22,12 +22,12 @@ class LLMInference(metaclass=abc.ABCMeta):
         self.chain = None
 
     @abc.abstractmethod
-    def create_chain(self) -> Chain:
+    def create_chain(self, inputs) -> Chain:
         pass
 
-    def get_chain(self) -> Chain:
+    def get_chain(self, inputs) -> Chain:
         if self.chain is None:
-            self.chain = self.create_chain()
+            self.chain = self.create_chain(inputs)
 
         return self.chain
 
@@ -48,7 +48,7 @@ class LLMInference(metaclass=abc.ABCMeta):
         try:
             self.llm_loader.streamer.reset(q)
 
-            chain = self.get_chain()
+            chain = self.get_chain(inputs)
             result = (
                 self._run_chain_with_streaming_handler(
                     chain, inputs, streaming_handler, testing
@@ -61,7 +61,14 @@ class LLMInference(metaclass=abc.ABCMeta):
             result["answer"] = remove_extra_spaces(result["answer"])
 
         source_path = os.environ.get("SOURCE_PATH")
-        if source_path is not None and len(source_path) > 0:
+        base_url = os.environ.get("PDF_FILE_BASE_URL")
+        if base_url is not None and len(base_url) > 0:
+            documents = result["source_documents"]
+            for doc in documents:
+                source = doc.metadata["source"]
+                title = source.split("/")[-1]
+                doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"
+        elif source_path is not None and len(source_path) > 0:
             documents = result["source_documents"]
             for doc in documents:
                 source = doc.metadata["source"]
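
The new PDF_FILE_BASE_URL branch attaches a clickable URL to each source document by URL-encoding its file name onto the base URL; the SOURCE_PATH rewrite remains the fallback. A self-contained sketch of the transformation (base URL and file name below are made-up values for illustration):

    import urllib.parse

    # Hypothetical values, for illustration only
    base_url = "https://example.com/pdfs/"
    source = "data/pdf_docs/annual report 2023.pdf"

    title = source.split("/")[-1]
    url = f"{base_url}{urllib.parse.quote(title)}"
    print(url)  # https://example.com/pdfs/annual%20report%202023.pdf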
app_modules/llm_qa_chain.py CHANGED
@@ -8,14 +8,23 @@ from app_modules.llm_inference import LLMInference
 class QAChain(LLMInference):
     vectorstore: VectorStore
 
-    def __init__(self, vectorstore, llm_loader):
+    def __init__(self, vectorstore, llm_loader, doc_id_to_vectorstore_mapping=None):
         super().__init__(llm_loader)
         self.vectorstore = vectorstore
+        self.doc_id_to_vectorstore_mapping = doc_id_to_vectorstore_mapping
+
+    def get_chain(self, inputs) -> Chain:
+        return self.create_chain(inputs)
+
+    def create_chain(self, inputs) -> Chain:
+        vectorstore = self.vectorstore
+        if "chat_id" in inputs:
+            if inputs["chat_id"] in self.doc_id_to_vectorstore_mapping:
+                vectorstore = self.doc_id_to_vectorstore_mapping[inputs["chat_id"]]
 
-    def create_chain(self) -> Chain:
         qa = ConversationalRetrievalChain.from_llm(
             self.llm_loader.llm,
-            self.vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs),
+            vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs),
             max_tokens_limit=self.llm_loader.max_tokens_limit,
             return_source_documents=True,
         )
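
QAChain now overrides get_chain to rebuild the chain on every call instead of reusing the base class's cached self.chain, so the retriever always matches the chat_id of the current request. Since the mapping defaults to None, inputs carrying a chat_id presuppose that a mapping was supplied (as app_init now does). The selection logic in isolation, with stand-in objects in place of real vectorstores:

    # Stand-ins for real vectorstore objects, for illustration only
    default_store = object()
    doc_id_to_vectorstore_mapping = {"doc-123": object()}

    def select_vectorstore(inputs):
        # fall back to the combined store unless a per-document index exists
        vectorstore = default_store
        if "chat_id" in inputs:
            if inputs["chat_id"] in doc_id_to_vectorstore_mapping:
                vectorstore = doc_id_to_vectorstore_mapping[inputs["chat_id"]]
        return vectorstore

    assert select_vectorstore({"chat_id": "doc-123"}) is not default_store
    assert select_vectorstore({"chat_id": "unknown"}) is default_store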
app_modules/llm_summarize_chain.py CHANGED
@@ -23,7 +23,7 @@ class SummarizeChain(LLMInference):
     def __init__(self, llm_loader):
         super().__init__(llm_loader)
 
-    def create_chain(self) -> Chain:
+    def create_chain(self, inputs) -> Chain:
         use_llama_2_prompt_template = (
             os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
         )
server.py CHANGED
@@ -28,11 +28,11 @@ class ChatResponse(BaseModel):
 
 def do_chat(
     question: str,
-    history: Optional[List] = [],
+    history: Optional[List] = None,
     chat_id: Optional[str] = None,
     streaming_handler: any = None,
 ):
-    if chat_id is None:
+    if history is not None:
         chat_history = []
         if chat_history_enabled:
             for element in history:
@@ -41,7 +41,8 @@ def do_chat(
 
     start = timer()
     result = qa_chain.call_chain(
-        {"question": question, "chat_history": chat_history}, streaming_handler
+        {"question": question, "chat_history": chat_history, "chat_id": chat_id},
+        streaming_handler,
     )
     end = timer()
     print(f"Completed in {end - start:.3f}s")
@@ -61,20 +62,26 @@ def do_chat(
 
 @serving(websocket=True)
 def chat(
-    question: str, history: Optional[List] = [], chat_id: Optional[str] = None, **kwargs
+    question: str,
+    history: Optional[List] = None,
+    chat_id: Optional[str] = None,
+    **kwargs,
 ) -> str:
     print("question@chat:", question)
     streaming_handler = kwargs.get("streaming_handler")
     result = do_chat(question, history, chat_id, streaming_handler)
     resp = ChatResponse(
-        sourceDocs=result["source_documents"] if chat_id is None else []
+        sourceDocs=result["source_documents"] if history is not None else []
     )
     return json.dumps(resp.dict())
 
 
 @serving
 def chat_sync(
-    question: str, history: Optional[List] = [], chat_id: Optional[str] = None, **kwargs
+    question: str,
+    history: Optional[List] = None,
+    chat_id: Optional[str] = None,
+    **kwargs,
 ) -> str:
     print("question@chat_sync:", question)
     result = do_chat(question, history, chat_id, None)
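
Two behavioral changes ride along with the signature cleanup: chat_id is now forwarded to the chain (which is what routes retrieval to a per-document index), and sourceDocs is returned only when a history list was supplied, rather than when chat_id was absent. Assuming the @serving endpoints are exposed over HTTP in the usual way, a call might look like this (host, port, and payload shape are assumptions, not taken from the repo):

    import requests

    # Hypothetical endpoint and doc id, for illustration only
    resp = requests.post(
        "http://localhost:8080/chat_sync",
        json={
            "question": "What are the key findings?",
            "history": [],         # non-None history => sourceDocs included
            "chat_id": "doc-123",  # selects that document's index, if present
        },
    )
    print(resp.json())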
test.py CHANGED
@@ -30,6 +30,7 @@ class MyCustomHandler(BaseCallbackHandler):
 
 
 chatting = len(sys.argv) > 1 and sys.argv[1] == "chat"
+chat_id = sys.argv[2] if len(sys.argv) > 2 else None
 questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
 chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") or "true"
 
@@ -68,8 +69,9 @@ while True:
     custom_handler.reset()
 
     start = timer()
+    inputs = {"question": query, "chat_history": chat_history, "chat_id": chat_id}
     result = qa_chain.call_chain(
-        {"question": query, "chat_history": chat_history},
+        inputs,
         custom_handler,
         None,
         True,
@@ -87,13 +89,14 @@ while True:
     if standalone_question is not None:
         print(f"Load relevant documents for standalone question: {standalone_question}")
         start = timer()
-        qa = qa_chain.get_chain()
+        qa = qa_chain.get_chain(inputs)
         docs = qa.retriever.get_relevant_documents(standalone_question)
         end = timer()
-
-        # print(docs)
         print(f"Completed in {end - start:.3f}s")
 
+        if chatting:
+            print(docs)
+
     if chat_history_enabled == "true":
         chat_history.append((query, result["answer"]))
 
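
With the new optional second argument, the test script can exercise a single document's index, e.g. python test.py chat doc-123 (the id must match a sub-directory under the index path; doc-123 is made up here). In chat mode the documents retrieved for the standalone question are now printed for inspection.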
web CHANGED
@@ -1 +1 @@
-Subproject commit a6e3dd97a3cb23eb06f8ad94644aa5b71e624f61
+Subproject commit 15f2b72afe6170badfb982c7adba585af30d578a