NPK_prediction_model2 / modeling_npk.py
GodfreyOwino's picture
Update: Add custom NPKPredictionModel implementation
3716637 verified
raw
history blame
2.1 kB
import copy
import os
import pickle

import pandas as pd
from huggingface_hub import hf_hub_download
from transformers import PreTrainedModel, PretrainedConfig
class NPKConfig(PretrainedConfig):
    """Configuration class registered under the ``"npk"`` model type.

    It declares no fields of its own: every keyword argument is handed
    straight to :class:`PretrainedConfig`.
    """

    model_type = "npk"

    def __init__(self, **kwargs):
        # No NPK-specific settings to record; the base class does all the work.
        super(NPKConfig, self).__init__(**kwargs)
class NPKPredictionModel(PreTrainedModel):
    """Expose a pickled regressor and crop-name label encoder as a HF model.

    Two pickled artefacts are expected next to the config (in a Hub repo or
    a local directory): ``npk_prediction_model.pkl``, the regressor whose
    ``predict`` is called in :meth:`forward`, and ``label_encoder.pkl``, the
    encoder applied to the ``crop_name`` feature. Both are loaded eagerly
    when the model is constructed.
    """

    config_class = NPKConfig

    # Artefact file names expected alongside the config.
    _XGB_FILENAME = "npk_prediction_model.pkl"
    _LABEL_ENCODER_FILENAME = "label_encoder.pkl"
    # Output keys, in the column order of the regressor's prediction row.
    _OUTPUT_NAMES = ("Nitrogen Need", "Phosphorus Need", "Potassium Need")

    def __init__(self, config):
        super().__init__(config)
        self.xgb_model = None
        self.label_encoder = None
        self._load_models()

    def _artifact_path(self, filename):
        """Return a local filesystem path for ``filename``.

        ``config.name_or_path`` is used as the source. If it names an
        existing local directory the file is read from there; otherwise it
        is treated as a Hub repo id and the file is fetched via
        :func:`hf_hub_download`.

        Raises:
            ValueError: if the config carries no source location.
        """
        source = self.config.name_or_path
        if not source:
            raise ValueError(
                "NPKPredictionModel needs config.name_or_path to be a Hub repo id "
                f"or a local directory containing {filename!r}; construct it with "
                "NPKPredictionModel.from_pretrained(...)."
            )
        if os.path.isdir(source):
            return os.path.join(source, filename)
        return hf_hub_download(repo_id=source, filename=filename)

    def _load_models(self):
        """Load the pickled regressor and label encoder onto the instance."""
        xgb_path = self._artifact_path(self._XGB_FILENAME)
        le_path = self._artifact_path(self._LABEL_ENCODER_FILENAME)
        # SECURITY: pickle.load runs arbitrary code embedded in the file.
        # Only load these artefacts from a repository/directory you trust.
        with open(xgb_path, "rb") as f:
            self.xgb_model = pickle.load(f)
        with open(le_path, "rb") as f:
            self.label_encoder = pickle.load(f)

    def forward(self, inputs):
        """Predict nitrogen, phosphorus and potassium need for one sample.

        Args:
            inputs: Mapping of feature name to value; must contain
                ``"crop_name"``. List values are collapsed to their first
                element (an empty list becomes ``None``), so a one-row batch
                such as ``{"crop_name": ["maize"], ...}`` is accepted.

        Returns:
            Dict mapping ``"Nitrogen Need"``, ``"Phosphorus Need"`` and
            ``"Potassium Need"`` to floats taken from the first prediction row.

        Raises:
            KeyError: if ``"crop_name"`` is missing from ``inputs``.
        """
        processed_inputs = {}
        for key, value in inputs.items():
            if isinstance(value, list):
                # Only the first row of a list-valued field is used.
                value = value[0] if value else None
            processed_inputs[key] = value

        # The regressor consumes the encoded crop label, not the raw name.
        processed_inputs["crop_name"] = self.label_encoder.transform(
            [processed_inputs["crop_name"]]
        )[0]
        input_df = pd.DataFrame([processed_inputs])

        # XGBoost checks feature names *in order*, but callers (e.g. JSON
        # payloads) may send keys in any order. When the fitted feature names
        # are known and match the supplied keys exactly, reorder to match.
        feature_names = getattr(self.xgb_model, "feature_names_in_", None)
        if feature_names is not None and set(feature_names) == set(input_df.columns):
            input_df = input_df[list(feature_names)]

        prediction = self.xgb_model.predict(input_df)
        first_row = prediction[0]
        return {
            name: float(first_row[index])
            for index, name in enumerate(self._OUTPUT_NAMES)
        }

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model from a Hub repo id or local directory.

        ``*model_args`` is accepted for signature compatibility with
        :meth:`PreTrainedModel.from_pretrained` and ignored. When no
        ``config`` keyword is given, the remaining keyword arguments are
        forwarded to :meth:`NPKConfig.from_pretrained`.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = NPKConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        else:
            # Do not mutate the caller's config object below.
            config = copy.deepcopy(config)
        # The pickled artefacts live next to the config at this location, so
        # point the config here instead of relying on whatever _name_or_path
        # was serialized into config.json (possibly empty or stale).
        config.name_or_path = str(pretrained_model_name_or_path)
        return cls(config)