Spaces:
Paused
Paused
Daniel Marques
committed on
Commit
•
99f6cbc
1
Parent(s):
ef75206
fix: add websocket in handlerToken
Browse files
main.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
import glob
|
3 |
import shutil
|
4 |
import subprocess
|
5 |
-
import
|
6 |
|
7 |
from typing import Any, Dict, List
|
8 |
|
@@ -17,7 +17,6 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
|
|
17 |
from langchain.prompts import PromptTemplate
|
18 |
from langchain.memory import ConversationBufferMemory
|
19 |
from langchain.callbacks.base import BaseCallbackHandler
|
20 |
-
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
21 |
|
22 |
|
23 |
from langchain.schema import LLMResult
|
@@ -59,7 +58,7 @@ DB = Chroma(
|
|
59 |
|
60 |
RETRIEVER = DB.as_retriever()
|
61 |
|
62 |
-
class MyCustomSyncHandler(
|
63 |
def __init__(self):
|
64 |
self.end = False
|
65 |
self.websocket = None
|
@@ -73,8 +72,10 @@ class MyCustomSyncHandler(StreamingStdOutCallbackHandler):
|
|
73 |
self.end = True
|
74 |
|
75 |
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
|
|
|
|
76 |
if self.websocket != None:
|
77 |
-
self.websocket.send_text(token)
|
78 |
|
79 |
print(token)
|
80 |
|
@@ -82,7 +83,7 @@ class MyCustomSyncHandler(StreamingStdOutCallbackHandler):
|
|
82 |
|
83 |
handlerToken = MyCustomSyncHandler()
|
84 |
|
85 |
-
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[])
|
86 |
|
87 |
template = """You are a helpful, respectful and honest assistant.
|
88 |
Always answer in the most helpful and safe way possible without trying to make up an answer, if you don't know the answer just say "I don't know" and don't share false information or topics that were not provided in your training. Use a maximum of 15 sentences. Your answer should be as concise and clear as possible. Always say "thank you for asking!" at the end of your answer.
|
@@ -101,7 +102,6 @@ QA = RetrievalQA.from_chain_type(
|
|
101 |
return_source_documents=SHOW_SOURCES,
|
102 |
chain_type_kwargs={
|
103 |
"prompt": QA_CHAIN_PROMPT,
|
104 |
-
"callbacks": [handlerToken]
|
105 |
},
|
106 |
)
|
107 |
|
@@ -260,6 +260,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
260 |
while True:
|
261 |
handlerToken.websocket = websocket
|
262 |
|
|
|
|
|
263 |
data = await websocket.receive_text()
|
264 |
res = QA(data)
|
265 |
print(res)
|
|
|
2 |
import glob
|
3 |
import shutil
|
4 |
import subprocess
|
5 |
+
import asyncio
|
6 |
|
7 |
from typing import Any, Dict, List
|
8 |
|
|
|
17 |
from langchain.prompts import PromptTemplate
|
18 |
from langchain.memory import ConversationBufferMemory
|
19 |
from langchain.callbacks.base import BaseCallbackHandler
|
|
|
20 |
|
21 |
|
22 |
from langchain.schema import LLMResult
|
|
|
58 |
|
59 |
RETRIEVER = DB.as_retriever()
|
60 |
|
61 |
+
class MyCustomSyncHandler(BaseCallbackHandler):
|
62 |
def __init__(self):
|
63 |
self.end = False
|
64 |
self.websocket = None
|
|
|
72 |
self.end = True
|
73 |
|
74 |
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
75 |
+
print(token)
|
76 |
+
|
77 |
if self.websocket != None:
|
78 |
+
asyncio.run(self.websocket.send_text(token))
|
79 |
|
80 |
print(token)
|
81 |
|
|
|
83 |
|
84 |
handlerToken = MyCustomSyncHandler()
|
85 |
|
86 |
+
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handlerToken])
|
87 |
|
88 |
template = """You are a helpful, respectful and honest assistant.
|
89 |
Always answer in the most helpful and safe way possible without trying to make up an answer, if you don't know the answer just say "I don't know" and don't share false information or topics that were not provided in your training. Use a maximum of 15 sentences. Your answer should be as concise and clear as possible. Always say "thank you for asking!" at the end of your answer.
|
|
|
102 |
return_source_documents=SHOW_SOURCES,
|
103 |
chain_type_kwargs={
|
104 |
"prompt": QA_CHAIN_PROMPT,
|
|
|
105 |
},
|
106 |
)
|
107 |
|
|
|
260 |
while True:
|
261 |
handlerToken.websocket = websocket
|
262 |
|
263 |
+
print(handlerToken.websocket)
|
264 |
+
|
265 |
data = await websocket.receive_text()
|
266 |
res = QA(data)
|
267 |
print(res)
|