cleardusk commited on
Commit
d4a5c81
·
1 Parent(s): 866a537

chore: refine mask_crop loading

Browse files
Files changed (3) hide show
  1. app.py +3 -3
  2. src/config/inference_config.py +3 -1
  3. src/utils/crop.py +1 -1
app.py CHANGED
@@ -16,8 +16,8 @@ import gdown
16
  import os
17
  import spaces
18
 
19
- folder_url = f"https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib"
20
- gdown.download_folder(url=folder_url, output="pretrained_weights", quiet=False)
21
 
22
  def partial_fields(target_class, kwargs):
23
  return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
@@ -175,4 +175,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
175
  outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
176
  )
177
 
178
- demo.launch()
 
16
  import os
17
  import spaces
18
 
19
+ # folder_url = f"https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib"
20
+ # gdown.download_folder(url=folder_url, output="pretrained_weights", quiet=False)
21
 
22
  def partial_fields(target_class, kwargs):
23
  return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
 
175
  outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
176
  )
177
 
178
+ demo.launch()
src/config/inference_config.py CHANGED
@@ -5,6 +5,8 @@ config dataclass used for inference
5
  """
6
 
7
  import os.path as osp
 
 
8
  from dataclasses import dataclass
9
  from typing import Literal, Tuple
10
  from .base_config import PrintableConfig, make_abs_path
@@ -38,7 +40,7 @@ class InferenceConfig(PrintableConfig):
38
 
39
  flag_write_result: bool = True # whether to write output video
40
  flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
41
- mask_crop = None
42
  flag_write_gif: bool = False
43
  size_gif: int = 256
44
  ref_max_shape: int = 1280
 
5
  """
6
 
7
  import os.path as osp
8
+ import cv2
9
+ from numpy import ndarray
10
  from dataclasses import dataclass
11
  from typing import Literal, Tuple
12
  from .base_config import PrintableConfig, make_abs_path
 
40
 
41
  flag_write_result: bool = True # whether to write output video
42
  flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
43
+ mask_crop: ndarray = cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
44
  flag_write_gif: bool = False
45
  size_gif: int = 256
46
  ref_max_shape: int = 1280
src/utils/crop.py CHANGED
@@ -409,4 +409,4 @@ def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori):
409
  dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
410
  result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
411
  result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
412
- return result
 
409
  dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
410
  result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
411
  result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
412
+ return result