Spaces:
Paused
Paused
Daniel Marques
commited on
Commit
•
e72e226
1
Parent(s):
dc8d635
fix: add callback
Browse files- load_models.py +2 -2
- main.py +0 -3
load_models.py
CHANGED
@@ -3,6 +3,7 @@ import logging
|
|
3 |
from auto_gptq import AutoGPTQForCausalLM
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
from langchain.llms import LlamaCpp, HuggingFacePipeline
|
|
|
6 |
|
7 |
from transformers import (
|
8 |
AutoModelForCausalLM,
|
@@ -204,8 +205,6 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
|
|
204 |
|
205 |
streamer = TextStreamer(tokenizer, skip_prompt=True)
|
206 |
|
207 |
-
logging.info(streamer)
|
208 |
-
|
209 |
pipe = pipeline(
|
210 |
"text-generation",
|
211 |
model=model,
|
@@ -217,6 +216,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
|
|
217 |
repetition_penalty=1.0,
|
218 |
generation_config=generation_config,
|
219 |
streamer=streamer
|
|
|
220 |
)
|
221 |
|
222 |
local_llm = HuggingFacePipeline(pipeline=pipe)
|
|
|
3 |
from auto_gptq import AutoGPTQForCausalLM
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
from langchain.llms import LlamaCpp, HuggingFacePipeline
|
6 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
7 |
|
8 |
from transformers import (
|
9 |
AutoModelForCausalLM,
|
|
|
205 |
|
206 |
streamer = TextStreamer(tokenizer, skip_prompt=True)
|
207 |
|
|
|
|
|
208 |
pipe = pipeline(
|
209 |
"text-generation",
|
210 |
model=model,
|
|
|
216 |
repetition_penalty=1.0,
|
217 |
generation_config=generation_config,
|
218 |
streamer=streamer
|
219 |
+
callbacks=[StreamingStdOutCallbackHandler()]
|
220 |
)
|
221 |
|
222 |
local_llm = HuggingFacePipeline(pipeline=pipe)
|
main.py
CHANGED
@@ -179,9 +179,6 @@ async def predict(data: Predict):
|
|
179 |
(os.path.basename(str(document.metadata["source"])), str(document.page_content))
|
180 |
)
|
181 |
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
return {"response": prompt_response_dict}
|
186 |
else:
|
187 |
raise HTTPException(status_code=400, detail="Prompt Incorrect")
|
|
|
179 |
(os.path.basename(str(document.metadata["source"])), str(document.page_content))
|
180 |
)
|
181 |
|
|
|
|
|
|
|
182 |
return {"response": prompt_response_dict}
|
183 |
else:
|
184 |
raise HTTPException(status_code=400, detail="Prompt Incorrect")
|