gufett0 committed
Commit e90ba30 · 1 Parent(s): 703abf3

changed GemmaLLMInterface

Files changed (1)
  1. interface.py +24 -23
interface.py CHANGED
@@ -6,15 +6,13 @@ import torch
 from transformers import TextIteratorStreamer
 from threading import Thread
 
-
-
 class GemmaLLMInterface(CustomLLM):
     model: Any
     tokenizer: Any
     context_window: int = 8192
     num_output: int = 2048
     model_name: str = "gemma_2"
-
+
     def _format_prompt(self, message: str) -> str:
         return (
             f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
@@ -27,48 +25,51 @@ class GemmaLLMInterface(CustomLLM):
             num_output=self.num_output,
             model_name=self.model_name,
         )
-
-    @llm_completion_callback()
-    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+
+    def _prepare_generation(self, prompt: str) -> tuple:
         prompt = self._format_prompt(prompt)
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.to(device)
-
-        # Tokenize prompt and move inputs to the correct device
+
         inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
-
-        # Ensure the input doesn't exceed the maximum token length
         if inputs["input_ids"].shape[1] > self.context_window:
             inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
-
-        # Create a streamer to handle token streaming
+
         streamer = TextIteratorStreamer(self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
-
-        # Generate kwargs for the model
+
         generate_kwargs = {
             "input_ids": inputs["input_ids"],
             "streamer": streamer,
             "max_new_tokens": self.num_output,
             "do_sample": True,
-            "top_p": 0.9,  # You can tweak these sampling params based on your needs
+            "top_p": 0.9,
             "top_k": 50,
             "temperature": 0.7,
             "num_beams": 1,
             "repetition_penalty": 1.1,
         }
-
-        # Launch the generation in a separate thread to stream the output
+
+        return streamer, generate_kwargs
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        streamer, generate_kwargs = self._prepare_generation(prompt)
+
         t = Thread(target=self.model.generate, kwargs=generate_kwargs)
         t.start()
-
-        # Collect the streamed response token by token
+
         response = ""
        for new_token in streamer:
             response += new_token
-            yield CompletionResponse(text=response)
+
+        return CompletionResponse(text=response)
 
     @llm_completion_callback()
     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
-        # Use the complete method to stream the output in real-time
-        for response in self.complete(prompt):
-            yield response
+        streamer, generate_kwargs = self._prepare_generation(prompt)
+
+        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        for new_token in streamer:
+            yield CompletionResponse(text=new_token)
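
After this change, generation setup lives in _prepare_generation and both completion paths share it: complete() drains the streamer into a single response, while stream_complete() yields tokens as they arrive. A minimal usage sketch of the refactored class follows; the checkpoint name, loading code, and prompt are illustrative assumptions and are not part of the commit.

# Hypothetical driver script (not part of this commit): load a Gemma-style
# checkpoint with transformers and exercise both completion paths.
from transformers import AutoModelForCausalLM, AutoTokenizer

from interface import GemmaLLMInterface

model_id = "google/gemma-2-2b-it"  # assumed checkpoint; any Gemma chat model using the same turn template should fit
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)

# Blocking path: complete() collects the streamed tokens and returns one CompletionResponse.
print(llm.complete("Summarize what a retrieval pipeline does.").text)

# Streaming path: stream_complete() yields a CompletionResponse per token
# produced by the background generate() thread.
for chunk in llm.stream_complete("Summarize what a retrieval pipeline does."):
    print(chunk.text, end="", flush=True)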