from typing import Any
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

from llama_index.core.llms import (
    CustomLLM,
    LLMMetadata,
    CompletionResponse,
    CompletionResponseGen,
)
from llama_index.core.llms.callbacks import llm_completion_callback

import spaces  # Hugging Face Spaces helper (ZeroGPU support)
class GemmaLLMInterface(CustomLLM):
    model: Any
    tokenizer: Any
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        # Wrap the user message in the Gemma chat template.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )
    def _prepare_generation(self, prompt: str) -> tuple:
        prompt = self._format_prompt(prompt)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
        # Truncate from the left so the most recent tokens stay within the context window.
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]

        # Stream decoded tokens as they are generated, skipping the echoed prompt.
        streamer = TextIteratorStreamer(self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = {
            "input_ids": inputs["input_ids"],
            "streamer": streamer,
            "max_new_tokens": self.num_output,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.7,
            "num_beams": 1,
            "repetition_penalty": 1.1,
        }
        return streamer, generate_kwargs
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        # Run generation in a background thread and drain the streamer synchronously.
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            response += new_token
        return CompletionResponse(text=response)
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        # Yield the accumulated text plus the latest delta, per the LlamaIndex streaming convention.
        response = ""
        for new_token in streamer:
            response += new_token
            yield CompletionResponse(text=response, delta=new_token)
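

# A minimal usage sketch, not part of the original Space code: it assumes the
# "google/gemma-2-2b-it" checkpoint (any Gemma chat model would do) and shows the
# interface being registered as the default LlamaIndex LLM via Settings before
# non-streaming and streaming completion calls.
if __name__ == "__main__":
    from llama_index.core import Settings

    model_id = "google/gemma-2-2b-it"  # assumed checkpoint; may require HF authentication
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)
    Settings.llm = llm  # make it the default LLM for LlamaIndex components

    # Non-streaming completion
    print(llm.complete("What is retrieval-augmented generation?").text)

    # Streaming completion: print deltas as they arrive
    for chunk in llm.stream_complete("Summarize the Gemma chat format in one sentence."):
        print(chunk.delta, end="", flush=True)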