from threading import Thread
from typing import Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback

# for transformers 2
class GemmaLLMInterface(CustomLLM):
    """LlamaIndex CustomLLM wrapper around a Gemma chat model.

    `model` is expected to be a Hugging Face causal LM with its tokenizer
    attached as `model.tokenizer` (see the usage sketch below the class).
    """

    model: Any = None
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma-2b-it"

    def _format_prompt(self, message: str) -> str:
        # Wrap the user message in Gemma's chat turn template.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )
    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # Use max_new_tokens: max_length also counts prompt tokens, so a long
        # prompt could leave no room for the answer.
        outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) breaks once special tokens are skipped during decoding.
        generated = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.model.tokenizer.decode(generated, skip_special_tokens=True).strip()
        return CompletionResponse(text=response if response else "No response generated.")
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        prompt = self._format_prompt(prompt)
        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # transformers' generate() has no `streaming=True` flag; stream instead
        # through a TextIteratorStreamer with generation running on a background
        # thread (the same pattern as the commented-out variant below).
        streamer = TextIteratorStreamer(
            self.model.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=self.num_output)
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        streamed_response = ""
        for new_token in streamer:
            if new_token:
                streamed_response += new_token
                yield CompletionResponse(text=streamed_response, delta=new_token)
        thread.join()
        if not streamed_response:
            yield CompletionResponse(text="No response generated.", delta="No response generated.")
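
# Illustrative wiring into LlamaIndex (a sketch, not this repo's startup code;
# the checkpoint id and the Settings-based registration are assumptions):
#
#   from llama_index.core import Settings
#
#   model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto")
#   model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
#   Settings.llm = GemmaLLMInterface(model=model)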

# for transformers 1
"""class GemmaLLMInterface(CustomLLM):
    model: Any
    tokenizer: Any
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )
    def _prepare_generation(self, prompt: str) -> tuple:
        prompt = self._format_prompt(prompt)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
        # If the prompt is too long, keep only the most recent tokens; truncate
        # the attention mask to match the truncated input_ids.
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
            inputs["attention_mask"] = inputs["attention_mask"][:, -self.context_window:]
        streamer = TextIteratorStreamer(self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "streamer": streamer,
            "max_new_tokens": self.num_output,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.7,
            "num_beams": 1,
            "repetition_penalty": 1.1,
        }
        return streamer, generate_kwargs
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        # Run generation on a background thread and drain the streamer here.
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            response += new_token
        t.join()
        return CompletionResponse(text=response)
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        # Accumulate text so each response carries the running text plus delta.
        response = ""
        for new_token in streamer:
            response += new_token
            yield CompletionResponse(text=response, delta=new_token)
        t.join()
"""