import streamlit as st
import torch
from transformers import AutoTokenizer
from semviqa.ser.qatc_model import QATCForQuestionAnswering
from semviqa.tvc.model import ClaimModelForClassification
from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
from semviqa.tvc.tvc_eval import classify_claim
import time
import io
# Model loading, cached so Streamlit reruns reuse the same objects.
@st.cache_resource()
def load_model(model_name, model_class, is_bc=False):
    """Load a tokenizer/model pair from a HuggingFace checkpoint.

    Args:
        model_name: checkpoint identifier passed to ``from_pretrained``.
        model_class: model class exposing a ``from_pretrained`` constructor.
        is_bc: True for the binary classifier (2 labels); otherwise 3 labels.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    num_labels = 2 if is_bc else 3
    model = model_class.from_pretrained(model_name, num_labels=num_labels)
    model.eval()
    return tokenizer, model
# Set up page configuration
st.set_page_config(page_title="SemViQA Demo", layout="wide")

# Custom CSS for the fixed header and main container (height = viewport - 55px)
st.markdown("""
""", unsafe_allow_html=True)

# --- Fixed Header ---
# Render the fixed header: title, navigation and subtitle
st.markdown("""
""", unsafe_allow_html=True)

# NOTE(review): the rest of the script branches on `nav_option`, but no widget
# ever defined it (the header comment mentions a nav radio that was lost).
# Restored here so the Verify / History / About branches are reachable.
nav_option = st.radio(
    "Navigation",
    ["Verify", "History", "About"],
    horizontal=True,
    label_visibility="collapsed",
)

# --- Main Container ---
with st.container():
    st.markdown("", unsafe_allow_html=True)
# Sidebar: global settings (shared by every page)
with st.sidebar.expander("⚙️ Settings", expanded=True):
    # Thresholds used by the TF-IDF/QATC evidence-extraction step
    tfidf_threshold = st.slider("TF-IDF Threshold", 0.0, 1.0, 0.5, 0.01)
    length_ratio_threshold = st.slider("Length Ratio Threshold", 0.1, 1.0, 0.5, 0.01)

    # Checkpoint choices, one list per task
    _qatc_choices = [
        "SemViQA/qatc-infoxlm-viwikifc",
        "SemViQA/qatc-infoxlm-isedsc01",
        "SemViQA/qatc-vimrc-viwikifc",
        "SemViQA/qatc-vimrc-isedsc01",
    ]
    _bc_choices = [
        "SemViQA/bc-xlmr-viwikifc",
        "SemViQA/bc-xlmr-isedsc01",
        "SemViQA/bc-infoxlm-viwikifc",
        "SemViQA/bc-infoxlm-isedsc01",
        "SemViQA/bc-erniem-viwikifc",
        "SemViQA/bc-erniem-isedsc01",
    ]
    _tc_choices = [
        "SemViQA/tc-xlmr-viwikifc",
        "SemViQA/tc-xlmr-isedsc01",
        "SemViQA/tc-infoxlm-viwikifc",
        "SemViQA/tc-infoxlm-isedsc01",
        "SemViQA/tc-erniem-viwikifc",
        "SemViQA/tc-erniem-isedsc01",
    ]
    qatc_model_name = st.selectbox("QATC Model", _qatc_choices)
    bc_model_name = st.selectbox("Binary Classification Model", _bc_choices)
    tc_model_name = st.selectbox("Three-Class Classification Model", _tc_choices)

    show_details = st.checkbox("Show probability details", value=False)
# Initialise the verification history and the most recent result in session state
# so they survive Streamlit reruns.
if "history" not in st.session_state:
    st.session_state.history = []
if "latest_result" not in st.session_state:
    st.session_state.latest_result = None

# Load the selected checkpoints (cached by load_model, so switching back to a
# previously chosen model is instant).
tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering)
tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True)
tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification)
# Icon shown next to each verdict label in the UI.
verdict_icons = dict([
    ("SUPPORTED", "✅"),
    ("REFUTED", "❌"),
    ("NEI", "⚠️"),
])
# --- Verify page: run the full extract-evidence + classify pipeline ---
if nav_option == "Verify":
    st.subheader("Verify a Claim")
    # Two-column layout: input on the left, result on the right
    col_input, col_result = st.columns([2, 1])
    with col_input:
        claim = st.text_area("Enter Claim", "Vietnam is a country in Southeast Asia.")
        context = st.text_area("Enter Context", "Vietnam is a country located in Southeast Asia, covering an area of over 331,000 km² with a population of more than 98 million people.")
        verify_clicked = st.button("Verify", key="verify_button")
    with col_result:
        if verify_clicked:
            with st.spinner("Loading and running verification..."):
                # Cosmetic progress bar. Fixed: the original stepped
                # range(1, 101, 20) and stalled at 81%; step 20..100 so the
                # bar visibly completes.
                progress_bar = st.progress(0)
                for pct in range(20, 101, 20):
                    time.sleep(0.1)
                    progress_bar.progress(pct)

                # Hoisted: the device string was recomputed three times.
                device = "cuda" if torch.cuda.is_available() else "cpu"
                with torch.no_grad():
                    # Step 1: extract supporting evidence from the context
                    evidence = extract_evidence_tfidf_qatc(
                        claim, context, model_qatc, tokenizer_qatc,
                        device,
                        confidence_threshold=tfidf_threshold,
                        length_ratio_threshold=length_ratio_threshold
                    )
                    verdict = "NEI"
                    details = ""
                    # Step 2: three-class classifier (0 is treated as NEI)
                    prob3class, pred_tc = classify_claim(
                        claim, evidence, model_tc, tokenizer_tc, device
                    )
                    if pred_tc != 0:
                        # Step 3: binary classifier refines SUPPORTED/REFUTED;
                        # fall back to the 3-class prediction when it is the
                        # more confident of the two.
                        prob2class, pred_bc = classify_claim(
                            claim, evidence, model_bc, tokenizer_bc, device
                        )
                        if pred_bc == 0:
                            verdict = "SUPPORTED"
                        elif prob2class > prob3class:
                            verdict = "REFUTED"
                        else:
                            verdict = ["NEI", "SUPPORTED", "REFUTED"][pred_tc]
                        if show_details:
                            # Fixed: the original was an unterminated string
                            # literal split across lines; rebuilt as one f-string.
                            details = f"3-Class Probability: {prob3class.item():.2f} - 2-Class Probability: {prob2class.item():.2f}"

                # Persist the run into history and as the latest result
                st.session_state.history.append({
                    "claim": claim,
                    "evidence": evidence,
                    "verdict": verdict
                })
                st.session_state.latest_result = {
                    "claim": claim,
                    "evidence": evidence,
                    "verdict": verdict,
                    "details": details
                }

                # Release GPU memory held by the forward passes
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        res = st.session_state.latest_result
        # NOTE(review): the original had a dangling `else: st.info(...)`; the
        # matching guard was lost in extraction — reconstructed as `if res:`.
        if res:
            st.markdown("Verification Result", unsafe_allow_html=True)
            st.markdown(
                f"""
                Claim: {res['claim']}
                Evidence: {res['evidence']}
                {verdict_icons.get(res['verdict'], '')} {res['verdict']}
                {res['details']}
                """,
                unsafe_allow_html=True,
            )
            result_text = f"Claim: {res['claim']}\nEvidence: {res['evidence']}\nVerdict: {res['verdict']}\nDetails: {res['details']}"
            st.download_button("Download Result", data=result_text, file_name="verification_result.txt", mime="text/plain")
        else:
            st.info("No verification result yet.")
elif nav_option == "History":
st.subheader("Verification History")
if st.session_state.history:
for idx, record in enumerate(reversed(st.session_state.history), 1):
st.markdown(f"**{idx}. Claim:** {record['claim']} \n**Result:** {verdict_icons.get(record['verdict'], '')} {record['verdict']}")
else:
st.write("No verification history yet.")
elif nav_option == "About":
st.subheader("About")
st.markdown("""
""", unsafe_allow_html=True)
st.markdown("""
**Description:**
SemViQA is a semantic QA system designed for fact-checking in Vietnamese.
It extracts evidence from the provided context and classifies the claim as **SUPPORTED**, **REFUTED**, or **NEI** (Not Enough Information) using state-of-the-art models.
""")
st.markdown("
", unsafe_allow_html=True)