gufett0 committed on
Commit f680265 · 1 Parent(s): c5a2ac1

added new class

Files changed (1)
  1. interface.py +38 -45
interface.py CHANGED
@@ -9,27 +9,18 @@ from pydantic import Field, field_validator
 
 # for transformers 2
 class GemmaLLMInterface(CustomLLM):
-    model_id: str = Field(default="google/gemma-2-2b-it")
-    context_window: int = Field(default=8192)
-    num_output: int = Field(default=2048)
-    tokenizer: Any = Field(default=None)
-    model: Any = Field(default=None)
-
-    # Validators are restructured to avoid deepcopy issues
-    @field_validator('context_window', 'num_output', pre=True)
-    def ensure_integer(cls, v):
-        return int(v) if isinstance(v, str) else v
-
-    def __init__(self, **data):
-        super().__init__(**data)
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+    def __init__(self, model_id: str = "google/gemma-2b-it", context_window: int = 8192, num_output: int = 2048):
+        self.model_id = model_id
+        self.context_window = context_window
+        self.num_output = num_output
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         self.model = AutoModelForCausalLM.from_pretrained(
-            self.model_id,
+            model_id,
             device_map="auto",
             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
         )
         self.model.eval()
-
+
     def _format_prompt(self, message: str) -> str:
         return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
 
@@ -40,42 +31,44 @@ class GemmaLLMInterface(CustomLLM):
             num_output=self.num_output,
             model_name=self.model_id,
         )
-
+
+    def _prepare_inputs(self, prompt: str) -> dict:
+        formatted_prompt = self._format_prompt(prompt)
+        inputs = self.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True).to(self.model.device)
+        if inputs["input_ids"].shape[1] > self.context_window:
+            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
+        return inputs
+
+    def _generate(self, inputs: dict) -> Iterator[str]:
+        for output in self.model.generate(
+            **inputs,
+            max_new_tokens=self.num_output,
+            do_sample=True,
+            top_p=0.9,
+            top_k=50,
+            temperature=0.7,
+            num_beams=1,
+            repetition_penalty=1.1,
+            streamer=None,
+            return_dict_in_generate=True,
+            output_scores=False,
+        ):
+            new_tokens = output.sequences[:, inputs["input_ids"].shape[-1]:]
+            yield self.tokenizer.decode(new_tokens[0], skip_special_tokens=True)
+
     @llm_completion_callback()
     def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-        formatted_prompt = self._format_prompt(prompt)
-        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
-
-        with torch.no_grad():
-            outputs = self.model.generate(
-                **inputs,
-                max_new_tokens=self.num_output,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.95,
-            )
-
-        response = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+        inputs = self._prepare_inputs(prompt)
+        response = "".join(self._generate(inputs))
         return CompletionResponse(text=response)
 
     @llm_completion_callback()
     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
-        formatted_prompt = self._format_prompt(prompt)
-        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
-
+        inputs = self._prepare_inputs(prompt)
         response = ""
-        with torch.no_grad():
-            for output in self.model.generate(
-                **inputs,
-                max_new_tokens=self.num_output,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.95,
-                streamer=True,
-            ):
-                token = self.tokenizer.decode(output, skip_special_tokens=True)
-                response += token
-                yield CompletionResponse(text=response, delta=token)
+        for new_token in self._generate(inputs):
+            response += new_token
+            yield CompletionResponse(text=response, delta=new_token)
 
 
 # for transformers 1
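
For reference, a minimal usage sketch of the refactored class (not part of this commit). It assumes interface.py is importable as a module, that its top-of-file imports (pydantic, typing, torch, transformers, and the llama_index CustomLLM helpers referenced in the hunks above) are in place, and that the class constructs cleanly against the installed llama_index version; the prompts and the commented-out Settings registration are illustrative.

```python
# Hypothetical usage of the new GemmaLLMInterface; assumes interface.py is on
# the import path and google/gemma-2b-it can be downloaded (license accepted,
# enough memory for bfloat16/float32 weights).
from interface import GemmaLLMInterface

llm = GemmaLLMInterface()  # commit defaults: gemma-2b-it, 8192 context, 2048 new tokens

# Blocking completion: complete() joins everything _generate() yields.
print(llm.complete("Explain what a context window is.").text)

# Streaming completion: each CompletionResponse carries the accumulated text
# so far plus the newly decoded chunk in .delta.
for chunk in llm.stream_complete("Write a haiku about GPUs."):
    print(chunk.delta, end="", flush=True)

# Optionally register it as the default LLM for a llama_index pipeline:
# from llama_index.core import Settings
# Settings.llm = llm
```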