from threading import Thread
from typing import Any, Iterator

import torch
from pydantic import Field
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from llama_index.core.llms import (
    CompletionResponse,
    CompletionResponseGen,
    CustomLLM,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback


# for transformers 2
class GemmaLLMInterface(CustomLLM):
    # CustomLLM is a Pydantic model, so every attribute set on the instance must
    # be declared as a field; the model and tokenizer are loaded in __init__.
    model_id: str = "google/gemma-2b-it"
    context_window: int = 8192
    num_output: int = 2048
    model: Any = Field(default=None, exclude=True)
    tokenizer: Any = Field(default=None, exclude=True)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        self.model.eval()

    def _format_prompt(self, message: str) -> str:
        # Gemma instruction-tuned chat template.
        return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_id,
        )

    def _prepare_inputs(self, prompt: str) -> dict:
        formatted_prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(
            formatted_prompt, return_tensors="pt", add_special_tokens=True
        ).to(self.model.device)
        # Keep only the most recent tokens if the prompt exceeds the context window.
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
            inputs["attention_mask"] = inputs["attention_mask"][:, -self.context_window:]
        return inputs

    def _generate(self, inputs: dict) -> Iterator[str]:
        # model.generate() is not an iterator, so run it in a background thread and
        # let TextIteratorStreamer yield decoded text chunks as they are produced.
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=self.num_output,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=0.7,
            num_beams=1,
            repetition_penalty=1.1,
        )
        thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
        thread.start()
        for new_text in streamer:
            yield new_text
        thread.join()

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        inputs = self._prepare_inputs(prompt)
        response = "".join(self._generate(inputs))
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        inputs = self._prepare_inputs(prompt)
        response = ""
        for new_token in self._generate(inputs):
            response += new_token
            yield CompletionResponse(text=response, delta=new_token)


# for transformers 1
"""class GemmaLLMInterface(CustomLLM):
    model: Any
    tokenizer: Any
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    def _prepare_generation(self, prompt: str) -> tuple:
        prompt = self._format_prompt(prompt)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": inputs["input_ids"],
            "streamer": streamer,
            "max_new_tokens": self.num_output,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.7,
            "num_beams": 1,
            "repetition_penalty": 1.1,
        }
        return streamer, generate_kwargs
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            response += new_token
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        try:
            for new_token in streamer:
                yield CompletionResponse(text=new_token)
        except StopIteration:
            return"""
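

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original wrapper).
# It assumes the llama_index.core Settings API; the prompts are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from llama_index.core import Settings

    llm = GemmaLLMInterface()  # loads google/gemma-2b-it on first use
    Settings.llm = llm         # register as the default LLM for LlamaIndex

    # Blocking completion: returns the full response at once.
    print(llm.complete("Explain what a vector index is in one sentence.").text)

    # Streaming completion: each chunk carries the newly generated delta.
    for chunk in llm.stream_complete("List three uses of RAG."):
        print(chunk.delta, end="", flush=True)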