|
import copy
import os
import pickle

import pandas as pd
from huggingface_hub import hf_hub_download
from transformers import PreTrainedModel, PretrainedConfig
|
|
|
class NPKConfig(PretrainedConfig):
    """Configuration for the ``npk`` model type.

    Declares no settings of its own: every keyword argument is forwarded
    unchanged to :class:`transformers.PretrainedConfig`.
    """

    model_type = "npk"

    def __init__(self, **config_kwargs):
        # Pure pass-through constructor; all options are handled by the base class.
        super().__init__(**config_kwargs)
|
|
|
class NPKPredictionModel(PreTrainedModel):
    """Transformers-compatible wrapper around a pickled NPK regressor.

    The model holds no torch parameters of its own. On construction it loads
    two pickled artifacts from the location named by ``config.name_or_path``
    (a local directory or a Hugging Face Hub repo id):

    * ``npk_prediction_model.pkl`` -- the regressor stored as ``xgb_model``
      (presumably an XGBoost sklearn-API estimator; confirm against the repo),
    * ``label_encoder.pkl`` -- the encoder applied to the ``crop_name`` input.

    Security: both files are read with :func:`pickle.load`, which can execute
    arbitrary code embedded in the file. Only load artifacts from trusted sources.
    """

    config_class = NPKConfig

    def __init__(self, config):
        super().__init__(config)
        self.xgb_model = None
        self.label_encoder = None
        self._load_models()

    def _resolve_artifact(self, filename):
        """Return a local filesystem path to ``filename`` for this model's source.

        A local directory named by ``config.name_or_path`` is used directly;
        anything else is treated as a Hub repo id and downloaded (cached).

        Raises:
            ValueError: if ``config.name_or_path`` is empty or unset.
        """
        source = self.config.name_or_path
        if not source:
            raise ValueError(
                "config.name_or_path is empty; cannot locate the pickled model "
                "artifacts. Load the model with NPKPredictionModel.from_pretrained()."
            )
        if os.path.isdir(source):
            return os.path.join(source, filename)
        # NOTE(review): hub options such as revision/token are not forwarded here.
        return hf_hub_download(repo_id=source, filename=filename)

    def _load_models(self):
        """Load the regressor and the crop-name label encoder from disk/Hub."""
        xgb_path = self._resolve_artifact("npk_prediction_model.pkl")
        le_path = self._resolve_artifact("label_encoder.pkl")

        # SECURITY: pickle.load runs arbitrary code from the file; the source
        # repository/directory must be trusted.
        with open(xgb_path, "rb") as f:
            self.xgb_model = pickle.load(f)
        with open(le_path, "rb") as f:
            self.label_encoder = pickle.load(f)

    def forward(self, inputs):
        """Predict nutrient needs for a single sample.

        Args:
            inputs: Mapping of feature name to value. It must contain
                ``"crop_name"``; every key becomes a feature column for the
                regressor. A list value is reduced to its first element (an
                empty list becomes ``None``), so pipeline-style
                ``{"feature": [value]}`` payloads are accepted.

        Returns:
            dict with float values for ``"Nitrogen Need"``, ``"Phosphorus Need"``
            and ``"Potassium Need"``: columns 0, 1 and 2 of the regressor's
            first prediction row.

        Raises:
            KeyError: if ``"crop_name"`` is missing from ``inputs``.
            ValueError: if the crop name cannot be encoded, or if the input keys
                do not match the regressor's ``feature_names_in_`` (when it has one).
        """
        processed_inputs = {}
        for key, value in inputs.items():
            if isinstance(value, list):
                value = value[0] if value else None
            processed_inputs[key] = value

        if "crop_name" not in processed_inputs:
            raise KeyError("'crop_name' is required in inputs")
        crop_name = processed_inputs["crop_name"]
        try:
            processed_inputs["crop_name"] = self.label_encoder.transform([crop_name])[0]
        except ValueError as err:
            known = getattr(self.label_encoder, "classes_", None)
            hint = "" if known is None else "; known crops: " + ", ".join(map(str, known))
            raise ValueError(f"Could not encode crop_name {crop_name!r}{hint}") from err

        input_df = pd.DataFrame([processed_inputs])

        expected = getattr(self.xgb_model, "feature_names_in_", None)
        if expected is not None:
            expected = list(expected)
            expected_set = set(expected)
            missing = [name for name in expected if name not in processed_inputs]
            unexpected = [name for name in processed_inputs if name not in expected_set]
            if missing or unexpected:
                raise ValueError(
                    f"Input features do not match the model: missing={missing}, "
                    f"unexpected={unexpected}"
                )
            # sklearn-style estimators reject columns that are not in fit-time
            # order, so align to it instead of relying on the caller's dict order.
            input_df = input_df[expected]

        row = self.xgb_model.predict(input_df)[0]
        return {
            "Nitrogen Need": float(row[0]),
            "Phosphorus Need": float(row[1]),
            "Potassium Need": float(row[2]),
        }

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model from a Hub repo id or a local directory.

        Args:
            pretrained_model_name_or_path: Hub repo id or local directory holding
                ``config.json`` and the two pickled artifacts.
            *model_args: Accepted for API compatibility; ignored.
            **kwargs: ``config`` may supply a ready-made config (it is copied,
                not mutated). Other kwargs are passed to ``NPKConfig.from_pretrained``
                when no config is supplied.

        Returns:
            A new ``NPKPredictionModel`` with its artifacts loaded.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = NPKConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        else:
            # Avoid mutating the caller's config object below.
            config = copy.deepcopy(config)
        # _load_models() locates the pickles via config.name_or_path. A supplied
        # config may have it empty, and a loaded config.json may hold a stale
        # save-time path, so pin it to the source actually being loaded.
        config.name_or_path = pretrained_model_name_or_path
        return cls(config)