from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from llama_index.core.llms import (
    CustomLLM,
    LLMMetadata,
    CompletionResponse,
    CompletionResponseGen,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from typing import Any, Iterator
from threading import Thread
import torch
from pydantic import Field, field_validator


# for transformers 2
class GemmaLLMInterface(CustomLLM):
    model: Any = None
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma-2b-it"

    def _format_prompt(self, message: str) -> str:
        # Wrap the message in Gemma's chat turn markers.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # max_new_tokens caps only the generated tokens; max_length would also count
        # the prompt and can silently leave no room for the answer.
        outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
        # Decode only the newly generated tokens so the prompt is not echoed back.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.model.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return CompletionResponse(text=response if response else "No response generated.")

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        prompt = self._format_prompt(prompt)
        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # generate() has no `streaming` flag; stream tokens through a
        # TextIteratorStreamer while generation runs in a background thread.
        streamer = TextIteratorStreamer(
            self.model.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=self.num_output)
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        streamed_response = ""
        for new_token in streamer:
            if new_token:
                streamed_response += new_token
                yield CompletionResponse(text=streamed_response, delta=new_token)
        if not streamed_response:
            yield CompletionResponse(text="No response generated.", delta="No response generated.")


# for transformers 1
"""class GemmaLLMInterface(CustomLLM):
    model: Any
    tokenizer: Any
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    def _prepare_generation(self, prompt: str) -> tuple:
        prompt = self._format_prompt(prompt)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
        # Keep only the most recent tokens if the prompt exceeds the context window.
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": inputs["input_ids"],
            "streamer": streamer,
            "max_new_tokens": self.num_output,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.7,
            "num_beams": 1,
            "repetition_penalty": 1.1,
        }
        return streamer, generate_kwargs

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            response += new_token
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        streamed_response = ""
        for new_token in streamer:
            streamed_response += new_token
            yield CompletionResponse(text=streamed_response, delta=new_token)"""
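

# Example wiring (a minimal sketch, not part of the interface above): it assumes the
# gated "google/gemma-2b-it" checkpoint is accessible with your Hugging Face token,
# and that the tokenizer is attached to the model object, since GemmaLLMInterface
# reads `self.model.tokenizer` and `self.model.device`. Prompts and model id here
# are illustrative placeholders.
if __name__ == "__main__":
    model_id = "google/gemma-2b-it"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    # Attach the tokenizer so the interface can find it on the model object.
    model.tokenizer = AutoTokenizer.from_pretrained(model_id)

    llm = GemmaLLMInterface(model=model)

    # Blocking completion.
    print(llm.complete("Explain what a vector index is in one sentence.").text)

    # Streaming completion: print each delta as it arrives.
    for chunk in llm.stream_complete("List three uses of LlamaIndex."):
        print(chunk.delta, end="", flush=True)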