"""LowKey / app.py — Gradio Space that applies the LowKey adversarial perturbation to protect face photos from facial recognition systems."""
import spaces
import gradio as gr
import torch
from PIL import Image
import numpy as np
from util.feature_extraction_utils import normalize_transforms
from util.attack_utils import Attack
from util.prepare_utils import prepare_models, prepare_dir_vec, get_ensemble
from align.detector import detect_faces
from align.align_trans import get_reference_facial_points, warp_and_crop_face
import torchvision.transforms as transforms
to_tensor = transforms.ToTensor()
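
# Attack hyperparameters for the LowKey perturbation: the LPIPS ("alex") perceptual
# term keeps the output visually close to the input, lr/n_iters/n_starts control the
# optimization loop, and kernel_size_gf/sigma_gf set the Gaussian smoothing filter.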
eps = 0.05
n_iters = 50
input_size = [112, 112]
attack_type = "lpips"
c_tv = None
c_sim = 0.05
lr = 0.0025
net_type = "alex"
noise_size = 0.005
n_starts = 1
kernel_size_gf = 7
sigma_gf = 3
combination = True
using_subspace = False
V_reduction_root = "./"
model_backbones = ["IR_152", "IR_152", "ResNet_152", "ResNet_152"]
model_roots = [
"https://github.com/cmu-spuds/lowkey_gradio/releases/download/weights/Backbone_IR_152_Arcface_Epoch_112.pth",
"https://github.com/cmu-spuds/lowkey_gradio/releases/download/weights/Backbone_IR_152_Cosface_Epoch_70.pth",
"https://github.com/cmu-spuds/lowkey_gradio/releases/download/weights/Backbone_ResNet_152_Arcface_Epoch_65.pth",
"https://github.com/cmu-spuds/lowkey_gradio/releases/download/weights/Backbone_ResNet_152_Cosface_Epoch_68.pth",
]
direction = 1
crop_size = 112
scale = crop_size / 112.0
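
# Warm the torch hub cache with every backbone checkpoint at import time.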
for root in model_roots:
torch.hub.load_state_dict_from_url(root, map_location="cpu", progress=True)
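
# GPU-scoped handler: on Hugging Face Spaces (ZeroGPU) each call may hold the GPU
# for up to 120 seconds.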
@spaces.GPU(duration=120)
def protect(img, progress=gr.Progress(track_tqdm=True)):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models_attack, V_reduction, dim = prepare_models(
model_backbones,
input_size,
model_roots,
kernel_size_gf,
sigma_gf,
combination,
using_subspace,
V_reduction_root,
)
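
    # Detect the face, estimate its 5-point landmarks, and compute the affine
    # transform that maps it onto the 112x112 reference crop.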
img = Image.fromarray(img)
reference = get_reference_facial_points(default_square=True) * scale
h, w, c = np.array(img).shape
_, landmarks = detect_faces(img)
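    # Guard: if the detector finds no face, fail with a readable UI error
    # instead of an IndexError below.
    if landmarks is None or len(landmarks) == 0:
        raise gr.Error("No face detected in the uploaded image.")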
facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)]
_, tfm = warp_and_crop_face(
np.array(img), facial5points, reference, crop_size=(crop_size, crop_size)
)
    # Normalize the affine alignment transform into a PyTorch warp matrix (theta).
theta = normalize_transforms(tfm, w, h)
tensor_img = to_tensor(img).unsqueeze(0).to(device)
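
    # No subspace reduction is applied; the attack works in the full
    # 512-dimensional embedding space.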
V_reduction = None
dim = 512
# Find gradient direction vector
dir_vec_extractor = get_ensemble(
models=models_attack,
sigma_gf=None,
kernel_size_gf=None,
combination=False,
V_reduction=V_reduction,
warp=True,
theta_warp=theta,
)
dir_vec = prepare_dir_vec(dir_vec_extractor, tensor_img, dim, combination)
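
    # Build the attack against the full ensemble and push the (warped) face
    # embedding along dir_vec while the LPIPS term limits visible change.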
attack = Attack(
models_attack,
dim,
attack_type,
eps,
c_sim,
net_type,
lr,
n_iters,
noise_size,
n_starts,
c_tv,
sigma_gf,
kernel_size_gf,
combination,
warp=True,
theta_warp=theta,
V_reduction=V_reduction,
)
img_attacked = attack.execute(tensor_img, dir_vec, direction).detach().cpu()
img_attacked_pil = transforms.ToPILImage()(img_attacked[0])
return img_attacked_pil
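
# Minimal Gradio UI: a numpy image in, the protected PIL image out.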
gr.Interface(
fn=protect,
inputs=gr.components.Image(height=512, width=512),
outputs=gr.components.Image(type="pil"),
allow_flagging="never",
).launch(show_error=True, quiet=False, share=False)