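# Gradio demo for FLODA (FLorence-2 Optimized for Deepfake Assessment):
# loads the Florence-2 base model plus the fine-tuned FLODA checkpoint and
# serves a single-image deepfake-detection interface.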
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import torch
import numpy as np
import os
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Workaround: drop the optional flash_attn import so Florence-2's
    remote code can load on machines without flash-attention installed."""
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float32

# Load the Florence-2 base model (used by caption_generate below),
# with the flash_attn import patched out.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    caption_model = AutoModelForCausalLM.from_pretrained(
        'microsoft/Florence-2-base-ft', trust_remote_code=True,
        revision='refs/pr/6', torch_dtype=torch_dtype
    ).to(device)

# Load the fine-tuned FLODA deepfake-detection model and its processor.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        'byh711/FLODA-deepfake', trust_remote_code=True, torch_dtype=torch_dtype
    ).to(device)
processor = AutoProcessor.from_pretrained('byh711/FLODA-deepfake', trust_remote_code=True)
model.eval()
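# Note: float32 keeps the demo runnable on CPU. On a CUDA-only deployment,
# half precision would roughly halve memory use, e.g. (an untested sketch,
# not part of the original app):
#   torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32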
def caption_generate(task_prompt, text_input=None, image=None):
    """Run a Florence-2 captioning task on `image` and return the parsed text."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = caption_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
    # Drop the first and last characters (surrounding delimiters) from the answer.
    return parsed_answer[task_prompt][1:-1]
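# Example (hypothetical usage; "<MORE_DETAILED_CAPTION>" is a standard
# Florence-2 task prompt, and `img` would be a PIL image or numpy array):
#   caption = caption_generate("<MORE_DETAILED_CAPTION>", image=img)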
def run_example(task_prompt, text_input=None, image=None):
    """Run the FLODA deepfake-detection task on `image` and map the model's
    yes/no answer to a human-readable verdict."""
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    # Cast only the floating-point tensors (pixel values) to the model dtype;
    # the token ids must stay integer.
    inputs = {k: v.to(torch_dtype) if v.is_floating_point() else v for k, v in inputs.items()}
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    result = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))['<DEEPFAKE_DETECTION>']
    # For this task prompt, "yes" means the image is judged authentic.
    if result.strip().lower() == "yes":
        return "This is a real image."
    elif result.strip().lower() == "no":
        return "This is a fake image."
    else:
        return f"Uncertain. Model output: {result}"
# Define the Gradio interface
css = """
body {
    background-color: #1e1e2e;
    color: #d4d4dc;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #444;
    background-color: #282c34;
    color: #f1f1f1;
    padding: 10px;
}
.gr-button {
    background-color: #3a3f51;
    border: none;
    color: #ffffff;
    padding: 10px 20px;
    text-align: center;
    font-size: 14px;
    cursor: pointer;
    transition: 0.3s;
}
.gr-button:hover {
    background-color: #4b5263;
}
.gr-textbox {
    background-color: #2e2e38;
    border: 1px solid #555;
    color: #ffffff;
}
.gr-markdown {
    color: #d4d4dc;
}
"""
js_func = """ | |
function refresh() { | |
const url = new URL(window.location); | |
if (url.searchParams.get('__theme') !== 'dark') { | |
url.searchParams.set('__theme', 'dark'); | |
window.location.href = url.href; | |
} | |
} | |
""" | |
TITLE = "# FLODA: Vision-Language Models for Deepfake Detection"
DESCRIPTION = """
FLODA (FLorence-2 Optimized for Deepfake Assessment) is an advanced deepfake detection model leveraging the power of [Florence-2](https://huggingface.co/microsoft/Florence-2-base-ft).
FLODA combines image captioning with authenticity assessment in a single end-to-end architecture, demonstrating superior performance over existing baselines.
Learn more about FLODA in the paper and code repository [here](https://github.com/byh711/FLODA).
"""
with gr.Blocks(js=js_func, css=css) as demo:
    gr.Markdown(TITLE)
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="FLODA: Deepfake Detection"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture", type="numpy")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_text = gr.Textbox(label="Output Text")
        # Wire the button to the detection function with the fixed task prompt.
        submit_btn.click(
            fn=lambda image: run_example("<DEEPFAKE_DETECTION>", text_input=None, image=image),
            inputs=[input_img],
            outputs=[output_text]
        )

demo.launch(debug=True)
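# Note: when running locally, demo.launch(share=True) would additionally
# create a temporary public link (an optional tweak, not part of the original app).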