"""Language-model stage: generates chat replies with a Hugging Face causal LM.

Replies are streamed from a worker thread through a ``TextIteratorStreamer``.
They are yielded sentence by sentence, or all at once on "mps".
"""

from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    TextIteratorStreamer,
)
import torch

from LLM.chat import Chat
from baseHandler import BaseHandler
from rich.console import Console

import logging
from nltk import sent_tokenize

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class LanguageModelHandler(BaseHandler):
    """
    Handles the language model part.

    Keeps a ``Chat`` history, runs a text-generation pipeline over it in a
    worker thread, and yields the streamed reply.
    """

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs=None,
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        """Load model and tokenizer, build the streaming pipeline, then warm up.

        Args:
            model_name: Hugging Face model id or local path.
            device: Device the model is moved to and the pipeline runs on
                (e.g. "cuda", "mps").
            torch_dtype: Name of a ``torch`` dtype attribute, e.g. "float16".
            gen_kwargs: Extra keyword arguments forwarded to every pipeline
                call. They take precedence over the defaults set here.
                ``None`` means no extras.
            user_role: Role name used for user messages in the chat history.
            chat_size: Passed to ``Chat``. Presumably bounds how much history
                is kept; confirm in ``LLM.chat``.
            init_chat_role: If set, the chat is seeded with an initial message
                of this role.
            init_chat_prompt: Content of that initial message. It is required
                when ``init_chat_role`` is set.

        Raises:
            ValueError: If ``init_chat_role`` is set but ``init_chat_prompt``
                is empty.
        """
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Load with the resolved torch.dtype, so the weights use exactly the
        # dtype recorded on self.torch_dtype.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=self.torch_dtype, trust_remote_code=True
        ).to(device)
        self.pipe = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
        )
        # One streamer is shared by warmup() and process(). The pipeline runs
        # in a worker thread and feeds it, while the caller iterates over it.
        self.streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        self.gen_kwargs = {
            "streamer": self.streamer,
            "return_full_text": False,
            **(gen_kwargs or {}),
        }

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()

    def warmup(self):
        """Run ``n_steps`` throwaway generations on a fixed prompt.

        The configured generation settings are reused unchanged. On "cuda",
        the total time of the warm-up steps is measured with CUDA events and
        logged.
        """
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Write me a poem about Machine Learning."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
        # Plain copy of the configured settings. min_new_tokens and
        # max_new_tokens are optional, so no KeyError is raised when the
        # caller did not provide them.
        warmup_gen_kwargs = dict(self.gen_kwargs)

        n_steps = 2

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            thread = Thread(
                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
            )
            thread.start()
            # Drain the shared streamer, then join the worker. This keeps the
            # next step from overlapping a still-running pipeline thread.
            for _ in self.streamer:
                pass
            thread.join()

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, prompt):
        """Add ``prompt`` to the chat, generate a reply and yield it in pieces.

        On "mps", the whole reply is yielded once at the end. On other
        devices, each sentence found by ``nltk.sent_tokenize`` is yielded as
        soon as text for a following sentence has streamed in, and the
        remaining text is yielded last. The full reply is appended to the chat
        as an "assistant" message before that final yield.
        """
        logger.debug("infering language model...")

        self.chat.append({"role": self.user_role, "content": prompt})
        thread = Thread(
            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
        )
        thread.start()

        if self.device == "mps":
            generated_text = "".join(self.streamer)
            printable_text = generated_text
            torch.mps.empty_cache()
        else:
            generated_text, printable_text = "", ""
            for new_text in self.streamer:
                generated_text += new_text
                printable_text += new_text
                sentences = sent_tokenize(printable_text)
                if len(sentences) > 1:
                    yield (sentences[0])
                    # Carry over only the latest chunk. This assumes the new
                    # sentence starts at or after the beginning of that chunk
                    # (NOTE(review): presumably true for the streamer's
                    # word-level chunks; confirm).
                    printable_text = new_text
        thread.join()

        self.chat.append({"role": "assistant", "content": generated_text})

        # don't forget last sentence
        yield printable_text