# chatbot-llamaindex / interface.py
from threading import Thread
from typing import Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback
import spaces


# Request a ZeroGPU allocation (up to 120 s) when this interface runs on Hugging Face Spaces.
@spaces.GPU(duration=120)
class GemmaLLMInterface(CustomLLM):
    """LlamaIndex CustomLLM wrapper around a Hugging Face Gemma 2 model."""

    model: Any
    tokenizer: Any
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        # Wrap the raw message in Gemma's chat-turn markers.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    def _prepare_generation(self, prompt: str) -> tuple:
        """Tokenize the prompt and build the kwargs shared by complete/stream_complete."""
        prompt = self._format_prompt(prompt)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
        # Keep only the most recent tokens if the prompt exceeds the context window.
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
            inputs["attention_mask"] = inputs["attention_mask"][:, -self.context_window:]
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "streamer": streamer,
            "max_new_tokens": self.num_output,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.7,
            "num_beams": 1,
            "repetition_penalty": 1.1,
        }
        return streamer, generate_kwargs

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        # Run generation in a background thread and drain the streamer into a single string.
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            response += new_token
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        # LlamaIndex expects the accumulated text in `text` and the new chunk in `delta`.
        response = ""
        for new_token in streamer:
            response += new_token
            yield CompletionResponse(text=response, delta=new_token)
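

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of plugging this interface into LlamaIndex. The checkpoint
# name "google/gemma-2-2b-it" is an assumption; any Gemma 2 instruct checkpoint
# that follows the <start_of_turn> chat format should work, and access to the
# gated weights plus enough GPU/CPU memory is required.
if __name__ == "__main__":
    from llama_index.core import Settings

    model_id = "google/gemma-2-2b-it"  # assumed checkpoint; swap in the one used by the Space
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)
    Settings.llm = llm  # make it the default LLM for LlamaIndex components
    print(llm.complete("What is LlamaIndex?").text)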