Daniel Marques committed
Commit: 198843f
Parent(s): 8fa0233

fix: add streamer

Files changed:
- load_models.py +31 -1
- main.py +5 -8
load_models.py
CHANGED
@@ -1,9 +1,15 @@
 import torch
+import asyncio
 import logging
+from typing import Any, Dict, List
+
 from auto_gptq import AutoGPTQForCausalLM
 from huggingface_hub import hf_hub_download
 from langchain.llms import LlamaCpp, HuggingFacePipeline
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.schema import LLMResult
+
+from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
 
 from transformers import (
     AutoModelForCausalLM,
@@ -22,6 +28,29 @@ torch.set_grad_enabled(False)
 from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH, MODELS_PATH
 
 
+class MyCustomSyncHandler(BaseCallbackHandler):
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        print(f"Sync handler being called in a `thread_pool_executor`: token: {token}")
+
+class MyCustomAsyncHandler(AsyncCallbackHandler):
+    """Async callback handler that can be used to handle callbacks from langchain."""
+
+    async def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Run when chain starts running."""
+        print("zzzz....")
+        await asyncio.sleep(0.3)
+        class_name = serialized["name"]
+        print("Hi! I just woke up. Your llm is starting")
+
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Run when chain ends running."""
+        print("zzzz....")
+        await asyncio.sleep(0.3)
+        print("Hi! I just woke up. Your llm is ending")
+
+
 def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging, stream = False):
     """
     Load a GGUF/GGML quantized model using LlamaCpp.
@@ -66,6 +95,7 @@ def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, loggin
 
         #add stream
         kwargs["stream"] = stream
+        kwargs["callbacks"] = [MyCustomSyncHandler(), MyCustomAsyncHandler()]
 
         return LlamaCpp(**kwargs)
     except:
@@ -220,7 +250,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
         repetition_penalty=1.0,
         generation_config=generation_config,
         streamer=streamer,
-        callbacks=[
+        callbacks=[MyCustomSyncHandler(), MyCustomAsyncHandler()]
     )
 
     local_llm = HuggingFacePipeline(pipeline=pipe)
main.py
CHANGED
@@ -42,10 +42,7 @@ DB = Chroma(
 
 RETRIEVER = DB.as_retriever()
 
-
-
-LLM = models[0]
-STREAMER = models[1]
+LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
 
 template = """you are a helpful, respectful and honest assistant. You should only use the source documents provided to answer the questions.
 You should only respond only topics that contains in documents use to training.
@@ -182,10 +179,10 @@ async def predict(data: Predict):
                 (os.path.basename(str(document.metadata["source"])), str(document.page_content))
             )
 
-        generated_text = ""
-        for new_text in STREAMER:
-            generated_text += new_text
-            print(generated_text)
+        # generated_text = ""
+        # for new_text in STREAMER:
+        #     generated_text += new_text
+        #     print(generated_text)
 
         return {"response": prompt_response_dict}
     else: