aaa / app.py
Yi-666's picture
Update app.py
82baf39 verified
raw
history blame
2.45 kB
import streamlit as st
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc
# 初始化情感分析模型
pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
# 定义页面功能
st.sidebar.title("App Navigation")
page = st.sidebar.radio("Choose a feature", ["Sentiment Analysis", "Model Evaluation"])
if page == "Sentiment Analysis":
# 设置页面标题
st.title("Twitter Sentiment Analysis App")
# 输入框让用户输入文本
user_input = st.text_input("Enter a tweet to analyze:")
if user_input:
# 调用模型分析用户输入
result = pipe(user_input)
st.write("Sentiment Analysis Result:", result)
elif page == "Model Evaluation":
# 设置评估页面标题
st.title("Model Precision-Recall Evaluation")
# 上传真实标签和预测概率
st.write("### 输入数据")
y_true_input = st.text_area("输入真实标签 (用逗号分隔)", "1,0,1,1,0,1")
y_score_input = st.text_area("输入预测概率 (用逗号分隔)", "0.95,0.1,0.85,0.75,0.2,0.9")
if y_true_input and y_score_input:
try:
# 解析输入
y_true = list(map(int, y_true_input.split(",")))
y_score = list(map(float, y_score_input.split(",")))
# 检查长度是否一致
if len(y_true) != len(y_score):
st.error("真实标签和预测概率的长度不一致!请重新输入。")
else:
# 计算 Precision 和 Recall
precision, recall, _ = precision_recall_curve(y_true, y_score)
pr_auc = auc(recall, precision)
# 绘制 PR 曲线
fig, ax = plt.subplots()
ax.plot(recall, precision, label=f"PR Curve (AUC = {pr_auc:.2f})")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Precision-Recall Curve")
ax.legend(loc="best")
ax.grid()
# 显示图像
st.pyplot(fig)
st.success(f"PR Curve AUC: {pr_auc:.2f}")
except Exception as e:
st.error(f"发生错误: {e}")
else:
st.info("请输入真实标签和预测概率以生成 PR 曲线。")