from typing import Any
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback
import spaces



# The ZeroGPU decorator is applied below to the methods that actually run
# generation (complete / stream_complete); spaces.GPU wraps callables, not classes.
class GemmaLLMInterface(CustomLLM):
    # Model and tokenizer are supplied at construction time
    # (e.g. a Gemma 2 checkpoint loaded with transformers).
    model: Any
    tokenizer: Any
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"
    
    def _format_prompt(self, message: str) -> str:
        # Wrap the user message in Gemma's chat-turn markers.
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )
    
    def _prepare_generation(self, prompt: str) -> tuple:
        prompt = self._format_prompt(prompt)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
        # If the prompt exceeds the context window, keep only the most recent
        # tokens; truncate input_ids and attention_mask together so they stay aligned.
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
            inputs["attention_mask"] = inputs["attention_mask"][:, -self.context_window:]

        # The streamer lets generate() run in a background thread while tokens are
        # consumed here; the prompt and special tokens are skipped so only newly
        # generated text is emitted.
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True
        )

        generate_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "streamer": streamer,
            "max_new_tokens": self.num_output,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.7,
            "num_beams": 1,
            "repetition_penalty": 1.1,
        }

        return streamer, generate_kwargs
           
    @llm_completion_callback()
    @spaces.GPU(duration=120)  # request a ZeroGPU slot while generation runs
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        streamer, generate_kwargs = self._prepare_generation(prompt)

        # Run generation in a background thread and collect the full response.
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()

        response = ""
        for new_token in streamer:
            response += new_token
        t.join()

        return CompletionResponse(text=response)

    @llm_completion_callback()
    @spaces.GPU(duration=120)  # request a ZeroGPU slot while generation runs
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)

        # Run generation in a background thread and yield tokens as they arrive.
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()

        # LlamaIndex streaming convention: `text` carries the accumulated response
        # so far, `delta` carries the newly generated piece.
        response = ""
        for new_token in streamer:
            response += new_token
            yield CompletionResponse(text=response, delta=new_token)
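

# --- Usage sketch (illustrative) -------------------------------------------
# A minimal example of wiring this wrapper into code that expects a LlamaIndex
# LLM. The checkpoint id "google/gemma-2-2b-it" and the prompts below are
# assumptions for demonstration only; any Gemma 2 chat checkpoint that uses the
# turn markers in _format_prompt should work the same way.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b-it", torch_dtype=torch.bfloat16
    )

    llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)

    # Blocking completion: returns the full generated text at once.
    print(llm.complete("Explain retrieval-augmented generation in one sentence.").text)

    # Streaming completion: print each newly generated chunk as it arrives.
    for chunk in llm.stream_complete("List three uses of a vector database."):
        print(chunk.delta, end="", flush=True)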