from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback
from typing import Any
import torch
from threading import Thread
from pydantic import Field, field_validator
import keras
import keras_nlp
# Alternative implementation for transformers (Gemma 2); __setattr__ is used to bypass Pydantic field checks.
"""class GemmaLLMInterface(CustomLLM): | |
def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs): | |
super().__init__(**kwargs) | |
object.__setattr__(self, "model_id", model_id) | |
model = AutoModelForCausalLM.from_pretrained( | |
model_id, | |
device_map="auto", | |
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, | |
) | |
tokenizer = AutoTokenizer.from_pretrained(model_id) | |
object.__setattr__(self, "model", model) | |
object.__setattr__(self, "tokenizer", tokenizer) | |
object.__setattr__(self, "context_window", 8192) | |
object.__setattr__(self, "num_output", 2048) | |
def _format_prompt(self, message: str) -> str: | |
return ( | |
f"<start_of_turn>user\n{message}<end_of_turn>\n" | |
f"<start_of_turn>model\n" | |
) | |
@property | |
def metadata(self) -> LLMMetadata: | |
return LLMMetadata( | |
context_window=self.context_window, | |
num_output=self.num_output, | |
model_name=self.model_id, | |
) | |
@llm_completion_callback() | |
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: | |
prompt = self._format_prompt(prompt) | |
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) | |
outputs = self.model.generate(**inputs, max_new_tokens=self.num_output) | |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip() | |
response = response[len(prompt):].strip() | |
return CompletionResponse(text=response if response else "No response generated.") | |
@llm_completion_callback() | |
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: | |
#prompt = self._format_prompt(prompt) | |
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) | |
streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True) | |
generation_kwargs = dict(inputs, max_new_tokens=self.num_output, streamer=streamer) | |
thread = Thread(target=self.model.generate, kwargs=generation_kwargs) | |
thread.start() | |
streamed_response = "" | |
for new_text in streamer: | |
if new_text: | |
streamed_response += new_text | |
yield CompletionResponse(text=streamed_response, delta=new_text) | |
if not streamed_response: | |
yield CompletionResponse(text="No response generated.", delta="No response generated.")""" | |
# Implementation for Keras (keras_nlp GemmaCausalLM)
class GemmaLLMInterface(CustomLLM):
    model: keras_nlp.models.GemmaCausalLM = None
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        # Wrap the user message in Gemma's chat turn markers.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        raw_response = self.model.generate(prompt, max_length=self.num_output)
        # keras_nlp returns the prompt followed by the completion; strip the prompt prefix.
        response = raw_response[len(prompt):]
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # keras_nlp does not expose token-level streaming here, so generate the full
        # completion once and re-emit it incrementally, character by character.
        full_response = self.complete(prompt).text
        response = ""
        for token in full_response:
            response += token
            yield CompletionResponse(text=response, delta=token)
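

# Example usage (a minimal sketch, not part of the original Space code): load a Gemma 2
# instruction-tuned checkpoint with keras_nlp and register the wrapper as the default
# LlamaIndex LLM. The preset name "gemma2_instruct_2b_en" is an assumption and depends on
# the installed keras_nlp version; gated Gemma weights also require Kaggle/HF credentials.
if __name__ == "__main__":
    from llama_index.core import Settings

    gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
    Settings.llm = GemmaLLMInterface(model=gemma_model)

    # Quick smoke test of both completion paths.
    print(Settings.llm.complete("What is LlamaIndex?").text)
    for chunk in Settings.llm.stream_complete("What is LlamaIndex?"):
        print(chunk.delta, end="", flush=True)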