raj-tomar001 committed on
Commit
ed6e5bf
·
verified ·
1 Parent(s): 880a0f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -223
app.py CHANGED
@@ -1,169 +1,78 @@
1
- import gradio as gr
2
- from transformers import DebertaTokenizer, DebertaForSequenceClassification, DistilBertTokenizer, DistilBertForSequenceClassification
3
- from transformers import pipeline
4
  import json
5
- import numpy as np
6
  import random
 
 
 
 
7
 
8
- save_path_abstract = './fine-tuned-distillberta'
9
- model_abstract = DistilBertForSequenceClassification.from_pretrained(save_path_abstract)
10
- tokenizer_abstract = DistilBertTokenizer.from_pretrained(save_path_abstract)
11
-
12
- classifier_abstract = pipeline('text-classification', model=model_abstract, tokenizer=tokenizer_abstract)
13
 
14
- save_path_essay = './fine-tuned-distillberta'
15
- model_essay = DistilBertForSequenceClassification.from_pretrained(save_path_essay)
16
- tokenizer_essay = DistilBertTokenizer.from_pretrained(save_path_essay)
 
 
17
 
18
- classifier_essay = pipeline('text-classification', model=model_essay, tokenizer=tokenizer_essay)
19
 
20
- demo_essays = json.load(open('samples.json'))
21
- index = None
 
22
 
 
 
23
 
24
- ################# HELPER FUNCTIONS (DETECTION TAB) ####################
 
 
 
 
 
25
 
26
  def process_result_detection_tab(text):
27
- '''
28
- Classify the text into one of the four categories by averaging the soft predictions of the two models.
29
-
30
- Args:
31
- text: str: the text to be classified
32
- Returns:
33
- dict: a dictionary with the following keys:
34
- 'Machine Generated': float: the probability that the text is machine generated
35
- 'Human Written': float: the probability that the text is human written
36
- 'Machine Written, Machine Humanized': float: the probability that the text is machine written and machine humanized
37
- 'Human Written, Machine Polished': float: the probability that the text is human written and machine polished
38
- '''
39
- mapping = {'llm': 'Machine Generated', 'human':'Human Written', 'machine-humanized': 'Machine Written, Machine Humanized', 'machine-polished': 'Human Written, Machine Polished'}
40
- result = classifier_abstract(text)
41
- result_r = classifier_essay(text)
42
-
43
- labels = [mapping[x['label']] for x in result]
44
- scores = list(0.5 * np.array([x['score'] for x in result]) + 0.5 * np.array([x['score'] for x in result_r]))
45
 
46
  final_results = dict(zip(labels, scores))
47
- print(final_results)
48
- return final_results
49
-
50
- def update_detection_tab(name, uploaded_file, radio_input):
51
- '''
52
- Callback function to update the result of the classification based on the input text or uploaded file.
53
- Args:
54
- name: str: the input text from the Textbox
55
- uploaded_file: file: the uploaded file from the file input
56
- Returns:
57
- dict: the result of the classification including labels and scores
58
- '''
59
-
60
- if name == '' and uploaded_file is None:
61
  return ""
62
- if uploaded_file is not None:
63
- return f"Work in progress"
64
- else:
65
- return process_result_detection_tab(name)
66
 
67
- def active_button_detection_tab(input_text, file_input):
68
- '''
69
- Callback function to activate the 'Check Origin' button when the input text or file input
70
- is not empty. For text input, the button can be clicked only when the word count is between
71
- 50 and 500.
72
-
73
- Args:
74
- input_text: str: the input text from the textbox
75
- file_input: file: the uploaded file from the file input
76
- Returns:
77
- gr.Button: The 'Check Origin' button with the appropriate interactivity.
78
- '''
79
-
80
- if (input_text == "" and file_input is None) or (file_input is None and not (50 <= len(input_text.split()) <= 500)):
81
  return gr.Button("Check Origin", variant="primary", interactive=False)
82
-
83
  return gr.Button("Check Origin", variant="primary", interactive=True)
84
 
85
  def clear_detection_tab():
86
- '''
87
- Callback function to clear the input text and file input in the 'Try it!' tab.
88
- The interactivity of the 'Check Origin' button is set to False to prevent user click when the Textbox is empty.
89
-
90
- Args:
91
- None
92
- Returns:
93
- str: An empty string to clear the Textbox.
94
- None: None to clear the file input.
95
- gr.Button: The 'Check Origin' button with no interactivity.
96
- '''
97
-
98
- return "", None, gr.Button("Check Origin", variant="primary", interactive=False)
99
 
100
  def count_words_detection_tab(text):
101
- '''
102
- Callback function called when the input text is changed to update the word count.
103
- Args:
104
- text: str: the input text from the Textbox
105
- Returns:
106
- str: the word count of the input text for the Markdown widget
107
- '''
108
- return (f'{len(text.split())}/500 words (Minimum 50 words)')
109
-
110
-
111
- ################# HELPER FUNCTIONS (CHALLENGE TAB) ####################
112
-
113
- def clear_challenge_tab():
114
- '''
115
- Callback function to clear the text and result in the 'Challenge Yourself' tab.
116
- The interactivity of the buttons is set to False to prevent user click when the Textbox is empty.
117
-
118
- Args:
119
- None
120
- Returns:
121
- gr.Button: The 'Machine-Generated' button with no interactivity.
122
- gr.Button: The 'Human-Written' button with no interactivity.
123
- gr.Button: The 'Machine-Humanized' button with no interactivity.
124
- gr.Button: The 'Machine-Polished' button with no interactivity.
125
- str: An empty string to clear the Textbox.
126
- '''
127
-
128
- mg = gr.Button("Machine-Generated", variant="secondary", interactive=False)
129
- hw = gr.Button("Human-Written", variant="secondary", interactive=False)
130
- mh = gr.Button("Machine-Humanized", variant="secondary", interactive=False)
131
- mp = gr.Button("Machine-Polished", variant="secondary", interactive=False)
132
-
133
- return mg, hw, mh, mp, ''
134
 
135
  def generate_text_challenge_tab():
136
- '''
137
- Callback function to randomly sample an essay from the dataset and set the interactivity of the buttons to True.
138
- Args:
139
- None
140
- Returns:
141
- str: A sample text from the dataset
142
- gr.Button: The 'Machine-Generated' button with interactivity.
143
- gr.Button: The 'Human-Written' button with interactivity.
144
- gr.Button: The 'Machine-Humanized' button with interactivity.
145
- gr.Button: The 'Machine-Polished' button with interactivity.
146
- str: An empty string to clear the Result.
147
- '''
148
-
149
- global index # to access the index of the sample text for the show_result function
150
  mg = gr.Button("Machine-Generated", variant="secondary", interactive=True)
151
  hw = gr.Button("Human-Written", variant="secondary", interactive=True)
152
- mh = gr.Button("Machine-Humanized", variant="secondary", interactive=True)
153
  mp = gr.Button("Machine-Polished", variant="secondary", interactive=True)
 
154
  index = random.choice(range(80))
155
  essay = demo_essays[index][0]
156
  return essay, mg, hw, mh, mp, ''
157
 
158
  def correct_label_challenge_tab():
159
- '''
160
- Function to return the correct label of the sample text based on the index (global variable).
161
- Args:
162
- None
163
- Returns:
164
- str: The correct label of the sample text
165
- '''
166
-
167
  if 0 <= index < 20 :
168
  return 'Human-Written'
169
  elif 20 <= index < 40:
@@ -174,21 +83,6 @@ def correct_label_challenge_tab():
174
  return 'Machine-Humanized'
175
 
176
  def show_result_challenge_tab(button):
177
- '''
178
- Callback function to show the result of the classification based on the button clicked by the user.
179
- The correct label of the sample text is displayed in the primary variant.
180
- The chosen label by the user is displayed in the stop variant if it is incorrect.
181
-
182
- Args:
183
- button: str: the label of the button clicked by the user
184
- Returns:
185
- str: the outcome of the classification
186
- gr.Button: The 'Machine-Generated' button with the appropriate variant.
187
- gr.Button: The 'Human-Written' button with the appropriate variant.
188
- gr.Button: The 'Machine-Humanized' button with the appropriate variant.
189
- gr.Button: The 'Machine-Polished' button with the appropriate variant.
190
- '''
191
-
192
  correct_btn = correct_label_challenge_tab()
193
  mg = gr.Button("Machine-Generated", variant="secondary")
194
  hw = gr.Button("Human-Written", variant="secondary")
@@ -213,104 +107,90 @@ def show_result_challenge_tab(button):
213
  elif correct_btn == 'Machine-Polished':
214
  mp = gr.Button("Machine-Polished", variant="primary")
215
 
216
- outcome = ''
217
- if button == correct_btn:
218
- outcome = 'Correct'
219
- else:
220
- outcome = 'Incorrect'
221
 
222
  return outcome, mg, hw, mh, mp
223
 
224
-
225
- ############################## GRADIO UI ##############################
226
-
227
- with gr.Blocks() as demo:
228
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  gr.Markdown("""<h1><centre>Machine Generated Text (MGT) Detection</center></h1>""")
230
  with gr.Tab('Try it!'):
 
231
 
232
  with gr.Row():
233
- radio_button = gr.Dropdown(['Student Essay', 'Scientific Abstract'], label = 'Text Type', info = 'We have specialized models that work on domain-specific text.', value='Student Essay')
234
-
235
- with gr.Row():
236
-
237
  input_text = gr.Textbox(placeholder="Paste your text here...", label="Text", lines=10, max_lines=15)
238
- file_input = gr.File(label="Upload File", file_types=[".txt", ".pdf"])
239
 
240
  with gr.Row():
241
  wc = gr.Markdown("0/500 words (Minimum 50 words)")
242
  with gr.Row():
243
  check_button = gr.Button("Check Origin", variant="primary", interactive=False)
244
- clear_button = gr.ClearButton([input_text, file_input], variant="stop")
245
 
246
  out = gr.Label(label='Result')
247
  clear_button.add(out)
248
 
249
- check_button.click(fn=update_detection_tab, inputs=[input_text, file_input, radio_button], outputs=out)
250
 
251
  input_text.change(count_words_detection_tab, input_text, wc, show_progress=False)
252
  input_text.input(
253
  active_button_detection_tab,
254
- [input_text, file_input],
255
  [check_button],
256
  )
257
 
258
- file_input.upload(
259
- active_button_detection_tab,
260
- [input_text, file_input],
261
- [check_button],
262
- )
263
-
264
  clear_button.click(
265
  clear_detection_tab,
266
  inputs=[],
267
- outputs=[input_text, file_input, check_button],
268
- )
269
-
270
-
271
- # Adding JavaScript to simulate file input click
272
- gr.Markdown(
273
- """
274
- <script>
275
- document.addEventListener("DOMContentLoaded", function() {
276
- const uploadButton = Array.from(document.getElementsByTagName('button')).find(el => el.innerText === "Upload File");
277
- if (uploadButton) {
278
- uploadButton.onclick = function() {
279
- document.querySelector('input[type="file"]').click();
280
- };
281
- }
282
- });
283
- </script>
284
- """
285
  )
286
 
287
  with gr.Tab('Challenge Yourself!'):
288
- gr.Markdown(
289
- """
290
- <style>
291
- .gr-button-secondary {
292
- width: 100px;
293
- height: 30px;
294
- padding: 5px;
295
- }
296
- .gr-row {
297
- display: flex;
298
- align-items: center;
299
- gap: 10px;
300
- }
301
- .gr-block {
302
- padding: 20px;
303
- }
304
- .gr-markdown p {
305
- font-size: 16px;
306
- }
307
- </style>
308
- <span style='font-family: Arial, sans-serif; font-size: 20px;'>Was this text written by <strong>human</strong> or <strong>AI</strong>?</span>
309
- <p style='font-family: Arial, sans-serif;'>Try detecting one of our sample texts:</p>
310
- """
311
- )
312
-
313
-
314
  with gr.Row():
315
  generate = gr.Button("Generate Sample Text", variant="primary")
316
  clear = gr.ClearButton([], variant="stop")
@@ -319,7 +199,6 @@ with gr.Blocks() as demo:
319
  text = gr.Textbox(value="", label="Text", lines=20, interactive=False)
320
 
321
  with gr.Row():
322
-
323
  mg = gr.Button("Machine-Generated", variant="secondary", interactive=False)
324
  hw = gr.Button("Human-Written", variant="secondary", interactive=False)
325
  mh = gr.Button("Machine-Humanized", variant="secondary", interactive=False)
@@ -333,8 +212,12 @@ with gr.Blocks() as demo:
333
  for button in [mg, hw, mh, mp]:
334
  button.click(show_result_challenge_tab, [button], [result, mg, hw, mh, mp])
335
 
336
- clear.click(clear_challenge_tab, [], [mg, hw, mh, mp, result])
337
-
 
 
 
 
 
338
 
339
  demo.launch(share=False)
340
-
 
 
 
 
1
  import json
 
2
  import random
3
+ from pathlib import Path
4
+ import gradio as gr
5
+ import numpy as np
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
7
 
8
+ # Constants
9
+ MIN_WORDS = 50
10
+ MAX_WORDS = 500
11
+ SAMPLE_JSON_PATH = Path('samples.json')
 
12
 
13
+ # Load models
14
+ def load_model(model_name):
15
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
16
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
17
+ return pipeline('text-classification', model=model, tokenizer=tokenizer, truncation=True, max_length=512, top_k=4)
18
 
19
+ classifier = load_model("./fine-tuned-distillberta")
20
 
21
+ # Load sample essays
22
+ with open(SAMPLE_JSON_PATH, 'r') as f:
23
+ demo_essays = json.load(f)
24
 
25
+ # Global variable to store the current essay index
26
+ current_essay_index = None
27
 
28
+ TEXT_CLASS_MAPPING = {
29
+ 'llm': 'Machine Generated',
30
+ 'human': 'Human Written',
31
+ 'machine-humanized': 'Machine Written, Machine Humanized',
32
+ 'machine-polished': 'Human Written, Machine Polished'
33
+ }
34
 
35
  def process_result_detection_tab(text):
36
+
37
+ result = classifier(text)[0]
38
+
39
+ labels = [TEXT_CLASS_MAPPING[x['label']] for x in result]
40
+ scores = list(np.array([x['score'] for x in result]))
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  final_results = dict(zip(labels, scores))
43
+
44
+ # Return only the label with the highest score
45
+ return max(final_results, key=final_results.get)
46
+
47
+ def update_detection_tab(name):
48
+ if name == '':
 
 
 
 
 
 
 
 
49
  return ""
50
+ return process_result_detection_tab(name)
 
 
 
51
 
52
+ def active_button_detection_tab(input_text):
53
+ if not (50 <= len(input_text.split()) <= 500):
 
 
 
 
 
 
 
 
 
 
 
 
54
  return gr.Button("Check Origin", variant="primary", interactive=False)
 
55
  return gr.Button("Check Origin", variant="primary", interactive=True)
56
 
57
  def clear_detection_tab():
58
+ return "", gr.Button("Check Origin", variant="primary", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def count_words_detection_tab(text):
61
+ return f'{len(text.split())}/500 words (Minimum 50 words)'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def generate_text_challenge_tab():
64
+ global index
65
+
 
 
 
 
 
 
 
 
 
 
 
 
66
  mg = gr.Button("Machine-Generated", variant="secondary", interactive=True)
67
  hw = gr.Button("Human-Written", variant="secondary", interactive=True)
68
+ mh = gr.Button("Machine-Humanized", variant="secondary", interactive=True)
69
  mp = gr.Button("Machine-Polished", variant="secondary", interactive=True)
70
+
71
  index = random.choice(range(80))
72
  essay = demo_essays[index][0]
73
  return essay, mg, hw, mh, mp, ''
74
 
75
  def correct_label_challenge_tab():
 
 
 
 
 
 
 
 
76
  if 0 <= index < 20 :
77
  return 'Human-Written'
78
  elif 20 <= index < 40:
 
83
  return 'Machine-Humanized'
84
 
85
  def show_result_challenge_tab(button):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  correct_btn = correct_label_challenge_tab()
87
  mg = gr.Button("Machine-Generated", variant="secondary")
88
  hw = gr.Button("Human-Written", variant="secondary")
 
107
  elif correct_btn == 'Machine-Polished':
108
  mp = gr.Button("Machine-Polished", variant="primary")
109
 
110
+ outcome = 'Correct' if button == correct_btn else 'Incorrect'
 
 
 
 
111
 
112
  return outcome, mg, hw, mh, mp
113
 
114
+ css = """
115
+ body, .gradio-container {
116
+ font-family: Arial, sans-serif;
117
+ }
118
+ .gr-button {
119
+ background-color: #1e1e1e;
120
+ border: 1px solid #333333;
121
+ color: #ffffff;
122
+ }
123
+ .gr-button:hover {
124
+ background-color: #2e2e2e;
125
+ }
126
+ .gr-input, .gr-textarea {
127
+ background-color: #1f2937;
128
+ border: 1px solid #333333;
129
+ color: #ffffff;
130
+ }
131
+ .gr-form {
132
+ background-color: #1f2937;
133
+ border: 1px solid #333333;
134
+ }
135
+ .class-intro {
136
+ background-color: #1f2937;
137
+ border: 1px solid #333333;
138
+ padding: 15px;
139
+ margin-bottom: 20px;
140
+ border-radius: 5px;
141
+ }
142
+ .class-intro h2 {
143
+ margin-top: 0;
144
+ color: #ffffff;
145
+ }
146
+ .class-intro p {
147
+ margin-bottom: 5px;
148
+ }
149
+ """
150
+
151
+ class_intro_html = """
152
+ <div class="class-intro">
153
+ <h2>Text Classes</h2>
154
+ <p><strong>Human Written:</strong> Original text created by humans.</p>
155
+ <p><strong>Machine Generated:</strong> Text created by AI from basic prompts, without style instructions.</p>
156
+ <p><strong>Human Written, Machine Polished:</strong> Human text refined by AI for grammar and flow, without new content.</p>
157
+ <p><strong>Machine Written, Machine Humanized:</strong> AI-generated text modified to mimic human writing style.</p>
158
+ </div>
159
+ """
160
+
161
+ with gr.Blocks(css=css) as demo:
162
  gr.Markdown("""<h1><centre>Machine Generated Text (MGT) Detection</center></h1>""")
163
  with gr.Tab('Try it!'):
164
+ gr.HTML(class_intro_html)
165
 
166
  with gr.Row():
 
 
 
 
167
  input_text = gr.Textbox(placeholder="Paste your text here...", label="Text", lines=10, max_lines=15)
 
168
 
169
  with gr.Row():
170
  wc = gr.Markdown("0/500 words (Minimum 50 words)")
171
  with gr.Row():
172
  check_button = gr.Button("Check Origin", variant="primary", interactive=False)
173
+ clear_button = gr.ClearButton([input_text], variant="stop")
174
 
175
  out = gr.Label(label='Result')
176
  clear_button.add(out)
177
 
178
+ check_button.click(fn=update_detection_tab, inputs=[input_text], outputs=out)
179
 
180
  input_text.change(count_words_detection_tab, input_text, wc, show_progress=False)
181
  input_text.input(
182
  active_button_detection_tab,
183
+ [input_text],
184
  [check_button],
185
  )
186
 
 
 
 
 
 
 
187
  clear_button.click(
188
  clear_detection_tab,
189
  inputs=[],
190
+ outputs=[input_text, check_button],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  )
192
 
193
  with gr.Tab('Challenge Yourself!'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  with gr.Row():
195
  generate = gr.Button("Generate Sample Text", variant="primary")
196
  clear = gr.ClearButton([], variant="stop")
 
199
  text = gr.Textbox(value="", label="Text", lines=20, interactive=False)
200
 
201
  with gr.Row():
 
202
  mg = gr.Button("Machine-Generated", variant="secondary", interactive=False)
203
  hw = gr.Button("Human-Written", variant="secondary", interactive=False)
204
  mh = gr.Button("Machine-Humanized", variant="secondary", interactive=False)
 
212
  for button in [mg, hw, mh, mp]:
213
  button.click(show_result_challenge_tab, [button], [result, mg, hw, mh, mp])
214
 
215
+ clear.click(lambda: ("",
216
+ gr.Button("Machine-Generated", variant="secondary", interactive=False),
217
+ gr.Button("Human-Written", variant="secondary", interactive=False),
218
+ gr.Button("Machine-Humanized", variant="secondary", interactive=False),
219
+ gr.Button("Machine-Polished", variant="secondary", interactive=False),
220
+ ""),
221
+ outputs=[text, mg, hw, mh, mp, result])
222
 
223
  demo.launch(share=False)