import logging

import torch
from mlx_lm import generate, load, stream_generate
from rich.console import Console

from baseHandler import BaseHandler
from LLM.chat import Chat

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class MLXLanguageModelHandler(BaseHandler):
    """
    Handles the language-model step of the pipeline: runs a chat model
    locally via MLX and streams sentence-sized chunks of the response
    downstream.
    """

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="mps",
        torch_dtype="float16",
        gen_kwargs=None,
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        # `device` and `torch_dtype` are accepted for interface parity with
        # the other handlers but are unused here: MLX manages device
        # placement and dtypes itself.
        self.model_name = model_name
        self.model, self.tokenizer = load(self.model_name)
        self.gen_kwargs = gen_kwargs if gen_kwargs is not None else {}
        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role
        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")
        # Run a couple of short generations so kernels are compiled and
        # weights are resident before the first real request arrives.
        dummy_input_text = "Write me a poem about Machine Learning."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
        n_steps = 2
        for _ in range(n_steps):
            prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
            generate(
                self.model,
                self.tokenizer,
                prompt=prompt,
                max_tokens=self.gen_kwargs["max_new_tokens"],
                verbose=False,
            )

    def process(self, prompt):
        logger.debug("inferring language model...")
        self.chat.append({"role": self.user_role, "content": prompt})

        # Gemma chat templates reject the "system" role, so drop any system
        # messages before templating when a Gemma model is loaded.
        if "gemma" in self.model_name.lower():
            chat_messages = [
                msg for msg in self.chat.to_list() if msg["role"] != "system"
            ]
        else:
            chat_messages = self.chat.to_list()

        prompt = self.tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )
        output = ""
        curr_output = ""
        for t in stream_generate(
            self.model,
            self.tokenizer,
            prompt,
            max_tokens=self.gen_kwargs["max_new_tokens"],
        ):
            output += t
            curr_output += t
            # Yield at sentence boundaries (or at Phi-3's <|end|> marker) so
            # the TTS stage can start speaking before generation finishes.
            if curr_output.endswith((".", "?", "!", "<|end|>")):
                yield curr_output.replace("<|end|>", "")
                curr_output = ""
        generated_text = output.replace("<|end|>", "")
        # Release cached MPS memory held by other (PyTorch-based) handlers.
        torch.mps.empty_cache()
        self.chat.append({"role": "assistant", "content": generated_text})
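

# A minimal sketch of the sentence-chunking behaviour above, using a fake
# token stream instead of a real MLX model. `fake_stream` and this demo block
# are illustrative assumptions, not part of the pipeline; in practice the
# handler is driven through BaseHandler's input/output queues.
if __name__ == "__main__":
    def fake_stream():
        # Stand-in for stream_generate: yields text fragments one at a time.
        yield from ["Machine ", "learning ", "is fun.", " Really!", "<|end|>"]

    curr = ""
    for t in fake_stream():
        curr += t
        # Same boundary test as in process(): flush on end-of-sentence
        # punctuation or the model's end marker.
        if curr.endswith((".", "?", "!", "<|end|>")):
            print("chunk:", curr.replace("<|end|>", ""))
            curr = ""
    # Prints:
    # chunk: Machine learning is fun.
    # chunk:  Really!
    # chunk: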