gufett0 commited on
Commit
0467f17
·
1 Parent(s): 997eb0b

Changed class interface to use an iterator

Browse files
Files changed (2) hide show
  1. backend.py +2 -2
  2. interface.py +9 -7
backend.py CHANGED
@@ -20,7 +20,7 @@ login(huggingface_token)
20
 
21
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
 
23
- """model_id = "google/gemma-2-2b-it"
24
  model = AutoModelForCausalLM.from_pretrained(
25
  model_id,
26
  device_map="auto",
@@ -28,7 +28,7 @@ model = AutoModelForCausalLM.from_pretrained(
28
  token=True)
29
 
30
  model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
31
- model.eval()"""
32
 
33
  # what models will be used by LlamaIndex:
34
  Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
 
20
 
21
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
 
23
+ model_id = "google/gemma-2-2b-it"
24
  model = AutoModelForCausalLM.from_pretrained(
25
  model_id,
26
  device_map="auto",
 
28
  token=True)
29
 
30
  model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
31
+ model.eval()
32
 
33
  # what models will be used by LlamaIndex:
34
  Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
interface.py CHANGED
@@ -11,15 +11,17 @@ from pydantic import Field, field_validator
11
  class GemmaLLMInterface(CustomLLM):
12
  def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
13
  super().__init__(**kwargs)
14
- object.__setattr__(self, "model_id", model_id) # Use object.__setattr__ to bypass Pydantic restrictions
15
- self.model = AutoModelForCausalLM.from_pretrained(
16
- self.model_id,
17
  device_map="auto",
18
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
19
  )
20
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
21
- self.context_window = 8192
22
- self.num_output = 2048
 
 
23
 
24
  def _format_prompt(self, message: str) -> str:
25
  return (
@@ -32,7 +34,7 @@ class GemmaLLMInterface(CustomLLM):
32
  return LLMMetadata(
33
  context_window=self.context_window,
34
  num_output=self.num_output,
35
- model_name=self.model_id, # Passing the correct model ID here
36
  )
37
 
38
 
 
11
  class GemmaLLMInterface(CustomLLM):
12
  def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
13
  super().__init__(**kwargs)
14
+ object.__setattr__(self, "model_id", model_id) # Bypass Pydantic for model_id
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ model_id,
17
  device_map="auto",
18
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
19
  )
20
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
21
+ object.__setattr__(self, "model", model) # Bypass Pydantic for model
22
+ object.__setattr__(self, "tokenizer", tokenizer) # Bypass Pydantic for tokenizer
23
+ object.__setattr__(self, "context_window", 8192)
24
+ object.__setattr__(self, "num_output", 2048)
25
 
26
  def _format_prompt(self, message: str) -> str:
27
  return (
 
34
  return LLMMetadata(
35
  context_window=self.context_window,
36
  num_output=self.num_output,
37
+ model_name=self.model_id, # Returning the correct model ID
38
  )
39
 
40