Daniel Marques committed · Commit 27e6a14 · 1 Parent(s): abb8521

fix: add websocket in handlerToken

Files changed:
- constants.py +3 -0
- main.py +38 -53
- requirements.txt +1 -0
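
The change set moves SHOW_SOURCES into constants.py, has main.py pick the device at runtime (mps, cuda, or cpu), replaces the module-level handlerToken with a MyCustomSyncHandler instance passed through the QA callbacks, connects a Redis client, and makes the websocket endpoint run the QA chain and reply with the value read back from Redis; requirements.txt gains the redis dependency.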
constants.py
CHANGED
@@ -13,6 +13,8 @@ ROOT_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
 
 PATH_NAME_SOURCE_DIRECTORY = "SOURCE_DOCUMENTS"
 
+SHOW_SOURCES=True
+
 # Define the folder for storing database
 SOURCE_DIRECTORY = f"{ROOT_DIRECTORY}/{PATH_NAME_SOURCE_DIRECTORY}"
 
@@ -43,6 +45,7 @@ N_BATCH = 2048
 # N_BATCH = 512
 
 
+
 # https://python.langchain.com/en/latest/_modules/langchain/document_loaders/excel.html#UnstructuredExcelLoader
 DOCUMENT_MAP = {
     ".txt": TextLoader,
main.py
CHANGED
@@ -1,33 +1,29 @@
+from typing import Any, Dict, List
+
 import os
 import glob
 import shutil
 import subprocess
 import asyncio
-
-
+import redis
+import torch
 
 from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
 from fastapi.staticfiles import StaticFiles
 
 from pydantic import BaseModel
 
-#
+# langchain
 from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.prompts import PromptTemplate
-from langchain.memory import ConversationBufferMemory
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import LLMResult
+from langchain.vectorstores import Chroma
 
 from prompt_template_utils import get_prompt_template
-
-# from langchain.embeddings import HuggingFaceEmbeddings
 from load_models import load_model
 
-
-from langchain.vectorstores import Chroma
-
-from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME, PATH_NAME_SOURCE_DIRECTORY
+from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME, PATH_NAME_SOURCE_DIRECTORY, SHOW_SOURCES
 
 class Predict(BaseModel):
     prompt: str
@@ -35,54 +31,36 @@ class Predict(BaseModel):
 class Delete(BaseModel):
     filename: str
 
-
-
-
-
-
-
-
-DEVICE_TYPE = "cuda"
-SHOW_SOURCES = True
+if torch.backends.mps.is_available():
+    DEVICE_TYPE = "mps"
+elif torch.cuda.is_available():
+    DEVICE_TYPE = "cuda"
+else:
+    DEVICE_TYPE = "cpu"
 
 EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
-
-# load the vectorstore
-DB = Chroma(
-    persist_directory=PERSIST_DIRECTORY,
-    embedding_function=EMBEDDINGS,
-    client_settings=CHROMA_SETTINGS,
-)
-
+DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
 RETRIEVER = DB.as_retriever()
 
 class MyCustomSyncHandler(BaseCallbackHandler):
-    def __init__(self):
-        self.end = False
-
     def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        self
+        print(f'on_llm_start self {self}')
+        print(f'on_llm_start kwargs {prompts}')
+        print(f'on_llm_start token {kwargs}')
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        self
+        print(f'on_llm_end self {self}')
+        print(f'on_llm_end kwargs {response}')
+        print(f'on_llm_end token {kwargs}')
 
     def on_llm_new_token(self, token: str, **kwargs) -> Any:
-        print(self)
-        print(kwargs)
-
-
-# Create State
-handlerToken = MyCustomSyncHandler()
+        print(f'on_llm_new_token self {self}')
+        print(f'on_llm_new_token kwargs {kwargs}')
+        print(f'on_llm_new_token token {token}')
 
-LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True
-
-# template = """You are a helpful, respectful and honest assistant.
-# Always answer in the most helpful and safe way possible without trying to make up an answer, if you don't know the answer just say "I don't know" and don't share false information or topics that were not provided in your training. Use a maximum of 15 sentences. Your answer should be as concise and clear as possible. Always say "thank you for asking!" at the end of your answer.
-# Context: {context}
-# Question: {question}
-# """
+LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
 
 prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
 
@@ -97,7 +75,9 @@ QA = RetrievalQA.from_chain_type(
     },
 )
 
+redisClient = redis.Redis(host='localhost', port=6379, db=0)
 
+r.set('foo', 'bar')
 
 app = FastAPI(title="homepage-app")
 api_app = FastAPI(title="api app")
@@ -146,7 +126,7 @@ def run_ingest_route():
         retriever=RETRIEVER,
         return_source_documents=SHOW_SOURCES,
         chain_type_kwargs={
-            "prompt":
+            "prompt": prompt,
             "memory": memory
         },
     )
@@ -250,16 +230,21 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
 
     await websocket.accept()
 
-    oldReceiveText = ''
-
     try:
         while True:
            prompt = await websocket.receive_text()
+            QA(
+                inputs=prompt,
+                return_only_outputs=True,
+                callbacks=[MyCustomSyncHandler()],
+                tags=f'{client_id}',
+                run_name=f'{client_id}',
+                include_run_info=True
+            )
+
+            response = redisClient.get('foo')
 
-
-            handlerToken.callback = websocket.send_text
-            oldReceiveText = prompt
-            await QA(prompt)
+            await websocket.send_text(response)
 
     except WebSocketDisconnect:
        print('disconnect')
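
Note: as committed, MyCustomSyncHandler only prints the tokens it receives, and the websocket loop replies with whatever is stored under the Redis key 'foo', so generated text never reaches the client. A minimal sketch of the token streaming the commit message hints at could look like the following; it assumes the QA chain and api_app defined in main.py, that the chain and underlying LLM support asynchronous execution via Chain.acall, and that the handler name and route path are illustrative only, not part of this commit.

# Sketch only: stream tokens straight to the websocket client.
# Assumes QA and api_app from main.py; WebSocketTokenHandler and the
# "/ws/{client_id}" route path are hypothetical names for illustration.
from fastapi import WebSocket, WebSocketDisconnect
from langchain.callbacks.base import AsyncCallbackHandler


class WebSocketTokenHandler(AsyncCallbackHandler):
    """Forward each generated token to a single websocket client."""

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Push the token to the browser as soon as the model emits it.
        await self.websocket.send_text(token)


@api_app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    await websocket.accept()
    try:
        while True:
            prompt = await websocket.receive_text()
            # acall runs the chain asynchronously so the handler can await
            # send_text for every streamed token (requires async support in
            # the chain and the underlying LLM).
            await QA.acall(inputs=prompt, callbacks=[WebSocketTokenHandler(websocket)])
    except WebSocketDisconnect:
        print('disconnect')

Another option, closer to the Redis client this commit introduces, would be for the synchronous handler to publish each token to Redis and for the websocket loop to relay it to the client.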
requirements.txt
CHANGED
@@ -29,6 +29,7 @@ uvicorn
 fastapi
 websockets
 pydantic
+redis
 
 # Streamlit related
 streamlit