import streamlit as st
import torch
from transformers import AutoTokenizer
from semviqa.ser.qatc_model import QATCForQuestionAnswering
from semviqa.tvc.model import ClaimModelForClassification
from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
from semviqa.tvc.tvc_eval import classify_claim
import time
import io  # NOTE(review): unused in this chunk; kept in case another part of the file needs it

# Single source of truth for the inference device.  The original recomputed
# this expression before every model call.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Load models with caching so repeated reruns of the Streamlit script reuse
# the same tokenizer/model pair instead of re-downloading/re-initialising.
@st.cache_resource()
def load_model(model_name, model_class, is_bc=False):
    """Load and cache a tokenizer/model pair from the Hugging Face hub.

    Args:
        model_name: Hub identifier of the checkpoint to load.
        model_class: Model class exposing ``from_pretrained`` (a QATC
            question-answering model or a claim-classification model).
        is_bc: True for a binary classifier (2 labels); False for the
            three-class classifier (3 labels).

    Returns:
        Tuple ``(tokenizer, model)`` with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = model_class.from_pretrained(model_name, num_labels=2 if is_bc else 3)
    model.eval()
    return tokenizer, model


# --- Page configuration ------------------------------------------------------
st.set_page_config(page_title="SemViQA Demo", layout="wide")

# Custom CSS for the fixed header and the main container
# (container height = viewport - 55px).
# NOTE(review): the CSS payload was stripped when this file was extracted;
# the empty string below preserves the original call site — restore the
# stylesheet from version control if available.
st.markdown(""" """, unsafe_allow_html=True)

# --- Fixed header ------------------------------------------------------------
# Title, navigation (radio) and subtitle.
# NOTE(review): the surrounding HTML markup was lost in extraction; only the
# visible text content survives.
st.markdown("""
SemViQA: Semantic Fact-Checking System for Vietnamese
""", unsafe_allow_html=True)

# Navigation: a horizontal radio switches between the app's pages.
nav_option = st.radio("", ["Verify", "History", "About"], horizontal=True, key="nav")

st.markdown("""
Enter a claim and context to verify its accuracy
""", unsafe_allow_html=True)

# --- Main container -----------------------------------------------------------
with st.container():
    # NOTE(review): originally opened a styled <div>; markup lost in extraction.
    st.markdown("", unsafe_allow_html=True)

    # Sidebar: global settings (thresholds and model checkpoints).
    with st.sidebar.expander("⚙️ Settings", expanded=True):
        tfidf_threshold = st.slider("TF-IDF Threshold", 0.0, 1.0, 0.5, 0.01)
        length_ratio_threshold = st.slider("Length Ratio Threshold", 0.1, 1.0, 0.5, 0.01)
        qatc_model_name = st.selectbox("QATC Model", [
            "SemViQA/qatc-infoxlm-viwikifc",
            "SemViQA/qatc-infoxlm-isedsc01",
            "SemViQA/qatc-vimrc-viwikifc",
            "SemViQA/qatc-vimrc-isedsc01",
        ])
        bc_model_name = st.selectbox("Binary Classification Model", [
            "SemViQA/bc-xlmr-viwikifc",
            "SemViQA/bc-xlmr-isedsc01",
            "SemViQA/bc-infoxlm-viwikifc",
            "SemViQA/bc-infoxlm-isedsc01",
            "SemViQA/bc-erniem-viwikifc",
            "SemViQA/bc-erniem-isedsc01",
        ])
        tc_model_name = st.selectbox("Three-Class Classification Model", [
            "SemViQA/tc-xlmr-viwikifc",
            "SemViQA/tc-xlmr-isedsc01",
            "SemViQA/tc-infoxlm-viwikifc",
            "SemViQA/tc-infoxlm-isedsc01",
            "SemViQA/tc-erniem-viwikifc",
            "SemViQA/tc-erniem-isedsc01",
        ])
        show_details = st.checkbox("Show probability details", value=False)

    # Initialise verification history and the latest result in session state
    # so they survive Streamlit reruns.
    if 'history' not in st.session_state:
        st.session_state.history = []
    if 'latest_result' not in st.session_state:
        st.session_state.latest_result = None

    # Load the selected models (cached by load_model).
    tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering)
    tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True)
    tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification)

    # Icon shown next to each verdict label.
    verdict_icons = {
        "SUPPORTED": "✅",
        "REFUTED": "❌",
        "NEI": "⚠️",
    }

    # --- Page dispatch ---------------------------------------------------------
    if nav_option == "Verify":
        st.subheader("Verify a Claim")
        # Two-column layout: inputs on the left, results on the right.
        col_input, col_result = st.columns([2, 1])
        with col_input:
            claim = st.text_area("Enter Claim", "Vietnam is a country in Southeast Asia.")
            context = st.text_area(
                "Enter Context",
                "Vietnam is a country located in Southeast Asia, covering an area of over "
                "331,000 km² with a population of more than 98 million people.",
            )
            verify_clicked = st.button("Verify", key="verify_button")
        with col_result:
            if verify_clicked:
                with st.spinner("Loading and running verification..."):
                    # Cosmetic progress bar — the real work happens below.
                    progress_bar = st.progress(0)
                    for i in range(1, 101, 20):
                        time.sleep(0.1)
                        progress_bar.progress(i)

                    with torch.no_grad():
                        # Step 1: extract supporting evidence from the context.
                        evidence = extract_evidence_tfidf_qatc(
                            claim, context, model_qatc, tokenizer_qatc,
                            DEVICE,
                            confidence_threshold=tfidf_threshold,
                            length_ratio_threshold=length_ratio_threshold,
                        )

                        verdict = "NEI"
                        details = ""

                        # Step 2: three-class classification of the claim
                        # against the extracted evidence.
                        prob3class, pred_tc = classify_claim(
                            claim, evidence, model_tc, tokenizer_tc, DEVICE
                        )

                        # Step 3: if the three-class model says "not NEI",
                        # refine with the binary classifier and pick the
                        # verdict from whichever model is more confident.
                        if pred_tc != 0:
                            prob2class, pred_bc = classify_claim(
                                claim, evidence, model_bc, tokenizer_bc, DEVICE
                            )
                            if pred_bc == 0:
                                verdict = "SUPPORTED"
                            elif prob2class > prob3class:
                                verdict = "REFUTED"
                            else:
                                verdict = ["NEI", "SUPPORTED", "REFUTED"][pred_tc]
                            if show_details:
                                # NOTE(review): originally wrapped in HTML that
                                # was lost in extraction; text content kept.
                                details = (
                                    f"3-Class Probability: {prob3class.item():.2f} - "
                                    f"2-Class Probability: {prob2class.item():.2f}"
                                )

                        # Persist the run in history and as the latest result.
                        st.session_state.history.append({
                            "claim": claim,
                            "evidence": evidence,
                            "verdict": verdict,
                        })
                        st.session_state.latest_result = {
                            "claim": claim,
                            "evidence": evidence,
                            "verdict": verdict,
                            "details": details,
                        }

                    # Free GPU memory between runs.
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                res = st.session_state.latest_result
                # NOTE(review): the result card was originally styled HTML;
                # the markup was lost in extraction, so plain text is shown.
                st.markdown("Verification Result", unsafe_allow_html=True)
                st.markdown(
                    f"""
Claim: {res['claim']}

Evidence: {res['evidence']}

{verdict_icons.get(res['verdict'], '')} {res['verdict']}

{res['details']}
""",
                    unsafe_allow_html=True,
                )
                result_text = (
                    f"Claim: {res['claim']}\n"
                    f"Evidence: {res['evidence']}\n"
                    f"Verdict: {res['verdict']}\n"
                    f"Details: {res['details']}"
                )
                st.download_button(
                    "Download Result",
                    data=result_text,
                    file_name="verification_result.txt",
                    mime="text/plain",
                )
            else:
                st.info("No verification result yet.")

    elif nav_option == "History":
        st.subheader("Verification History")
        if st.session_state.history:
            # Most recent verification first.
            for idx, record in enumerate(reversed(st.session_state.history), 1):
                st.markdown(
                    f"**{idx}. Claim:** {record['claim']} \n"
                    f"**Result:** {verdict_icons.get(record['verdict'], '')} {record['verdict']}"
                )
        else:
            st.write("No verification history yet.")

    elif nav_option == "About":
        st.subheader("About")
        # NOTE(review): originally an HTML block; markup lost in extraction.
        st.markdown("""
""", unsafe_allow_html=True)
        st.markdown("""
        **Description:**
        SemViQA is a semantic QA system designed for fact-checking in Vietnamese.
        It extracts evidence from the provided context and classifies the claim as **SUPPORTED**, **REFUTED**, or **NEI** (Not Enough Information) using state-of-the-art models.
        """)

    # NOTE(review): originally closed the main-container <div>; markup lost.
    st.markdown("", unsafe_allow_html=True)