GodfreyOwino commited on
Commit
3716637
·
verified ·
1 Parent(s): 0ded05a

Update: Add custom NPKPredictionModel implementation

Browse files
Files changed (2) hide show
  1. config.json +5 -4
  2. modeling_npk.py +13 -13
config.json CHANGED
@@ -2,7 +2,8 @@
2
  "model_type": "npk",
3
  "architectures": ["NPKPredictionModel"],
4
  "auto_map": {
5
- "AutoConfig": "modeling_npk.NPKConfig",
6
- "AutoModel": "modeling_npk.NPKPredictionModel"
7
- }
8
- }
 
 
2
  "model_type": "npk",
3
  "architectures": ["NPKPredictionModel"],
4
  "auto_map": {
5
+ "AutoConfig": "modeling_npk.NPKConfig",
6
+ "AutoModel": "modeling_npk.NPKPredictionModel"
7
+ },
8
+ "trust_remote_code": true
9
+ }
modeling_npk.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import pickle
3
  import pandas as pd
4
  from transformers import PreTrainedModel, PretrainedConfig
@@ -17,6 +16,18 @@ class NPKPredictionModel(PreTrainedModel):
17
  super().__init__(config)
18
  self.xgb_model = None
19
  self.label_encoder = None
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def forward(self, inputs):
22
  # Preprocess inputs
@@ -48,15 +59,4 @@ class NPKPredictionModel(PreTrainedModel):
48
  config = NPKConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
49
 
50
  model = cls(config)
51
-
52
- # Load the XGBoost model and label encoder
53
- xgb_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="npk_prediction_model.pkl")
54
- le_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="label_encoder.pkl")
55
-
56
- with open(xgb_path, 'rb') as f:
57
- model.xgb_model = pickle.load(f)
58
-
59
- with open(le_path, 'rb') as f:
60
- model.label_encoder = pickle.load(f)
61
-
62
- return model
 
 
1
  import pickle
2
  import pandas as pd
3
  from transformers import PreTrainedModel, PretrainedConfig
 
16
  super().__init__(config)
17
  self.xgb_model = None
18
  self.label_encoder = None
19
+ self._load_models()
20
+
21
+ def _load_models(self):
22
+ # Load the XGBoost model and label encoder
23
+ xgb_path = hf_hub_download(repo_id=self.config._name_or_path, filename="npk_prediction_model.pkl")
24
+ le_path = hf_hub_download(repo_id=self.config._name_or_path, filename="label_encoder.pkl")
25
+
26
+ with open(xgb_path, 'rb') as f:
27
+ self.xgb_model = pickle.load(f)
28
+
29
+ with open(le_path, 'rb') as f:
30
+ self.label_encoder = pickle.load(f)
31
 
32
  def forward(self, inputs):
33
  # Preprocess inputs
 
59
  config = NPKConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
60
 
61
  model = cls(config)
62
+ return model