# chatbot-llamaindex / interface.py
from transformers import AutoTokenizer, AutoModelForCausalLM
from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback
from typing import Any
import torch
from transformers import TextIteratorStreamer
from threading import Thread
# for transformers 2
class GemmaLLMInterface(CustomLLM):
    # CustomLLM is a Pydantic model, so attributes must be declared as fields
    # before they can be assigned in __init__.
    model_id: str = "google/gemma-2-2b-it"
    context_window: int = 8192
    num_output: int = 2048
    tokenizer: Any = None
    model: Any = None

    def __init__(self, model_id: str = "google/gemma-2-2b-it", context_window: int = 8192, num_output: int = 2048):
        super().__init__(model_id=model_id, context_window=context_window, num_output=num_output)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        self.model.eval()

def _format_prompt(self, message: str) -> str:
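        # Wrap the user message in Gemma's chat-turn markers and open the model turn.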
return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
context_window=self.context_window,
num_output=self.num_output,
model_name=self.model_id,
)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
formatted_prompt = self._format_prompt(prompt)
inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=self.num_output,
do_sample=True,
temperature=0.7,
top_p=0.95,
)
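        # Decode only the newly generated tokens, skipping the echoed prompt.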
response = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
return CompletionResponse(text=response)
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        formatted_prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        # generate() returns full sequences rather than an iterator, so stream the
        # output through a TextIteratorStreamer fed by a background generation thread.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=self.num_output,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            streamer=streamer,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        response = ""
        for token in streamer:
            response += token
            yield CompletionResponse(text=response, delta=token)
        thread.join()
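
# A minimal wiring sketch (an assumption, not part of this module): the interface
# can be registered as the global default LLM so that indexes and query engines
# built elsewhere pick it up, e.g.
#
#   from llama_index.core import Settings
#   Settings.llm = GemmaLLMInterface()
#   print(Settings.llm.complete("Hello!").text)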
# for transformers 1
"""class GemmaLLMInterface(CustomLLM):
model: Any
tokenizer: Any
context_window: int = 8192
num_output: int = 2048
model_name: str = "gemma_2"
def _format_prompt(self, message: str) -> str:
return (
f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
)
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
context_window=self.context_window,
num_output=self.num_output,
model_name=self.model_name,
)
def _prepare_generation(self, prompt: str) -> tuple:
prompt = self._format_prompt(prompt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(device)
inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
if inputs["input_ids"].shape[1] > self.context_window:
inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
streamer = TextIteratorStreamer(self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
"input_ids": inputs["input_ids"],
"streamer": streamer,
"max_new_tokens": self.num_output,
"do_sample": True,
"top_p": 0.9,
"top_k": 50,
"temperature": 0.7,
"num_beams": 1,
"repetition_penalty": 1.1,
}
return streamer, generate_kwargs
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
streamer, generate_kwargs = self._prepare_generation(prompt)
t = Thread(target=self.model.generate, kwargs=generate_kwargs)
t.start()
response = ""
for new_token in streamer:
response += new_token
return CompletionResponse(text=response)
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            # text carries the accumulated output, delta the newly streamed token.
            response += new_token
            yield CompletionResponse(text=response, delta=new_token)"""