"""Hugging Face wrapper around a pickled XGBoost NPK-requirement predictor."""

import os
import pickle

import pandas as pd
from huggingface_hub import hf_hub_download
from transformers import PreTrainedModel, PretrainedConfig

# File names of the pickled artifacts stored next to the config.
_XGB_MODEL_FILENAME = "npk_prediction_model.pkl"
_LABEL_ENCODER_FILENAME = "label_encoder.pkl"


class NPKConfig(PretrainedConfig):
    """Configuration for :class:`NPKPredictionModel`.

    Adds no fields of its own; every keyword argument is forwarded to
    ``PretrainedConfig``.
    """

    model_type = "npk"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class NPKPredictionModel(PreTrainedModel):
    """Predict nitrogen, phosphorus and potassium needs for a crop.

    The prediction is delegated to a pickled model (``xgb_model``) and the
    crop name is encoded with a pickled label encoder (``label_encoder``).
    Both artifacts are loaded when the instance is constructed.
    """

    config_class = NPKConfig

    def __init__(self, config):
        super().__init__(config)
        self.xgb_model = None
        self.label_encoder = None
        self._load_models()

    def _load_models(self):
        """Load the predictor and the label encoder named by the config.

        The location is read from ``self.config._name_or_path``. It may be a
        local directory containing the two pickle files, or a Hub repo id.
        """
        source = self.config._name_or_path
        self.xgb_model = self._load_pickle(source, _XGB_MODEL_FILENAME)
        self.label_encoder = self._load_pickle(source, _LABEL_ENCODER_FILENAME)

    @staticmethod
    def _load_pickle(source, filename):
        """Unpickle ``filename`` from a local directory or a Hub repo.

        Args:
            source: Local directory path or Hub repo id.
            filename: Name of the pickle file inside ``source``.

        Returns:
            The unpickled object.
        """
        # A local checkpoint directory is read directly so that the model
        # accepts the same locations as ``NPKConfig.from_pretrained``.
        if source and os.path.isdir(source):
            path = os.path.join(source, filename)
        else:
            path = hf_hub_download(repo_id=source, filename=filename)

        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only load artifacts from repositories/directories you trust.
        with open(path, "rb") as handle:
            return pickle.load(handle)

    def forward(self, inputs):
        """Predict the NPK needs for a single sample.

        Args:
            inputs: Mapping of feature name to value. A list value is
                reduced to its first element (``None`` for an empty list).
                Must contain the key ``'crop_name'``.

        Returns:
            A dict with float values under the keys ``'Nitrogen Need'``,
            ``'Phosphorus Need'`` and ``'Potassium Need'``.

        Raises:
            KeyError: If ``inputs`` has no ``'crop_name'`` entry.
        """
        # Work on a copy so the caller's mapping is never modified.
        features = {
            name: (value[0] if value else None)
            if isinstance(value, list)
            else value
            for name, value in inputs.items()
        }

        # The predictor receives the encoded crop label, not the raw name.
        crop_name = features['crop_name']
        features['crop_name'] = self.label_encoder.transform([crop_name])[0]

        # One-row frame; column order follows the order of ``inputs``.
        input_df = pd.DataFrame([features])
        prediction = self.xgb_model.predict(input_df)

        # First (only) row holds the three outputs in N, P, K order.
        row = prediction[0]
        return {
            'Nitrogen Need': float(row[0]),
            'Phosphorus Need': float(row[1]),
            'Potassium Need': float(row[2]),
        }

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model from a Hub repo id or a local directory.

        Args:
            pretrained_model_name_or_path: Repo id or directory holding the
                config and the two pickle artifacts.
            *model_args: Accepted for signature compatibility; unused.
            **kwargs: May contain ``config``; the remaining entries are
                forwarded to ``NPKConfig.from_pretrained`` when no config is
                supplied.

        Returns:
            A constructed :class:`NPKPredictionModel`.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = NPKConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Record where the artifacts live. ``_load_models`` reads this value,
        # and neither a caller-supplied config nor the stored config.json is
        # guaranteed to carry the right location.
        config._name_or_path = str(pretrained_model_name_or_path)

        return cls(config)