# semviqa-demo / app.py
import streamlit as st
import torch
from transformers import AutoTokenizer
from semviqa.ser.qatc_model import QATCForQuestionAnswering
from semviqa.tvc.model import ClaimModelForClassification
from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
from semviqa.tvc.tvc_eval import classify_claim
import time
import io
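
# NOTE: besides streamlit, torch and transformers, this demo assumes the
# `semviqa` package is installed (it is published on PyPI; see the About page).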
# Load models with caching
@st.cache_resource()
def load_model(model_name, model_class, is_bc=False):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = model_class.from_pretrained(model_name, num_labels=3 if not is_bc else 2)
    model.eval()
    return tokenizer, model
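
# st.cache_resource caches one tokenizer/model pair per unique combination of
# arguments, so switching checkpoints in the sidebar does not re-initialise
# models that were already loaded earlier in the session.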
# Set up page configuration
st.set_page_config(page_title="SemViQA Demo", layout="wide")
# Custom CSS for the fixed header and the main container (height = viewport - 55px)
st.markdown("""
<style>
.header-container {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 55px;
background-color: #fff;
z-index: 1000;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
align-items: center;
justify-content: space-between;
padding: 0 20px;
}
.header-title {
font-size: 14px;
font-weight: bold;
color: #4A90E2;
}
.header-nav {
margin: 0 20px;
}
.header-subtitle {
font-size: 12px;
color: #666;
text-align: right;
}
.main-container {
margin-top: 55px;
height: calc(100vh - 55px);
overflow-y: auto;
padding: 20px;
}
.stButton>button {
background-color: #4CAF50;
color: white;
font-size: 16px;
width: 100%;
border-radius: 8px;
padding: 10px;
}
.stTextArea textarea {
font-size: 16px;
min-height: 120px;
}
.result-box {
background-color: #f9f9f9;
padding: 20px;
border-radius: 10px;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
margin-top: 20px;
}
.verdict {
font-size: 24px;
font-weight: bold;
margin: 0;
display: flex;
align-items: center;
}
.verdict-icon {
margin-right: 10px;
}
</style>
""", unsafe_allow_html=True)
# --- Fixed Header ---
# Use st.markdown to render the fixed header, including the title, nav (radio) and subtitle
st.markdown("""
<div class='header-container'>
<div class='header-title'>SemViQA: Semantic Fact-Checking System for Vietnamese</div>
<div class='header-nav'>
""", unsafe_allow_html=True)
# Navigation: use st.radio to switch between pages (displayed horizontally)
nav_option = st.radio("Navigation", ["Verify", "History", "About"], horizontal=True, key="nav", label_visibility="collapsed")
st.markdown("""
</div>
<div class='header-subtitle'>Enter a claim and context to verify its accuracy</div>
</div>
""", unsafe_allow_html=True)
# --- Main Container ---
with st.container():
    st.markdown("<div class='main-container'>", unsafe_allow_html=True)

    # Sidebar: global settings (unchanged)
    with st.sidebar.expander("⚙️ Settings", expanded=True):
        tfidf_threshold = st.slider("TF-IDF Threshold", 0.0, 1.0, 0.5, 0.01)
        length_ratio_threshold = st.slider("Length Ratio Threshold", 0.1, 1.0, 0.5, 0.01)
        qatc_model_name = st.selectbox("QATC Model", [
            "SemViQA/qatc-infoxlm-viwikifc",
            "SemViQA/qatc-infoxlm-isedsc01",
            "SemViQA/qatc-vimrc-viwikifc",
            "SemViQA/qatc-vimrc-isedsc01"
        ])
        bc_model_name = st.selectbox("Binary Classification Model", [
            "SemViQA/bc-xlmr-viwikifc",
            "SemViQA/bc-xlmr-isedsc01",
            "SemViQA/bc-infoxlm-viwikifc",
            "SemViQA/bc-infoxlm-isedsc01",
            "SemViQA/bc-erniem-viwikifc",
            "SemViQA/bc-erniem-isedsc01"
        ])
        tc_model_name = st.selectbox("Three-Class Classification Model", [
            "SemViQA/tc-xlmr-viwikifc",
            "SemViQA/tc-xlmr-isedsc01",
            "SemViQA/tc-infoxlm-viwikifc",
            "SemViQA/tc-infoxlm-isedsc01",
            "SemViQA/tc-erniem-viwikifc",
            "SemViQA/tc-erniem-isedsc01"
        ])
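        # Checkpoint names follow the pattern <task>-<backbone>-<fine-tuning data>
        # (e.g. "SemViQA/bc-infoxlm-viwikifc"); see the SemViQA Hugging Face page
        # linked on the About tab for details on the individual checkpoints.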
        show_details = st.checkbox("Show probability details", value=False)

    # Initialize the verification history and the latest result
    if 'history' not in st.session_state:
        st.session_state.history = []
    if 'latest_result' not in st.session_state:
        st.session_state.latest_result = None

    # Load the selected models
    tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering)
    tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True)
    tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification)

    # Icons for each verdict
    verdict_icons = {
        "SUPPORTED": "✅",
        "REFUTED": "❌",
        "NEI": "⚠️"
    }
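
    # Verification pipeline (used on the Verify page): evidence is first extracted
    # from the context with extract_evidence_tfidf_qatc (TF-IDF retrieval refined by
    # the QATC model), then the claim/evidence pair is scored by the three-class
    # model and, when the prediction is not NEI, cross-checked with the binary model.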
    # Display content according to the navigation selection
    if nav_option == "Verify":
        st.subheader("Verify a Claim")
        # Two-column layout: input on the left, results on the right
        col_input, col_result = st.columns([2, 1])
        with col_input:
            claim = st.text_area("Enter Claim", "Vietnam is a country in Southeast Asia.")
            context = st.text_area("Enter Context", "Vietnam is a country located in Southeast Asia, covering an area of over 331,000 km² with a population of more than 98 million people.")
            verify_clicked = st.button("Verify", key="verify_button")
        with col_result:
            if verify_clicked:
                with st.spinner("Loading and running verification..."):
                    # Show a progress bar simulating the processing steps
                    progress_bar = st.progress(0)
                    for i in range(1, 101, 20):
                        time.sleep(0.1)
                        progress_bar.progress(i)
                    with torch.no_grad():
                        # Extract evidence and classify the claim
                        evidence = extract_evidence_tfidf_qatc(
                            claim, context, model_qatc, tokenizer_qatc,
                            "cuda" if torch.cuda.is_available() else "cpu",
                            confidence_threshold=tfidf_threshold,
                            length_ratio_threshold=length_ratio_threshold
                        )
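                        # Verdict logic: the three-class model runs first; if it predicts
                        # anything other than NEI (label 0), the binary model is consulted.
                        # A binary prediction of 0 maps to SUPPORTED; otherwise the more
                        # confident of the two classifiers decides between REFUTED and the
                        # three-class label.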
verdict = "NEI"
details = ""
prob3class, pred_tc = classify_claim(
claim, evidence, model_tc, tokenizer_tc,
"cuda" if torch.cuda.is_available() else "cpu"
)
if pred_tc != 0:
prob2class, pred_bc = classify_claim(
claim, evidence, model_bc, tokenizer_bc,
"cuda" if torch.cuda.is_available() else "cpu"
)
if pred_bc == 0:
verdict = "SUPPORTED"
elif prob2class > prob3class:
verdict = "REFUTED"
else:
verdict = ["NEI", "SUPPORTED", "REFUTED"][pred_tc]
if show_details:
details = f"<p><strong>3-Class Probability:</strong> {prob3class.item():.2f} - <strong>2-Class Probability:</strong> {prob2class.item():.2f}</p>"
# Lưu lịch sử và kết quả kiểm chứng mới nhất
st.session_state.history.append({
"claim": claim,
"evidence": evidence,
"verdict": verdict
})
st.session_state.latest_result = {
"claim": claim,
"evidence": evidence,
"verdict": verdict,
"details": details
}
if torch.cuda.is_available():
torch.cuda.empty_cache()

                res = st.session_state.latest_result
                st.markdown("<h3>Verification Result</h3>", unsafe_allow_html=True)
                st.markdown(f"""
<div class='result-box'>
<p><strong>Claim:</strong> {res['claim']}</p>
<p><strong>Evidence:</strong> {res['evidence']}</p>
<p class='verdict'><span class='verdict-icon'>{verdict_icons.get(res['verdict'], '')}</span>{res['verdict']}</p>
{res['details']}
</div>
""", unsafe_allow_html=True)
                result_text = f"Claim: {res['claim']}\nEvidence: {res['evidence']}\nVerdict: {res['verdict']}\nDetails: {res['details']}"
                st.download_button("Download Result", data=result_text, file_name="verification_result.txt", mime="text/plain")
            else:
                st.info("No verification result yet.")

    elif nav_option == "History":
        st.subheader("Verification History")
        if st.session_state.history:
            for idx, record in enumerate(reversed(st.session_state.history), 1):
                st.markdown(f"**{idx}. Claim:** {record['claim']}  \n**Result:** {verdict_icons.get(record['verdict'], '')} {record['verdict']}")
        else:
            st.write("No verification history yet.")

    elif nav_option == "About":
        st.subheader("About")
        st.markdown("""
<p align="center">
<a href="https://arxiv.org/abs/2503.00955">
<img src="https://img.shields.io/badge/arXiv-2411.00918-red?style=flat&label=arXiv">
</a>
<a href="https://huggingface.co/SemViQA">
<img src="https://img.shields.io/badge/Hugging%20Face-Model-yellow?style=flat">
</a>
<a href="https://pypi.org/project/SemViQA">
<img src="https://img.shields.io/pypi/v/SemViQA?color=blue&label=PyPI">
</a>
<a href="https://github.com/DAVID-NGUYEN-S16/SemViQA">
<img src="https://img.shields.io/github/stars/DAVID-NGUYEN-S16/SemViQA?style=social">
</a>
</p>
""", unsafe_allow_html=True)
        st.markdown("""
**Description:**
SemViQA is a semantic QA system designed for fact-checking in Vietnamese.
It extracts evidence from the provided context and classifies the claim as **SUPPORTED**, **REFUTED**, or **NEI** (Not Enough Information) using state-of-the-art models.
""")
st.markdown("</div>", unsafe_allow_html=True)