# chatbot-llamaindex / interface.py
from threading import Thread
from typing import Any, Iterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from llama_index.core.llms import (
    CompletionResponse,
    CompletionResponseGen,
    CustomLLM,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from pydantic import Field, field_validator


# Second iteration of the interface ("transformers 2" in the original notes):
# `model` is expected to bundle the tokenizer, i.e. expose .tokenizer,
# .device and .generate().
class GemmaLLMInterface(CustomLLM):
    """LlamaIndex CustomLLM wrapper around a Gemma chat model."""

    # Object holding the HF model; must expose .tokenizer, .device and .generate().
    model: Any = None
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma-2b-it"

    def _format_prompt(self, message: str) -> str:
        # Wrap the user message in Gemma's chat-turn markers and open the model turn.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=self.context_window,
num_output=self.num_output,
model_name=self.model_name,
)
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) is unreliable because special tokens are stripped on decode.
        generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.model.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        return CompletionResponse(text=response if response else "No response generated.")
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        prompt = self._format_prompt(prompt)
        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # transformers' generate() has no `streaming=True` flag; stream tokens via a
        # TextIteratorStreamer fed by generate() running in a background thread.
        streamer = TextIteratorStreamer(
            self.model.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": self.num_output}
        Thread(target=self.model.generate, kwargs=generation_kwargs).start()
        streamed_response = ""
        for new_token in streamer:
            if new_token:
                streamed_response += new_token
                yield CompletionResponse(text=streamed_response, delta=new_token)
        if not streamed_response:
            yield CompletionResponse(text="No response generated.", delta="No response generated.")
# First iteration ("transformers 1"), kept for reference: it takes the model
# and tokenizer as separate fields and moves the model to the detected device
# on each generation call.
"""class GemmaLLMInterface(CustomLLM):
model: Any
tokenizer: Any
context_window: int = 8192
num_output: int = 2048
model_name: str = "gemma_2"
def _format_prompt(self, message: str) -> str:
return (
f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
)
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
context_window=self.context_window,
num_output=self.num_output,
model_name=self.model_name,
)
def _prepare_generation(self, prompt: str) -> tuple:
prompt = self._format_prompt(prompt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(device)
inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
if inputs["input_ids"].shape[1] > self.context_window:
inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
streamer = TextIteratorStreamer(self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
"input_ids": inputs["input_ids"],
"streamer": streamer,
"max_new_tokens": self.num_output,
"do_sample": True,
"top_p": 0.9,
"top_k": 50,
"temperature": 0.7,
"num_beams": 1,
"repetition_penalty": 1.1,
}
return streamer, generate_kwargs
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
streamer, generate_kwargs = self._prepare_generation(prompt)
t = Thread(target=self.model.generate, kwargs=generate_kwargs)
t.start()
response = ""
for new_token in streamer:
response += new_token
return CompletionResponse(text=response)
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            response += new_token
            yield CompletionResponse(text=response, delta=new_token)"""
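

# --- Usage sketch (illustrative, not part of the deployed Space) -------------
# A minimal example of wiring GemmaLLMInterface into LlamaIndex. It assumes
# `model` is a small wrapper object exposing .tokenizer, .device and
# .generate(), since that is the shape the class above relies on; the
# GemmaBundle class and the checkpoint name below are assumptions made only
# for illustration.
if __name__ == "__main__":
    from dataclasses import dataclass

    from llama_index.core import Settings

    @dataclass
    class GemmaBundle:
        """Hypothetical wrapper pairing a HF causal LM with its tokenizer."""
        hf_model: Any
        tokenizer: Any

        @property
        def device(self) -> torch.device:
            return self.hf_model.device

        def generate(self, **kwargs: Any):
            return self.hf_model.generate(**kwargs)

    tok = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    lm = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    llm = GemmaLLMInterface(model=GemmaBundle(hf_model=lm, tokenizer=tok))
    Settings.llm = llm  # register as the default LLM for LlamaIndex components

    # Blocking completion, then token-by-token streaming.
    print(llm.complete("What does a RAG pipeline do?").text)
    for chunk in llm.stream_complete("Explain LlamaIndex in one sentence."):
        print(chunk.delta, end="", flush=True)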