Spaces: changed class interface with iterator

Files changed:
- backend.py   +2 -2
- interface.py  +9 -7
backend.py
CHANGED
@@ -20,7 +20,7 @@ login(huggingface_token)
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-
+model_id = "google/gemma-2-2b-it"
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
@@ -28,7 +28,7 @@ model = AutoModelForCausalLM.from_pretrained(
     token=True)
 
 model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
-model.eval()
+model.eval()
 
 # what models will be used by LlamaIndex:
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
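For orientation, backend.py's model setup after this commit reads roughly as follows. This is a hedged sketch, not the verbatim file: the imports, the token plumbing around login(), and line 27 of the from_pretrained call (which the diff does not show) are assumptions filled in from the visible hunk headers and from interface.py.

    # Sketch of backend.py's model setup after this commit; imports and the
    # token plumbing are assumptions, the rest follows the diff.
    import os
    import torch
    from huggingface_hub import login
    from transformers import AutoModelForCausalLM, AutoTokenizer

    huggingface_token = os.environ.get("HF_TOKEN")  # assumed: Space secret
    login(huggingface_token)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model_id = "google/gemma-2-2b-it"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        # Line 27 is hidden between the two hunks; a torch_dtype kwarg matching
        # interface.py is a plausible guess for what sits there.
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        token=True)

    # Note the asymmetry the commit leaves in place: the model now comes from
    # the Gemma 2 repo, while the tokenizer still loads "google/gemma-2b-it".
    model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    model.eval()  # switch to inference mode (disables dropout)

If the tokenizer mismatch is unintended, pointing AutoTokenizer.from_pretrained at model_id, as interface.py now does, would keep the pair in sync.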
interface.py
CHANGED
@@ -11,15 +11,17 @@ from pydantic import Field, field_validator
 class GemmaLLMInterface(CustomLLM):
     def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
         super().__init__(**kwargs)
-        object.__setattr__(self, "model_id", model_id) #
-
-
+        object.__setattr__(self, "model_id", model_id)  # Bypass Pydantic for model_id
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
             device_map="auto",
             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
         )
-
-        self
-        self
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        object.__setattr__(self, "model", model)  # Bypass Pydantic for model
+        object.__setattr__(self, "tokenizer", tokenizer)  # Bypass Pydantic for tokenizer
+        object.__setattr__(self, "context_window", 8192)
+        object.__setattr__(self, "num_output", 2048)
 
     def _format_prompt(self, message: str) -> str:
         return (
@@ -32,7 +34,7 @@ class GemmaLLMInterface(CustomLLM):
         return LLMMetadata(
             context_window=self.context_window,
             num_output=self.num_output,
-            model_name=self.model_id, #
+            model_name=self.model_id,  # Returning the correct model ID
         )
 
 
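The object.__setattr__ calls are the heart of this change: CustomLLM is a Pydantic model, so a plain "self.model = model" on an attribute the class never declared raises a validation error, which is what the removed "self…" assignments were running into. object.__setattr__ writes straight into the instance and skips Pydantic's __setattr__ hook. Below is a minimal self-contained sketch of the class this diff converges on; the complete/stream_complete bodies, the prompt template, and the generation arguments are illustrative assumptions, not part of the commit.

    # Hedged end-to-end sketch of the pattern this commit lands on; method
    # bodies beyond __init__/metadata are assumptions for illustration.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from llama_index.core.llms import (
        CustomLLM,
        CompletionResponse,
        CompletionResponseGen,
        LLMMetadata,
    )
    from llama_index.core.llms.callbacks import llm_completion_callback


    class GemmaLLMInterface(CustomLLM):
        def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
            super().__init__(**kwargs)
            # CustomLLM is a Pydantic model: undeclared attributes cannot be
            # set with "self.x = ...", so bypass Pydantic's __setattr__ hook.
            object.__setattr__(self, "model_id", model_id)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            )
            model.eval()
            object.__setattr__(self, "model", model)
            object.__setattr__(self, "tokenizer", AutoTokenizer.from_pretrained(model_id))
            object.__setattr__(self, "context_window", 8192)
            object.__setattr__(self, "num_output", 2048)

        def _format_prompt(self, message: str) -> str:
            # Assumed Gemma chat template for a single user turn.
            return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"

        @property
        def metadata(self) -> LLMMetadata:
            return LLMMetadata(
                context_window=self.context_window,
                num_output=self.num_output,
                model_name=self.model_id,
            )

        @llm_completion_callback()
        def complete(self, prompt: str, **kwargs) -> CompletionResponse:
            inputs = self.tokenizer(
                self._format_prompt(prompt), return_tensors="pt"
            ).to(self.model.device)
            with torch.no_grad():
                out = self.model.generate(**inputs, max_new_tokens=self.num_output)
            text = self.tokenizer.decode(
                out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
            )
            return CompletionResponse(text=text)

        @llm_completion_callback()
        def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
            # Minimal streaming shim: yields the blocking completion as one
            # chunk; a real implementation would use TextIteratorStreamer.
            full = self.complete(prompt, **kwargs).text

            def gen() -> CompletionResponseGen:
                yield CompletionResponse(text=full, delta=full)

            return gen()

Stashing state via object.__setattr__ works, but an alternative worth noting is declaring the attributes on the class with Pydantic's PrivateAttr (or as proper fields), so that Pydantic knows about them and ordinary assignment validates as expected.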