import logging
from threading import Thread

import torch
from nltk import sent_tokenize
from rich.console import Console
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)

from baseHandler import BaseHandler
from LLM.chat import Chat

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class LanguageModelHandler(BaseHandler):
    """
    Handles the language model part.
    """

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs={},
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
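        """
        Load the tokenizer and model, build a streaming text-generation
        pipeline, and initialize the rolling chat history.
        """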
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=self.torch_dtype, trust_remote_code=True
        ).to(device)
        self.pipe = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
        )
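        # TextIteratorStreamer lets generation run in a background thread while
        # this handler consumes decoded tokens as they arrive; skipping the
        # prompt and special tokens keeps only the newly generated text.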
        self.streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        self.gen_kwargs = {
            "streamer": self.streamer,
            "return_full_text": False,
            **gen_kwargs,
        }

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Write me a poem about Machine Learning."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
        warmup_gen_kwargs = {
            # fall back to a short generation if these were not provided in
            # gen_kwargs, rather than raising a KeyError
            "min_new_tokens": self.gen_kwargs.get("min_new_tokens", 64),
            "max_new_tokens": self.gen_kwargs.get("max_new_tokens", 64),
            **self.gen_kwargs,
        }
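
        # Drive a couple of full generations so kernels and caches are
        # initialized; each step drains the streamer, which only finishes once
        # the background generation thread has produced its last token.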
        n_steps = 2

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            thread = Thread(
                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
            )
            thread.start()
            for _ in self.streamer:
                pass

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()

            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, prompt):
        logger.debug("inferring language model...")

        self.chat.append({"role": self.user_role, "content": prompt})
        thread = Thread(
            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
        )
        thread.start()
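
        # On MPS the whole answer is accumulated and yielded once, and the MPS
        # allocator cache is cleared afterwards; elsewhere, complete sentences
        # are yielded as soon as the tokenizer can split them off.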
if self.device == "mps": | |
generated_text = "" | |
for new_text in self.streamer: | |
generated_text += new_text | |
printable_text = generated_text | |
torch.mps.empty_cache() | |
else: | |
generated_text, printable_text = "", "" | |
for new_text in self.streamer: | |
generated_text += new_text | |
printable_text += new_text | |
sentences = sent_tokenize(printable_text) | |
if len(sentences) > 1: | |
yield (sentences[0]) | |
printable_text = new_text | |
self.chat.append({"role": "assistant", "content": generated_text}) | |
# don't forget last sentence | |
yield printable_text | |
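

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the pipeline wiring. It assumes
    # BaseHandler follows the stop_event / queue_in / queue_out constructor
    # convention used by the rest of this project; adjust the call if the
    # actual signature differs.
    from queue import Queue
    from threading import Event

    handler = LanguageModelHandler(
        Event(),
        Queue(),
        Queue(),
        setup_kwargs={"device": "cuda", "gen_kwargs": {"max_new_tokens": 64}},
    )
    for sentence in handler.process("Tell me a short joke about threads."):
        console.print(sentence)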