Daniel Marques committed • Commit 07e217d • 1 Parent(s): 99f6cbc

fix: add websocket in handlerToken

Files changed:
- main.py +13 -9
- requirements.txt +1 -0
main.py CHANGED

@@ -4,6 +4,7 @@ import shutil
 import subprocess
 import asyncio
 
+
 from typing import Any, Dict, List
 
 from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
@@ -17,10 +18,10 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.memory import ConversationBufferMemory
 from langchain.callbacks.base import BaseCallbackHandler
-
-
 from langchain.schema import LLMResult
 
+from varstate import State
+
 # from langchain.embeddings import HuggingFaceEmbeddings
 from load_models import load_model
 
@@ -35,8 +36,6 @@ class Predict(BaseModel):
 class Delete(BaseModel):
     filename: str
 
-
-
 # if torch.backends.mps.is_available():
 #     DEVICE_TYPE = "mps"
 # elif torch.cuda.is_available():
@@ -59,9 +58,9 @@ DB = Chroma(
 RETRIEVER = DB.as_retriever()
 
 class MyCustomSyncHandler(BaseCallbackHandler):
-    def __init__(self):
+    def __init__(self, state):
         self.end = False
-        self.
+        self.state = state
 
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -80,8 +79,13 @@ class MyCustomSyncHandler(BaseCallbackHandler):
         print(token)
 
 
+# Create State
+
+tokenMessageLLM = State()
+
+get, update = tokenMessageLLM.create('')
 
-handlerToken = MyCustomSyncHandler()
+handlerToken = MyCustomSyncHandler(update)
 
 LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handlerToken])
 
@@ -258,9 +262,9 @@ async def websocket_endpoint(websocket: WebSocket):
 
     try:
         while True:
-
+            tokenMessageLLM.after_create(lambda now, old: print(f"{old} updated to {now}."))
 
-            print(
+            print(tokenMessageLLM)
 
             data = await websocket.receive_text()
             res = QA(data)
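The diff passes the new `update` callable into the handler as `state`, but the callback that consumes it sits outside these hunks. For orientation, here is a minimal sketch of how the handler could push each streamed token through that callable; the `on_llm_new_token` body is an assumption built around the `print(token)` context line above, not code from this commit:

# Sketch only: __init__ and the on_llm_start signature appear in the
# diff; the method bodies below are assumed.
from typing import Any, Callable, Dict, List

from langchain.callbacks.base import BaseCallbackHandler

class MyCustomSyncHandler(BaseCallbackHandler):
    def __init__(self, state: Callable[[str], None]):
        self.end = False
        self.state = state  # the `update` half of tokenMessageLLM.create('')

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # Assumed: reset the end-of-stream flag when a new generation starts.
        self.end = False

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(token)
        # Assumed: publish the token so after_create listeners
        # (e.g. the websocket loop's) are notified.
        self.state(token)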
requirements.txt CHANGED

@@ -29,6 +29,7 @@ uvicorn
 fastapi
 websockets
 pydantic
+varstate
 
 # Streamlit related
 streamlit
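The only new dependency is varstate, whose API this commit does not document; the calls in main.py (`State()`, `.create('')` returning a `(get, update)` pair, and `.after_create(...)`) are all there is to go on. Below is a minimal stand-in with that inferred behaviour, useful for tracing the data flow; the real library may well differ:

# Stand-in for the inferred varstate.State behaviour: an observable
# value holder. Inferred from the diff above, not from varstate itself.
from typing import Any, Callable, List, Tuple

class State:
    def __init__(self) -> None:
        self._value: Any = None
        self._listeners: List[Callable[[Any, Any], None]] = []

    def create(self, initial: Any) -> Tuple[Callable[[], Any], Callable[[Any], None]]:
        # Hands back a (get, update) pair, as in
        # `get, update = tokenMessageLLM.create('')`.
        self._value = initial

        def get() -> Any:
            return self._value

        def update(new: Any) -> None:
            old, self._value = self._value, new
            for listener in self._listeners:
                listener(new, old)

        return get, update

    def after_create(self, listener: Callable[[Any, Any], None]) -> None:
        # Registers a (now, old) listener fired on every update.
        self._listeners.append(listener)

# Mirrors the wiring in main.py:
tokenMessageLLM = State()
get, update = tokenMessageLLM.create('')
tokenMessageLLM.after_create(lambda now, old: print(f"{old} updated to {now}."))
update("hello")  # fires the listener with now='hello', old=''

Under this reading, note that main.py calls after_create inside the `while True:` receive loop, so a new listener is registered on every iteration unless varstate deduplicates them.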