Daniel Marques committed
Commit: 760ae83
Parent(s): 8a26b55

fix: add callback StreamingStdOutCallbackHandler

Files changed:
- load_models.py +2 -2
- main.py +0 -14
load_models.py CHANGED

@@ -5,6 +5,7 @@ import logging
 from auto_gptq import AutoGPTQForCausalLM
 from huggingface_hub import hf_hub_download
 from langchain.llms import LlamaCpp, HuggingFacePipeline
+
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 from transformers import (
@@ -68,7 +69,7 @@ def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, loggin
         kwargs["stream"] = stream
 
         if stream == True:
-            kwargs["callbacks"] =
+            kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]
 
         return LlamaCpp(**kwargs)
     except:
@@ -220,7 +221,6 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
         top_k=40,
         repetition_penalty=1.0,
         generation_config=generation_config,
-        # callbacks=callbacks
     )
 
     local_llm = HuggingFacePipeline(pipeline=pipe)
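The load_models.py change attaches a StreamingStdOutCallbackHandler to the LlamaCpp instance whenever streaming is requested, so generated tokens are printed to stdout as they arrive. A minimal, self-contained sketch of that pattern follows; the helper name and model path are illustrative, not taken from the repo.

# Sketch of the streaming-callback pattern this commit wires into
# load_quantized_model_gguf_ggml; names and paths here are illustrative.
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import LlamaCpp


def build_streaming_llm(model_path: str, stream: bool = True) -> LlamaCpp:
    kwargs = {
        "model_path": model_path,
        "streaming": True,  # generate token by token so the handler fires per token
        "verbose": True,
    }
    if stream:
        # Print each generated token to stdout as it is produced.
        kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]
    return LlamaCpp(**kwargs)


# Usage (requires a real GGUF/GGML model file on disk):
# llm = build_streaming_llm("models/llama-2-7b-chat.Q4_K_M.gguf")
# llm("What does a callback handler do?")  # tokens stream to stdout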
main.py CHANGED

@@ -44,19 +44,6 @@ class MyCustomHandler(BaseCallbackHandler):
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         print("finish")
 
-class CustomHandler(BaseCallbackHandler):
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        print(f" CustomHandler: {token}")
-
-    async def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        class_name = serialized["name"]
-        print("CustomHandler start")
-
-    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        print("CustomHandler finish")
-
 # if torch.backends.mps.is_available():
 #     DEVICE_TYPE = "mps"
 # elif torch.cuda.is_available():
@@ -101,7 +88,6 @@ QA = RetrievalQA.from_chain_type(
     chain_type_kwargs={
         "prompt": QA_CHAIN_PROMPT,
         "memory": memory,
-        "callbacks": [CustomHandler()]
     },
 )
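For comparison, below is a minimal self-contained sketch (not part of this commit) of the kind of chain-level handler main.py drops here. With StreamingStdOutCallbackHandler already attached to the LLM in load_models.py, a handler like this is only needed for custom per-token behaviour, and it would be attached via the LLM's callbacks list rather than through chain_type_kwargs.

from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult


class TokenPrinterHandler(BaseCallbackHandler):
    """Illustrative stand-in for the removed CustomHandler (sync methods only)."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        print("generation started")

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Called once per generated token when the LLM streams.
        print(token, end="", flush=True)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        print("\ngeneration finished")

Passing callbacks=[TokenPrinterHandler()] when constructing the LlamaCpp instance, as the load_models.py change does for StreamingStdOutCallbackHandler, keeps callback wiring in one place instead of duplicating it in chain_type_kwargs.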