Daniel Marques commited on
Commit
ef75206
1 Parent(s): d815dea

fix: add websocket in handlerToken

Browse files
Files changed (1) hide show
  1. main.py +10 -3
main.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import glob
3
  import shutil
4
  import subprocess
 
5
 
6
  from typing import Any, Dict, List
7
 
@@ -16,6 +17,9 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
16
  from langchain.prompts import PromptTemplate
17
  from langchain.memory import ConversationBufferMemory
18
  from langchain.callbacks.base import BaseCallbackHandler
 
 
 
19
  from langchain.schema import LLMResult
20
 
21
  # from langchain.embeddings import HuggingFaceEmbeddings
@@ -55,7 +59,7 @@ DB = Chroma(
55
 
56
  RETRIEVER = DB.as_retriever()
57
 
58
- class MyCustomSyncHandler(BaseCallbackHandler):
59
  def __init__(self):
60
  self.end = False
61
  self.websocket = None
@@ -74,7 +78,6 @@ class MyCustomSyncHandler(BaseCallbackHandler):
74
 
75
  print(token)
76
 
77
- self.token += token
78
 
79
 
80
  handlerToken = MyCustomSyncHandler()
@@ -252,6 +255,7 @@ async def websocket_endpoint(websocket: WebSocket):
252
  global QA
253
 
254
  await websocket.accept()
 
255
  try:
256
  while True:
257
  handlerToken.websocket = websocket
@@ -261,7 +265,10 @@ async def websocket_endpoint(websocket: WebSocket):
261
  print(res)
262
 
263
  except WebSocketDisconnect:
264
- await websocket.send_text(f"disconnect")
 
 
 
265
 
266
 
267
 
 
2
  import glob
3
  import shutil
4
  import subprocess
5
+ import sys
6
 
7
  from typing import Any, Dict, List
8
 
 
17
  from langchain.prompts import PromptTemplate
18
  from langchain.memory import ConversationBufferMemory
19
  from langchain.callbacks.base import BaseCallbackHandler
20
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
21
+
22
+
23
  from langchain.schema import LLMResult
24
 
25
  # from langchain.embeddings import HuggingFaceEmbeddings
 
59
 
60
  RETRIEVER = DB.as_retriever()
61
 
62
+ class MyCustomSyncHandler(StreamingStdOutCallbackHandler):
63
  def __init__(self):
64
  self.end = False
65
  self.websocket = None
 
78
 
79
  print(token)
80
 
 
81
 
82
 
83
  handlerToken = MyCustomSyncHandler()
 
255
  global QA
256
 
257
  await websocket.accept()
258
+
259
  try:
260
  while True:
261
  handlerToken.websocket = websocket
 
265
  print(res)
266
 
267
  except WebSocketDisconnect:
268
+ print('disconnect')
269
+ except RuntimeError as error:
270
+ print(error)
271
+
272
 
273
 
274