ohayonguy commited on
Commit
5afc7ad
·
1 Parent(s): 9a668a2

added support for num steps

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -58,7 +58,7 @@ def generate_reconstructions(pmrf_model, x, y, non_noisy_z0, num_flow_steps, dev
58
 
59
  @torch.inference_mode()
60
  @spaces.GPU()
61
- def enhance_face(img, face_helper, has_aligned, only_center_face=False, paste_back=True, scale=2):
62
  face_helper.clean_all()
63
 
64
  if has_aligned: # the inputs are already aligned
@@ -79,7 +79,7 @@ def enhance_face(img, face_helper, has_aligned, only_center_face=False, paste_ba
79
 
80
  dummy_x = torch.zeros_like(cropped_face_t)
81
  with torch.autocast("cuda", dtype=torch.bfloat16):
82
- output = generate_reconstructions(pmrf, dummy_x, cropped_face_t, None, 25, device)
83
  restored_face = tensor2img(output.to(torch.float32).squeeze(0), rgb2bgr=True, min_max=(0, 1))
84
  # restored_face = cropped_face
85
 
@@ -104,7 +104,7 @@ def enhance_face(img, face_helper, has_aligned, only_center_face=False, paste_ba
104
 
105
  @torch.inference_mode()
106
  @spaces.GPU()
107
- def inference(img, aligned, scale, num_steps):
108
  if scale > 4:
109
  scale = 4 # avoid too large scale value
110
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
@@ -136,7 +136,7 @@ def inference(img, aligned, scale, num_steps):
136
 
137
  has_aligned = True if aligned == 'Yes' else False
138
  _, restored_aligned, restored_img = enhance_face(img, face_helper, has_aligned, only_center_face=False,
139
- paste_back=True)
140
  if has_aligned:
141
  output = restored_aligned[0]
142
  else:
 
58
 
59
  @torch.inference_mode()
60
  @spaces.GPU()
61
+ def enhance_face(img, face_helper, has_aligned, num_flow_steps, only_center_face=False, paste_back=True, scale=2):
62
  face_helper.clean_all()
63
 
64
  if has_aligned: # the inputs are already aligned
 
79
 
80
  dummy_x = torch.zeros_like(cropped_face_t)
81
  with torch.autocast("cuda", dtype=torch.bfloat16):
82
+ output = generate_reconstructions(pmrf, dummy_x, cropped_face_t, None, num_flow_steps, device)
83
  restored_face = tensor2img(output.to(torch.float32).squeeze(0), rgb2bgr=True, min_max=(0, 1))
84
  # restored_face = cropped_face
85
 
 
104
 
105
  @torch.inference_mode()
106
  @spaces.GPU()
107
+ def inference(img, aligned, scale, num_flow_steps):
108
  if scale > 4:
109
  scale = 4 # avoid too large scale value
110
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
 
136
 
137
  has_aligned = True if aligned == 'Yes' else False
138
  _, restored_aligned, restored_img = enhance_face(img, face_helper, has_aligned, only_center_face=False,
139
+ paste_back=True, num_flow_steps=num_flow_steps)
140
  if has_aligned:
141
  output = restored_aligned[0]
142
  else: