Jacob Logas committed on
Commit
8722cc4
·
1 Parent(s): 1173b78

Update for zerogpu

Browse files
Files changed (3) hide show
  1. app.py +8 -6
  2. requirements.txt +3 -1
  3. util/prepare_utils.py +4 -1
app.py CHANGED
@@ -8,6 +8,7 @@ from util.prepare_utils import prepare_models, prepare_dir_vec, get_ensemble
8
  from align.detector import detect_faces
9
  from align.align_trans import get_reference_facial_points, warp_and_crop_face
10
  import torchvision.transforms as transforms
 
11
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  print(device)
@@ -30,10 +31,10 @@ using_subspace = False
30
  V_reduction_root = "./"
31
  model_backbones = ["IR_152", "IR_152", "ResNet_152", "ResNet_152"]
32
  model_roots = [
33
- "models/Backbone_IR_152_Arcface_Epoch_112.pth",
34
- "models/Backbone_IR_152_Cosface_Epoch_70.pth",
35
- "models/Backbone_ResNet_152_Arcface_Epoch_65.pth",
36
- "models/Backbone_ResNet_152_Cosface_Epoch_68.pth",
37
  ]
38
  direction = 1
39
  crop_size = 112
@@ -51,6 +52,7 @@ models_attack, V_reduction, dim = prepare_models(
51
  )
52
 
53
 
 
54
  def protect(img):
55
  img = Image.fromarray(img)
56
  reference = get_reference_facial_points(default_square=True) * scale
@@ -110,7 +112,7 @@ def protect(img):
110
 
111
  gr.Interface(
112
  fn=protect,
113
- inputs=gr.components.Image(shape=(512, 512)),
114
  outputs=gr.components.Image(type="pil"),
115
  allow_flagging="never",
116
- ).launch(show_error=True, quiet=False, share=True)
 
8
  from align.detector import detect_faces
9
  from align.align_trans import get_reference_facial_points, warp_and_crop_face
10
  import torchvision.transforms as transforms
11
+ import spaces
12
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  print(device)
 
31
  V_reduction_root = "./"
32
  model_backbones = ["IR_152", "IR_152", "ResNet_152", "ResNet_152"]
33
  model_roots = [
34
+ "https://github.com/cmu-spuds/lowkey_gradio/releases/download/weights/Backbone_IR_152_Arcface_Epoch_112.pth",
35
+ "https://github.com/cmu-spuds/lowkey_gradio/releases/download/weights/Backbone_IR_152_Cosface_Epoch_70.pth",
36
+ "https://github.com/cmu-spuds/lowkey_gradio/releases/download/weights/Backbone_ResNet_152_Arcface_Epoch_65.pth",
37
+ "https://github.com/cmu-spuds/lowkey_gradio/releases/download/weights/Backbone_ResNet_152_Cosface_Epoch_68.pth",
38
  ]
39
  direction = 1
40
  crop_size = 112
 
52
  )
53
 
54
 
55
+ @spaces.GPU
56
  def protect(img):
57
  img = Image.fromarray(img)
58
  reference = get_reference_facial_points(default_square=True) * scale
 
112
 
113
  gr.Interface(
114
  fn=protect,
115
+ inputs=gr.components.Image(height=512, width=512),
116
  outputs=gr.components.Image(type="pil"),
117
  allow_flagging="never",
118
+ ).launch(show_error=True, quiet=False, share=False)
requirements.txt CHANGED
@@ -4,4 +4,6 @@ Pillow>=10.4.0
4
  torch>=2.3.1
5
  torchvision>=0.18.1
6
  tqdm>=4.66.4
7
- lpips>=0.1.4
 
 
 
4
  torch>=2.3.1
5
  torchvision>=0.18.1
6
  tqdm>=4.66.4
7
+ opencv-python>=4.10.0.84
8
+ git+https://github.com/logasja/lpips-pytorch
9
+ spaces>=0.28.3
util/prepare_utils.py CHANGED
@@ -226,7 +226,10 @@ def prepare_models(
226
  models_attack = []
227
  for i in range(len(model_backbones)):
228
  model = backbone_dict[model_backbones[i]]
229
- model.load_state_dict(torch.load(model_roots[i], map_location=device))
 
 
 
230
  models_attack.append(model)
231
 
232
  if using_subspace:
 
226
  models_attack = []
227
  for i in range(len(model_backbones)):
228
  model = backbone_dict[model_backbones[i]]
229
+ state_dict = torch.hub.load_state_dict_from_url(
230
+ model_roots[i], map_location=device, progress=True
231
+ )
232
+ model.load_state_dict(state_dict)
233
  models_attack.append(model)
234
 
235
  if using_subspace: