from threading import Thread
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from llama_index.core.llms import (
    CompletionResponse,
    CompletionResponseGen,
    CustomLLM,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback

# for transformers 2
class GemmaLLMInterface(CustomLLM):
    # CustomLLM is a pydantic model, so configuration must be declared as
    # fields; plain attribute assignment inside __init__ would be rejected.
    model_id: str = "google/gemma-2-2b-it"
    context_window: int = 8192
    num_output: int = 2048
    model: Any = None
    tokenizer: Any = None

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Load the tokenizer and model once; use bfloat16 on GPU, float32 on CPU.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        self.model.eval()

    def _format_prompt(self, message: str) -> str:
        return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_id,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        formatted_prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.num_output,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
            )
        # Decode only the newly generated tokens, skipping the echoed prompt.
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        formatted_prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        # generate() returns a tensor rather than yielding tokens, so stream
        # through a TextIteratorStreamer while generation runs in a background thread.
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=self.num_output,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            streamer=streamer,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        response = ""
        for token in streamer:
            response += token
            yield CompletionResponse(text=response, delta=token)
        thread.join()

# for transformers 1
"""class GemmaLLMInterface(CustomLLM):
model: Any
tokenizer: Any
context_window: int = 8192
num_output: int = 2048
model_name: str = "gemma_2"
def _format_prompt(self, message: str) -> str:
return (
f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
)
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
context_window=self.context_window,
num_output=self.num_output,
model_name=self.model_name,
)
def _prepare_generation(self, prompt: str) -> tuple:
prompt = self._format_prompt(prompt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(device)
inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
if inputs["input_ids"].shape[1] > self.context_window:
inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
streamer = TextIteratorStreamer(self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
"input_ids": inputs["input_ids"],
"streamer": streamer,
"max_new_tokens": self.num_output,
"do_sample": True,
"top_p": 0.9,
"top_k": 50,
"temperature": 0.7,
"num_beams": 1,
"repetition_penalty": 1.1,
}
return streamer, generate_kwargs
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
streamer, generate_kwargs = self._prepare_generation(prompt)
t = Thread(target=self.model.generate, kwargs=generate_kwargs)
t.start()
response = ""
for new_token in streamer:
response += new_token
return CompletionResponse(text=response)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
streamer, generate_kwargs = self._prepare_generation(prompt)
t = Thread(target=self.model.generate, kwargs=generate_kwargs)
t.start()
try:
for new_token in streamer:
yield CompletionResponse(text=new_token)
except StopIteration:
return""" |