gufett0 commited on
Commit
ff72627
·
1 Parent(s): b277c0d

Changed the class interface to use an iterator

Browse files
Files changed (1) hide show
  1. interface.py +5 -5
interface.py CHANGED
@@ -9,15 +9,15 @@ from pydantic import Field, field_validator
9
 
10
  # for transformers 2
11
  class GemmaLLMInterface(CustomLLM):
12
- def __init__(self, model_name: str = "google/gemma-2b-it", **kwargs):
13
  super().__init__(**kwargs)
14
- self.model_name = model_name
15
  self.model = AutoModelForCausalLM.from_pretrained(
16
- self.model_name,
17
  device_map="auto",
18
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
19
  )
20
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
21
  self.context_window = 8192
22
  self.num_output = 2048
23
 
@@ -32,7 +32,7 @@ class GemmaLLMInterface(CustomLLM):
32
  return LLMMetadata(
33
  context_window=self.context_window,
34
  num_output=self.num_output,
35
- model_name=self.model_name,
36
  )
37
 
38
  @llm_completion_callback()
 
9
 
10
  # for transformers 2
11
  class GemmaLLMInterface(CustomLLM):
12
+ def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
13
  super().__init__(**kwargs)
14
+ self.model_id = model_id
15
  self.model = AutoModelForCausalLM.from_pretrained(
16
+ self.model_id,
17
  device_map="auto",
18
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
19
  )
20
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
21
  self.context_window = 8192
22
  self.num_output = 2048
23
 
 
32
  return LLMMetadata(
33
  context_window=self.context_window,
34
  num_output=self.num_output,
35
+ model_name=self.model_id,
36
  )
37
 
38
  @llm_completion_callback()