Spaces:
Paused
Paused
Daniel Marques
commited on
Commit
•
2ea73cf
1
Parent(s):
3861b3b
fix: add streamer
Browse files- load_models.py +1 -6
- main.py +13 -4
load_models.py
CHANGED
@@ -222,9 +222,4 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
|
|
222 |
local_llm = HuggingFacePipeline(pipeline=pipe)
|
223 |
logging.info("Local LLM Loaded")
|
224 |
|
225 |
-
|
226 |
-
for new_text in streamer:
|
227 |
-
generated_text += new_text
|
228 |
-
print(generated_text)
|
229 |
-
|
230 |
-
return local_llm
|
|
|
222 |
local_llm = HuggingFacePipeline(pipeline=pipe)
|
223 |
logging.info("Local LLM Loaded")
|
224 |
|
225 |
+
return [local_llm, streamer]
|
|
|
|
|
|
|
|
|
|
main.py
CHANGED
@@ -42,9 +42,11 @@ DB = Chroma(
|
|
42 |
|
43 |
RETRIEVER = DB.as_retriever()
|
44 |
|
45 |
-
|
|
|
|
|
46 |
|
47 |
-
template = """
|
48 |
You should only respond only topics that contains in documents use to training.
|
49 |
Use the following pieces of context to answer the question at the end.
|
50 |
Always answer in the most helpful and safe way possible.
|
@@ -70,7 +72,6 @@ QA = RetrievalQA.from_chain_type(
|
|
70 |
},
|
71 |
)
|
72 |
|
73 |
-
|
74 |
class Predict(BaseModel):
|
75 |
prompt: str
|
76 |
|
@@ -145,7 +146,7 @@ def get_files():
|
|
145 |
def delete_source_route(data: Delete):
|
146 |
filename = data.filename
|
147 |
path_source_documents = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
|
148 |
-
file_to_delete = f"{path_source_documents}
|
149 |
|
150 |
if os.path.exists(file_to_delete):
|
151 |
try:
|
@@ -166,6 +167,9 @@ async def predict(data: Predict):
|
|
166 |
# print(f'User Prompt: {user_prompt}')
|
167 |
# Get the answer from the chain
|
168 |
res = QA(user_prompt)
|
|
|
|
|
|
|
169 |
answer, docs = res["result"], res["source_documents"]
|
170 |
|
171 |
prompt_response_dict = {
|
@@ -179,6 +183,11 @@ async def predict(data: Predict):
|
|
179 |
(os.path.basename(str(document.metadata["source"])), str(document.page_content))
|
180 |
)
|
181 |
|
|
|
|
|
|
|
|
|
|
|
182 |
return {"response": prompt_response_dict}
|
183 |
else:
|
184 |
raise HTTPException(status_code=400, detail="Prompt Incorrect")
|
|
|
42 |
|
43 |
RETRIEVER = DB.as_retriever()
|
44 |
|
45 |
+
models = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=False)
|
46 |
+
LLM = models[0]
|
47 |
+
STREAMER = models[1]
|
48 |
|
49 |
+
template = """you are a helpful, respectful and honest assistant. You should only use the source documents provided to answer the questions.
|
50 |
You should only respond only topics that contains in documents use to training.
|
51 |
Use the following pieces of context to answer the question at the end.
|
52 |
Always answer in the most helpful and safe way possible.
|
|
|
72 |
},
|
73 |
)
|
74 |
|
|
|
75 |
class Predict(BaseModel):
|
76 |
prompt: str
|
77 |
|
|
|
146 |
def delete_source_route(data: Delete):
|
147 |
filename = data.filename
|
148 |
path_source_documents = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
|
149 |
+
file_to_delete = f"{path_source_documents}/{filename}"
|
150 |
|
151 |
if os.path.exists(file_to_delete):
|
152 |
try:
|
|
|
167 |
# print(f'User Prompt: {user_prompt}')
|
168 |
# Get the answer from the chain
|
169 |
res = QA(user_prompt)
|
170 |
+
|
171 |
+
print(res)
|
172 |
+
|
173 |
answer, docs = res["result"], res["source_documents"]
|
174 |
|
175 |
prompt_response_dict = {
|
|
|
183 |
(os.path.basename(str(document.metadata["source"])), str(document.page_content))
|
184 |
)
|
185 |
|
186 |
+
generated_text = ""
|
187 |
+
for new_text in STREAMER:
|
188 |
+
generated_text += new_text
|
189 |
+
print(generated_text)
|
190 |
+
|
191 |
return {"response": prompt_response_dict}
|
192 |
else:
|
193 |
raise HTTPException(status_code=400, detail="Prompt Incorrect")
|