lib search API ready
Browse files
- app_modules/init.py +16 -2
- app_modules/llm_chat_chain.py +1 -1
- app_modules/llm_inference.py +12 -5
- app_modules/llm_qa_chain.py +12 -3
- app_modules/llm_summarize_chain.py +1 -1
- server.py +13 -6
- test.py +7 -4
- web +1 -1
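Taken together, the diffs below appear to add per-document ("library") search on top of the existing QA setup: app_init now loads one vector store per sub-folder of the index directory into a doc_id_to_vectorstore_mapping, QAChain picks the store that matches the incoming chat_id, the create_chain/get_chain interface gains an inputs argument to carry that id, and server.py and test.py thread chat_id through their entry points. llm_inference.py additionally gains optional PDF_FILE_BASE_URL handling that attaches a public URL to each source document.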
app_modules/init.py
CHANGED
@@ -79,14 +79,28 @@ def app_init(initQAChain: bool = True):
 
     print(f"Completed in {end - start:.3f}s")
 
-    vectorstore = load_vectorstor(index_path)
+    vectorstore = load_vectorstor(using_faiss, index_path, embeddings)
+
+    doc_id_to_vectorstore_mapping = {}
+    rootdir = index_path
+    for file in os.listdir(rootdir):
+        d = os.path.join(rootdir, file)
+        if os.path.isdir(d):
+            v = load_vectorstor(using_faiss, d, embeddings)
+            doc_id_to_vectorstore_mapping[file] = v
+
+    # print(doc_id_to_vectorstore_mapping)
 
     start = timer()
     llm_loader = LLMLoader(llm_model_type)
     llm_loader.init(
         n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type
     )
-    qa_chain = QAChain(vectorstore, llm_loader) if initQAChain else None
+    qa_chain = (
+        QAChain(vectorstore, llm_loader, doc_id_to_vectorstore_mapping)
+        if initQAChain
+        else None
+    )
     end = timer()
     print(f"Completed in {end - start:.3f}s")
 
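For context, the call above now passes using_faiss and embeddings into load_vectorstor (spelling as in the repo); that helper is defined outside this commit. A minimal sketch of what such a loader typically looks like with LangChain, assuming FAISS and Chroma are the two supported backends:

# Hypothetical sketch of load_vectorstor(); the real helper is not part of this diff.
from langchain.vectorstores import FAISS, Chroma

def load_vectorstor(using_faiss: bool, index_path: str, embeddings):
    """Load a persisted vector index from index_path using the given embeddings."""
    if using_faiss:
        # FAISS indexes are persisted as a local folder
        return FAISS.load_local(index_path, embeddings)
    # otherwise assume a persisted Chroma collection
    return Chroma(persist_directory=index_path, embedding_function=embeddings)

With one index per sub-folder, each folder name doubles as the chat_id used for routing in llm_qa_chain.py below.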
app_modules/llm_chat_chain.py
CHANGED
@@ -27,7 +27,7 @@ class ChatChain(LLMInference):
     def __init__(self, llm_loader):
         super().__init__(llm_loader)
 
-    def create_chain(self) -> Chain:
+    def create_chain(self, inputs) -> Chain:
         template = (
             get_llama_2_prompt_template()
             if os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
app_modules/llm_inference.py
CHANGED
@@ -22,12 +22,12 @@ class LLMInference(metaclass=abc.ABCMeta):
         self.chain = None
 
     @abc.abstractmethod
-    def create_chain(self) -> Chain:
+    def create_chain(self, inputs) -> Chain:
         pass
 
-    def get_chain(self) -> Chain:
+    def get_chain(self, inputs) -> Chain:
         if self.chain is None:
-            self.chain = self.create_chain()
+            self.chain = self.create_chain(inputs)
 
         return self.chain
 
@@ -48,7 +48,7 @@ class LLMInference(metaclass=abc.ABCMeta):
         try:
             self.llm_loader.streamer.reset(q)
 
-            chain = self.get_chain()
+            chain = self.get_chain(inputs)
             result = (
                 self._run_chain_with_streaming_handler(
                     chain, inputs, streaming_handler, testing
@@ -61,7 +61,14 @@ class LLMInference(metaclass=abc.ABCMeta):
             result["answer"] = remove_extra_spaces(result["answer"])
 
             source_path = os.environ.get("SOURCE_PATH")
-            if source_path is not None and len(source_path) > 0:
+            base_url = os.environ.get("PDF_FILE_BASE_URL")
+            if base_url is not None and len(base_url) > 0:
+                documents = result["source_documents"]
+                for doc in documents:
+                    source = doc.metadata["source"]
+                    title = source.split("/")[-1]
+                    doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"
+            elif source_path is not None and len(source_path) > 0:
                 documents = result["source_documents"]
                 for doc in documents:
                     source = doc.metadata["source"]
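The new branch above builds a public link for each source document when PDF_FILE_BASE_URL is set, keeping only the file name and URL-encoding it (this assumes urllib is already imported in the module). A quick illustration with made-up values:

import urllib.parse

base_url = "https://example.com/pdfs/"        # stands in for os.environ["PDF_FILE_BASE_URL"]
source = "data/pdfs/Annual Report 2022.pdf"   # stands in for doc.metadata["source"]

title = source.split("/")[-1]
url = f"{base_url}{urllib.parse.quote(title)}"
print(url)  # https://example.com/pdfs/Annual%20Report%202022.pdf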
app_modules/llm_qa_chain.py
CHANGED
@@ -8,14 +8,23 @@ from app_modules.llm_inference import LLMInference
 class QAChain(LLMInference):
     vectorstore: VectorStore
 
-    def __init__(self, vectorstore, llm_loader):
+    def __init__(self, vectorstore, llm_loader, doc_id_to_vectorstore_mapping=None):
         super().__init__(llm_loader)
         self.vectorstore = vectorstore
+        self.doc_id_to_vectorstore_mapping = doc_id_to_vectorstore_mapping
+
+    def get_chain(self, inputs) -> Chain:
+        return self.create_chain(inputs)
+
+    def create_chain(self, inputs) -> Chain:
+        vectorstore = self.vectorstore
+        if "chat_id" in inputs:
+            if inputs["chat_id"] in self.doc_id_to_vectorstore_mapping:
+                vectorstore = self.doc_id_to_vectorstore_mapping[inputs["chat_id"]]
 
-    def create_chain(self) -> Chain:
         qa = ConversationalRetrievalChain.from_llm(
             self.llm_loader.llm,
-            self.vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs),
+            vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs),
             max_tokens_limit=self.llm_loader.max_tokens_limit,
             return_source_documents=True,
         )
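QAChain now also overrides get_chain to rebuild the chain on every call, so the retriever always matches the chat_id of the current request; when chat_id is missing or not in the mapping, the default vectorstore is used. A hypothetical invocation (the chat_id value and question are made up, and the call_chain signature follows its use in server.py and test.py):

# "annual_report_2022" stands in for a sub-folder name under the index directory
inputs = {
    "question": "What were the key findings?",
    "chat_history": [],
    "chat_id": "annual_report_2022",
}
result = qa_chain.call_chain(inputs, None)
print(result["answer"])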
app_modules/llm_summarize_chain.py
CHANGED
@@ -23,7 +23,7 @@ class SummarizeChain(LLMInference):
     def __init__(self, llm_loader):
         super().__init__(llm_loader)
 
-    def create_chain(self) -> Chain:
+    def create_chain(self, inputs) -> Chain:
         use_llama_2_prompt_template = (
             os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
         )
server.py
CHANGED
@@ -28,11 +28,11 @@ class ChatResponse(BaseModel):
 
 def do_chat(
     question: str,
-    history: Optional[List] =
+    history: Optional[List] = None,
     chat_id: Optional[str] = None,
     streaming_handler: any = None,
 ):
-    if
+    if history is not None:
         chat_history = []
         if chat_history_enabled:
             for element in history:
@@ -41,7 +41,8 @@ def do_chat(
 
     start = timer()
     result = qa_chain.call_chain(
-        {"question": question, "chat_history": chat_history},
+        {"question": question, "chat_history": chat_history, "chat_id": chat_id},
+        streaming_handler,
    )
     end = timer()
     print(f"Completed in {end - start:.3f}s")
@@ -61,20 +62,26 @@ def do_chat(
 
 @serving(websocket=True)
 def chat(
-    question: str,
+    question: str,
+    history: Optional[List] = None,
+    chat_id: Optional[str] = None,
+    **kwargs,
 ) -> str:
     print("question@chat:", question)
     streaming_handler = kwargs.get("streaming_handler")
     result = do_chat(question, history, chat_id, streaming_handler)
     resp = ChatResponse(
-        sourceDocs=result["source_documents"] if
+        sourceDocs=result["source_documents"] if history is not None else []
     )
     return json.dumps(resp.dict())
 
 
 @serving
 def chat_sync(
-    question: str,
+    question: str,
+    history: Optional[List] = None,
+    chat_id: Optional[str] = None,
+    **kwargs,
 ) -> str:
     print("question@chat_sync:", question)
     result = do_chat(question, history, chat_id, None)
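Both endpoints now accept history and chat_id and forward them to do_chat. A minimal in-process sketch of the updated call (placeholder values; the @serving endpoints above remain the real entry points):

result = do_chat(
    question="What is the notice period?",   # placeholder question
    history=[],                              # may now also be None
    chat_id="annual_report_2022",            # hypothetical per-document index id
    streaming_handler=None,
)
print(result["answer"])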
test.py
CHANGED
@@ -30,6 +30,7 @@ class MyCustomHandler(BaseCallbackHandler):
 
 
 chatting = len(sys.argv) > 1 and sys.argv[1] == "chat"
+chat_id = sys.argv[2] if len(sys.argv) > 2 else None
 questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
 chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") or "true"
 
@@ -68,8 +69,9 @@ while True:
     custom_handler.reset()
 
     start = timer()
+    inputs = {"question": query, "chat_history": chat_history, "chat_id": chat_id}
     result = qa_chain.call_chain(
-        {"question": query, "chat_history": chat_history},
+        inputs,
         custom_handler,
         None,
         True,
@@ -87,13 +89,14 @@ while True:
     if standalone_question is not None:
         print(f"Load relevant documents for standalone question: {standalone_question}")
         start = timer()
-        qa = qa_chain.get_chain()
+        qa = qa_chain.get_chain(inputs)
         docs = qa.retriever.get_relevant_documents(standalone_question)
         end = timer()
-
-        # print(docs)
         print(f"Completed in {end - start:.3f}s")
 
+    if chatting:
+        print(docs)
+
     if chat_history_enabled == "true":
         chat_history.append((query, result["answer"]))
 
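With the extra sys.argv handling, the test script can presumably be run as python test.py chat <doc_id>, where the optional second argument selects one of the per-document vector stores loaded in app_modules/init.py; without it, chat_id stays None and the default index is used.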
web
CHANGED
@@ -1 +1 @@
-Subproject commit
+Subproject commit 15f2b72afe6170badfb982c7adba585af30d578a