from threading import Thread
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback

# LlamaIndex CustomLLM wrapper for Gemma 2. CustomLLM is a Pydantic model, so
# object.__setattr__ is used below to set attributes that are not declared as fields.
class GemmaLLMInterface(CustomLLM):
    def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
        super().__init__(**kwargs)
        object.__setattr__(self, "model_id", model_id)  
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        object.__setattr__(self, "model", model)  
        object.__setattr__(self, "tokenizer", tokenizer) 
        object.__setattr__(self, "context_window", 8192)
        object.__setattr__(self, "num_output", 2048)

    def _format_prompt(self, message: str) -> str:
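        # Wrap the raw user message in Gemma's chat-turn markers so the
        # instruction-tuned model sees the prompt format it expects.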
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
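        # LlamaIndex uses these values to size the context it sends and the
        # number of output tokens it requests.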
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_id, 
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) is unreliable once special tokens have been stripped.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return CompletionResponse(text=response if response else "No response generated.")

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # skip_prompt=True keeps the echoed prompt out of the streamed text.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(inputs, max_new_tokens=self.num_output, streamer=streamer)
        
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        streamed_response = ""
        for new_text in streamer:
            if new_text:
                streamed_response += new_text
                yield CompletionResponse(text=streamed_response, delta=new_text)

        # Make sure the generation thread has finished before returning.
        thread.join()

        if not streamed_response:
            yield CompletionResponse(text="No response generated.", delta="No response generated.")
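
# A minimal usage sketch (assumption: this module is run directly and the environment
# can download and load google/gemma-2-2b-it). It registers the wrapper as the default
# LlamaIndex LLM and exercises both complete() and stream_complete().
if __name__ == "__main__":
    from llama_index.core import Settings

    llm = GemmaLLMInterface()
    Settings.llm = llm  # components that need an LLM will now pick up this wrapper

    # Blocking completion.
    print(llm.complete("Explain in one sentence what a custom LLM wrapper does.").text)

    # Streaming completion: each chunk carries the new delta plus the running text.
    for chunk in llm.stream_complete("List three reasons to run a small local model."):
        print(chunk.delta, end="", flush=True)
    print()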