xuandin committed on
Commit
35de67c
·
verified ·
1 Parent(s): 77dabd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -65
app.py CHANGED
@@ -11,17 +11,22 @@ from semviqa.tvc.tvc_eval import classify_claim
11
  def load_model(model_name, model_class, is_bc=False):
12
  tokenizer = AutoTokenizer.from_pretrained(model_name)
13
  model = model_class.from_pretrained(model_name, num_labels=3 if not is_bc else 2)
 
14
  return tokenizer, model
15
 
16
- # Set up page configuration and custom CSS for a modern, clean look
17
  st.set_page_config(page_title="SemViQA Demo", layout="wide")
18
 
 
19
  st.markdown("""
20
  <style>
 
 
 
21
  .big-title {
22
  font-size: 36px;
23
  font-weight: bold;
24
- color: #4A90E2;
25
  text-align: center;
26
  margin-top: 20px;
27
  }
@@ -38,9 +43,10 @@ st.markdown("""
38
  width: 100%;
39
  border-radius: 8px;
40
  padding: 10px;
 
41
  }
42
- .stTextArea textarea {
43
- font-size: 16px;
44
  }
45
  .result-box {
46
  background-color: #f9f9f9;
@@ -52,7 +58,6 @@ st.markdown("""
52
  .verdict {
53
  font-size: 24px;
54
  font-weight: bold;
55
- margin: 0;
56
  display: flex;
57
  align-items: center;
58
  }
@@ -62,11 +67,12 @@ st.markdown("""
62
  </style>
63
  """, unsafe_allow_html=True)
64
 
65
- st.markdown("<p class='big-title'>SemViQA: Semantic Question Answering System for Vietnamese Fact-Checking</p>", unsafe_allow_html=True)
 
66
  st.markdown("<p class='sub-title'>Enter a claim and context to verify its accuracy</p>", unsafe_allow_html=True)
67
 
68
- # Sidebar: Settings and additional features
69
- with st.sidebar.expander("⚙️ Settings", expanded=False):
70
  tfidf_threshold = st.slider("TF-IDF Threshold", 0.0, 1.0, 0.5, 0.01)
71
  length_ratio_threshold = st.slider("Length Ratio Threshold", 0.1, 1.0, 0.5, 0.01)
72
  qatc_model_name = st.selectbox("QATC Model", [
@@ -93,71 +99,55 @@ with st.sidebar.expander("⚙️ Settings", expanded=False):
93
  ])
94
  show_details = st.checkbox("Show probability details", value=False)
95
 
96
- # Initialize verification history in session state
97
- if 'history' not in st.session_state:
98
- st.session_state.history = []
99
-
100
- # Load the selected models
101
  tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering)
102
  tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True)
103
  tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification)
104
 
105
- # User input fields
106
- claim = st.text_area("Enter Claim", "Vietnam is a country in Southeast Asia.")
107
- context = st.text_area("Enter Context", "Vietnam is a country located in Southeast Asia, covering an area of over 331,000 km² with a population of more than 98 million people.")
108
-
109
- # Define icon mapping for each verdict label
110
  verdict_icons = {
111
  "SUPPORTED": "✅",
112
  "REFUTED": "❌",
113
  "NEI": "⚠️"
114
  }
115
 
116
- if st.button("Verify"):
117
- with st.spinner("Verifying..."):
118
- # Extract evidence
119
- evidence = extract_evidence_tfidf_qatc(
120
- claim, context, model_qatc, tokenizer_qatc, "cuda" if torch.cuda.is_available() else "cpu",
121
- confidence_threshold=tfidf_threshold, length_ratio_threshold=length_ratio_threshold
122
- )
123
-
124
- # Classify the claim
125
- verdict = "NEI"
126
- prob3class, pred_tc = classify_claim(claim, evidence, model_tc, tokenizer_tc, "cuda" if torch.cuda.is_available() else "cpu")
127
-
128
- details = ""
129
- if pred_tc != 0:
130
- prob2class, pred_bc = classify_claim(claim, evidence, model_bc, tokenizer_bc, "cuda" if torch.cuda.is_available() else "cpu")
131
- if pred_bc == 0:
132
- verdict = "SUPPORTED"
133
- elif prob2class > prob3class:
134
- verdict = "REFUTED"
135
- else:
136
- verdict = ["NEI", "SUPPORTED", "REFUTED"][pred_tc]
137
- if show_details:
138
- details = f"<p><strong>3-Class Probability:</strong> {prob3class:.2f} - <strong>2-Class Probability:</strong> {prob2class:.2f}</p>"
139
-
140
- # Save the verification record in session history
141
- st.session_state.history.append({
142
- "claim": claim,
143
- "evidence": evidence,
144
- "verdict": verdict
145
- })
146
-
147
- # Display the results with icon and label (without extra "Verdict:" text)
148
- st.markdown(f"""
149
- <div class='result-box'>
150
- <h3>Result</h3>
151
- <p><strong>Evidence:</strong> {evidence}</p>
152
- <p class='verdict'><span class='verdict-icon'>{verdict_icons.get(verdict, '')}</span>{verdict}</p>
153
- {details}
154
- </div>
155
- """, unsafe_allow_html=True)
156
 
157
- # Display verification history in the sidebar
158
- with st.sidebar.expander("Verification History", expanded=False):
159
- if st.session_state.history:
160
- for idx, record in enumerate(reversed(st.session_state.history), 1):
161
- st.markdown(f"**{idx}. Claim:** {record['claim']} \n**Result:** {verdict_icons.get(record['verdict'], '')} {record['verdict']}")
162
- else:
163
- st.write("No verification history yet.")
 
11
  def load_model(model_name, model_class, is_bc=False):
12
  tokenizer = AutoTokenizer.from_pretrained(model_name)
13
  model = model_class.from_pretrained(model_name, num_labels=3 if not is_bc else 2)
14
+ model.eval()
15
  return tokenizer, model
16
 
17
+ # Page Configuration
18
  st.set_page_config(page_title="SemViQA Demo", layout="wide")
19
 
20
+ # Custom CSS for improved UI
21
  st.markdown("""
22
  <style>
23
+ body {
24
+ font-family: 'Arial', sans-serif;
25
+ }
26
  .big-title {
27
  font-size: 36px;
28
  font-weight: bold;
29
+ color: #0078D4;
30
  text-align: center;
31
  margin-top: 20px;
32
  }
 
43
  width: 100%;
44
  border-radius: 8px;
45
  padding: 10px;
46
+ transition: 0.3s;
47
  }
48
+ .stButton>button:hover {
49
+ background-color: #45a049;
50
  }
51
  .result-box {
52
  background-color: #f9f9f9;
 
58
  .verdict {
59
  font-size: 24px;
60
  font-weight: bold;
 
61
  display: flex;
62
  align-items: center;
63
  }
 
67
  </style>
68
  """, unsafe_allow_html=True)
69
 
70
+ # Page Header
71
+ st.markdown("<p class='big-title'>SemViQA: Vietnamese Fact-Checking System</p>", unsafe_allow_html=True)
72
  st.markdown("<p class='sub-title'>Enter a claim and context to verify its accuracy</p>", unsafe_allow_html=True)
73
 
74
+ # Sidebar: Settings
75
+ with st.sidebar.expander("⚙️ Settings", expanded=True):
76
  tfidf_threshold = st.slider("TF-IDF Threshold", 0.0, 1.0, 0.5, 0.01)
77
  length_ratio_threshold = st.slider("Length Ratio Threshold", 0.1, 1.0, 0.5, 0.01)
78
  qatc_model_name = st.selectbox("QATC Model", [
 
99
  ])
100
  show_details = st.checkbox("Show probability details", value=False)
101
 
102
+ # Load Models
 
 
 
 
103
  tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering)
104
  tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True)
105
  tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification)
106
 
107
+ # Define verdict icons
 
 
 
 
108
  verdict_icons = {
109
  "SUPPORTED": "✅",
110
  "REFUTED": "❌",
111
  "NEI": "⚠️"
112
  }
113
 
114
+ # Tabs for functionalities
115
+ tabs = st.tabs(["Verify", "History", "About"])
116
+
117
+ # --- Verify Tab ---
118
+ with tabs[0]:
119
+ st.subheader("Verify a Claim")
120
+ claim = st.text_area("Enter Claim", "Vietnam is a country in Southeast Asia.")
121
+ context = st.text_area("Enter Context", "Vietnam is a country located in Southeast Asia.")
122
+
123
+ if st.button("Verify", key="verify_button"):
124
+ with st.spinner("Verifying..."):
125
+ with torch.no_grad():
126
+ evidence = extract_evidence_tfidf_qatc(
127
+ claim, context, model_qatc, tokenizer_qatc,
128
+ "cuda" if torch.cuda.is_available() else "cpu",
129
+ confidence_threshold=tfidf_threshold,
130
+ length_ratio_threshold=length_ratio_threshold
131
+ )
132
+ verdict = "NEI"
133
+ prob3class, pred_tc = classify_claim(claim, evidence, model_tc, tokenizer_tc, "cuda" if torch.cuda.is_available() else "cpu")
134
+ if pred_tc != 0:
135
+ prob2class, pred_bc = classify_claim(claim, evidence, model_bc, tokenizer_bc, "cuda" if torch.cuda.is_available() else "cpu")
136
+ verdict = "SUPPORTED" if pred_bc == 0 else "REFUTED" if prob2class > prob3class else ["NEI", "SUPPORTED", "REFUTED"][pred_tc]
137
+
138
+ # Display result
139
+ st.markdown(f"""
140
+ <div class='result-box'>
141
+ <h3>Result</h3>
142
+ <p><strong>Evidence:</strong> {evidence}</p>
143
+ <p class='verdict'><span class='verdict-icon'>{verdict_icons.get(verdict, '')}</span>{verdict}</p>
144
+ </div>
145
+ """, unsafe_allow_html=True)
146
+
147
+ if torch.cuda.is_available():
148
+ torch.cuda.empty_cache()
 
 
 
 
 
149
 
150
+ # --- About Tab ---
151
+ with tabs[2]:
152
+ st.subheader("About SemViQA")
153
+ st.markdown("""SemViQA is a semantic fact-checking system for Vietnamese information verification.""")