Daniel Marques commited on
Commit
b0d4d1d
1 Parent(s): 2084d31

feat: add websocket

Browse files
Files changed (1) hide show
  1. main.py +3 -37
main.py CHANGED
@@ -40,7 +40,7 @@ EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, mode
40
  DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
41
  RETRIEVER = DB.as_retriever()
42
 
43
- LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[])
44
 
45
  prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
46
 
@@ -211,24 +211,7 @@ async def create_upload_file(file: UploadFile):
211
 
212
  @api_app.websocket("/ws/{user_id}")
213
  async def websocket_endpoint_student(websocket: WebSocket, user_id: str):
214
- DB_USER = Chroma(
215
- persist_directory=PERSIST_DIRECTORY,
216
- embedding_function=EMBEDDINGS,
217
- client_settings=CHROMA_SETTINGS,
218
- )
219
-
220
- RETRIEVER = DB_USER.as_retriever()
221
-
222
- QA = RetrievalQA.from_chain_type(
223
- llm=LLM,
224
- chain_type="stuff",
225
- retriever=RETRIEVER,
226
- return_source_documents=SHOW_SOURCES,
227
- chain_type_kwargs={
228
- "prompt": prompt,
229
- "memory": memory
230
- },
231
- )
232
 
233
  message = {
234
  "message": f"Student {user_id} connected"
@@ -258,24 +241,7 @@ async def websocket_endpoint_student(websocket: WebSocket, user_id: str):
258
 
259
  @api_app.websocket("/ws/{room_id}/{user_id}")
260
  async def websocket_endpoint_room(websocket: WebSocket, room_id: str, user_id: str):
261
- DB_ROOM = Chroma(
262
- persist_directory=PERSIST_DIRECTORY,
263
- embedding_function=EMBEDDINGS,
264
- client_settings=CHROMA_SETTINGS,
265
- )
266
-
267
- RETRIEVER = DB_ROOM.as_retriever()
268
-
269
- QA = RetrievalQA.from_chain_type(
270
- llm=LLM,
271
- chain_type="stuff",
272
- retriever=RETRIEVER,
273
- return_source_documents=SHOW_SOURCES,
274
- chain_type_kwargs={
275
- "prompt": prompt,
276
- "memory": memory,
277
- },
278
- )
279
 
280
  message = {
281
  "message": f"Student {user_id} connected to the classroom"
 
40
  DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
41
  RETRIEVER = DB.as_retriever()
42
 
43
+ LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
44
 
45
  prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
46
 
 
211
 
212
  @api_app.websocket("/ws/{user_id}")
213
  async def websocket_endpoint_student(websocket: WebSocket, user_id: str):
214
+ global QA
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  message = {
217
  "message": f"Student {user_id} connected"
 
241
 
242
  @api_app.websocket("/ws/{room_id}/{user_id}")
243
  async def websocket_endpoint_room(websocket: WebSocket, room_id: str, user_id: str):
244
+ global QA
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  message = {
247
  "message": f"Student {user_id} connected to the classroom"