|
import streamlit as st |
|
from transformers import pipeline |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import matplotlib.pyplot as plt |
|
from sklearn.metrics import precision_recall_curve, auc |
|
|
|
|
|
@st.cache_resource
def _load_sentiment_pipeline():
    """Build the sentiment text-classification pipeline once and reuse it.

    Streamlit re-executes this script on every widget interaction. Without
    caching, the model would be re-instantiated on each rerun; wrapping the
    construction in ``st.cache_resource`` keeps a single shared instance.

    Returns:
        The ``transformers`` text-classification pipeline backed by the
        ``cardiffnlp/twitter-roberta-base-sentiment-latest`` checkpoint.
    """
    return pipeline(
        "text-classification",
        model="cardiffnlp/twitter-roberta-base-sentiment-latest",
    )


# Module-level handle used by the "Sentiment Analysis" page.
pipe = _load_sentiment_pipeline()
|
|
|
|
|
# Sidebar navigation: the selected label decides which page is rendered below.
st.sidebar.title("App Navigation")
page = st.sidebar.radio(
    "Choose a feature",
    ["Sentiment Analysis", "Model Evaluation"],
)

if page == "Sentiment Analysis":
    st.title("Twitter Sentiment Analysis App")

    user_input = st.text_input("Enter a tweet to analyze:")

    # An empty text box is falsy, so nothing is classified until text is entered.
    if user_input:
        result = pipe(user_input)
        st.write("Sentiment Analysis Result:", result)

elif page == "Model Evaluation":
    st.title("Model Precision-Recall Evaluation")

    # Heading reads "Input data"; the two text areas take comma-separated
    # ground-truth labels and predicted probabilities, pre-filled with a sample.
    st.write("### 输入数据")
    y_true_input = st.text_area("输入真实标签 (用逗号分隔)", "1,0,1,1,0,1")
    y_score_input = st.text_area(
        "输入预测概率 (用逗号分隔)",
        "0.95,0.1,0.85,0.75,0.2,0.9",
    )

    if y_true_input and y_score_input:
        # UI boundary: any parsing or metric failure is shown to the user as an
        # error message instead of propagating out of the script.
        try:
            # int()/float() accept surrounding whitespace; an empty or
            # non-numeric token raises ValueError and is reported below.
            y_true = list(map(int, y_true_input.split(",")))
            y_score = list(map(float, y_score_input.split(",")))

            if len(y_true) != len(y_score):
                st.error("真实标签和预测概率的长度不一致!请重新输入。")
            else:
                precision, recall, _ = precision_recall_curve(y_true, y_score)
                pr_auc = auc(recall, precision)

                fig, ax = plt.subplots()
                try:
                    ax.plot(recall, precision, label=f"PR Curve (AUC = {pr_auc:.2f})")
                    ax.set_xlabel("Recall")
                    ax.set_ylabel("Precision")
                    ax.set_title("Precision-Recall Curve")
                    ax.legend(loc="best")
                    ax.grid()

                    st.pyplot(fig)
                finally:
                    # pyplot keeps a global reference to every figure it
                    # creates; close it so figures do not pile up across
                    # Streamlit reruns.
                    plt.close(fig)

                st.success(f"PR Curve AUC: {pr_auc:.2f}")
        except Exception as e:
            st.error(f"发生错误: {e}")
    else:
        # Shown when either text area is empty: asks the user to supply both.
        st.info("请输入真实标签和预测概率以生成 PR 曲线。")
|
|
|
|