Spaces:
Running
Running
import json | |
import random | |
from pathlib import Path | |
import gradio as gr | |
import numpy as np | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
# Constants
# Accepted word-count range for text submitted to the detector.
MIN_WORDS: int = 50
MAX_WORDS: int = 500
# JSON file with the essays used by the "Challenge Yourself!" tab.
# Relative path, so it is resolved against the current working directory.
SAMPLE_JSON_PATH: Path = Path('samples.json')
# Load models
def load_model(model_name):
    """Build a text-classification pipeline for the given checkpoint.

    Args:
        model_name: path or model id accepted by ``from_pretrained``.

    Returns:
        A ``transformers`` text-classification pipeline configured with
        ``truncation=True``, ``max_length=512`` and ``top_k=4``.
    """
    # Tokenizer is loaded before the model, as in the original ordering.
    tok = AutoTokenizer.from_pretrained(model_name)
    clf_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    options = dict(truncation=True, max_length=512, top_k=4)
    return pipeline('text-classification', model=clf_model, tokenizer=tok, **options)
# Model is loaded once, at import time.
classifier = load_model("./fine_tuned_roberta-base")

# Load sample essays. Decode explicitly as UTF-8 (the JSON standard encoding)
# rather than relying on the host's locale-dependent default for open().
with open(SAMPLE_JSON_PATH, 'r', encoding='utf-8') as f:
    demo_essays = json.load(f)

# Global variable intended to store the current essay index.
# NOTE(review): nothing in this file reads it; the challenge tab tracks its
# essay through the module-level ``index`` instead.
current_essay_index = None

# Maps the classifier's raw label ids to the class names shown in the UI.
TEXT_CLASS_MAPPING = {
    'LABEL_0': 'Machine Generated',
    'LABEL_1': 'Human Written',
    'LABEL_2': 'Machine Written, Machine Humanized',
    'LABEL_3': 'Human Written, Machine Polished',
}
def process_result_detection_tab(text):
    """Classify *text* and return the display name of the most likely class.

    ``classifier(text)[0]`` is the list of ``{'label': ..., 'score': ...}``
    predictions for the single input string.

    Returns:
        The ``TEXT_CLASS_MAPPING`` name of the highest-scoring label.

    Raises:
        KeyError: if the winning label is missing from ``TEXT_CLASS_MAPPING``.
    """
    predictions = classifier(text)[0]
    # Only the single top class is reported, so take the argmax directly
    # instead of round-tripping the scores through numpy and a dict.
    best = max(predictions, key=lambda pred: pred['score'])
    return TEXT_CLASS_MAPPING[best['label']]
def update_detection_tab(name):
    """Handle a "Check Origin" click.

    Args:
        name: the current contents of the detection textbox.

    Returns:
        ``""`` for an empty textbox; otherwise the predicted class name from
        ``process_result_detection_tab``.
    """
    return process_result_detection_tab(name) if name != '' else ""
def active_button_detection_tab(input_text):
    """Return the "Check Origin" button, enabled only for an acceptable length.

    The button is interactive when the whitespace-separated word count of
    *input_text* lies within ``MIN_WORDS``..``MAX_WORDS`` (inclusive).
    """
    word_count = len(input_text.split())
    # Use the module constants so the limits are defined in one place rather
    # than repeating the literals 50/500.
    within_limits = MIN_WORDS <= word_count <= MAX_WORDS
    return gr.Button("Check Origin", variant="primary", interactive=within_limits)
def clear_detection_tab():
    """Reset the detection tab: empty the textbox and disable "Check Origin".

    Returns:
        ``(text_value, button)`` for the ``[input_text, check_button]`` outputs.
    """
    disabled_button = gr.Button("Check Origin", variant="primary", interactive=False)
    empty_text = ""
    return empty_text, disabled_button
def count_words_detection_tab(text):
    """Return the word-count caption, e.g. ``'12/500 words (Minimum 50 words)'``.

    Words are counted by whitespace splitting; the limits come from
    ``MAX_WORDS``/``MIN_WORDS`` so the caption cannot drift from them.
    """
    return f'{len(text.split())}/{MAX_WORDS} words (Minimum {MIN_WORDS} words)'
def generate_text_challenge_tab():
    """Pick a random sample essay and reset the challenge tab.

    Side effect: stores the chosen position in the module-level ``index``,
    which ``correct_label_challenge_tab`` later reads.

    Returns:
        ``(essay, mg, hw, mh, mp, '')``: the essay (element ``[0]`` of the
        chosen ``demo_essays`` entry), the four answer buttons re-created as
        enabled/secondary, and an empty result.

    NOTE(review): ``index`` is one module-level global, so it looks shared by
    every browser session served by this process; concurrent users could
    overwrite each other's essay before answering. Per-session ``gr.State``
    would avoid this.
    """
    global index
    answer_buttons = [
        gr.Button(label, variant="secondary", interactive=True)
        for label in ("Machine-Generated", "Human-Written", "Machine-Humanized", "Machine-Polished")
    ]
    # Only the first 80 entries are drawn: correct_label_challenge_tab maps
    # exactly that range onto the four classes.
    index = random.randrange(80)
    essay = demo_essays[index][0]
    return (essay, *answer_buttons, '')
# Answer key for the sample essays: consecutive runs of _ESSAYS_PER_CLASS
# entries in demo_essays are assumed to belong to these classes, in order.
_CHALLENGE_CLASSES = ('Human-Written', 'Machine-Generated', 'Machine-Polished', 'Machine-Humanized')
_ESSAYS_PER_CLASS = 20


def correct_label_challenge_tab(essay_index=None):
    """Return the true class label for a sample essay.

    Args:
        essay_index: position of the essay in ``demo_essays``. Defaults to the
            module-level ``index`` set by ``generate_text_challenge_tab``, so
            existing zero-argument calls behave as before. Passing it
            explicitly lets a caller use a per-session index instead of the
            shared global.

    Returns:
        The answer-button label for the essay's class, or ``None`` when the
        index falls outside the 80 labelled essays (0-79).

    Raises:
        NameError: if called without an argument before any essay has been
            generated, because the global ``index`` does not exist yet.
    """
    if essay_index is None:
        essay_index = index
    if not 0 <= essay_index < len(_CHALLENGE_CLASSES) * _ESSAYS_PER_CLASS:
        return None
    return _CHALLENGE_CLASSES[int(essay_index // _ESSAYS_PER_CLASS)]
def show_result_challenge_tab(button):
    """Reveal the answer after the user picks a class on the challenge tab.

    Args:
        button: the value (label) of the clicked answer button.

    Returns:
        ``(outcome, mg, hw, mh, mp)``. ``outcome`` is ``'Correct'`` or
        ``'Incorrect'``. Each button update is styled ``"primary"`` for the
        correct class, ``"stop"`` for a wrong pick, and ``"secondary"``
        otherwise.
    """
    correct_btn = correct_label_challenge_tab()

    def variant_for(label):
        # The correct answer wins even when it is also the user's pick.
        if label == correct_btn:
            return "primary"
        if label == button:
            return "stop"
        return "secondary"

    order = ("Machine-Generated", "Human-Written", "Machine-Humanized", "Machine-Polished")
    styled_buttons = [gr.Button(label, variant=variant_for(label)) for label in order]
    outcome = 'Correct' if button == correct_btn else 'Incorrect'
    return (outcome, *styled_buttons)
# Stylesheet passed to gr.Blocks. It sets a sans-serif font for the whole app
# and gives the class-description panel (.class-intro) padding and rounded
# corners. The former empty ".gr-input, .gr-textarea {}" rule had no effect
# and was removed.
css = """
body, .gradio-container {
    font-family: Arial, sans-serif;
}
.class-intro {
    padding: 15px;
    margin-bottom: 20px;
    border-radius: 5px;
}
.class-intro h2 {
    margin-top: 0;
}
.class-intro p {
    margin-bottom: 5px;
}
"""
# (name, description) pairs rendered into the "Text Classes" panel on the
# "Try it!" tab, in display order.
_CLASS_DESCRIPTIONS = (
    ("Human Written", "Original text created by humans."),
    ("Machine Generated", "Text created by AI from basic prompts, without style instructions."),
    ("Human Written, Machine Polished", "Human text refined by AI for grammar and flow, without new content."),
    ("Machine Written, Machine Humanized", "AI-generated text modified to mimic human writing style."),
)
class_intro_html = (
    '\n<div class="class-intro">\n'
    '<h2>Text Classes</h2>\n'
    + ''.join(f'<p><strong>{name}:</strong> {desc}</p>\n' for name, desc in _CLASS_DESCRIPTIONS)
    + '</div>\n'
)
with gr.Blocks(css=css) as demo:
    # Title. The opening tag was previously misspelled "<centre>", which is
    # not an HTML element and did not match the closing "</center>".
    gr.Markdown("""<h1><center>Machine Generated Text (MGT) Detection</center></h1>""")

    # --- Tab 1: classify user-supplied text ---------------------------------
    with gr.Tab('Try it!'):
        gr.HTML(class_intro_html)
        with gr.Row():
            input_text = gr.Textbox(placeholder="Paste your text here...", label="Text", lines=10, max_lines=15)
        with gr.Row():
            # Seed the caption from the same helper that keeps it updated,
            # so the initial text cannot drift from later updates.
            wc = gr.Markdown(count_words_detection_tab(""))
        with gr.Row():
            # Starts disabled; active_button_detection_tab enables it.
            check_button = gr.Button("Check Origin", variant="primary", interactive=False)
            clear_button = gr.ClearButton([input_text], variant="stop")
        out = gr.Label(label='Result')
        clear_button.add(out)
        check_button.click(fn=update_detection_tab, inputs=[input_text], outputs=out)
        # Caption refresh uses .change. Per Gradio's docs, .change also fires
        # on programmatic updates such as clearing.
        input_text.change(count_words_detection_tab, input_text, wc, show_progress=False)
        input_text.input(
            active_button_detection_tab,
            [input_text],
            [check_button],
        )
        # Clearing also re-disables the check button.
        clear_button.click(
            clear_detection_tab,
            inputs=[],
            outputs=[input_text, check_button],
        )

    # --- Tab 2: guess the origin of a sample essay --------------------------
    with gr.Tab('Challenge Yourself!'):
        with gr.Row():
            generate = gr.Button("Generate Sample Text", variant="primary")
            clear = gr.ClearButton([], variant="stop")
        with gr.Row():
            text = gr.Textbox(value="", label="Text", lines=20, interactive=False)
        with gr.Row():
            # Answer buttons start disabled; generate_text_challenge_tab
            # re-creates them as interactive.
            mg = gr.Button("Machine-Generated", variant="secondary", interactive=False)
            hw = gr.Button("Human-Written", variant="secondary", interactive=False)
            mh = gr.Button("Machine-Humanized", variant="secondary", interactive=False)
            mp = gr.Button("Machine-Polished", variant="secondary", interactive=False)
        with gr.Row():
            result = gr.Label(label="Result", value="")
        clear.add([result, text])
        generate.click(generate_text_challenge_tab, [], [text, mg, hw, mh, mp, result])
        # Each answer button is passed as its own input, so the handler
        # receives the clicked button's value.
        for button in [mg, hw, mh, mp]:
            button.click(show_result_challenge_tab, [button], [result, mg, hw, mh, mp])
        # Reset text, answer buttons (disabled again) and result.
        clear.click(lambda: ("",
                             gr.Button("Machine-Generated", variant="secondary", interactive=False),
                             gr.Button("Human-Written", variant="secondary", interactive=False),
                             gr.Button("Machine-Humanized", variant="secondary", interactive=False),
                             gr.Button("Machine-Polished", variant="secondary", interactive=False),
                             ""),
                    outputs=[text, mg, hw, mh, mp, result])

demo.launch(share=False)