from typing import Any
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback
from llama_index.core.bridge.pydantic import PrivateAttr


# for transformers, version 2 (active implementation)
class GemmaLLMInterface(CustomLLM):
    """Expose a local Gemma-2 checkpoint to LlamaIndex through the CustomLLM interface."""

    model_id: str = "google/gemma-2-2b-it"
    context_window: int = 8192
    num_output: int = 2048

    # The HF model and tokenizer are kept as private attributes so that
    # pydantic does not try to validate or serialize them.
    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self._model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        self._model.eval()

    def _format_prompt(self, message: str) -> str:
        # Gemma chat template: a single user turn followed by the model turn marker.
        return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_id,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        formatted_prompt = self._format_prompt(prompt)
        inputs = self._tokenizer(formatted_prompt, return_tensors="pt").to(self._model.device)
        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=self.num_output,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = self._tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        formatted_prompt = self._format_prompt(prompt)
        inputs = self._tokenizer(formatted_prompt, return_tensors="pt").to(self._model.device)
        # generate() does not yield tokens by itself: run it in a background
        # thread and read incremental text from a TextIteratorStreamer.
        streamer = TextIteratorStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=self.num_output,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
        )
        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()
        response = ""
        for token in streamer:
            response += token
            yield CompletionResponse(text=response, delta=token)
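
# Usage sketch (illustrative, not part of the original interface): with
# llama-index-core installed, the class above can be registered as the global
# LLM via Settings and then called directly. The helper name and the prompts
# below are placeholders.
def _demo_gemma_llm() -> None:
    from llama_index.core import Settings

    llm = GemmaLLMInterface()  # downloads/loads google/gemma-2-2b-it
    Settings.llm = llm  # register as LlamaIndex's default LLM

    print(llm.complete("Summarize what a vector index does.").text)

    # Streaming: print each delta as it arrives.
    for chunk in llm.stream_complete("Write a haiku about GPUs."):
        print(chunk.delta, end="", flush=True)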
# for transformers, version 1 (kept for reference, disabled)
"""
class GemmaLLMInterface(CustomLLM):
    model: Any
    tokenizer: Any
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    def _prepare_generation(self, prompt: str) -> tuple:
        prompt = self._format_prompt(prompt)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
        # Truncate from the left if the prompt exceeds the context window.
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
        streamer = TextIteratorStreamer(self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = {
            "input_ids": inputs["input_ids"],
            "streamer": streamer,
            "max_new_tokens": self.num_output,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.7,
            "num_beams": 1,
            "repetition_penalty": 1.1,
        }
        return streamer, generate_kwargs

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            response += new_token
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            response += new_token
            yield CompletionResponse(text=response, delta=new_token)
"""