Daniel Marques committed on
Commit
36d73c0
·
1 Parent(s): 880c11d

fix: add callback

Browse files
Files changed (3) hide show
  1. callbacks.py +21 -0
  2. main.py +4 -44
  3. redis-implements/main.py +258 -0
callbacks.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Any, Dict, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult


class MyCustomSyncHandler(BaseCallbackHandler):
    """Forward LLM streaming events to Redis pub/sub.

    Each event is published on the channel named after the first tag of
    the run that produced it. While tokens arrive the handler publishes
    the text accumulated so far; when the LLM or a chain finishes (or the
    LLM fails) it publishes the literal string ``'end'``.
    """

    def __init__(self, redisClient):
        """Store the Redis client used for publishing and start with an empty buffer."""
        self.message = ''
        self.redisClient = redisClient

    def _publish(self, callback_kwargs: Dict[str, Any], payload: str) -> None:
        """Publish *payload* on the channel named by the run's first tag.

        Runs without tags have no channel to publish on, so they are
        skipped instead of raising KeyError/IndexError inside the callback.
        """
        tags = callback_kwargs.get("tags") or []
        if tags:
            self.redisClient.publish(f'{tags[0]}', payload)

    def on_llm_new_token(self, token: str, **kwargs) -> Any:
        """Append *token* to the buffer and publish the whole buffer."""
        self.message += token
        self._publish(kwargs, self.message)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        """Publish the end marker and clear the buffer for the next run."""
        self._publish(kwargs, 'end')
        # Without this reset every later answer would be prefixed with the
        # text of all previous answers.
        self.message = ''

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Publish the end marker on failure and discard the partial buffer."""
        self._publish(kwargs, 'end')
        self.message = ''

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        """Publish the end marker when a chain finishes."""
        self._publish(kwargs, 'end')
main.py CHANGED
@@ -1,13 +1,10 @@
1
- from typing import Any, Dict, List, Union
2
 
3
  import os
4
  import glob
5
  import shutil
6
  import subprocess
7
- import redis
8
  import torch
9
- import concurrent.futures
10
- import json
11
 
12
  from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
13
  from fastapi.staticfiles import StaticFiles
@@ -43,34 +40,7 @@ EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, mode
43
  DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
44
  RETRIEVER = DB.as_retriever()
45
 
46
- redisClient = redis.Redis(host='localhost', port=6379, db=0)
47
-
48
- class MyCustomSyncHandler(BaseCallbackHandler):
49
- def __init__(self, redisClient):
50
- self.message = ''
51
- self.redisClient = redisClient
52
-
53
- def on_llm_new_token(self, token: str, **kwargs) -> Any:
54
- self.message += token
55
- self.redisClient.publish(f'{kwargs["tags"][0]}', self.message)
56
-
57
- def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
58
- print("on_llm_end end")
59
- self.redisClient.publish(f'{kwargs["tags"][0]}', 'end')
60
-
61
- def on_llm_error(
62
- self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
63
- ) -> Any:
64
- print("on_llm_error end")
65
- self.redisClient.publish(f'{kwargs["tags"][0]}', 'end')
66
-
67
- def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
68
- print("on_chain_end end")
69
- self.redisClient.publish(f'{kwargs["tags"][0]}', 'end')
70
-
71
- handleCallback = MyCustomSyncHandler(redisClient)
72
-
73
- LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handleCallback])
74
 
75
  prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
76
 
@@ -238,20 +208,10 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
238
  try:
239
  while True:
240
  prompt = await websocket.receive_text()
241
- pubsub = redisClient.pubsub()
242
- pubsub.subscribe(f'{client_id}')
243
 
244
- response = QA(inputs=prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True, callbacks=[handleCallback])
245
- await websocket.send_text(f'{response}')
246
 
247
- # with concurrent.futures.ThreadPoolExecutor() as executor:
248
- # executor.submit(QA(inputs=prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True, callbacks=[handleCallback]))
249
-
250
- # for item in pubsub.listen():
251
- # if item["type"] == "message":
252
- # message = item["data"].decode('utf-8')
253
- # if message == "end": pubsub.unsubscribe({client_id})
254
- # await websocket.send_text(f'{message}')
255
 
256
 
257
  except WebSocketDisconnect:
 
1
+ from typing import Any, Dict, Union
2
 
3
  import os
4
  import glob
5
  import shutil
6
  import subprocess
 
7
  import torch
 
 
8
 
9
  from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
10
  from fastapi.staticfiles import StaticFiles
 
40
  DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
41
  RETRIEVER = DB.as_retriever()
42
 
43
+ LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
46
 
 
208
  try:
209
  while True:
210
  prompt = await websocket.receive_text()
 
 
211
 
212
+ response = QA(inputs=prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True)
 
213
 
214
+ await websocket.send_text(f'{response}')
 
 
 
 
 
 
 
215
 
216
 
217
  except WebSocketDisconnect:
redis-implements/main.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Union
2
+
3
+ import os
4
+ import glob
5
+ import shutil
6
+ import subprocess
7
+ import redis
8
+ import torch
9
+ import concurrent.futures
10
+ import json
11
+
12
+ from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
13
+ from fastapi.staticfiles import StaticFiles
14
+
15
+ from pydantic import BaseModel
16
+
17
+ # langchain
18
+ from langchain.chains import RetrievalQA
19
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
20
+ from langchain.callbacks.base import BaseCallbackHandler
21
+ from langchain.schema import LLMResult
22
+ from langchain.vectorstores import Chroma
23
+
24
+ from prompt_template_utils import get_prompt_template
25
+ from load_models import load_model
26
+
27
+ from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME, PATH_NAME_SOURCE_DIRECTORY, SHOW_SOURCES
28
+
29
# Request body accepted by the POST /predict route.
class Predict(BaseModel):
    prompt: str  # text handed to the QA chain by predict()
31
+
32
# Request body accepted by the DELETE /api/delete_document route.
class Delete(BaseModel):
    filename: str  # name appended to the source-documents directory path
34
+
35
+ if torch.backends.mps.is_available():
36
+ DEVICE_TYPE = "mps"
37
+ elif torch.cuda.is_available():
38
+ DEVICE_TYPE = "cuda"
39
+ else:
40
+ DEVICE_TYPE = "cpu"
41
+
42
+ EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
43
+ DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
44
+ RETRIEVER = DB.as_retriever()
45
+
46
+ redisClient = redis.Redis(host='localhost', port=6379, db=0)
47
+
48
class MyCustomSyncHandler(BaseCallbackHandler):
    """Forward LLM streaming events to Redis pub/sub.

    Each event is published on the channel named after the first tag of
    the run that produced it. While tokens arrive the handler publishes
    the text accumulated so far; when the LLM or a chain finishes (or the
    LLM fails) it publishes the literal string ``'end'``.
    """

    def __init__(self, redisClient):
        """Store the Redis client used for publishing and start with an empty buffer."""
        self.message = ''
        self.redisClient = redisClient

    def _publish(self, callback_kwargs: Dict[str, Any], payload: str) -> None:
        """Publish *payload* on the channel named by the run's first tag.

        Runs started without tags (the /predict route calls the chain with
        none) have no channel, so they are skipped instead of raising
        KeyError/IndexError inside the callback.
        """
        tags = callback_kwargs.get("tags") or []
        if tags:
            self.redisClient.publish(f'{tags[0]}', payload)

    def on_llm_new_token(self, token: str, **kwargs) -> Any:
        """Append *token* to the buffer and publish the whole buffer."""
        self.message += token
        self._publish(kwargs, self.message)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        """Publish the end marker and clear the buffer for the next run."""
        print("on_llm_end end")
        self._publish(kwargs, 'end')
        # Without this reset every later answer would be prefixed with the
        # text of all previous answers.
        self.message = ''

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Publish the end marker on failure and discard the partial buffer."""
        print("on_llm_error end")
        self._publish(kwargs, 'end')
        self.message = ''

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        """Publish the end marker when a chain finishes."""
        print("on_chain_end end")
        self._publish(kwargs, 'end')
70
+
71
+ handleCallback = MyCustomSyncHandler(redisClient)
72
+
73
+ LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handleCallback])
74
+
75
+ prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
76
+
77
+ QA = RetrievalQA.from_chain_type(
78
+ llm=LLM,
79
+ chain_type="stuff",
80
+ retriever=RETRIEVER,
81
+ return_source_documents=SHOW_SOURCES,
82
+ chain_type_kwargs={
83
+ "prompt": prompt,
84
+ "memory": memory
85
+ },
86
+ )
87
+
88
+ app = FastAPI(title="homepage-app")
89
+ api_app = FastAPI(title="api app")
90
+
91
+ app.mount("/api", api_app, name="api")
92
+ app.mount("/", StaticFiles(directory="static",html = True), name="static")
93
+
94
@api_app.get("/training")
def run_ingest_route():
    """Rebuild the vector store and the QA chain from the source documents.

    Deletes PERSIST_DIRECTORY, runs ``ingest.py`` as a subprocess, then
    re-creates the module-level DB, RETRIEVER and QA objects.

    Raises:
        HTTPException: 500 if the persist directory is missing or cannot
            be removed, 400 if the ingest script exits non-zero, 500 for
            any other failure.
    """
    global DB
    global RETRIEVER
    global QA

    try:
        if not os.path.exists(PERSIST_DIRECTORY):
            raise HTTPException(status_code=500, detail="The directory does not exist")

        try:
            shutil.rmtree(PERSIST_DIRECTORY)
        except OSError as e:
            raise HTTPException(status_code=500, detail=f"Error: {e.filename} - {e.strerror}.") from e

        run_langest_commands = ["python", "ingest.py"]

        if DEVICE_TYPE == "cpu":
            run_langest_commands.append("--device_type")
            run_langest_commands.append(DEVICE_TYPE)

        result = subprocess.run(run_langest_commands, capture_output=True)

        if result.returncode != 0:
            # The placeholder in this message was never filled in; report
            # what the script wrote to stderr.
            stderr_text = result.stderr.decode(errors="replace").strip()
            raise HTTPException(
                status_code=400,
                detail="Script execution failed: {}".format(stderr_text),
            )

        # load the vectorstore
        DB = Chroma(
            persist_directory=PERSIST_DIRECTORY,
            embedding_function=EMBEDDINGS,
            client_settings=CHROMA_SETTINGS,
        )

        RETRIEVER = DB.as_retriever()

        QA = RetrievalQA.from_chain_type(
            llm=LLM,
            chain_type="stuff",
            retriever=RETRIEVER,
            return_source_documents=SHOW_SOURCES,
            chain_type_kwargs={
                "prompt": prompt,
                "memory": memory
            },
        )

        return {"response": "The training was successfully completed"}
    except HTTPException:
        # Keep the status code and detail chosen above instead of letting
        # the generic handler below re-wrap them as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}") from e
143
+
144
@api_app.get("/api/files")
def get_files():
    """List what is stored in the source-documents directory.

    Returns a dict holding the directory path (built from the current
    working directory and PATH_NAME_SOURCE_DIRECTORY) under "directory"
    and the glob matches for everything directly inside it under "files".
    """
    source_dir = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    everything_inside = os.path.join(source_dir, '*')
    return {"directory": source_dir, "files": glob.glob(everything_inside)}
150
+
151
@api_app.delete("/api/delete_document")
def delete_source_route(data: Delete):
    """Delete one file from the source-documents directory.

    Args:
        data: request body whose ``filename`` names the file to remove.

    Returns:
        A dict with a confirmation message under "message".

    Raises:
        HTTPException: 400 if the name is not a plain file name, the file
            does not exist, or it cannot be removed.
    """
    filename = data.filename

    # The name comes from the client: refuse anything that could address a
    # path outside the documents directory.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=400, detail="Invalid filename")

    path_source_documents = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    file_to_delete = f"{path_source_documents}/{filename}"

    if not os.path.exists(file_to_delete):
        # print() returns None, so log first and pass the text itself as detail.
        missing_message = f"The file {file_to_delete} does not exist."
        print(missing_message)
        raise HTTPException(status_code=400, detail=missing_message)

    try:
        os.remove(file_to_delete)
    except OSError as e:
        error_message = f"error: {e}."
        print(error_message)
        raise HTTPException(status_code=400, detail=error_message) from e

    print(f"{file_to_delete} has been deleted.")

    return {"message": f"{file_to_delete} has been deleted."}
167
+
168
@api_app.post('/predict')
async def predict(data: Predict):
    """Answer a prompt with the QA chain.

    Args:
        data: request body carrying the user's prompt.

    Returns:
        ``{"response": {...}}`` where the inner dict has the keys
        "Prompt", "Answer" and "Sources"; each source is a
        (file base name, page content) pair.

    Raises:
        HTTPException: 400 when the prompt is empty.
    """
    global QA
    user_prompt = data.prompt
    if not user_prompt:
        raise HTTPException(status_code=400, detail="Prompt Incorrect")

    res = QA(user_prompt)

    answer = res["result"]
    # The chain is built with return_source_documents=SHOW_SOURCES, so the
    # key is absent when sources are switched off.
    docs = res.get("source_documents", [])

    prompt_response_dict = {
        "Prompt": user_prompt,
        "Answer": answer,
        "Sources": [
            (os.path.basename(str(document.metadata["source"])), str(document.page_content))
            for document in docs
        ],
    }

    return {"response": prompt_response_dict}
191
+
192
@api_app.post("/save_document/")
async def create_upload_file(file: UploadFile):
    """Store an uploaded document in the source-documents directory.

    Args:
        file: the uploaded file.

    Returns:
        ``{"filename": <name the file was stored under>}``.

    Raises:
        HTTPException: 400 if the file is larger than 10 MB, has a content
            type outside the accepted list, or has no usable file name.
    """
    # Get the file size (in bytes)
    file.file.seek(0, 2)
    file_size = file.file.tell()

    # move the cursor back to the beginning
    await file.seek(0)

    if file_size > 10 * 1024 * 1024:
        # more than 10 MB
        raise HTTPException(status_code=400, detail="File too large")

    content_type = file.content_type

    if content_type not in [
        "text/plain",
        "text/markdown",
        "text/x-markdown",
        "text/csv",
        "application/msword",
        "application/pdf",
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/x-python",
        "application/x-python-code"]:
        raise HTTPException(status_code=400, detail="Invalid file type")

    # The file name is supplied by the client; keep only its final path
    # component so the upload cannot be written outside upload_dir.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name")

    upload_dir = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    # exist_ok avoids failing when the directory appears between a check and the create.
    os.makedirs(upload_dir, exist_ok=True)

    dest = os.path.join(upload_dir, safe_name)

    with open(dest, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    return {"filename": safe_name}
231
+
232
@api_app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Stream QA answers to a websocket client through Redis pub/sub.

    For every prompt received, the QA chain is run on a worker thread with
    the client id as its tag, and each message published on the Redis
    channel named after the client id is forwarded to the websocket until
    the ``"end"`` marker is seen.

    Args:
        websocket: the client connection.
        client_id: id taken from the URL; also used as the channel name.
    """
    global QA

    await websocket.accept()

    channel = f'{client_id}'

    try:
        while True:
            prompt = await websocket.receive_text()

            pubsub = redisClient.pubsub()
            pubsub.subscribe(channel)
            try:
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                    # Hand the callable and its arguments to the executor;
                    # calling QA(...) here would run it on this thread and
                    # submit its result dict instead.
                    # tags is a list: a bare string would presumably be
                    # split into one tag per character - confirm against
                    # the installed langchain version.
                    future = executor.submit(
                        QA,
                        inputs=prompt,
                        return_only_outputs=True,
                        tags=[channel],
                        include_run_info=True,
                        callbacks=[handleCallback],
                    )

                    saw_end = False
                    while not saw_end:
                        # NOTE(review): this poll blocks the event loop for
                        # up to the timeout on every iteration.
                        item = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
                        if item is None:
                            if future.done():
                                # The chain finished (or failed) without
                                # publishing an end marker; stop waiting.
                                break
                            continue
                        if item["type"] != "message":
                            continue
                        message = item["data"].decode('utf-8')
                        await websocket.send_text(f'{message}')
                        saw_end = message == "end"

                    if not saw_end:
                        # Still tell the client that this answer is over.
                        await websocket.send_text('end')

                    failure = future.exception()
                    if failure is not None:
                        print(failure)
            finally:
                # One pubsub is opened per prompt; release it every time.
                pubsub.unsubscribe(channel)
                pubsub.close()

    except WebSocketDisconnect:
        print('disconnect')
    except RuntimeError as error:
        print(error)