gufett0 committed on
Commit 5b579d8 · 1 Parent(s): 716b08f

changed class interface with iterator

Files changed (1)
  1. interface.py +1 -72
interface.py CHANGED
@@ -34,7 +34,7 @@ class GemmaLLMInterface(CustomLLM):
         return LLMMetadata(
             context_window=self.context_window,
             num_output=self.num_output,
-            model_name=self.model_id, # Returning the correct model ID
+            model_name=self.model_id,
         )
 
 
@@ -67,74 +67,3 @@ class GemmaLLMInterface(CustomLLM):
         if not streamed_response:
             yield CompletionResponse(text="No response generated.", delta="No response generated.")
 
-# for transformers 1
-"""class GemmaLLMInterface(CustomLLM):
-    model: Any
-    tokenizer: Any
-    context_window: int = 8192
-    num_output: int = 2048
-    model_name: str = "gemma_2"
-
-    def _format_prompt(self, message: str) -> str:
-        return (
-            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
-        )
-
-    @property
-    def metadata(self) -> LLMMetadata:
-        return LLMMetadata(
-            context_window=self.context_window,
-            num_output=self.num_output,
-            model_name=self.model_name,
-        )
-
-    def _prepare_generation(self, prompt: str) -> tuple:
-        prompt = self._format_prompt(prompt)
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(device)
-
-        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
-        if inputs["input_ids"].shape[1] > self.context_window:
-            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
-
-        streamer = TextIteratorStreamer(self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
-
-        generate_kwargs = {
-            "input_ids": inputs["input_ids"],
-            "streamer": streamer,
-            "max_new_tokens": self.num_output,
-            "do_sample": True,
-            "top_p": 0.9,
-            "top_k": 50,
-            "temperature": 0.7,
-            "num_beams": 1,
-            "repetition_penalty": 1.1,
-        }
-
-        return streamer, generate_kwargs
-
-    @llm_completion_callback()
-    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-        streamer, generate_kwargs = self._prepare_generation(prompt)
-
-        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
-        t.start()
-
-        response = ""
-        for new_token in streamer:
-            response += new_token
-
-        return CompletionResponse(text=response)
-
-    @llm_completion_callback()
-    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
-        streamer, generate_kwargs = self._prepare_generation(prompt)
-
-        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
-        t.start()
-
-        try:
-            for new_token in streamer:
-                yield CompletionResponse(text=new_token)
-        except StopIteration:
-            return"""
 
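
Usage note: below is a minimal, hypothetical sketch of how the iterator-style stream_complete interface touched by this commit might be consumed via LlamaIndex; the constructor arguments (model, tokenizer) are assumed from the removed transformers-based class and may not match the current interface.py.

# Hypothetical usage sketch -- the constructor arguments are assumptions,
# not taken from this commit.
llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)

# stream_complete yields llama_index CompletionResponse objects as tokens
# arrive, so the caller can iterate over the response like any generator.
for chunk in llm.stream_complete("Describe the Gemma model in one sentence."):
    print(chunk.delta or chunk.text, end="", flush=True)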