# predict.py
"""Predict a class label for a piece of text with a fine-tuned classifier.

The tokenizer, the sequence-classification model and a pickled label encoder
are loaded once, at import time, from ``model_path`` (a Hugging Face Hub repo
id or a local directory). ``predict_class`` maps a single string to the label
of its highest-scoring class.
"""
import pickle

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.utils import cached_file

model_path = 'shirleylqs/mistral-snomed-classification'

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
# from_pretrained already returns the model in eval mode; made explicit so
# dropout stays disabled for inference regardless of how the model is built.
model.eval()

# `model_path` is a Hub repo id, not a local folder, so opening
# f"{model_path}/label_encoder.pkl" directly only worked when a directory of
# that name existed relative to the CWD. cached_file resolves the file from a
# local directory or downloads it from the Hub (and raises if it is missing).
label_encoder_file = cached_file(model_path, 'label_encoder.pkl')

# SECURITY: pickle.load executes arbitrary code embedded in the file. Only
# load this from a repository you trust. Consider shipping the labels as JSON
# (or in model.config.id2label, if populated -- TODO confirm) instead.
with open(label_encoder_file, 'rb') as f:
    # Presumably a scikit-learn LabelEncoder; only `inverse_transform` is used.
    label_encoder = pickle.load(f)


def predict_class(text, max_length=128):
    """Return the predicted label for a single input string.

    Args:
        text: The text to classify.
        max_length: Maximum number of tokens fed to the model. Longer inputs
            are truncated. Defaults to 128, the previously hard-coded value.

    Returns:
        The label that ``label_encoder.inverse_transform`` maps the
        highest-scoring class id to.
    """
    inputs = tokenizer(
        text, return_tensors='pt', truncation=True, max_length=max_length
    )
    # Put the input tensors on the same device as the model weights, so this
    # keeps working if the model is moved to a GPU.
    inputs = inputs.to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # argmax over the class dimension; .item() requires exactly one
    # prediction, i.e. a single input string.
    predicted_class_id = logits.argmax(-1).item()
    return label_encoder.inverse_transform([predicted_class_id])[0]


if __name__ == "__main__":
    text = "purulent discharge"
    predicted_label = predict_class(text)
    print(predicted_label)