# chatbot-llamaindex / interface.py
from typing import Any
from threading import Thread

import keras_nlp
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

from llama_index.core.llms import (
    CustomLLM,
    LLMMetadata,
    CompletionResponse,
    CompletionResponseGen,
)
from llama_index.core.llms.callbacks import llm_completion_callback

# torch, threading and the transformers imports above are only needed by the
# commented-out Transformers variant below; the active class uses keras_nlp.

# Transformers-based variant (kept for reference); object.__setattr__ is used to
# bypass the Pydantic field checks on the CustomLLM base class.
"""class GemmaLLMInterface(CustomLLM):
def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
super().__init__(**kwargs)
object.__setattr__(self, "model_id", model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
object.__setattr__(self, "model", model)
object.__setattr__(self, "tokenizer", tokenizer)
object.__setattr__(self, "context_window", 8192)
object.__setattr__(self, "num_output", 2048)
def _format_prompt(self, message: str) -> str:
return (
f"<start_of_turn>user\n{message}<end_of_turn>\n"
f"<start_of_turn>model\n"
)
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
context_window=self.context_window,
num_output=self.num_output,
model_name=self.model_id,
)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
prompt = self._format_prompt(prompt)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
response = response[len(prompt):].strip()
return CompletionResponse(text=response if response else "No response generated.")
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
#prompt = self._format_prompt(prompt)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
generation_kwargs = dict(inputs, max_new_tokens=self.num_output, streamer=streamer)
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
streamed_response = ""
for new_text in streamer:
if new_text:
streamed_response += new_text
yield CompletionResponse(text=streamed_response, delta=new_text)
if not streamed_response:
yield CompletionResponse(text="No response generated.", delta="No response generated.")"""

# Keras (keras_nlp) implementation: the active variant used by this module.
class GemmaLLMInterface(CustomLLM):
    # The Keras-NLP Gemma model is injected at construction time,
    # e.g. GemmaLLMInterface(model=<a keras_nlp.models.GemmaCausalLM instance>).
    model: keras_nlp.models.GemmaCausalLM = None
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        # Gemma instruction-tuned chat template: a user turn followed by the opening
        # of the model turn, which the model is expected to complete.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        # keras_nlp's generate() returns the prompt together with the completion;
        # max_length bounds the total (prompt plus generated) sequence length.
        raw_response = self.model.generate(prompt, max_length=self.num_output)
        # Strip the echoed prompt so only the model's turn is returned.
        response = raw_response[len(prompt):]
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # Pseudo-streaming: the full completion is generated first, then re-emitted
        # character by character. Accumulate into a separate buffer so the yielded
        # text does not grow past the original response.
        full_response = self.complete(prompt).text
        streamed = ""
        for token in full_response:
            streamed += token
            yield CompletionResponse(text=streamed, delta=token)
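

# Minimal usage sketch showing how this CustomLLM could plug into LlamaIndex Settings
# and a VectorStoreIndex. This is an illustration, not the Space's actual runtime path
# (the app may build the model and index elsewhere, e.g. in app.py). The keras_nlp
# preset name, the HuggingFaceEmbedding model, and the "data" folder are assumptions.
if __name__ == "__main__":
    from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    # Load Gemma 2 through Keras-NLP and wrap it in the LlamaIndex interface above.
    gemma = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
    Settings.llm = GemmaLLMInterface(model=gemma)
    Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

    # Index a local folder of documents and query it with streaming output.
    documents = SimpleDirectoryReader("data").load_data()
    index = VectorStoreIndex.from_documents(documents)
    query_engine = index.as_query_engine(streaming=True)
    streaming_response = query_engine.query("What do these documents cover?")
    streaming_response.print_response_stream()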