from transformers import AutoTokenizer, AutoModelForCausalLM
from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback
from typing import Any, Optional
import torch
from transformers import TextIteratorStreamer
from threading import Thread
from pydantic import Field, field_validator
import keras
import keras_nlp

# Transformers-based implementation for Gemma 2 (object.__setattr__ is used to bypass Pydantic's attribute validation); kept as a disabled reference.
"""class GemmaLLMInterface(CustomLLM):
    def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
        super().__init__(**kwargs)
        object.__setattr__(self, "model_id", model_id)  
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        object.__setattr__(self, "model", model)  
        object.__setattr__(self, "tokenizer", tokenizer) 
        object.__setattr__(self, "context_window", 8192)
        object.__setattr__(self, "num_output", 2048)

    def _format_prompt(self, message: str) -> str:
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_id, 
        )
        
        
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        response = response[len(prompt):].strip()
        return CompletionResponse(text=response if response else "No response generated.")

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        #prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        
        streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
        generation_kwargs = dict(inputs, max_new_tokens=self.num_output, streamer=streamer)
        
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        streamed_response = ""
        for new_text in streamer:
            if new_text:
                streamed_response += new_text
                yield CompletionResponse(text=streamed_response, delta=new_text)
        
        if not streamed_response:
            yield CompletionResponse(text="No response generated.", delta="No response generated.")"""

# Keras / keras_nlp implementation (active): the loaded Gemma model is passed in as a Pydantic field.
class GemmaLLMInterface(CustomLLM):
    model: Optional[keras_nlp.models.GemmaCausalLM] = None
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt = self._format_prompt(prompt)
        # keras_nlp's generate() returns the prompt followed by the completion,
        # so strip the prompt prefix before returning the response text.
        raw_response = self.model.generate(prompt, max_length=self.num_output)
        response = raw_response[len(prompt) :]
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # keras_nlp's generate() is not incremental, so produce the full
        # completion once and re-emit it character by character as a pseudo-stream.
        full_response = self.complete(prompt).text
        streamed = ""
        for token in full_response:
            streamed += token
            yield CompletionResponse(text=streamed, delta=token)
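

# Minimal usage sketch (not part of the original module). Assumptions: Gemma 2
# weights are reachable through keras_nlp and the preset name
# "gemma2_instruct_2b_en" is available in this environment; wiring the LLM into
# llama_index's global Settings is optional and shown only for illustration.
if __name__ == "__main__":
    from llama_index.core import Settings

    gemma = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
    llm = GemmaLLMInterface(model=gemma)
    Settings.llm = llm

    # Blocking completion.
    print(llm.complete("Explain retrieval-augmented generation in one sentence.").text)

    # Pseudo-streaming completion.
    for chunk in llm.stream_complete("List three common uses of text embeddings."):
        print(chunk.delta, end="", flush=True)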