# chatbot-llamaindex / interface.py
from threading import Thread
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from llama_index.core.llms import (
    CompletionResponse,
    CompletionResponseGen,
    CustomLLM,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback


class GemmaLLMInterface(CustomLLM):
    """LlamaIndex CustomLLM wrapper around a Hugging Face Gemma-2 chat model."""

    def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
        super().__init__(**kwargs)
        # CustomLLM is a Pydantic model, so object.__setattr__ is used to attach
        # plain attributes (model, tokenizer, limits) while bypassing field validation.
        object.__setattr__(self, "model_id", model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            # bfloat16 on GPU for memory savings; fall back to float32 on CPU.
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "tokenizer", tokenizer)
        object.__setattr__(self, "context_window", 8192)
        object.__setattr__(self, "num_output", 2048)

    def _format_prompt(self, message: str) -> str:
        # Wrap the raw message in Gemma's chat-turn markers.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_id,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
        # Decode only the newly generated tokens: slicing the decoded string by
        # len(prompt) is unreliable because skip_special_tokens drops the chat markers.
        generated_ids = outputs[0][inputs["input_ids"].shape[-1]:]
        response = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        return CompletionResponse(text=response if response else "No response generated.")

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # skip_prompt=True makes the streamer yield only newly generated text
        # instead of echoing the prompt back first.
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(inputs, max_new_tokens=self.num_output, streamer=streamer)
        # Run generation in a background thread so tokens can be consumed here as they arrive.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        streamed_response = ""
        for new_text in streamer:
            if new_text:
                streamed_response += new_text
                yield CompletionResponse(text=streamed_response, delta=new_text)
        thread.join()
        if not streamed_response:
            yield CompletionResponse(text="No response generated.", delta="No response generated.")