commit 4359eb6 (parent: 2826548)
Author: dh-mc

    code complete

app.py CHANGED
@@ -6,6 +6,7 @@ from timeit import default_timer as timer
 
 import gradio as gr
 from anyio.from_thread import start_blocking_portal
+
 from app_modules.init import app_init
 from app_modules.utils import print_llm_response, remove_extra_spaces
 
app_modules/llm_chat_chain.py CHANGED
@@ -1,7 +1,8 @@
+from langchain import LLMChain, PromptTemplate
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chains.base import Chain
 from langchain.memory import ConversationBufferMemory
-from langchain import LLMChain, PromptTemplate
+
 from app_modules.llm_inference import LLMInference
 
 
app_modules/llm_loader.py CHANGED
@@ -93,7 +93,7 @@ class LLMLoader:
     def __init__(self, llm_model_type, max_tokens_limit: int = 2048):
         self.llm_model_type = llm_model_type
         self.llm = None
-        self.streamer = TextIteratorStreamer("")
+        self.streamer = None
         self.max_tokens_limit = max_tokens_limit
         self.search_kwargs = {"k": 4}
 
@@ -138,7 +138,9 @@ class LLMLoader:
             bnb_8bit_use_double_quant=load_quantized_model == "8bit",
         )
 
-        callbacks = [self.streamer]
+        callbacks = []
+        if self.streamer is not None:
+            callbacks.append(self.streamer)
         if custom_handler is not None:
             callbacks.append(custom_handler)
 
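The line removed in the first hunk, `self.streamer = TextIteratorStreamer("")`, was a latent bug: transformers' TextIteratorStreamer takes a tokenizer as its first argument, so the streamer cannot be built before a model's tokenizer exists. The commit therefore starts with `self.streamer = None` and, in the second hunk, only registers callbacks that actually exist. A minimal sketch of the deferred pattern, assuming a stock transformers tokenizer (the "gpt2" name and the skip_prompt flag are illustrative placeholders, not this repo's exact code):

from transformers import AutoTokenizer, TextIteratorStreamer

# Placeholder model; the real loader would use whatever model it serves.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Create the streamer only once a real tokenizer is available...
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

# ...and build the callback list defensively, as the patched code does.
callbacks = []
if streamer is not None:
    callbacks.append(streamer)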
server.py CHANGED
@@ -1,74 +1,21 @@
 """Main entrypoint for the app."""
 import json
 import os
-import time
-from queue import Queue
 from timeit import default_timer as timer
 from typing import List, Optional
 
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.vectorstores.chroma import Chroma
-from langchain.vectorstores.faiss import FAISS
 from lcserve import serving
 from pydantic import BaseModel
 
-from app_modules.presets import *
-from app_modules.qa_chain import QAChain
-from app_modules.utils import *
+from app_modules.init import app_init
+from app_modules.llm_chat_chain import ChatChain
+from app_modules.utils import print_llm_response
 
-# Constants
-init_settings()
+llm_loader, qa_chain = app_init()
 
-# https://github.com/huggingface/transformers/issues/17611
-os.environ["CURL_CA_BUNDLE"] = ""
-
-hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
-print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
-print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
-
-hf_embeddings_model_name = (
-    os.environ.get("HF_EMBEDDINGS_MODEL_NAME") or "hkunlp/instructor-xl"
-)
-n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")
-index_path = os.environ.get("FAISS_INDEX_PATH") or os.environ.get("CHROMADB_INDEX_PATH")
-using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None
-llm_model_type = os.environ.get("LLM_MODEL_TYPE")
 chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
-show_param_settings = os.environ.get("SHOW_PARAM_SETTINGS") == "true"
-share_gradio_app = os.environ.get("SHARE_GRADIO_APP") == "true"
-
-
-streaming_enabled = True  # llm_model_type in ["openai", "llamacpp"]
-
-start = timer()
-embeddings = HuggingFaceInstructEmbeddings(
-    model_name=hf_embeddings_model_name,
-    model_kwargs={"device": hf_embeddings_device_type},
-)
-end = timer()
-
-print(f"Completed in {end - start:.3f}s")
-
-start = timer()
-
-print(f"Load index from {index_path} with {'FAISS' if using_faiss else 'Chroma'}")
 
-if not os.path.isdir(index_path):
-    raise ValueError(f"{index_path} does not exist!")
-elif using_faiss:
-    vectorstore = FAISS.load_local(index_path, embeddings)
-else:
-    vectorstore = Chroma(embedding_function=embeddings, persist_directory=index_path)
-
-end = timer()
-
-print(f"Completed in {end - start:.3f}s")
-
-start = timer()
-qa_chain = QAChain(vectorstore, llm_model_type)
-qa_chain.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
-end = timer()
-print(f"Completed in {end - start:.3f}s")
+uuid_to_chat_chain_mapping = dict()
 
 
 class ChatResponse(BaseModel):
@@ -80,30 +27,41 @@ class ChatResponse(BaseModel):
 
 
 @serving(websocket=True)
-def chat(question: str, history: Optional[List], **kwargs) -> str:
+def chat(
+    question: str, history: Optional[List] = [], uuid: Optional[str] = None, **kwargs
+) -> str:
+    print(f"uuid: {uuid}")
     # Get the `streaming_handler` from `kwargs`. This is used to stream data to the client.
-    streaming_handler = kwargs.get("streaming_handler") if streaming_enabled else None
-    chat_history = []
-    if chat_history_enabled:
-        for element in history:
-            item = (element[0] or "", element[1] or "")
-            chat_history.append(item)
-
-    start = timer()
-    result = qa_chain.call(
-        {"question": question, "chat_history": chat_history}, streaming_handler
-    )
-    end = timer()
-    print(f"Completed in {end - start:.3f}s")
-
-    resp = ChatResponse(sourceDocs=result["source_documents"])
-
-    if not streaming_enabled:
-        resp.token = remove_extra_spaces(result["answer"])
-        print(resp.token)
-
-    return json.dumps(resp.dict())
+    streaming_handler = kwargs.get("streaming_handler")
+    if uuid is None:
+        chat_history = []
+        if chat_history_enabled:
+            for element in history:
+                item = (element[0] or "", element[1] or "")
+                chat_history.append(item)
+
+        start = timer()
+        result = qa_chain.call_chain(
+            {"question": question, "chat_history": chat_history}, streaming_handler
+        )
+        end = timer()
+        print(f"Completed in {end - start:.3f}s")
+
+        resp = ChatResponse(sourceDocs=result["source_documents"])
+
+        return json.dumps(resp.dict())
+    else:
+        if uuid in uuid_to_chat_chain_mapping:
+            chat = uuid_to_chat_chain_mapping[uuid]
+        else:
+            chat = ChatChain(llm_loader)
+            uuid_to_chat_chain_mapping[uuid] = chat
+        result = chat.call_chain({"question": question}, streaming_handler)
+        print(f"result: {result}")
+
+        resp = ChatResponse(sourceDocs=[])
+        return json.dumps(resp.dict())
 
 
 if __name__ == "__main__":
-    print_llm_response(json.loads(chat("What is PCI DSS?", [])))
+    print_llm_response(json.loads(chat("What's deep learning?", [])))
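Two changes drive this rewrite: the embeddings/vectorstore/chain bootstrap that used to run at import time now lives behind a single `app_init()` call, and the websocket handler gains a per-client mode. When a `uuid` is supplied, the client gets its own ChatChain cached in `uuid_to_chat_chain_mapping`, so its conversation memory persists across messages. A minimal sketch of that lookup pattern in isolation, with `make_chain` standing in for `ChatChain(llm_loader)` (the helper name is illustrative, not this repo's API):

from typing import Callable, Dict

# One chain per client id, mirroring uuid_to_chat_chain_mapping above.
sessions: Dict[str, object] = {}

def get_session_chain(uuid: str, make_chain: Callable[[], object]) -> object:
    # The first request from a uuid creates its chain; later requests reuse
    # it, so per-client conversation memory survives between messages.
    if uuid not in sessions:
        sessions[uuid] = make_chain()
    return sessions[uuid]

One caveat with the pattern as committed: entries are never evicted, so the mapping grows with every distinct uuid over a server's lifetime; a TTL or LRU cache is the usual remedy.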