Daniel Marques committed
Commit 760ae83
1 Parent(s): 8a26b55

fix: add callback StreamingStdOutCallbackHandler

Files changed (2)
  1. load_models.py +2 -2
  2. main.py +0 -14
load_models.py CHANGED
@@ -5,6 +5,7 @@ import logging
 from auto_gptq import AutoGPTQForCausalLM
 from huggingface_hub import hf_hub_download
 from langchain.llms import LlamaCpp, HuggingFacePipeline
+
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 from transformers import (
@@ -68,7 +69,7 @@ def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, loggin
         kwargs["stream"] = stream
 
         if stream == True:
-            kwargs["callbacks"] = callbacks
+            kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]
 
         return LlamaCpp(**kwargs)
     except:
@@ -220,7 +221,6 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
         top_k=40,
         repetition_penalty=1.0,
         generation_config=generation_config,
-        # callbacks=callbacks
     )
 
     local_llm = HuggingFacePipeline(pipeline=pipe)
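
For context, a minimal sketch of what the patched load_quantized_model_gguf_ggml path does when streaming is enabled (the model path and generation parameters below are placeholders, not values from this repository): a StreamingStdOutCallbackHandler is attached to the LlamaCpp instance so tokens are printed to stdout as they are generated, instead of referencing a previously undefined callbacks variable.

# Sketch only: mirrors the change in load_models.py; path and parameters are placeholders.
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import LlamaCpp

kwargs = {
    "model_path": "./models/example-model.gguf",  # placeholder, not a real repo path
    "n_ctx": 4096,
    "max_tokens": 512,
}

stream = True
kwargs["stream"] = stream
if stream == True:
    # The fix: attach the stdout streaming handler here rather than the
    # undefined `callbacks` variable used before this commit.
    kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]

llm = LlamaCpp(**kwargs)  # tokens are echoed to stdout during generation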
main.py CHANGED
@@ -44,19 +44,6 @@ class MyCustomHandler(BaseCallbackHandler):
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         print("finish")
 
-class CustomHandler(BaseCallbackHandler):
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        print(f" CustomHandler: {token}")
-
-    async def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        class_name = serialized["name"]
-        print("CustomHandler start")
-
-    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        print("CustomHandler finish")
-
 # if torch.backends.mps.is_available():
 #     DEVICE_TYPE = "mps"
 # elif torch.cuda.is_available():
@@ -101,7 +88,6 @@ QA = RetrievalQA.from_chain_type(
     chain_type_kwargs={
         "prompt": QA_CHAIN_PROMPT,
         "memory": memory,
-        "callbacks": [CustomHandler()]
     },
 )
 
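
With the per-chain CustomHandler removed, streaming now comes from the handler attached at model-load time. A minimal sketch of the resulting wiring under that assumption (llm, retriever, QA_CHAIN_PROMPT, and memory are assumed to be set up elsewhere in the app; chain_type is illustrative):

# Sketch only: the chain no longer passes callbacks; the LLM returned by
# load_model() already carries StreamingStdOutCallbackHandler.
from langchain.chains import RetrievalQA

QA = RetrievalQA.from_chain_type(
    llm=llm,                # assumed: built via load_model(..., stream=True)
    chain_type="stuff",     # illustrative choice
    retriever=retriever,    # assumed: existing vector-store retriever
    chain_type_kwargs={
        "prompt": QA_CHAIN_PROMPT,
        "memory": memory,
    },
)

# Querying the chain streams tokens to stdout via the LLM's own callback.
result = QA("example question")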