# predict.py
import pickle

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers.utils import cached_file
# Hub repo id (or local directory) that holds the model, the tokenizer and the
# pickled label encoder.
model_path = 'shirleylqs/mistral-snomed-classification'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# `model_path` may be a Hub repo id rather than a directory on disk, so a plain
# open() on '<model_path>/label_encoder.pkl' only works for a local checkout.
# cached_file() resolves the file in both cases (local directory or Hub
# download) and raises if the file cannot be found.
label_encoder_path = cached_file(model_path, 'label_encoder.pkl')

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load a label encoder from a repository you trust.
with open(label_encoder_path, 'rb') as f:
    label_encoder = pickle.load(f)
def predict_class(text: str):
    """Predict the label of ``text`` using the module-level classifier.

    The text is tokenized (truncated to at most 128 tokens), passed through
    ``model`` with gradient tracking disabled, and the index of the largest
    logit is mapped back to a label through ``label_encoder``.

    Args:
        text: The input to classify. Expected to be a single string: the
            ``.item()`` call below requires exactly one predicted index.

    Returns:
        The first element of ``label_encoder.inverse_transform`` applied to
        the predicted class index.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=128)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Highest-scoring index along the last axis; .item() assumes one input.
    predict_class_id = logits.argmax(-1).item()
    predict_label = label_encoder.inverse_transform([predict_class_id])[0]
    return predict_label
if __name__ == "__main__":
    # Smoke run: classify one sample phrase and print the predicted label.
    text = "purulent discharge"
    predicted_label = predict_class(text)
    print(predicted_label)