import torch
import spaces
from threading import Thread
from typing import Any

from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from llama_index.core.llms import (
    CustomLLM,
    LLMMetadata,
    CompletionResponse,
    CompletionResponseGen,
)
from llama_index.core.llms.callbacks import llm_completion_callback


class GemmaLLMInterface(CustomLLM):
    """LlamaIndex CustomLLM wrapper around a locally loaded Gemma 2 model."""

    model: Any
    tokenizer: Any
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        # Gemma 2 chat format: the <start_of_turn>/<end_of_turn> control tokens
        # are required for the instruction-tuned checkpoints to follow the turn
        # structure.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    def _prepare_generation(self, prompt: str) -> tuple:
        prompt = self._format_prompt(prompt)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=True
        ).to(device)
        # Keep only the most recent context_window tokens; truncate the
        # attention mask in lockstep so the two tensors stay aligned.
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
            inputs["attention_mask"] = inputs["attention_mask"][:, -self.context_window:]
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "streamer": streamer,
            "max_new_tokens": self.num_output,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.7,
            "num_beams": 1,
            "repetition_penalty": 1.1,
        }
        return streamer, generate_kwargs

    # @spaces.GPU must wrap the callables that actually touch the GPU.
    # Decorating the class itself (as in the original) would replace the class
    # with a function wrapper and break instantiation.
    @spaces.GPU(duration=120)
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        # Run generation in a background thread; the streamer yields decoded
        # tokens on this thread as they are produced.
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            response += new_token
        t.join()
        return CompletionResponse(text=response)

    # Recent versions of the `spaces` package also support decorating
    # generator functions, so the GPU allocation spans the streamed iteration.
    @spaces.GPU(duration=120)
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        # LlamaIndex streaming convention: `text` carries the accumulated
        # output so far, `delta` carries only the newest token.
        response = ""
        for new_token in streamer:
            response += new_token
            yield CompletionResponse(text=response, delta=new_token)
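

# --- Usage sketch (not part of the original interface) ---
# A minimal way to load a Gemma 2 checkpoint and drive the wrapper above.
# The checkpoint id "google/gemma-2-2b-it" is an assumption for illustration;
# substitute whichever Gemma 2 variant your Space actually serves. Outside a
# ZeroGPU Space, @spaces.GPU is a no-op and the code falls back to CPU/CUDA
# as detected in _prepare_generation.
if __name__ == "__main__":
    model_id = "google/gemma-2-2b-it"  # assumed checkpoint (gated on the Hub)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)

    # Blocking completion: returns the full response at once.
    print(llm.complete("Explain what a context window is.").text)

    # Streaming completion: print each delta as it arrives.
    for chunk in llm.stream_complete("Explain what a context window is."):
        print(chunk.delta, end="", flush=True)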