changed class interface with iterator
interface.py CHANGED (+5 -5)
@@ -9,15 +9,15 @@ from pydantic import Field, field_validator
 
 # for transformers 2
 class GemmaLLMInterface(CustomLLM):
-    def __init__(self,
+    def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
         super().__init__(**kwargs)
-        self.
+        self.model_id = model_id
         self.model = AutoModelForCausalLM.from_pretrained(
-            self.
+            self.model_id,
             device_map="auto",
             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
         )
-        self.tokenizer = AutoTokenizer.from_pretrained(self.
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         self.context_window = 8192
         self.num_output = 2048
 
@@ -32,7 +32,7 @@ class GemmaLLMInterface(CustomLLM):
         return LLMMetadata(
             context_window=self.context_window,
             num_output=self.num_output,
-            model_name=self.
+            model_name=self.model_id,
         )
 
     @llm_completion_callback()
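A minimal usage sketch of the updated constructor, assuming the rest of GemmaLLMInterface (not shown in this diff) implements the remaining CustomLLM methods; the model id below is simply the new default from the changed signature:

    # Usage sketch for the new interface; GemmaLLMInterface lives in interface.py.
    from interface import GemmaLLMInterface

    # model_id is the new constructor argument; omitting it falls back to
    # "google/gemma-2-2b-it", the default introduced in this commit.
    llm = GemmaLLMInterface(model_id="google/gemma-2-2b-it")

    # metadata now reports the configured model id alongside the window sizes.
    meta = llm.metadata
    print(meta.model_name)      # "google/gemma-2-2b-it"
    print(meta.context_window)  # 8192
    print(meta.num_output)      # 2048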