Daniel Marques committed on
Commit
27e6a14
·
1 Parent(s): abb8521

fix: add websocket in handlerToken

Browse files
Files changed (3) hide show
  1. constants.py +3 -0
  2. main.py +38 -53
  3. requirements.txt +1 -0
constants.py CHANGED
@@ -13,6 +13,8 @@ ROOT_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
13
 
14
  PATH_NAME_SOURCE_DIRECTORY = "SOURCE_DOCUMENTS"
15
 
 
 
16
  # Define the folder for storing database
17
  SOURCE_DIRECTORY = f"{ROOT_DIRECTORY}/{PATH_NAME_SOURCE_DIRECTORY}"
18
 
@@ -43,6 +45,7 @@ N_BATCH = 2048
43
  # N_BATCH = 512
44
 
45
 
 
46
  # https://python.langchain.com/en/latest/_modules/langchain/document_loaders/excel.html#UnstructuredExcelLoader
47
  DOCUMENT_MAP = {
48
  ".txt": TextLoader,
 
13
 
14
  PATH_NAME_SOURCE_DIRECTORY = "SOURCE_DOCUMENTS"
15
 
16
+ SHOW_SOURCES=True
17
+
18
  # Define the folder for storing database
19
  SOURCE_DIRECTORY = f"{ROOT_DIRECTORY}/{PATH_NAME_SOURCE_DIRECTORY}"
20
 
 
45
  # N_BATCH = 512
46
 
47
 
48
+
49
  # https://python.langchain.com/en/latest/_modules/langchain/document_loaders/excel.html#UnstructuredExcelLoader
50
  DOCUMENT_MAP = {
51
  ".txt": TextLoader,
main.py CHANGED
@@ -1,33 +1,29 @@
 
 
1
  import os
2
  import glob
3
  import shutil
4
  import subprocess
5
  import asyncio
6
-
7
- from typing import Any, Dict, List
8
 
9
  from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
10
  from fastapi.staticfiles import StaticFiles
11
 
12
  from pydantic import BaseModel
13
 
14
- # import torch
15
  from langchain.chains import RetrievalQA
16
  from langchain.embeddings import HuggingFaceInstructEmbeddings
17
- from langchain.prompts import PromptTemplate
18
- from langchain.memory import ConversationBufferMemory
19
  from langchain.callbacks.base import BaseCallbackHandler
20
  from langchain.schema import LLMResult
 
21
 
22
  from prompt_template_utils import get_prompt_template
23
-
24
- # from langchain.embeddings import HuggingFaceEmbeddings
25
  from load_models import load_model
26
 
27
- # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
28
- from langchain.vectorstores import Chroma
29
-
30
- from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME, PATH_NAME_SOURCE_DIRECTORY
31
 
32
  class Predict(BaseModel):
33
  prompt: str
@@ -35,54 +31,36 @@ class Predict(BaseModel):
35
  class Delete(BaseModel):
36
  filename: str
37
 
38
- # if torch.backends.mps.is_available():
39
- # DEVICE_TYPE = "mps"
40
- # elif torch.cuda.is_available():
41
- # DEVICE_TYPE = "cuda"
42
- # else:
43
- # DEVICE_TYPE = "cpu"
44
-
45
- DEVICE_TYPE = "cuda"
46
- SHOW_SOURCES = True
47
 
48
  EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
49
-
50
- # load the vectorstore
51
- DB = Chroma(
52
- persist_directory=PERSIST_DIRECTORY,
53
- embedding_function=EMBEDDINGS,
54
- client_settings=CHROMA_SETTINGS,
55
- )
56
-
57
  RETRIEVER = DB.as_retriever()
58
 
59
  class MyCustomSyncHandler(BaseCallbackHandler):
60
- def __init__(self):
61
- self.end = False
62
-
63
  def on_llm_start(
64
  self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
65
  ) -> None:
66
- self.end = False
 
 
67
 
68
  def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
69
- self.end = True
 
 
70
 
71
  def on_llm_new_token(self, token: str, **kwargs) -> Any:
72
- print(self)
73
- print(kwargs)
74
-
75
-
76
- # Create State
77
- handlerToken = MyCustomSyncHandler()
78
 
79
- LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handlerToken])
80
-
81
- # template = """You are a helpful, respectful and honest assistant.
82
- # Always answer in the most helpful and safe way possible without trying to make up an answer, if you don't know the answer just say "I don't know" and don't share false information or topics that were not provided in your training. Use a maximum of 15 sentences. Your answer should be as concise and clear as possible. Always say "thank you for asking!" at the end of your answer.
83
- # Context: {context}
84
- # Question: {question}
85
- # """
86
 
87
  prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
88
 
@@ -97,7 +75,9 @@ QA = RetrievalQA.from_chain_type(
97
  },
98
  )
99
 
 
100
 
 
101
 
102
  app = FastAPI(title="homepage-app")
103
  api_app = FastAPI(title="api app")
@@ -146,7 +126,7 @@ def run_ingest_route():
146
  retriever=RETRIEVER,
147
  return_source_documents=SHOW_SOURCES,
148
  chain_type_kwargs={
149
- "prompt": QA_CHAIN_PROMPT,
150
  "memory": memory
151
  },
152
  )
@@ -250,16 +230,21 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
250
 
251
  await websocket.accept()
252
 
253
- oldReceiveText = ''
254
-
255
  try:
256
  while True:
257
  prompt = await websocket.receive_text()
 
 
 
 
 
 
 
 
 
 
258
 
259
- if (oldReceiveText != prompt):
260
- handlerToken.callback = websocket.send_text
261
- oldReceiveText = prompt
262
- await QA(prompt)
263
 
264
  except WebSocketDisconnect:
265
  print('disconnect')
 
1
+ from typing import Any, Dict, List
2
+
3
  import os
4
  import glob
5
  import shutil
6
  import subprocess
7
  import asyncio
8
+ import redis
9
+ import torch
10
 
11
  from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
12
  from fastapi.staticfiles import StaticFiles
13
 
14
  from pydantic import BaseModel
15
 
16
+ # langchain
17
  from langchain.chains import RetrievalQA
18
  from langchain.embeddings import HuggingFaceInstructEmbeddings
 
 
19
  from langchain.callbacks.base import BaseCallbackHandler
20
  from langchain.schema import LLMResult
21
+ from langchain.vectorstores import Chroma
22
 
23
  from prompt_template_utils import get_prompt_template
 
 
24
  from load_models import load_model
25
 
26
+ from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME, PATH_NAME_SOURCE_DIRECTORY, SHOW_SOURCES
 
 
 
27
 
28
  class Predict(BaseModel):
29
  prompt: str
 
31
  class Delete(BaseModel):
32
  filename: str
33
 
34
+ if torch.backends.mps.is_available():
35
+ DEVICE_TYPE = "mps"
36
+ elif torch.cuda.is_available():
37
+ DEVICE_TYPE = "cuda"
38
+ else:
39
+ DEVICE_TYPE = "cpu"
 
 
 
40
 
41
  EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
42
+ DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
 
 
 
 
 
 
 
43
  RETRIEVER = DB.as_retriever()
44
 
45
  class MyCustomSyncHandler(BaseCallbackHandler):
 
 
 
46
  def on_llm_start(
47
  self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
48
  ) -> None:
49
+ print(f'on_llm_start self {self}')
50
+ print(f'on_llm_start kwargs {prompts}')
51
+ print(f'on_llm_start token {kwargs}')
52
 
53
  def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
54
+ print(f'on_llm_end self {self}')
55
+ print(f'on_llm_end kwargs {response}')
56
+ print(f'on_llm_end token {kwargs}')
57
 
58
  def on_llm_new_token(self, token: str, **kwargs) -> Any:
59
+ print(f'on_llm_new_token self {self}')
60
+ print(f'on_llm_new_token kwargs {kwargs}')
61
+ print(f'on_llm_new_token token {token}')
 
 
 
62
 
63
+ LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
 
 
 
 
 
 
64
 
65
  prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
66
 
 
75
  },
76
  )
77
 
78
+ redisClient = redis.Redis(host='localhost', port=6379, db=0)
79
 
80
+ r.set('foo', 'bar')
81
 
82
  app = FastAPI(title="homepage-app")
83
  api_app = FastAPI(title="api app")
 
126
  retriever=RETRIEVER,
127
  return_source_documents=SHOW_SOURCES,
128
  chain_type_kwargs={
129
+ "prompt": prompt,
130
  "memory": memory
131
  },
132
  )
 
230
 
231
  await websocket.accept()
232
 
 
 
233
  try:
234
  while True:
235
  prompt = await websocket.receive_text()
236
+ QA(
237
+ inputs=prompt,
238
+ return_only_outputs=True,
239
+ callbacks=[MyCustomSyncHandler()],
240
+ tags=f'{client_id}',
241
+ run_name=f'{client_id}',
242
+ include_run_info=True
243
+ )
244
+
245
+ response = redisClient.get('foo')
246
 
247
+ await websocket.send_text(response)
 
 
 
248
 
249
  except WebSocketDisconnect:
250
  print('disconnect')
requirements.txt CHANGED
@@ -29,6 +29,7 @@ uvicorn
29
  fastapi
30
  websockets
31
  pydantic
 
32
 
33
  # Streamlit related
34
  streamlit
 
29
  fastapi
30
  websockets
31
  pydantic
32
+ redis
33
 
34
  # Streamlit related
35
  streamlit