Xintao commited on
Commit
e6ac7d7
·
1 Parent(s): 65c0e3c

clean version

Browse files
Files changed (7) hide show
  1. README.md +3 -3
  2. app.py +64 -30
  3. gfpgan_utils.py +0 -119
  4. gfpganv1_clean_arch.py +0 -325
  5. realesrgan_utils.py +0 -281
  6. srvgg_arch.py +0 -67
  7. stylegan2_clean_arch.py +0 -369
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: GFPGAN
3
- emoji:
4
  colorFrom: yellow
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.1.7
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
  title: GFPGAN
3
+ emoji: 😁
4
  colorFrom: yellow
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.1.8
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py CHANGED
@@ -3,76 +3,110 @@ import os
3
  import cv2
4
  import gradio as gr
5
  import torch
 
 
 
6
 
7
- from realesrgan_utils import RealESRGANer
8
- from srvgg_arch import SRVGGNetCompact
9
-
10
- os.system("pip freeze")
11
- os.system(
12
- "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights")
13
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P ./weights")
14
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P ./weights")
15
 
16
  torch.hub.download_url_to_file(
17
  'https://upload.wikimedia.org/wikipedia/commons/thumb/a/ab/Abraham_Lincoln_O-77_matte_collodion_print.jpg/1024px-Abraham_Lincoln_O-77_matte_collodion_print.jpg',
18
  'lincoln.jpg')
19
- torch.hub.download_url_to_file('https://upload.wikimedia.org/wikipedia/commons/5/50/Albert_Einstein_%28Nobel%29.png',
20
- 'einstein.png')
21
  torch.hub.download_url_to_file(
22
- 'https://upload.wikimedia.org/wikipedia/commons/thumb/9/9d/Thomas_Edison2.jpg/1024px-Thomas_Edison2.jpg',
23
- 'edison.jpg')
24
  torch.hub.download_url_to_file(
25
- 'https://upload.wikimedia.org/wikipedia/commons/thumb/a/a9/Henry_Ford_1888.jpg/1024px-Henry_Ford_1888.jpg',
26
- 'Henry.jpg')
27
  torch.hub.download_url_to_file(
28
- 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/06/Frida_Kahlo%2C_by_Guillermo_Kahlo.jpg/800px-Frida_Kahlo%2C_by_Guillermo_Kahlo.jpg',
29
- 'Frida.jpg')
30
 
31
  # determine models according to model names
 
 
 
32
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
33
  netscale = 4
34
  model_path = os.path.join('weights', 'realesr-general-x4v3.pth')
35
-
36
- # restorer
37
  half = True if torch.cuda.is_available() else False
38
  upsampler = RealESRGANer(scale=netscale, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
39
 
40
  # Use GFPGAN for face enhancement
41
- from gfpgan_utils import GFPGANer
42
-
43
- face_enhancer = GFPGANer(
44
  model_path='weights/GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
 
45
  os.makedirs('output', exist_ok=True)
46
 
47
-
48
- def inference(img, scale=2):
49
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
 
 
 
 
50
 
51
  h, w = img.shape[0:2]
52
  if h < 400:
53
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
54
 
 
 
 
 
55
  try:
56
  _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
57
  except RuntimeError as error:
58
  print('Error', error)
59
  print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
60
  else:
 
61
  extension = 'png'
62
  if scale != 2:
 
63
  h, w = img.shape[0:2]
64
- output = cv2.resize((int(w * scale /2), int(h * scale/2)), interpolation=cv2.INTER_LINEAR)
 
 
 
 
 
 
 
65
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
66
- return output
67
 
68
 
69
  title = "GFPGAN: Practical Face Restoration Algorithm"
70
- description = "Gradio demo for GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please click submit only once"
71
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2101.04061' target='_blank'>Towards Real-World Blind Face Restoration with Generative Facial Prior</a> | <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Github Repo</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_GFPGAN' alt='visitor badge'></center>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  gr.Interface(
73
- inference, [gr.inputs.Image(type="filepath", label="Input"), gr.Number(lable="Rescaling factor", precision=2)],
74
- gr.outputs.Image(type="numpy", label="Output (The whole image)"),
 
 
 
 
75
  title=title,
76
  description=description,
77
  article=article,
78
- examples=[['lincoln.jpg', 2], ['einstein.png', 2], ['edison.jpg', 2], ['Henry.jpg', 2], ['Frida.jpg', 2]]).launch()
 
3
  import cv2
4
  import gradio as gr
5
  import torch
6
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
7
+ from gfpgan.utils import GFPGANer
8
+ from realesrgan.utils import RealESRGANer
9
 
10
+ # os.system("pip freeze")
11
+ # os.system(
12
+ # "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights")
13
+ # os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P ./weights")
14
+ # os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P ./weights")
 
 
 
15
 
16
  torch.hub.download_url_to_file(
17
  'https://upload.wikimedia.org/wikipedia/commons/thumb/a/ab/Abraham_Lincoln_O-77_matte_collodion_print.jpg/1024px-Abraham_Lincoln_O-77_matte_collodion_print.jpg',
18
  'lincoln.jpg')
 
 
19
  torch.hub.download_url_to_file(
20
+ 'https://user-images.githubusercontent.com/17445847/187400315-87a90ac9-d231-45d6-b377-38702bd1838f.jpg',
21
+ 'AI-generate.jpg')
22
  torch.hub.download_url_to_file(
23
+ 'https://user-images.githubusercontent.com/17445847/187400981-8a58f7a4-ef61-42d9-af80-bc6234cef860.jpg',
24
+ 'Blake_Lively.jpg')
25
  torch.hub.download_url_to_file(
26
+ 'https://user-images.githubusercontent.com/17445847/187401133-8a3bf269-5b4d-4432-b2f0-6d26ee1d3307.png',
27
+ '10045.jpg')
28
 
29
  # determine models according to model names
30
+
31
+
32
+ # background enhancer with RealESRGAN
33
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
34
  netscale = 4
35
  model_path = os.path.join('weights', 'realesr-general-x4v3.pth')
 
 
36
  half = True if torch.cuda.is_available() else False
37
  upsampler = RealESRGANer(scale=netscale, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
38
 
39
  # Use GFPGAN for face enhancement
40
+ face_enhancer_v3 = GFPGANer(
 
 
41
  model_path='weights/GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
42
+ face_enhancer_v2 = GFPGANer(
43
+ model_path='weights/GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
44
  os.makedirs('output', exist_ok=True)
45
 
46
+ def inference(img, version, scale):
47
+ print(torch.cuda.is_available())
48
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
49
+ if len(img.shape) == 3 and img.shape[2] == 4:
50
+ img_mode = 'RGBA'
51
+ else:
52
+ img_mode = None
53
 
54
  h, w = img.shape[0:2]
55
  if h < 400:
56
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
57
 
58
+ if version == 'v1.2':
59
+ face_enhancer = face_enhancer_v2
60
+ else:
61
+ face_enhancer = face_enhancer_v3
62
  try:
63
  _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
64
  except RuntimeError as error:
65
  print('Error', error)
66
  print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
67
  else:
68
+
69
  extension = 'png'
70
  if scale != 2:
71
+ interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
72
  h, w = img.shape[0:2]
73
+ output = cv2.resize(output, (int(w * scale /2), int(h * scale/2)), interpolation=interpolation)
74
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
75
+ extension = 'png'
76
+ else:
77
+ extension = 'jpg'
78
+ save_path = f'output/out.{extension}'
79
+ cv2.imwrite(save_path, output)
80
+
81
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
82
+ return output, save_path
83
 
84
 
85
  title = "GFPGAN: Practical Face Restoration Algorithm"
86
+ description = r"""
87
+ Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</b></a>. <br>
88
+ [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/GFPGAN?style=social)](https://github.com/TencentARC/GFPGAN)
89
+ It can be used to: <br>
90
+ - Upsample/Restore your **old photos**
91
+ - Upsample/Improve **AI-generated faces**
92
+
93
+ To use it, simply upload your image. Please click submit only once.
94
+ """
95
+ article = r"""<p style='text-align: center'><a href='https://arxiv.org/abs/2101.04061' target='_blank'>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</a> | <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Github Repo</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_GFPGAN' alt='visitor badge'></center>
96
+
97
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2101.04061)
98
+ [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/GFPGAN?style=social)](https://github.com/TencentARC/GFPGAN)
99
+ [![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
100
+
101
+ """
102
  gr.Interface(
103
+ inference,
104
+ [gr.inputs.Image(type="filepath", label="Input"),
105
+ gr.inputs.Radio(['v1.2','v1.3'], type="value", default='v1.3', label='GFPGAN version'),
106
+ gr.inputs.Number(label="Rescaling factor", default=2)],
107
+ [gr.outputs.Image(type="numpy", label="Output (The whole image)"),
108
+ gr.outputs.File(label="Download the output image")],
109
  title=title,
110
  description=description,
111
  article=article,
112
+ examples=[['AI-generate.jpg', 'v1.3', 2], ['lincoln.png', 'v1.3',2], ['Blake_Lively.jpg', 'v1.3',2], ['10045.jpg', 'v1.3',2]).launch()
gfpgan_utils.py DELETED
@@ -1,119 +0,0 @@
1
- import os
2
-
3
- import cv2
4
- import torch
5
- from basicsr.utils import img2tensor, tensor2img
6
- from basicsr.utils.download_util import load_file_from_url
7
- from facexlib.utils.face_restoration_helper import FaceRestoreHelper
8
- from torchvision.transforms.functional import normalize
9
-
10
- from gfpganv1_clean_arch import GFPGANv1Clean
11
-
12
- ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
-
14
-
15
- class GFPGANer():
16
- """Helper for restoration with GFPGAN.
17
-
18
- It will detect and crop faces, and then resize the faces to 512x512.
19
- GFPGAN is used to restored the resized faces.
20
- The background is upsampled with the bg_upsampler.
21
- Finally, the faces will be pasted back to the upsample background image.
22
-
23
- Args:
24
- model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
25
- upscale (float): The upscale of the final output. Default: 2.
26
- arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
27
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
28
- bg_upsampler (nn.Module): The upsampler for the background. Default: None.
29
- """
30
-
31
- def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
32
- self.upscale = upscale
33
- self.bg_upsampler = bg_upsampler
34
-
35
- # initialize model
36
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
37
- # initialize the GFP-GAN
38
- self.gfpgan = GFPGANv1Clean(
39
- out_size=512,
40
- num_style_feat=512,
41
- channel_multiplier=channel_multiplier,
42
- decoder_load_path=None,
43
- fix_decoder=False,
44
- num_mlp=8,
45
- input_is_latent=True,
46
- different_w=True,
47
- narrow=1,
48
- sft_half=True)
49
-
50
- # initialize face helper
51
- self.face_helper = FaceRestoreHelper(
52
- upscale,
53
- face_size=512,
54
- crop_ratio=(1, 1),
55
- det_model='retinaface_resnet50',
56
- save_ext='png',
57
- use_parse=True,
58
- device=self.device)
59
-
60
- if model_path.startswith('https://'):
61
- model_path = load_file_from_url(
62
- url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
63
- loadnet = torch.load(model_path)
64
- if 'params_ema' in loadnet:
65
- keyname = 'params_ema'
66
- else:
67
- keyname = 'params'
68
- self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
69
- self.gfpgan.eval()
70
- self.gfpgan = self.gfpgan.to(self.device)
71
-
72
- @torch.no_grad()
73
- def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
74
- self.face_helper.clean_all()
75
-
76
- if has_aligned: # the inputs are already aligned
77
- img = cv2.resize(img, (512, 512))
78
- self.face_helper.cropped_faces = [img]
79
- else:
80
- self.face_helper.read_image(img)
81
- # get face landmarks for each face
82
- self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
83
- # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
84
- # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
85
- # align and warp each face
86
- self.face_helper.align_warp_face()
87
-
88
- # face restoration
89
- for cropped_face in self.face_helper.cropped_faces:
90
- # prepare data
91
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
92
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
93
- cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
94
-
95
- try:
96
- output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
97
- # convert to image
98
- restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
99
- except RuntimeError as error:
100
- print(f'\tFailed inference for GFPGAN: {error}.')
101
- restored_face = cropped_face
102
-
103
- restored_face = restored_face.astype('uint8')
104
- self.face_helper.add_restored_face(restored_face)
105
-
106
- if not has_aligned and paste_back:
107
- # upsample the background
108
- if self.bg_upsampler is not None:
109
- # Now only support RealESRGAN for upsampling background
110
- bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
111
- else:
112
- bg_img = None
113
-
114
- self.face_helper.get_inverse_affine(None)
115
- # paste each restored face to the input image
116
- restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
117
- return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
118
- else:
119
- return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gfpganv1_clean_arch.py DELETED
@@ -1,325 +0,0 @@
1
- import math
2
- import random
3
-
4
- import torch
5
- from basicsr.utils.registry import ARCH_REGISTRY
6
- from torch import nn
7
- from torch.nn import functional as F
8
-
9
- from stylegan2_clean_arch import StyleGAN2GeneratorClean
10
-
11
-
12
- class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
13
- """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
14
-
15
- It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
16
-
17
- Args:
18
- out_size (int): The spatial size of outputs.
19
- num_style_feat (int): Channel number of style features. Default: 512.
20
- num_mlp (int): Layer number of MLP style layers. Default: 8.
21
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
22
- narrow (float): The narrow ratio for channels. Default: 1.
23
- sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
24
- """
25
-
26
- def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
27
- super(StyleGAN2GeneratorCSFT, self).__init__(
28
- out_size,
29
- num_style_feat=num_style_feat,
30
- num_mlp=num_mlp,
31
- channel_multiplier=channel_multiplier,
32
- narrow=narrow)
33
- self.sft_half = sft_half
34
-
35
- def forward(self,
36
- styles,
37
- conditions,
38
- input_is_latent=False,
39
- noise=None,
40
- randomize_noise=True,
41
- truncation=1,
42
- truncation_latent=None,
43
- inject_index=None,
44
- return_latents=False):
45
- """Forward function for StyleGAN2GeneratorCSFT.
46
-
47
- Args:
48
- styles (list[Tensor]): Sample codes of styles.
49
- conditions (list[Tensor]): SFT conditions to generators.
50
- input_is_latent (bool): Whether input is latent style. Default: False.
51
- noise (Tensor | None): Input noise or None. Default: None.
52
- randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
53
- truncation (float): The truncation ratio. Default: 1.
54
- truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
55
- inject_index (int | None): The injection index for mixing noise. Default: None.
56
- return_latents (bool): Whether to return style latents. Default: False.
57
- """
58
- # style codes -> latents with Style MLP layer
59
- if not input_is_latent:
60
- styles = [self.style_mlp(s) for s in styles]
61
- # noises
62
- if noise is None:
63
- if randomize_noise:
64
- noise = [None] * self.num_layers # for each style conv layer
65
- else: # use the stored noise
66
- noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
67
- # style truncation
68
- if truncation < 1:
69
- style_truncation = []
70
- for style in styles:
71
- style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
72
- styles = style_truncation
73
- # get style latents with injection
74
- if len(styles) == 1:
75
- inject_index = self.num_latent
76
-
77
- if styles[0].ndim < 3:
78
- # repeat latent code for all the layers
79
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
80
- else: # used for encoder with different latent code for each layer
81
- latent = styles[0]
82
- elif len(styles) == 2: # mixing noises
83
- if inject_index is None:
84
- inject_index = random.randint(1, self.num_latent - 1)
85
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
86
- latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
87
- latent = torch.cat([latent1, latent2], 1)
88
-
89
- # main generation
90
- out = self.constant_input(latent.shape[0])
91
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
92
- skip = self.to_rgb1(out, latent[:, 1])
93
-
94
- i = 1
95
- for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
96
- noise[2::2], self.to_rgbs):
97
- out = conv1(out, latent[:, i], noise=noise1)
98
-
99
- # the conditions may have fewer levels
100
- if i < len(conditions):
101
- # SFT part to combine the conditions
102
- if self.sft_half: # only apply SFT to half of the channels
103
- out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
104
- out_sft = out_sft * conditions[i - 1] + conditions[i]
105
- out = torch.cat([out_same, out_sft], dim=1)
106
- else: # apply SFT to all the channels
107
- out = out * conditions[i - 1] + conditions[i]
108
-
109
- out = conv2(out, latent[:, i + 1], noise=noise2)
110
- skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
111
- i += 2
112
-
113
- image = skip
114
-
115
- if return_latents:
116
- return image, latent
117
- else:
118
- return image, None
119
-
120
-
121
- class ResBlock(nn.Module):
122
- """Residual block with bilinear upsampling/downsampling.
123
-
124
- Args:
125
- in_channels (int): Channel number of the input.
126
- out_channels (int): Channel number of the output.
127
- mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
128
- """
129
-
130
- def __init__(self, in_channels, out_channels, mode='down'):
131
- super(ResBlock, self).__init__()
132
-
133
- self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
134
- self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
135
- self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
136
- if mode == 'down':
137
- self.scale_factor = 0.5
138
- elif mode == 'up':
139
- self.scale_factor = 2
140
-
141
- def forward(self, x):
142
- out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
143
- # upsample/downsample
144
- out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
145
- out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
146
- # skip
147
- x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
148
- skip = self.skip(x)
149
- out = out + skip
150
- return out
151
-
152
-
153
- @ARCH_REGISTRY.register()
154
- class GFPGANv1Clean(nn.Module):
155
- """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
156
-
157
- It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
158
-
159
- Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
160
-
161
- Args:
162
- out_size (int): The spatial size of outputs.
163
- num_style_feat (int): Channel number of style features. Default: 512.
164
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
165
- decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
166
- fix_decoder (bool): Whether to fix the decoder. Default: True.
167
-
168
- num_mlp (int): Layer number of MLP style layers. Default: 8.
169
- input_is_latent (bool): Whether input is latent style. Default: False.
170
- different_w (bool): Whether to use different latent w for different layers. Default: False.
171
- narrow (float): The narrow ratio for channels. Default: 1.
172
- sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
173
- """
174
-
175
- def __init__(
176
- self,
177
- out_size,
178
- num_style_feat=512,
179
- channel_multiplier=1,
180
- decoder_load_path=None,
181
- fix_decoder=True,
182
- # for stylegan decoder
183
- num_mlp=8,
184
- input_is_latent=False,
185
- different_w=False,
186
- narrow=1,
187
- sft_half=False):
188
-
189
- super(GFPGANv1Clean, self).__init__()
190
- self.input_is_latent = input_is_latent
191
- self.different_w = different_w
192
- self.num_style_feat = num_style_feat
193
-
194
- unet_narrow = narrow * 0.5 # by default, use a half of input channels
195
- channels = {
196
- '4': int(512 * unet_narrow),
197
- '8': int(512 * unet_narrow),
198
- '16': int(512 * unet_narrow),
199
- '32': int(512 * unet_narrow),
200
- '64': int(256 * channel_multiplier * unet_narrow),
201
- '128': int(128 * channel_multiplier * unet_narrow),
202
- '256': int(64 * channel_multiplier * unet_narrow),
203
- '512': int(32 * channel_multiplier * unet_narrow),
204
- '1024': int(16 * channel_multiplier * unet_narrow)
205
- }
206
-
207
- self.log_size = int(math.log(out_size, 2))
208
- first_out_size = 2**(int(math.log(out_size, 2)))
209
-
210
- self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
211
-
212
- # downsample
213
- in_channels = channels[f'{first_out_size}']
214
- self.conv_body_down = nn.ModuleList()
215
- for i in range(self.log_size, 2, -1):
216
- out_channels = channels[f'{2**(i - 1)}']
217
- self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
218
- in_channels = out_channels
219
-
220
- self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
221
-
222
- # upsample
223
- in_channels = channels['4']
224
- self.conv_body_up = nn.ModuleList()
225
- for i in range(3, self.log_size + 1):
226
- out_channels = channels[f'{2**i}']
227
- self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
228
- in_channels = out_channels
229
-
230
- # to RGB
231
- self.toRGB = nn.ModuleList()
232
- for i in range(3, self.log_size + 1):
233
- self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))
234
-
235
- if different_w:
236
- linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
237
- else:
238
- linear_out_channel = num_style_feat
239
-
240
- self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
241
-
242
- # the decoder: stylegan2 generator with SFT modulations
243
- self.stylegan_decoder = StyleGAN2GeneratorCSFT(
244
- out_size=out_size,
245
- num_style_feat=num_style_feat,
246
- num_mlp=num_mlp,
247
- channel_multiplier=channel_multiplier,
248
- narrow=narrow,
249
- sft_half=sft_half)
250
-
251
- # load pre-trained stylegan2 model if necessary
252
- if decoder_load_path:
253
- self.stylegan_decoder.load_state_dict(
254
- torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
255
- # fix decoder without updating params
256
- if fix_decoder:
257
- for _, param in self.stylegan_decoder.named_parameters():
258
- param.requires_grad = False
259
-
260
- # for SFT modulations (scale and shift)
261
- self.condition_scale = nn.ModuleList()
262
- self.condition_shift = nn.ModuleList()
263
- for i in range(3, self.log_size + 1):
264
- out_channels = channels[f'{2**i}']
265
- if sft_half:
266
- sft_out_channels = out_channels
267
- else:
268
- sft_out_channels = out_channels * 2
269
- self.condition_scale.append(
270
- nn.Sequential(
271
- nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
272
- nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
273
- self.condition_shift.append(
274
- nn.Sequential(
275
- nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
276
- nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
277
-
278
- def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
279
- """Forward function for GFPGANv1Clean.
280
-
281
- Args:
282
- x (Tensor): Input images.
283
- return_latents (bool): Whether to return style latents. Default: False.
284
- return_rgb (bool): Whether return intermediate rgb images. Default: True.
285
- randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
286
- """
287
- conditions = []
288
- unet_skips = []
289
- out_rgbs = []
290
-
291
- # encoder
292
- feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
293
- for i in range(self.log_size - 2):
294
- feat = self.conv_body_down[i](feat)
295
- unet_skips.insert(0, feat)
296
- feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
297
-
298
- # style code
299
- style_code = self.final_linear(feat.view(feat.size(0), -1))
300
- if self.different_w:
301
- style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
302
-
303
- # decode
304
- for i in range(self.log_size - 2):
305
- # add unet skip
306
- feat = feat + unet_skips[i]
307
- # ResUpLayer
308
- feat = self.conv_body_up[i](feat)
309
- # generate scale and shift for SFT layers
310
- scale = self.condition_scale[i](feat)
311
- conditions.append(scale.clone())
312
- shift = self.condition_shift[i](feat)
313
- conditions.append(shift.clone())
314
- # generate rgb images
315
- if return_rgb:
316
- out_rgbs.append(self.toRGB[i](feat))
317
-
318
- # decoder
319
- image, _ = self.stylegan_decoder([style_code],
320
- conditions,
321
- return_latents=return_latents,
322
- input_is_latent=self.input_is_latent,
323
- randomize_noise=randomize_noise)
324
-
325
- return image, out_rgbs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
realesrgan_utils.py DELETED
@@ -1,281 +0,0 @@
1
- import math
2
- import os
3
- import queue
4
- import threading
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- from basicsr.utils.download_util import load_file_from_url
10
- from torch.nn import functional as F
11
-
12
- ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
-
14
-
15
- class RealESRGANer():
16
- """A helper class for upsampling images with RealESRGAN.
17
-
18
- Args:
19
- scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
20
- model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
21
- model (nn.Module): The defined network. Default: None.
22
- tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
23
- input images into tiles, and then process each of them. Finally, they will be merged into one image.
24
- 0 denotes for do not use tile. Default: 0.
25
- tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
26
- pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
27
- half (float): Whether to use half precision during inference. Default: False.
28
- """
29
-
30
- def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
31
- self.scale = scale
32
- self.tile_size = tile
33
- self.tile_pad = tile_pad
34
- self.pre_pad = pre_pad
35
- self.mod_scale = None
36
- self.half = half
37
-
38
- # initialize model
39
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
- # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
41
- if model_path.startswith('https://'):
42
- model_path = load_file_from_url(
43
- url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
44
- loadnet = torch.load(model_path, map_location=torch.device('cpu'))
45
- # prefer to use params_ema
46
- if 'params_ema' in loadnet:
47
- keyname = 'params_ema'
48
- else:
49
- keyname = 'params'
50
- model.load_state_dict(loadnet[keyname], strict=True)
51
- model.eval()
52
- self.model = model.to(self.device)
53
- if self.half:
54
- self.model = self.model.half()
55
-
56
- def pre_process(self, img):
57
- """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
58
- """
59
- img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
60
- self.img = img.unsqueeze(0).to(self.device)
61
- if self.half:
62
- self.img = self.img.half()
63
-
64
- # pre_pad
65
- if self.pre_pad != 0:
66
- self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
67
- # mod pad for divisible borders
68
- if self.scale == 2:
69
- self.mod_scale = 2
70
- elif self.scale == 1:
71
- self.mod_scale = 4
72
- if self.mod_scale is not None:
73
- self.mod_pad_h, self.mod_pad_w = 0, 0
74
- _, _, h, w = self.img.size()
75
- if (h % self.mod_scale != 0):
76
- self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
77
- if (w % self.mod_scale != 0):
78
- self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
79
- self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
80
-
81
- def process(self):
82
- # model inference
83
- self.output = self.model(self.img)
84
-
85
- def tile_process(self):
86
- """It will first crop input images to tiles, and then process each tile.
87
- Finally, all the processed tiles are merged into one images.
88
-
89
- Modified from: https://github.com/ata4/esrgan-launcher
90
- """
91
- batch, channel, height, width = self.img.shape
92
- output_height = height * self.scale
93
- output_width = width * self.scale
94
- output_shape = (batch, channel, output_height, output_width)
95
-
96
- # start with black image
97
- self.output = self.img.new_zeros(output_shape)
98
- tiles_x = math.ceil(width / self.tile_size)
99
- tiles_y = math.ceil(height / self.tile_size)
100
-
101
- # loop over all tiles
102
- for y in range(tiles_y):
103
- for x in range(tiles_x):
104
- # extract tile from input image
105
- ofs_x = x * self.tile_size
106
- ofs_y = y * self.tile_size
107
- # input tile area on total image
108
- input_start_x = ofs_x
109
- input_end_x = min(ofs_x + self.tile_size, width)
110
- input_start_y = ofs_y
111
- input_end_y = min(ofs_y + self.tile_size, height)
112
-
113
- # input tile area on total image with padding
114
- input_start_x_pad = max(input_start_x - self.tile_pad, 0)
115
- input_end_x_pad = min(input_end_x + self.tile_pad, width)
116
- input_start_y_pad = max(input_start_y - self.tile_pad, 0)
117
- input_end_y_pad = min(input_end_y + self.tile_pad, height)
118
-
119
- # input tile dimensions
120
- input_tile_width = input_end_x - input_start_x
121
- input_tile_height = input_end_y - input_start_y
122
- tile_idx = y * tiles_x + x + 1
123
- input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
124
-
125
- # upscale tile
126
- try:
127
- with torch.no_grad():
128
- output_tile = self.model(input_tile)
129
- except RuntimeError as error:
130
- print('Error', error)
131
- print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
132
-
133
- # output tile area on total image
134
- output_start_x = input_start_x * self.scale
135
- output_end_x = input_end_x * self.scale
136
- output_start_y = input_start_y * self.scale
137
- output_end_y = input_end_y * self.scale
138
-
139
- # output tile area without padding
140
- output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
141
- output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
142
- output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
143
- output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
144
-
145
- # put tile into output image
146
- self.output[:, :, output_start_y:output_end_y,
147
- output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
148
- output_start_x_tile:output_end_x_tile]
149
-
150
- def post_process(self):
151
- # remove extra pad
152
- if self.mod_scale is not None:
153
- _, _, h, w = self.output.size()
154
- self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
155
- # remove prepad
156
- if self.pre_pad != 0:
157
- _, _, h, w = self.output.size()
158
- self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
159
- return self.output
160
-
161
- @torch.no_grad()
162
- def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
163
- h_input, w_input = img.shape[0:2]
164
- # img: numpy
165
- img = img.astype(np.float32)
166
- if np.max(img) > 256: # 16-bit image
167
- max_range = 65535
168
- print('\tInput is a 16-bit image')
169
- else:
170
- max_range = 255
171
- img = img / max_range
172
- if len(img.shape) == 2: # gray image
173
- img_mode = 'L'
174
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
175
- elif img.shape[2] == 4: # RGBA image with alpha channel
176
- img_mode = 'RGBA'
177
- alpha = img[:, :, 3]
178
- img = img[:, :, 0:3]
179
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
180
- if alpha_upsampler == 'realesrgan':
181
- alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
182
- else:
183
- img_mode = 'RGB'
184
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
185
-
186
- # ------------------- process image (without the alpha channel) ------------------- #
187
- self.pre_process(img)
188
- if self.tile_size > 0:
189
- self.tile_process()
190
- else:
191
- self.process()
192
- output_img = self.post_process()
193
- output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
194
- output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
195
- if img_mode == 'L':
196
- output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
197
-
198
- # ------------------- process the alpha channel if necessary ------------------- #
199
- if img_mode == 'RGBA':
200
- if alpha_upsampler == 'realesrgan':
201
- self.pre_process(alpha)
202
- if self.tile_size > 0:
203
- self.tile_process()
204
- else:
205
- self.process()
206
- output_alpha = self.post_process()
207
- output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
208
- output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
209
- output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
210
- else: # use the cv2 resize for alpha channel
211
- h, w = alpha.shape[0:2]
212
- output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
213
-
214
- # merge the alpha channel
215
- output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
216
- output_img[:, :, 3] = output_alpha
217
-
218
- # ------------------------------ return ------------------------------ #
219
- if max_range == 65535: # 16-bit image
220
- output = (output_img * 65535.0).round().astype(np.uint16)
221
- else:
222
- output = (output_img * 255.0).round().astype(np.uint8)
223
-
224
- if outscale is not None and outscale != float(self.scale):
225
- output = cv2.resize(
226
- output, (
227
- int(w_input * outscale),
228
- int(h_input * outscale),
229
- ), interpolation=cv2.INTER_LANCZOS4)
230
-
231
- return output, img_mode
232
-
233
-
234
- class PrefetchReader(threading.Thread):
235
- """Prefetch images.
236
-
237
- Args:
238
- img_list (list[str]): A image list of image paths to be read.
239
- num_prefetch_queue (int): Number of prefetch queue.
240
- """
241
-
242
- def __init__(self, img_list, num_prefetch_queue):
243
- super().__init__()
244
- self.que = queue.Queue(num_prefetch_queue)
245
- self.img_list = img_list
246
-
247
- def run(self):
248
- for img_path in self.img_list:
249
- img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
250
- self.que.put(img)
251
-
252
- self.que.put(None)
253
-
254
- def __next__(self):
255
- next_item = self.que.get()
256
- if next_item is None:
257
- raise StopIteration
258
- return next_item
259
-
260
- def __iter__(self):
261
- return self
262
-
263
-
264
- class IOConsumer(threading.Thread):
265
-
266
- def __init__(self, opt, que, qid):
267
- super().__init__()
268
- self._queue = que
269
- self.qid = qid
270
- self.opt = opt
271
-
272
- def run(self):
273
- while True:
274
- msg = self._queue.get()
275
- if isinstance(msg, str) and msg == 'quit':
276
- break
277
-
278
- output = msg['output']
279
- save_path = msg['save_path']
280
- cv2.imwrite(save_path, output)
281
- print(f'IO worker {self.qid} is done.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
srvgg_arch.py DELETED
@@ -1,67 +0,0 @@
1
- from torch import nn as nn
2
- from torch.nn import functional as F
3
-
4
-
5
- class SRVGGNetCompact(nn.Module):
6
- """A compact VGG-style network structure for super-resolution.
7
-
8
- It is a compact network structure, which performs upsampling in the last layer and no convolution is
9
- conducted on the HR feature space.
10
-
11
- Args:
12
- num_in_ch (int): Channel number of inputs. Default: 3.
13
- num_out_ch (int): Channel number of outputs. Default: 3.
14
- num_feat (int): Channel number of intermediate features. Default: 64.
15
- num_conv (int): Number of convolution layers in the body network. Default: 16.
16
- upscale (int): Upsampling factor. Default: 4.
17
- act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
18
- """
19
-
20
- def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
21
- super(SRVGGNetCompact, self).__init__()
22
- self.num_in_ch = num_in_ch
23
- self.num_out_ch = num_out_ch
24
- self.num_feat = num_feat
25
- self.num_conv = num_conv
26
- self.upscale = upscale
27
- self.act_type = act_type
28
-
29
- self.body = nn.ModuleList()
30
- # the first conv
31
- self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
32
- # the first activation
33
- if act_type == 'relu':
34
- activation = nn.ReLU(inplace=True)
35
- elif act_type == 'prelu':
36
- activation = nn.PReLU(num_parameters=num_feat)
37
- elif act_type == 'leakyrelu':
38
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
39
- self.body.append(activation)
40
-
41
- # the body structure
42
- for _ in range(num_conv):
43
- self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
44
- # activation
45
- if act_type == 'relu':
46
- activation = nn.ReLU(inplace=True)
47
- elif act_type == 'prelu':
48
- activation = nn.PReLU(num_parameters=num_feat)
49
- elif act_type == 'leakyrelu':
50
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
51
- self.body.append(activation)
52
-
53
- # the last conv
54
- self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
55
- # upsample
56
- self.upsampler = nn.PixelShuffle(upscale)
57
-
58
- def forward(self, x):
59
- out = x
60
- for i in range(0, len(self.body)):
61
- out = self.body[i](out)
62
-
63
- out = self.upsampler(out)
64
- # add the nearest upsampled image, so that the network learns the residual
65
- base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
66
- out += base
67
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
stylegan2_clean_arch.py DELETED
@@ -1,369 +0,0 @@
1
- import math
2
- import random
3
-
4
- import torch
5
- from basicsr.archs.arch_util import default_init_weights
6
- from basicsr.utils.registry import ARCH_REGISTRY
7
- from torch import nn
8
- from torch.nn import functional as F
9
-
10
-
11
- class NormStyleCode(nn.Module):
12
-
13
- def forward(self, x):
14
- """Normalize the style codes.
15
-
16
- Args:
17
- x (Tensor): Style codes with shape (b, c).
18
-
19
- Returns:
20
- Tensor: Normalized tensor.
21
- """
22
- return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
23
-
24
-
25
- class ModulatedConv2d(nn.Module):
26
- """Modulated Conv2d used in StyleGAN2.
27
-
28
- There is no bias in ModulatedConv2d.
29
-
30
- Args:
31
- in_channels (int): Channel number of the input.
32
- out_channels (int): Channel number of the output.
33
- kernel_size (int): Size of the convolving kernel.
34
- num_style_feat (int): Channel number of style features.
35
- demodulate (bool): Whether to demodulate in the conv layer. Default: True.
36
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
37
- eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
38
- """
39
-
40
- def __init__(self,
41
- in_channels,
42
- out_channels,
43
- kernel_size,
44
- num_style_feat,
45
- demodulate=True,
46
- sample_mode=None,
47
- eps=1e-8):
48
- super(ModulatedConv2d, self).__init__()
49
- self.in_channels = in_channels
50
- self.out_channels = out_channels
51
- self.kernel_size = kernel_size
52
- self.demodulate = demodulate
53
- self.sample_mode = sample_mode
54
- self.eps = eps
55
-
56
- # modulation inside each modulated conv
57
- self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
58
- # initialization
59
- default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
60
-
61
- self.weight = nn.Parameter(
62
- torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
63
- math.sqrt(in_channels * kernel_size**2))
64
- self.padding = kernel_size // 2
65
-
66
- def forward(self, x, style):
67
- """Forward function.
68
-
69
- Args:
70
- x (Tensor): Tensor with shape (b, c, h, w).
71
- style (Tensor): Tensor with shape (b, num_style_feat).
72
-
73
- Returns:
74
- Tensor: Modulated tensor after convolution.
75
- """
76
- b, c, h, w = x.shape # c = c_in
77
- # weight modulation
78
- style = self.modulation(style).view(b, 1, c, 1, 1)
79
- # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
80
- weight = self.weight * style # (b, c_out, c_in, k, k)
81
-
82
- if self.demodulate:
83
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
84
- weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
85
-
86
- weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
87
-
88
- # upsample or downsample if necessary
89
- if self.sample_mode == 'upsample':
90
- x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
91
- elif self.sample_mode == 'downsample':
92
- x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
93
-
94
- b, c, h, w = x.shape
95
- x = x.view(1, b * c, h, w)
96
- # weight: (b*c_out, c_in, k, k), groups=b
97
- out = F.conv2d(x, weight, padding=self.padding, groups=b)
98
- out = out.view(b, self.out_channels, *out.shape[2:4])
99
-
100
- return out
101
-
102
- def __repr__(self):
103
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
104
- f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
105
-
106
-
107
- class StyleConv(nn.Module):
108
- """Style conv used in StyleGAN2.
109
-
110
- Args:
111
- in_channels (int): Channel number of the input.
112
- out_channels (int): Channel number of the output.
113
- kernel_size (int): Size of the convolving kernel.
114
- num_style_feat (int): Channel number of style features.
115
- demodulate (bool): Whether demodulate in the conv layer. Default: True.
116
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
117
- """
118
-
119
- def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
120
- super(StyleConv, self).__init__()
121
- self.modulated_conv = ModulatedConv2d(
122
- in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
123
- self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
124
- self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
125
- self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
126
-
127
- def forward(self, x, style, noise=None):
128
- # modulate
129
- out = self.modulated_conv(x, style) * 2**0.5 # for conversion
130
- # noise injection
131
- if noise is None:
132
- b, _, h, w = out.shape
133
- noise = out.new_empty(b, 1, h, w).normal_()
134
- out = out + self.weight * noise
135
- # add bias
136
- out = out + self.bias
137
- # activation
138
- out = self.activate(out)
139
- return out
140
-
141
-
142
- class ToRGB(nn.Module):
143
- """To RGB (image space) from features.
144
-
145
- Args:
146
- in_channels (int): Channel number of input.
147
- num_style_feat (int): Channel number of style features.
148
- upsample (bool): Whether to upsample. Default: True.
149
- """
150
-
151
- def __init__(self, in_channels, num_style_feat, upsample=True):
152
- super(ToRGB, self).__init__()
153
- self.upsample = upsample
154
- self.modulated_conv = ModulatedConv2d(
155
- in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
156
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
157
-
158
- def forward(self, x, style, skip=None):
159
- """Forward function.
160
-
161
- Args:
162
- x (Tensor): Feature tensor with shape (b, c, h, w).
163
- style (Tensor): Tensor with shape (b, num_style_feat).
164
- skip (Tensor): Base/skip tensor. Default: None.
165
-
166
- Returns:
167
- Tensor: RGB images.
168
- """
169
- out = self.modulated_conv(x, style)
170
- out = out + self.bias
171
- if skip is not None:
172
- if self.upsample:
173
- skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
174
- out = out + skip
175
- return out
176
-
177
-
178
- class ConstantInput(nn.Module):
179
- """Constant input.
180
-
181
- Args:
182
- num_channel (int): Channel number of constant input.
183
- size (int): Spatial size of constant input.
184
- """
185
-
186
- def __init__(self, num_channel, size):
187
- super(ConstantInput, self).__init__()
188
- self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
189
-
190
- def forward(self, batch):
191
- out = self.weight.repeat(batch, 1, 1, 1)
192
- return out
193
-
194
-
195
- @ARCH_REGISTRY.register()
196
- class StyleGAN2GeneratorClean(nn.Module):
197
- """Clean version of StyleGAN2 Generator.
198
-
199
- Args:
200
- out_size (int): The spatial size of outputs.
201
- num_style_feat (int): Channel number of style features. Default: 512.
202
- num_mlp (int): Layer number of MLP style layers. Default: 8.
203
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
204
- narrow (float): Narrow ratio for channels. Default: 1.0.
205
- """
206
-
207
- def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
208
- super(StyleGAN2GeneratorClean, self).__init__()
209
- # Style MLP layers
210
- self.num_style_feat = num_style_feat
211
- style_mlp_layers = [NormStyleCode()]
212
- for i in range(num_mlp):
213
- style_mlp_layers.extend(
214
- [nn.Linear(num_style_feat, num_style_feat, bias=True),
215
- nn.LeakyReLU(negative_slope=0.2, inplace=True)])
216
- self.style_mlp = nn.Sequential(*style_mlp_layers)
217
- # initialization
218
- default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
219
-
220
- # channel list
221
- channels = {
222
- '4': int(512 * narrow),
223
- '8': int(512 * narrow),
224
- '16': int(512 * narrow),
225
- '32': int(512 * narrow),
226
- '64': int(256 * channel_multiplier * narrow),
227
- '128': int(128 * channel_multiplier * narrow),
228
- '256': int(64 * channel_multiplier * narrow),
229
- '512': int(32 * channel_multiplier * narrow),
230
- '1024': int(16 * channel_multiplier * narrow)
231
- }
232
- self.channels = channels
233
-
234
- self.constant_input = ConstantInput(channels['4'], size=4)
235
- self.style_conv1 = StyleConv(
236
- channels['4'],
237
- channels['4'],
238
- kernel_size=3,
239
- num_style_feat=num_style_feat,
240
- demodulate=True,
241
- sample_mode=None)
242
- self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
243
-
244
- self.log_size = int(math.log(out_size, 2))
245
- self.num_layers = (self.log_size - 2) * 2 + 1
246
- self.num_latent = self.log_size * 2 - 2
247
-
248
- self.style_convs = nn.ModuleList()
249
- self.to_rgbs = nn.ModuleList()
250
- self.noises = nn.Module()
251
-
252
- in_channels = channels['4']
253
- # noise
254
- for layer_idx in range(self.num_layers):
255
- resolution = 2**((layer_idx + 5) // 2)
256
- shape = [1, 1, resolution, resolution]
257
- self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
258
- # style convs and to_rgbs
259
- for i in range(3, self.log_size + 1):
260
- out_channels = channels[f'{2**i}']
261
- self.style_convs.append(
262
- StyleConv(
263
- in_channels,
264
- out_channels,
265
- kernel_size=3,
266
- num_style_feat=num_style_feat,
267
- demodulate=True,
268
- sample_mode='upsample'))
269
- self.style_convs.append(
270
- StyleConv(
271
- out_channels,
272
- out_channels,
273
- kernel_size=3,
274
- num_style_feat=num_style_feat,
275
- demodulate=True,
276
- sample_mode=None))
277
- self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
278
- in_channels = out_channels
279
-
280
- def make_noise(self):
281
- """Make noise for noise injection."""
282
- device = self.constant_input.weight.device
283
- noises = [torch.randn(1, 1, 4, 4, device=device)]
284
-
285
- for i in range(3, self.log_size + 1):
286
- for _ in range(2):
287
- noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
288
-
289
- return noises
290
-
291
- def get_latent(self, x):
292
- return self.style_mlp(x)
293
-
294
- def mean_latent(self, num_latent):
295
- latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
296
- latent = self.style_mlp(latent_in).mean(0, keepdim=True)
297
- return latent
298
-
299
- def forward(self,
300
- styles,
301
- input_is_latent=False,
302
- noise=None,
303
- randomize_noise=True,
304
- truncation=1,
305
- truncation_latent=None,
306
- inject_index=None,
307
- return_latents=False):
308
- """Forward function for StyleGAN2GeneratorClean.
309
-
310
- Args:
311
- styles (list[Tensor]): Sample codes of styles.
312
- input_is_latent (bool): Whether input is latent style. Default: False.
313
- noise (Tensor | None): Input noise or None. Default: None.
314
- randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
315
- truncation (float): The truncation ratio. Default: 1.
316
- truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
317
- inject_index (int | None): The injection index for mixing noise. Default: None.
318
- return_latents (bool): Whether to return style latents. Default: False.
319
- """
320
- # style codes -> latents with Style MLP layer
321
- if not input_is_latent:
322
- styles = [self.style_mlp(s) for s in styles]
323
- # noises
324
- if noise is None:
325
- if randomize_noise:
326
- noise = [None] * self.num_layers # for each style conv layer
327
- else: # use the stored noise
328
- noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
329
- # style truncation
330
- if truncation < 1:
331
- style_truncation = []
332
- for style in styles:
333
- style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
334
- styles = style_truncation
335
- # get style latents with injection
336
- if len(styles) == 1:
337
- inject_index = self.num_latent
338
-
339
- if styles[0].ndim < 3:
340
- # repeat latent code for all the layers
341
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
342
- else: # used for encoder with different latent code for each layer
343
- latent = styles[0]
344
- elif len(styles) == 2: # mixing noises
345
- if inject_index is None:
346
- inject_index = random.randint(1, self.num_latent - 1)
347
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
348
- latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
349
- latent = torch.cat([latent1, latent2], 1)
350
-
351
- # main generation
352
- out = self.constant_input(latent.shape[0])
353
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
354
- skip = self.to_rgb1(out, latent[:, 1])
355
-
356
- i = 1
357
- for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
358
- noise[2::2], self.to_rgbs):
359
- out = conv1(out, latent[:, i], noise=noise1)
360
- out = conv2(out, latent[:, i + 1], noise=noise2)
361
- skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
362
- i += 2
363
-
364
- image = skip
365
-
366
- if return_latents:
367
- return image, latent
368
- else:
369
- return image, None