from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from llama_index.core.llms import (
    CustomLLM,
    LLMMetadata,
    CompletionResponse,
    CompletionResponseGen,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from typing import Any
import torch
from threading import Thread
from pydantic import Field, field_validator
import keras
import keras_nlp


# Transformers-based variant (disabled). object.__setattr__ is used to bypass
# Pydantic's field validation on CustomLLM.
"""
class GemmaLLMInterface(CustomLLM):
    def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
        super().__init__(**kwargs)
        object.__setattr__(self, "model_id", model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "tokenizer", tokenizer)
        object.__setattr__(self, "context_window", 8192)
        object.__setattr__(self, "num_output", 2048)

    def _format_prompt(self, message: str) -> str:
        # Gemma chat turn format.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_id,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        # The decoded text echoes the prompt, so keep only the model's reply.
        response = response[len(prompt):].strip()
        return CompletionResponse(text=response if response else "No response generated.")

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
        generation_kwargs = dict(inputs, max_new_tokens=self.num_output, streamer=streamer)
        # Run generation in a background thread so the streamer can be consumed here.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        streamed_response = ""
        for new_text in streamer:
            if new_text:
                streamed_response += new_text
                yield CompletionResponse(text=streamed_response, delta=new_text)
        if not streamed_response:
            yield CompletionResponse(text="No response generated.", delta="No response generated.")
"""


# Keras (keras_nlp) based variant.
class GemmaLLMInterface(CustomLLM):
    model: keras_nlp.models.GemmaCausalLM = None
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        # Gemma chat turn format.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        raw_response = self.model.generate(prompt, max_length=self.num_output)
        # keras_nlp's generate() returns the prompt plus the completion, so
        # strip the echoed prompt and return only the model's reply.
        response = raw_response[len(prompt):]
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # keras_nlp's generate() is not incremental, so produce the full
        # completion first and then emit it character by character to satisfy
        # the streaming interface.
        response = self.complete(prompt).text
        streamed = ""
        for token in response:
            streamed += token
            yield CompletionResponse(text=streamed, delta=token)
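

# Example usage: a minimal sketch, not part of the original module. It assumes
# the keras_nlp preset name "gemma2_instruct_2b_en" and LlamaIndex's Settings
# object for registering the LLM; adjust both to your environment.
if __name__ == "__main__":
    from llama_index.core import Settings

    # Load a Gemma causal LM preset and wrap it in the custom interface above.
    gemma = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
    Settings.llm = GemmaLLMInterface(model=gemma)

    # One-shot completion through the wrapper.
    print(Settings.llm.complete("What is retrieval-augmented generation?").text)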