|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
|
|
|
|
# Checkpoint location, shared by the tokenizer and the model loader.
# NOTE(review): `from_pretrained` expects a directory written by
# `save_pretrained` (config + weights + tokenizer files); despite the ".pt"
# suffix this must be such a directory — confirm on disk.
MODEL_DIR = (
    "saved_models/"
    "model_20240302-214915_lr1e-05_optAdamW_lossBCEWithLogitsLoss_batch16_epoch10.pt"
)
NUM_LABELS = 8


@st.cache_resource
def _load_model(model_dir: str, num_labels: int):
    """Load the tokenizer and sequence classifier for *model_dir*.

    Streamlit re-runs this whole script on every widget interaction; caching
    the result with ``st.cache_resource`` keeps the model from being reloaded
    from disk on each rerun.

    Returns:
        ``(tokenizer, model)`` with the model already switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(model_dir)
    mdl = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=num_labels)
    mdl.eval()  # inference only: disable dropout etc.
    return tok, mdl


tokenizer, model = _load_model(MODEL_DIR, NUM_LABELS)
|
|
|
def predict(text, split=4):
    """Classify *text* and return its label probabilities split into two groups.

    The checkpoint was trained with BCEWithLogitsLoss (per its filename), i.e.
    each of the labels is an independent binary decision, so every logit is
    mapped through a sigmoid. A softmax across all labels would force the
    scores to compete and sum to 1, which does not match that objective.

    Args:
        text: the input string (a list of strings is also accepted by the
            tokenizer and yields batched outputs).
        split: number of leading labels that form the SAS group; the
            remaining labels form the SDS group. Defaults to 4.

    Returns:
        ``(sas_probs, sds_probs)`` tensors. For a single input string these
        are 1-D; for a batch the leading dimension is the batch.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Independent per-label probabilities; squeeze(0) drops the batch axis
    # only when a single example was passed.
    probabilities = torch.sigmoid(logits).squeeze(0)

    # Slice along the label (last) axis so batched input splits correctly too.
    sas_probs = probabilities[..., :split]
    sds_probs = probabilities[..., split:]

    return sas_probs, sds_probs
|
|
|
|
|
st.title("Multi-label Classification App")

# Show the hint as a placeholder, not a default value: as a default value the
# hint text itself was classified when the user pressed Predict without typing.
user_input = st.text_area("Enter text here", placeholder="Type something...")

if st.button("Predict"):
    if not user_input.strip():
        # Nothing meaningful to classify; tell the user instead of predicting.
        st.warning("Please enter some text to classify.")
    else:
        sas_probs, sds_probs = predict(user_input)
        st.write("SAS_Class probabilities:", sas_probs.numpy())
        st.write("SDS_Class probabilities:", sds_probs.numpy())
|
|