from threading import Thread
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback


# CustomLLM is a Pydantic model, so object.__setattr__ is used below to attach
# runtime attributes (model, tokenizer, ...) without triggering Pydantic's field validation.
class GemmaLLMInterface(CustomLLM):
    def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
        super().__init__(**kwargs)
        object.__setattr__(self, "model_id", model_id)
        # Load the model in bfloat16 on GPU, float32 on CPU.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "tokenizer", tokenizer)
        object.__setattr__(self, "context_window", 8192)
        object.__setattr__(self, "num_output", 2048)

    def _format_prompt(self, message: str) -> str:
        # Wrap the user message in Gemma's chat turn markers.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_id,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) is unreliable once special tokens have been stripped.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return CompletionResponse(text=response if response else "No response generated.")

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # skip_prompt=True keeps the echoed prompt out of the streamed tokens.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(inputs, max_new_tokens=self.num_output, streamer=streamer)
        # Run generation in a background thread so tokens can be yielded as they arrive.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        streamed_response = ""
        for new_text in streamer:
            if new_text:
                streamed_response += new_text
                yield CompletionResponse(text=streamed_response, delta=new_text)
        thread.join()
        if not streamed_response:
            yield CompletionResponse(text="No response generated.", delta="No response generated.")