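"""Gradio demo for "Detecting AutoEncoder is Enough to Catch LDM Generated Images"
(https://arxiv.org/abs/2411.06441): scores an input image with one of three detector
models and reports a 10-crop average score alongside a single-crop score."""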
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from statistics import mean
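# Allow arbitrarily large inputs: disable PIL's decompression-bomb safeguard.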
Image.MAX_IMAGE_PIXELS = None

def read_img_file(f):
    # Open the image and force RGB (drops alpha, converts grayscale/palette modes).
    img = Image.open(f)
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return img

# Random 256x256 crop with ImageNet mean/std normalization for the CNN detectors.
_transform_test_random = transforms.Compose([
    transforms.RandomCrop((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

# The EVA-02 ViT L/14 detector takes 252x252 crops (252 is divisible by the 14-pixel patch size).
_transform_test_random_vit = transforms.Compose([
    transforms.RandomCrop((252, 252)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

def detect(img, model_choices):
    # Load the selected checkpoint from disk on each request; inference runs on CPU.
    if model_choices == "EVA-02 ViT L/14":
        model = torch.load("./model_eva.pth", map_location="cpu").eval()
        _transform = _transform_test_random_vit
    elif model_choices == "ConvNext Large":
        model = torch.load("./model_convnext.pth", map_location="cpu").eval()
        _transform = _transform_test_random
    elif model_choices == "EfficientNet-V2 B0":
        model = torch.load("./model_effnet.pth", map_location="cpu").eval()
        _transform = _transform_test_random
    else:
        raise ValueError(f"Unknown model choice: {model_choices}")

    output = ""
    with torch.inference_mode():
        # 10 try method: average the sigmoid score over 10 random crops.
        tmp = []
        for _ in range(10):
            img_random_crop = _transform(img)
            outputs = model(img_random_crop.unsqueeze(0))
            outputs = torch.sigmoid(outputs).cpu().numpy()
            tmp.append(outputs[0][0])
        output += f"{tmp}\n"
        output += f"10 try method: {mean(tmp)}\n"

        # 1 try method: score a single random crop.
        img_crop = _transform(img)
        outputs = model(img_crop.unsqueeze(0))
        outputs = torch.sigmoid(outputs).cpu().numpy()
        output += f"1 try method: {outputs}\n"
    return output
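
# Hypothetical local usage outside the Gradio UI (filename is illustrative):
#   img = read_img_file("sample.png")
#   print(detect(img, "ConvNext Large"))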

model_choices = ["ConvNext Large", "EVA-02 ViT L/14", "EfficientNet-V2 B0"]

descr = """
Detecting AutoEncoder is Enough to Catch LDM Generated Images (https://arxiv.org/abs/2411.06441)
Code at https://github.com/qwertyforce/Detect_LDM_By_Detecting_VAE
Models at https://huggingface.co/qwertyforce/Detect_LDM_By_Detecting_VAE
"""
demo = gr.Interface(
    fn=detect,
    inputs=[gr.Image(type="pil", label="Input Image"),
            gr.Radio(
                model_choices,
                type="value",
                value="EVA-02 ViT L/14",
                label="Choose Detector Model",
            )],
    outputs="text",
    title="Detecting AutoEncoder is Enough to Catch LDM Generated Images",
    description=descr)

demo.launch()