import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load the model and tokenizer.
MODEL_DIR = "your_model_directory"


@st.cache_resource
def _load_model(model_dir: str = MODEL_DIR):
    """Load the tokenizer and classifier from *model_dir* once per process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the weights would be re-read from disk on every click.
    ``st.cache_resource`` keeps the returned objects alive across reruns.

    Returns:
        (tokenizer, model) with the model already switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(model_dir)
    # 8 output logits; predict() splits them into two groups of 4.
    mdl = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=8)
    mdl.eval()  # disable dropout etc. for inference
    return tok, mdl


tokenizer, model = _load_model()
def predict(text):
    """Classify *text* and return per-scale class probabilities.

    The model emits 8 logits. The first 4 are treated as the SAS_Class head
    and the last 4 as the SDS_Class head. This layout is an assumption
    carried over from the original code; confirm it against the training
    setup, which presumably uses a separate softmax/cross-entropy per head.

    Args:
        text: The input string (a list of strings is also accepted).

    Returns:
        (sas_probs, sds_probs): probability tensors. Each has shape (4,) when
        a single input is given, or (batch, 4) otherwise. Each row sums to 1.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Normalise each head independently. A single softmax over all 8 logits
    # would make SAS and SDS classes compete, so neither group would sum to 1.
    # Slice the last (label) axis so batched input is handled correctly.
    sas_probs = torch.softmax(logits[..., :4], dim=-1)
    sds_probs = torch.softmax(logits[..., 4:], dim=-1)
    if logits.shape[0] == 1:
        # Single input: drop the batch axis so callers get shape (4,), as before.
        sas_probs, sds_probs = sas_probs.squeeze(0), sds_probs.squeeze(0)
    return sas_probs, sds_probs
# Build the Streamlit app.
st.title("Multi-label Classification App")
# Text input from the user. The hint is a placeholder rather than the default
# value, so it is never sent to the model by accident.
user_input = st.text_area("Enter text here", placeholder="Type something...")
if st.button("Predict"):
    if not user_input.strip():
        st.warning("Please enter some text first.")
    else:
        # Show the prediction results.
        sas_probs, sds_probs = predict(user_input)
        st.write("SAS_Class probabilities:", sas_probs.numpy())
        st.write("SDS_Class probabilities:", sds_probs.numpy())