Inference now uses the auto_split_upscale mechanism.
1. auto_split_upscale lives in utils/dataops.py; it is adapted from the ESRGAN project as forked by joeyballentine and BlueAmulet (a condensed usage sketch follows the file list below).
2. The face-restoration models now include RestoreFormer++ by wzhouxiff.
3. Added support for parsing older RRDB models from the ESRGAN project.
4. Added support for parsing DAT models from the DAT project by zhengchen1999.
5. Added support for parsing HAT models from the HAT project by XPixelGroup.
6. Added support for parsing RealPLKSR models from the PLKSR project by dslisleedh and the neosr-project.

- app.py +467 -222
- requirements.txt +6 -2
- utils/dataops.py +127 -0
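For orientation, item 1 above is the central change: background upscaling now goes through auto_split_upscale, which retries on overlapping quadrants when the GPU runs out of memory. The following is a condensed sketch of that call path, assuming a RealESRGANer upsampler has already been constructed for the chosen model; the helper name upscale_background is illustrative only, and the full wiring is in the app.py listing below.

# Condensed sketch of the new background-upscale path (illustrative; the real
# wiring lives in Upscale.inference in app.py below).
import cv2
from realesrgan.utils import RealESRGANer
from utils.dataops import auto_split_upscale

def upscale_background(img_path: str, upsampler: RealESRGANer, scale: float = 4.0):
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    # auto_split_upscale first tries upsampler.enhance on the whole image; on a CUDA
    # out-of-memory error it recursively splits the image into overlapping quadrants,
    # upscales each one, and stitches the results back together.
    result, depth = auto_split_upscale(img, upsampler.enhance, scale)
    return result, depth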
app.py
CHANGED
import os
import gc
import cv2
import requests
import numpy as np
import gradio as gr
import torch
import traceback
from facexlib.utils.misc import download_from_url
from realesrgan.utils import RealESRGANer


# Define URLs and their corresponding local storage paths
face_model = {
    "GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
    "RestoreFormer++.ckpt": "https://github.com/wzhouxiff/RestoreFormerPlusPlus/releases/download/v1.0.0/RestoreFormer++.ckpt",
    # "CodeFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth",
    # legacy model
    "GFPGANv1.3.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
    "GFPGANv1.2.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth",
    "RestoreFormer.ckpt": "https://github.com/wzhouxiff/RestoreFormerPlusPlus/releases/download/v1.0.0/RestoreFormer.ckpt",
}
realesr_model = {
    # SRVGGNet
    "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", # x4 SRVGGNet (S size)
    "realesr-animevideov3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", # x4 SRVGGNet (XS size)
    # RRDBNet
    "RealESRGAN_x4plus_anime_6B.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", # x4 RRDBNet with 6 blocks
    "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
    "RealESRNet_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    "RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    # ESRGAN(oldRRDB)
    "4x-AnimeSharp.pth": "https://huggingface.co/utnah/esrgan/resolve/main/4x-AnimeSharp.pth?download=true", # https://openmodeldb.info/models/4x-AnimeSharp
    "4x_IllustrationJaNai_V1_ESRGAN_135k.pth": "https://drive.google.com/uc?export=download&confirm=1&id=1qpioSqBkB_IkSBhEAewSSNFt6qgkBimP", # https://openmodeldb.info/models/4x-IllustrationJaNai-V1-DAT2
    # DATNet
    "4xNomos8kDAT.pth": "https://github.com/Phhofm/models/releases/download/4xNomos8kDAT/4xNomos8kDAT.pth", # https://openmodeldb.info/models/4x-Nomos8kDAT
    "4x-DWTP-DS-dat2-v3.pth": "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-DWTP-DS-dat2-v3.pth", # https://openmodeldb.info/models/4x-DWTP-DS-dat2-v3
    "4x_IllustrationJaNai_V1_DAT2_190k.pth": "https://drive.google.com/uc?export=download&confirm=1&id=1qpioSqBkB_IkSBhEAewSSNFt6qgkBimP", # https://openmodeldb.info/models/4x-IllustrationJaNai-V1-DAT2
    # HAT
    "4xNomos8kSCHAT-L.pth": "https://github.com/Phhofm/models/releases/download/4xNomos8kSCHAT/4xNomos8kSCHAT-L.pth", # https://openmodeldb.info/models/4x-Nomos8kSCHAT-L
    "4xNomos8kSCHAT-S.pth": "https://github.com/Phhofm/models/releases/download/4xNomos8kSCHAT/4xNomos8kSCHAT-S.pth", # https://openmodeldb.info/models/4x-Nomos8kSCHAT-S
    "4xNomos8kHAT-L_otf.pth": "https://github.com/Phhofm/models/releases/download/4xNomos8kHAT-L_otf/4xNomos8kHAT-L_otf.pth", # https://openmodeldb.info/models/4x-Nomos8kHAT-L-otf
    # RealPLKSR_dysample
    "4xHFA2k_ludvae_realplksr_dysample.pth": "https://github.com/Phhofm/models/releases/download/4xHFA2k_ludvae_realplksr_dysample/4xHFA2k_ludvae_realplksr_dysample.pth", # https://openmodeldb.info/models/4x-HFA2k-ludvae-realplksr-dysample
    "4xArtFaces_realplksr_dysample.pth": "https://github.com/Phhofm/models/releases/download/4xArtFaces_realplksr_dysample/4xArtFaces_realplksr_dysample.pth", # https://openmodeldb.info/models/4x-ArtFaces-realplksr-dysample
    "4x-PBRify_RPLKSRd_V3.pth": "https://github.com/Kim2091/Kim2091-Models/releases/download/4x-PBRify_RPLKSRd_V3/4x-PBRify_RPLKSRd_V3.pth", # https://openmodeldb.info/models/4x-PBRify-RPLKSRd-V3
    "4xNomos2_realplksr_dysample.pth": "https://github.com/Phhofm/models/releases/download/4xNomos2_realplksr_dysample/4xNomos2_realplksr_dysample.pth", # https://openmodeldb.info/models/4x-Nomos2-realplksr-dysample
    # RealPLKSR
    "2x-AnimeSharpV2_RPLKSR_Sharp.pth": "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV2_Set/2x-AnimeSharpV2_RPLKSR_Sharp.pth", # https://openmodeldb.info/models/2x-AnimeSharpV2-RPLKSR-Sharp
    "2x-AnimeSharpV2_RPLKSR_Soft.pth": "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV2_Set/2x-AnimeSharpV2_RPLKSR_Soft.pth", # https://openmodeldb.info/models/2x-AnimeSharpV2-RPLKSR-Soft
    "4xPurePhoto-RealPLSKR.pth": "https://github.com/starinspace/StarinspaceUpscale/releases/download/Models/4xPurePhoto-RealPLSKR.pth", # https://openmodeldb.info/models/4x-PurePhoto-RealPLSKR
    "2x_Text2HD_v.1-RealPLKSR.pth": "https://github.com/starinspace/StarinspaceUpscale/releases/download/Models/2x_Text2HD_v.1-RealPLKSR.pth", # https://openmodeldb.info/models/2x-Text2HD-v-1
    "2xVHS2HD-RealPLKSR.pth": "https://github.com/starinspace/StarinspaceUpscale/releases/download/Models/2xVHS2HD-RealPLKSR.pth", # https://openmodeldb.info/models/2x-VHS2HD
    "4xNomosWebPhoto_RealPLKSR.pth": "https://github.com/Phhofm/models/releases/download/4xNomosWebPhoto_RealPLKSR/4xNomosWebPhoto_RealPLKSR.pth", # https://openmodeldb.info/models/4x-NomosWebPhoto-RealPLKSR
}

files_to_download = {
    "a1.jpg":
    "https://thumbs.dreamstime.com/b/tower-bridge-traditional-red-bus-black-white-colors-view-to-tower-bridge-london-black-white-colors-108478942.jpg",
    "a2.jpg":
    "https://media.istockphoto.com/id/523514029/photo/london-skyline-b-w.jpg?s=612x612&w=0&k=20&c=kJS1BAtfqYeUDaORupj0sBPc1hpzJhBUUqEFfRnHzZ0=",
    "a3.jpg":
    "https://i.guim.co.uk/img/media/06f614065ed82ca0e917b149a32493c791619854/0_0_3648_2789/master/3648.jpg?width=700&quality=85&auto=format&fit=max&s=05764b507c18a38590090d987c8b6202",
    "a4.jpg":
    "https://i.pinimg.com/736x/46/96/9e/46969eb94aec2437323464804d27706d--victorian-london-victorian-era.jpg",
}

def get_model_type(model_name):
    # Define model type mappings based on key parts of the model names
    model_type = "other"
    if any(value in model_name.lower() for value in ("realesrgan", "realesrnet")):
        model_type = "RRDB"
    elif "realesr" in model_name.lower():
        model_type = "SRVGG"
    elif "esrgan" in model_name.lower() or "4x-AnimeSharp.pth" == model_name:
        model_type = "ESRGAN"
    elif "dat" in model_name.lower():
        model_type = "DAT"
    elif "hat" in model_name.lower():
        model_type = "HAT"
    elif ("realplksr" in model_name.lower() and "dysample" in model_name.lower()) or "rplksrd" in model_name.lower():
        model_type = "RealPLKSR_dysample"
    elif "realplksr" in model_name.lower() or "rplksr" in model_name.lower():
        model_type = "RealPLKSR"
    return f"{model_type}, {model_name}"

typed_realesr_model = {get_model_type(key): value for key, value in realesr_model.items()}

def download_from_urls(urls, save_dir=None):
    for file_name, url in urls.items():
        download_from_url(url, file_name, save_dir)


class Upscale:
    def inference(self, img, face_restoration, realesr, scale: float):
        print(img)
        print(face_restoration, realesr, scale)
        try:
            self.scale = scale
            self.img_name = os.path.basename(str(img))
            self.basename, self.extension = os.path.splitext(self.img_name)

            img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED)  # cv2.imread(img, cv2.IMREAD_UNCHANGED)

            self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
            if len(img.shape) == 2:  # for gray inputs
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            h, w = img.shape[0:2]
            if h < 300:
                img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

            if face_restoration:
                download_from_url(face_model[face_restoration], face_restoration, os.path.join("weights", "face"))
            if realesr:
                realesr_type, realesr = realesr.split(", ", 1)
                download_from_url(realesr_model[realesr], realesr, os.path.join("weights", "realesr"))

            netscale = 4
            loadnet = None
            model = None
            is_auto_split_upscale = True
            half = True if torch.cuda.is_available() else False
            if realesr_type:
                from basicsr.archs.rrdbnet_arch import RRDBNet
                from basicsr.archs.realplksr_arch import realplksr
                # background enhancer with RealESRGAN
                if realesr_type == "RRDB":
                    netscale = 2 if "x2" in realesr else 4
                    num_block = 6 if "6B" in realesr else 23
                    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=netscale)
                elif realesr_type == "SRVGG":
                    from realesrgan.archs.srvgg_arch import SRVGGNetCompact
                    netscale = 4
                    num_conv = 16 if "animevideov3" in realesr else 32
                    model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=netscale, act_type='prelu')
                elif realesr_type == "ESRGAN":
                    netscale = 4
                    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale)
                    loadnet = {}
                    loadnet_origin = torch.load(os.path.join("weights", "realesr", realesr), map_location=torch.device('cpu'), weights_only=True)
                    for key, value in loadnet_origin.items():
                        new_key = key.replace("model.0", "conv_first").replace("model.1.sub.23.", "conv_body.").replace("model.1.sub", "body") \
                            .replace(".0.weight", ".weight").replace(".0.bias", ".bias").replace(".RDB1.", ".rdb1.").replace(".RDB2.", ".rdb2.").replace(".RDB3.", ".rdb3.") \
                            .replace("model.3.", "conv_up1.").replace("model.6.", "conv_up2.").replace("model.8.", "conv_hr.").replace("model.10.", "conv_last.")
                        loadnet[new_key] = value
                elif realesr_type == "DAT":
                    from basicsr.archs.dat_arch import DAT
                    half = False
                    netscale = 4
                    expansion_factor = 2. if "dat2" in realesr.lower() else 4.
                    model = DAT(img_size=64, in_chans=3, embed_dim=180, split_size=[8,32], depth=[6,6,6,6,6,6], num_heads=[6,6,6,6,6,6], expansion_factor=expansion_factor, upscale=netscale)
                    # # Speculate on the parameters.
                    # loadnet_origin = torch.load(os.path.join("weights", "realesr", realesr), map_location=torch.device('cpu'), weights_only=True)
                    # inferred_params = self.infer_parameters_from_state_dict_for_dat(loadnet_origin, netscale)
                    # for param, value in inferred_params.items():
                    #     print(f"{param}: {value}")
                elif realesr_type == "HAT":
                    half = False
                    netscale = 4
                    import torch.nn.functional as F
                    from basicsr.archs.hat_arch import HAT
                    class HATWithAutoPadding(HAT):
                        def pad_to_multiple(self, img, multiple):
                            """
                            Fill the image to multiples of both width and height as integers.
                            """
                            _, _, h, w = img.shape
                            pad_h = (multiple - h % multiple) % multiple
                            pad_w = (multiple - w % multiple) % multiple

                            # Padding on the top, bottom, left, and right.
                            pad_top = pad_h // 2
                            pad_bottom = pad_h - pad_top
                            pad_left = pad_w // 2
                            pad_right = pad_w - pad_left

                            img_padded = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode="reflect")
                            return img_padded, (pad_top, pad_bottom, pad_left, pad_right)

                        def remove_padding(self, img, pad_info):
                            """
                            Remove padding and restore to the original size, considering upscaling.
                            """
                            pad_top, pad_bottom, pad_left, pad_right = pad_info

                            # Adjust padding based on upscaling factor
                            pad_top = int(pad_top * self.upscale)
                            pad_bottom = int(pad_bottom * self.upscale)
                            pad_left = int(pad_left * self.upscale)
                            pad_right = int(pad_right * self.upscale)

                            return img[:, :, pad_top:-pad_bottom if pad_bottom > 0 else None, pad_left:-pad_right if pad_right > 0 else None]

                        def forward(self, x):
                            # Step 1: Auto padding
                            x_padded, pad_info = self.pad_to_multiple(x, self.window_size)

                            # Step 2: Normal model processing
                            x_processed = super().forward(x_padded)

                            # Step 3: Remove padding
                            x_cropped = self.remove_padding(x_processed, pad_info)
                            return x_cropped

                    # The parameters are derived from the XPixelGroup project files: HAT-L_SRx4_ImageNet-pretrain.yml and HAT-S_SRx4.yml.
                    # https://github.com/XPixelGroup/HAT/tree/main/options/test
                    if "hat-l" in realesr.lower():
                        window_size = 16
                        compress_ratio = 3
                        squeeze_factor = 30
                        depths = [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
                        embed_dim = 180
                        num_heads = [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
                        mlp_ratio = 2
                        upsampler = "pixelshuffle"
                    elif "hat-s" in realesr.lower():
                        window_size = 16
                        compress_ratio = 24
                        squeeze_factor = 24
                        depths = [6, 6, 6, 6, 6, 6]
                        embed_dim = 144
                        num_heads = [6, 6, 6, 6, 6, 6]
                        mlp_ratio = 2
                        upsampler = "pixelshuffle"
                    model = HATWithAutoPadding(img_size=64, patch_size=1, in_chans=3, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, compress_ratio=compress_ratio,
                                               squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=netscale,)
                elif realesr_type == "RealPLKSR_dysample":
                    netscale = 4
                    model = realplksr(upscaling_factor=netscale, dysample=True)
                elif realesr_type == "RealPLKSR":
                    half = False if "RealPLSKR" in realesr else half
                    netscale = 2 if realesr.startswith("2x") else 4
                    model = realplksr(dim=64, n_blocks=28, kernel_size=17, split_ratio=0.25, upscaling_factor=netscale)

            self.upsampler = None
            if loadnet:
                self.upsampler = RealESRGANer(scale=netscale, loadnet=loadnet, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
            elif model:
                self.upsampler = RealESRGANer(scale=netscale, model_path=os.path.join("weights", "realesr", realesr), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
            elif realesr:
                self.upsampler = None
                import PIL
                from image_gen_aux import UpscaleWithModel
                class UpscaleWithModel_Gfpgan(UpscaleWithModel):
                    def cv2pil(self, image):
                        ''' OpenCV type -> PIL type
                        https://qiita.com/derodero24/items/f22c22b22451609908ee
                        '''
                        new_image = image.copy()
                        if new_image.ndim == 2:  # Grayscale
                            pass
                        elif new_image.shape[2] == 3:  # Color
                            new_image = cv2.cvtColor(new_image, cv2.COLOR_BGR2RGB)
                        elif new_image.shape[2] == 4:  # Transparency
                            new_image = cv2.cvtColor(new_image, cv2.COLOR_BGRA2RGBA)
                        new_image = PIL.Image.fromarray(new_image)
                        return new_image

                    def pil2cv(self, image):
                        ''' PIL type -> OpenCV type
                        https://qiita.com/derodero24/items/f22c22b22451609908ee
                        '''
                        new_image = np.array(image, dtype=np.uint8)
                        if new_image.ndim == 2:  # Grayscale
                            pass
                        elif new_image.shape[2] == 3:  # Color
                            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
                        elif new_image.shape[2] == 4:  # Transparency
                            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
                        return new_image

                    def enhance(self, img, outscale=None):
                        # img: numpy
                        h_input, w_input = img.shape[0:2]
                        pil_img = self.cv2pil(img)
                        pil_img = self.__call__(pil_img)
                        cv_image = self.pil2cv(pil_img)
                        if outscale is not None and outscale != float(netscale):
                            cv_image = cv2.resize(
                                cv_image, (
                                    int(w_input * outscale),
                                    int(h_input * outscale),
                                ), interpolation=cv2.INTER_LANCZOS4)
                        return cv_image, None

                device = "cuda" if torch.cuda.is_available() else "cpu"
                upscaler = UpscaleWithModel.from_pretrained(os.path.join("weights", "realesr", realesr)).to(device)
                upscaler.__class__ = UpscaleWithModel_Gfpgan
                self.upsampler = upscaler
            self.face_enhancer = None

            if face_restoration:
                from gfpgan.utils import GFPGANer
                if face_restoration and face_restoration.startswith("GFPGANv1."):
                    self.face_enhancer = GFPGANer(model_path=os.path.join("weights", "face", face_restoration), upscale=self.scale, arch="clean", channel_multiplier=2, bg_upsampler=self.upsampler)
                elif face_restoration and face_restoration.startswith("RestoreFormer"):
                    arch = "RestoreFormer++" if face_restoration.startswith("RestoreFormer++") else "RestoreFormer"
                    self.face_enhancer = GFPGANer(model_path=os.path.join("weights", "face", face_restoration), upscale=self.scale, arch=arch, channel_multiplier=2, bg_upsampler=self.upsampler)
                elif face_restoration == 'CodeFormer.pth':
                    self.face_enhancer = GFPGANer(
                        model_path='weights/CodeFormer.pth', upscale=self.scale, arch='CodeFormer', channel_multiplier=2, bg_upsampler=self.upsampler)

            files = []
            outputs = []
            try:
                bg_upsample_img = None
                if self.upsampler and self.upsampler.enhance:
                    from utils.dataops import auto_split_upscale
                    bg_upsample_img, _ = auto_split_upscale(img, self.upsampler.enhance, self.scale) if is_auto_split_upscale else self.upsampler.enhance(img, outscale=self.scale)

                if self.face_enhancer:
                    cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, bg_upsample_img=bg_upsample_img)
                    # save faces
                    if cropped_faces and restored_aligned:
                        for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_aligned)):
                            # save cropped face
                            save_crop_path = f"output/{self.basename}{idx:02d}_cropped_faces.png"
                            self.imwriteUTF8(save_crop_path, cropped_face)
                            # save restored face
                            save_restore_path = f"output/{self.basename}{idx:02d}_restored_faces.png"
                            self.imwriteUTF8(save_restore_path, restored_face)
                            # save comparison image
                            save_cmp_path = f"output/{self.basename}{idx:02d}_cmp.png"
                            cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
                            self.imwriteUTF8(save_cmp_path, cmp_img)

                            files.append(save_crop_path)
                            files.append(save_restore_path)
                            files.append(save_cmp_path)
                            outputs.append(cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB))
                            outputs.append(cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB))
                            outputs.append(cv2.cvtColor(cmp_img, cv2.COLOR_BGR2RGB))

                restored_img = bg_upsample_img
            except RuntimeError as error:
                print(traceback.format_exc())
                print('Error', error)
            finally:
                if self.face_enhancer:
                    self.face_enhancer._cleanup()
                else:
                    # Free GPU memory and clean up resources
                    torch.cuda.empty_cache()
                    gc.collect()

            if not self.extension:
                self.extension = ".png" if self.img_mode == "RGBA" else ".jpg"  # RGBA images should be saved in png format
            save_path = f"output/{self.basename}{self.extension}"
            self.imwriteUTF8(save_path, restored_img)

            restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
            files.append(save_path)
            outputs.append(restored_img)
            return outputs, files
        except Exception as error:
            print(traceback.format_exc())
            print("global exception", error)
            return None, None


    def infer_parameters_from_state_dict_for_dat(self, state_dict, upscale=4):
        if "params" in state_dict:
            state_dict = state_dict["params"]
        elif "params_ema" in state_dict:
            state_dict = state_dict["params_ema"]

        inferred_params = {}

        # Speculate on the depth.
        depth = {}
        for key in state_dict.keys():
            if "blocks" in key:
                layer = int(key.split(".")[1])
                block = int(key.split(".")[3])
                depth[layer] = max(depth.get(layer, 0), block + 1)
        inferred_params["depth"] = [depth[layer] for layer in sorted(depth.keys())]

        # Speculate on the number of num_heads per layer.
        # ex.
        # layers.0.blocks.1.attn.temperature: torch.Size([6, 1, 1])
        # layers.5.blocks.5.attn.temperature: torch.Size([6, 1, 1])
        # The shape of temperature is [num_heads, 1, 1].
        num_heads = []
        for layer in range(len(inferred_params["depth"])):
            for block in range(inferred_params["depth"][layer]):
                key = f"layers.{layer}.blocks.{block}.attn.temperature"
                if key in state_dict:
                    num_heads_layer = state_dict[key].shape[0]
                    num_heads.append(num_heads_layer)
                    break

        inferred_params["num_heads"] = num_heads

        # Speculate on embed_dim.
        # ex. layers.0.blocks.0.attn.qkv.weight: torch.Size([540, 180])
        for key in state_dict.keys():
            if "attn.qkv.weight" in key:
                qkv_weight = state_dict[key]
                embed_dim = qkv_weight.shape[1]  # Note: The in_features of qkv corresponds to embed_dim.
                inferred_params["embed_dim"] = embed_dim
                break

        # Speculate on split_size.
        # ex.
        # layers.0.blocks.0.attn.attns.0.rpe_biases: torch.Size([945, 2])
        # layers.0.blocks.0.attn.attns.0.relative_position_index: torch.Size([256, 256])
        # layers.0.blocks.2.attn.attn_mask_0: torch.Size([16, 256, 256])
        # layers.0.blocks.2.attn.attn_mask_1: torch.Size([16, 256, 256])
        for key in state_dict.keys():
            if "relative_position_index" in key:
                relative_position_size = state_dict[key].shape[0]
                # Determine split_size[0] and split_size[1] based on the provided data.
                split_size_0, split_size_1 = 8, relative_position_size // 8  # 256 = 8 * 32
                inferred_params["split_size"] = [split_size_0, split_size_1]
                break

        # Speculate on the expansion_factor.
        # ex.
        # layers.0.blocks.0.ffn.fc1.weight: torch.Size([360, 180])
        # layers.5.blocks.5.ffn.fc1.weight: torch.Size([360, 180])
        if "embed_dim" in inferred_params:
            for key in state_dict.keys():
                if "ffn.fc1.weight" in key:
                    fc1_weight = state_dict[key]
                    expansion_factor = fc1_weight.shape[0] // inferred_params["embed_dim"]
                    inferred_params["expansion_factor"] = expansion_factor
                    break

        inferred_params["img_size"] = 64
        inferred_params["in_chans"] = 3  # Assume an RGB image.

        for key in state_dict.keys():
            print(f"{key}: {state_dict[key].shape}")

        return inferred_params


    def imwriteUTF8(self, save_path, image):  # `cv2.imwrite` does not support writing files to UTF-8 file paths.
        img_name = os.path.basename(save_path)
        _, extension = os.path.splitext(img_name)
        is_success, im_buf_arr = cv2.imencode(extension, image)
        if (is_success): im_buf_arr.tofile(save_path)


def main():
    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(0.975, device='cuda:0')
    # Ensure the target directory exists
    os.makedirs('output', exist_ok=True)

    # Iterate through each file
    download_from_urls(files_to_download, ".")

    title = "Image Upscaling & Restoration(esp. Face) using GFPGAN Algorithm"
    description = r"""Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration and Upscaling of the image with a Generative Facial Prior</b></a>.<br>
    Practically the algorithm is used to restore your **old photos** or improve **AI-generated faces**.<br>
    To use it, simply just upload the concerned image.<br>
    """
    article = r"""
    [![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
    [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/GFPGAN?style=social)](https://github.com/TencentARC/GFPGAN)
    [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2101.04061)
    <center><img src='https://visitor-badge.glitch.me/badge?page_id=dj_face_restoration_GFPGAN' alt='visitor badge'></center>
    """

    upscale = Upscale()

    demo = gr.Interface(
        upscale.inference, [
            gr.Image(type="filepath", label="Input", format="png"),
            gr.Dropdown(list(face_model.keys())+[None], type="value", value='GFPGANv1.4.pth', label='Face Restoration version', info="Face Restoration and RealESR can be freely combined in different ways, or one can be set to \"None\" to use only the other model. Face Restoration is primarily used for face restoration in real-life images, while RealESR serves as a background restoration model."),
            gr.Dropdown(list(typed_realesr_model.keys())+[None], type="value", value='SRVGG, realesr-general-x4v3.pth', label='RealESR version'),
            gr.Number(label="Rescaling factor", value=4),
        ], [
            gr.Gallery(type="numpy", label="Output (The whole image)", format="png"),
            gr.File(label="Download the output image")
        ],
        title=title,
        description=description,
        article=article,
        examples=[["a1.jpg", "GFPGANv1.4.pth", "SRVGG, realesr-general-x4v3.pth", 2],
                  ["a2.jpg", "GFPGANv1.4.pth", "SRVGG, realesr-general-x4v3.pth", 2],
                  ["a3.jpg", "GFPGANv1.4.pth", "SRVGG, realesr-general-x4v3.pth", 2],
                  ["a4.jpg", "GFPGANv1.4.pth", "SRVGG, realesr-general-x4v3.pth", 2]])

    demo.queue(default_concurrency_limit=4)
    demo.launch(inbrowser=True)


if __name__ == "__main__":
    main()
requirements.txt
CHANGED
@@ -2,10 +2,11 @@

 gradio==5.8.0

-basicsr @ git+https://github.com/
+basicsr @ git+https://github.com/avan06/BasicSR
 facexlib @ git+https://github.com/avan06/facexlib
 gfpgan @ git+https://github.com/avan06/GFPGAN
 realesrgan @ git+https://github.com/avan06/Real-ESRGAN
+
 numpy
 opencv-python

@@ -18,4 +19,7 @@ scipy
 tqdm
 lmdb
 pyyaml
-yapf
+yapf
+
+image_gen_aux @ git+https://github.com/huggingface/image_gen_aux
+gdown # supports downloading the large file from Google Drive
utils/dataops.py
ADDED
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# The file source is from the [ESRGAN](https://github.com/xinntao/ESRGAN) project
# forked by authors [joeyballentine](https://github.com/joeyballentine/ESRGAN) and [BlueAmulet](https://github.com/BlueAmulet/ESRGAN).

import gc

import numpy as np
import torch


def bgr_to_rgb(image: torch.Tensor) -> torch.Tensor:
    # flip image channels
    # https://github.com/pytorch/pytorch/issues/229
    out: torch.Tensor = image.flip(-3)
    # out: torch.Tensor = image[[2, 1, 0], :, :] #RGB to BGR #may be faster
    return out


def rgb_to_bgr(image: torch.Tensor) -> torch.Tensor:
    # same operation as bgr_to_rgb(), flip image channels
    return bgr_to_rgb(image)


def bgra_to_rgba(image: torch.Tensor) -> torch.Tensor:
    out: torch.Tensor = image[[2, 1, 0, 3], :, :]
    return out


def rgba_to_bgra(image: torch.Tensor) -> torch.Tensor:
    # same operation as bgra_to_rgba(), flip image channels
    return bgra_to_rgba(image)


def auto_split_upscale(
    lr_img: np.ndarray,
    upscale_function,
    scale: int = 4,
    overlap: int = 32,
    max_depth: int = None,
    current_depth: int = 1,
):
    # Attempt to upscale if unknown depth or if reached known max depth
    if max_depth is None or max_depth == current_depth:
        try:
            print(f"auto_split_upscale, current depth: {current_depth}")
            result, _ = upscale_function(lr_img, scale)
            return result, current_depth
        except RuntimeError as e:
            # Check to see if its actually the CUDA out of memory error
            if "CUDA" in str(e):
                # Collect garbage (clear VRAM)
                torch.cuda.empty_cache()
                gc.collect()
            # Re-raise the exception if not an OOM error
            else:
                raise RuntimeError(e)
        finally:
            # Free GPU memory and clean up resources
            torch.cuda.empty_cache()
            gc.collect()

    h, w, c = lr_img.shape

    # Split image into 4ths
    top_left = lr_img[: h // 2 + overlap, : w // 2 + overlap, :]
    top_right = lr_img[: h // 2 + overlap, w // 2 - overlap :, :]
    bottom_left = lr_img[h // 2 - overlap :, : w // 2 + overlap, :]
    bottom_right = lr_img[h // 2 - overlap :, w // 2 - overlap :, :]

    # Recursively upscale the quadrants
    # After we go through the top left quadrant, we know the maximum depth and no longer need to test for out-of-memory
    top_left_rlt, depth = auto_split_upscale(
        top_left,
        upscale_function,
        scale=scale,
        overlap=overlap,
        max_depth=max_depth,
        current_depth=current_depth + 1,
    )
    top_right_rlt, _ = auto_split_upscale(
        top_right,
        upscale_function,
        scale=scale,
        overlap=overlap,
        max_depth=depth,
        current_depth=current_depth + 1,
    )
    bottom_left_rlt, _ = auto_split_upscale(
        bottom_left,
        upscale_function,
        scale=scale,
        overlap=overlap,
        max_depth=depth,
        current_depth=current_depth + 1,
    )
    bottom_right_rlt, _ = auto_split_upscale(
        bottom_right,
        upscale_function,
        scale=scale,
        overlap=overlap,
        max_depth=depth,
        current_depth=current_depth + 1,
    )

    # Define output shape
    out_h = h * scale
    out_w = w * scale

    # Create blank output image
    output_img = np.zeros((out_h, out_w, c), np.uint8)

    # Fill output image with tiles, cropping out the overlaps
    output_img[: out_h // 2, : out_w // 2, :] = top_left_rlt[: out_h // 2, : out_w // 2, :]
    output_img[: out_h // 2, -out_w // 2 :, :] = top_right_rlt[: out_h // 2, -out_w // 2 :, :]
    output_img[-out_h // 2 :, : out_w // 2, :] = bottom_left_rlt[-out_h // 2 :, : out_w // 2, :]
    output_img[-out_h // 2 :, -out_w // 2 :, :] = bottom_right_rlt[-out_h // 2 :, -out_w // 2 :, :]

    return output_img, depth