Daniel Marques committed
Commit: 07e217d
Parent(s): 99f6cbc

fix: add websocket in handlerToken

Files changed (2):
  1. main.py  +13 -9
  2. requirements.txt  +1 -0
main.py CHANGED
@@ -4,6 +4,7 @@ import shutil
 import subprocess
 import asyncio
 
+
 from typing import Any, Dict, List
 
 from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
@@ -17,10 +18,10 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.memory import ConversationBufferMemory
 from langchain.callbacks.base import BaseCallbackHandler
-
-
 from langchain.schema import LLMResult
 
+from varstate import State
+
 # from langchain.embeddings import HuggingFaceEmbeddings
 from load_models import load_model
 
@@ -35,8 +36,6 @@ class Predict(BaseModel):
 class Delete(BaseModel):
     filename: str
 
-
-
 # if torch.backends.mps.is_available():
 #     DEVICE_TYPE = "mps"
 # elif torch.cuda.is_available():
@@ -59,9 +58,9 @@ DB = Chroma(
 RETRIEVER = DB.as_retriever()
 
 class MyCustomSyncHandler(BaseCallbackHandler):
-    def __init__(self):
+    def __init__(self, state):
         self.end = False
-        self.websocket = None
+        self.state = state
 
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -80,8 +79,13 @@ class MyCustomSyncHandler(BaseCallbackHandler):
         print(token)
 
 
-
-handlerToken = MyCustomSyncHandler()
+# Create State
+
+tokenMessageLLM = State()
+
+get, update = tokenMessageLLM.create('')
+
+handlerToken = MyCustomSyncHandler(update)
 
 LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handlerToken])
 
@@ -258,9 +262,9 @@ async def websocket_endpoint(websocket: WebSocket):
 
     try:
         while True:
-            handlerToken.websocket = websocket
+            tokenMessageLLM.after_create(lambda now, old: print(f"{old} updated to {now}."))
 
-            print(handlerToken.websocket)
+            print(tokenMessageLLM)
 
             data = await websocket.receive_text()
             res = QA(data)
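
Taken together, the main.py changes decouple the streaming callback from the websocket: instead of stashing a websocket reference on handlerToken, each token is pushed through a shared varstate State, and the endpoint subscribes to its updates. Below is a minimal sketch of that wiring, assuming State.create(initial) returns a (get, update) pair and after_create(fn) invokes fn(new_value, old_value) on each update (both inferred from the calls in this diff rather than from varstate's documentation); the on_llm_new_token body is hypothetical:

from typing import Any

from langchain.callbacks.base import BaseCallbackHandler
from varstate import State


class MyCustomSyncHandler(BaseCallbackHandler):
    def __init__(self, state):
        self.end = False
        self.state = state  # the `update` function returned by State.create

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Hypothetical body: push each streamed token into the shared
        # State instead of writing to a websocket held by the handler.
        self.state(token)


tokenMessageLLM = State()
get, update = tokenMessageLLM.create('')
handlerToken = MyCustomSyncHandler(update)

# Any consumer (here, the websocket loop) can observe token updates:
tokenMessageLLM.after_create(lambda now, old: print(f"{old} updated to {now}."))

One caveat in the endpoint as committed: after_create is called inside the `while True` loop, so a new subscriber is registered on every iteration; registering once per connection, before the loop, would avoid duplicate callbacks (assuming varstate does not deduplicate them).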
requirements.txt CHANGED
@@ -29,6 +29,7 @@ uvicorn
 fastapi
 websockets
 pydantic
+varstate
 
 # Streamlit related
 streamlit
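
requirements.txt gains the dependency that matches the new import in main.py, so a fresh environment picks it up with everything else (assuming varstate is published on PyPI under that name):

pip install -r requirements.txt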