avans06 committed
Commit c85d0ce
1 Parent(s): 0f09e67

The inference uses the auto_split_upscale mechanism.


1. auto_split_upscale is implemented in the utils/dataops.py file; its source is the ESRGAN project as forked by joeyballentine and BlueAmulet. A usage sketch follows this list.

2. The face model now supports RestoreFormer++, authored by wzhouxiff.

3. Added support for parsing older RRDB models from the ESRGAN project.

4. Added support for parsing DAT models from the DAT project by author zhengchen1999.

5. Added support for parsing HAT models from the HAT project by author XPixelGroup.

6. Added support for parsing RealPLKSR models from the PLKSR project by dslisleedh and the neosr-project.
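Below is a minimal sketch of the auto_split_upscale mechanism from item 1, assuming only what utils/dataops.py defines. The `fake_enhance` stand-in is hypothetical (the app itself passes `RealESRGANer.enhance` or the patched `UpscaleWithModel.enhance`); any upscale callable just has to return an `(image, extra)` tuple:

```python
import cv2
import numpy as np

from utils.dataops import auto_split_upscale

def fake_enhance(lr_img: np.ndarray, scale: int = 4):
    # Hypothetical stand-in for RealESRGANer.enhance; it must return an
    # (image, extra) tuple, because auto_split_upscale unpacks one.
    h, w = lr_img.shape[:2]
    out = cv2.resize(lr_img, (w * scale, h * scale), interpolation=cv2.INTER_LANCZOS4)
    return out, None

img = np.random.randint(0, 256, (1024, 1024, 3), np.uint8)

# auto_split_upscale first tries the whole image; on a CUDA out-of-memory
# RuntimeError it recursively splits the input into four overlapping
# quadrants, upscales each tile, and stitches the results back together.
output, depth = auto_split_upscale(img, fake_enhance, scale=4, overlap=32)
print(output.shape, depth)  # (4096, 4096, 3) and the recursion depth used
```

The `overlap` margin hides seams between tiles, and the depth discovered while upscaling the first quadrant is reused for the remaining three, so the out-of-memory probing happens only once per image.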

Files changed (3)
  1. app.py +467 -222
  2. requirements.txt +6 -2
  3. utils/dataops.py +127 -0
app.py CHANGED
@@ -1,4 +1,4 @@
- import os
  import gc
  import cv2
  import requests
@@ -6,244 +6,489 @@ import numpy as np
  import gradio as gr
  import torch
  import traceback
- from tqdm import tqdm
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
- from gfpgan.utils import GFPGANer
  from realesrgan.utils import RealESRGANer
- from basicsr.archs.rrdbnet_arch import RRDBNet

  # Define URLs and their corresponding local storage paths
  face_model = {
-     "GFPGANv1.2.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth",
-     "GFPGANv1.3.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
      "GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
-     "RestoreFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth",
-     "CodeFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth",
  }
  realesr_model = {
-     "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
-     "realesr-animevideov3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
-     "RealESRGAN_x4plus_anime_6B.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
      "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
      "RealESRNet_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
      "RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
-     "4x-AnimeSharp.pth": "https://huggingface.co/utnah/esrgan/resolve/main/4x-AnimeSharp.pth?download=true",
  }
- files_to_download = [
-     ("a1.jpg",
-      "https://thumbs.dreamstime.com/b/tower-bridge-traditional-red-bus-black-white-colors-view-to-tower-bridge-london-black-white-colors-108478942.jpg"),
-     ("a2.jpg",
-      "https://media.istockphoto.com/id/523514029/photo/london-skyline-b-w.jpg?s=612x612&w=0&k=20&c=kJS1BAtfqYeUDaORupj0sBPc1hpzJhBUUqEFfRnHzZ0="),
-     ("a3.jpg",
-      "https://i.guim.co.uk/img/media/06f614065ed82ca0e917b149a32493c791619854/0_0_3648_2789/master/3648.jpg?width=700&quality=85&auto=format&fit=max&s=05764b507c18a38590090d987c8b6202"),
-     ("a4.jpg",
-      "https://i.pinimg.com/736x/46/96/9e/46969eb94aec2437323464804d27706d--victorian-london-victorian-era.jpg"),
- ]
-
- # Ensure the target directory exists
- os.makedirs("weights", exist_ok=True)
- os.makedirs('output', exist_ok=True)
-
- def download_from_url(output_path, url):
-     try:
-         # Check if the file already exists
-         if os.path.exists(output_path):
-             print(f"File already exists, skipping download: {output_path}")
-             return
-
-         print(f"Downloading: {url}")
-         with requests.get(url, stream=True) as response, open(output_path, "wb") as f:
-             total_size = int(response.headers.get('content-length', 0))
-             with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
-                 for chunk in response.iter_content(chunk_size=8192):
-                     f.write(chunk)
-                     pbar.update(len(chunk))
-         print(f"Download successful: {output_path}")
-     except requests.RequestException as e:
-         print(f"Download failed: {url}, Error: {e}")
-
-
- # Iterate through each file
- for output_path, url in files_to_download:
-     # Check if the file already exists
-     if os.path.exists(output_path):
-         print(f"File already exists, skipping download: {output_path}")
-         continue
-
-     # Start downloading
-     download_from_url(output_path, url)
-
-
- def inference(img, version, realesr, scale: float):
-     print(img, version, scale)
-     try:
-         img_name = os.path.basename(str(img))
-         basename, extension = os.path.splitext(img_name)
-         img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
-         if len(img.shape) == 3 and img.shape[2] == 4:
-             img_mode = 'RGBA'
-         elif len(img.shape) == 2:  # for gray inputs
-             img_mode = None
-             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-         else:
-             img_mode = None
-
-         h, w = img.shape[0:2]
-         if h < 300:
-             img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

-         if version:
-             download_from_url(os.path.join("weights", version), face_model[version])
-         if realesr:
-             download_from_url(os.path.join("weights", realesr), realesr_model[realesr])

-         # background enhancer with RealESRGAN
-         if realesr == 'RealESRGAN_x4plus.pth':  # x4 RRDBNet model
-             netscale = 4
-             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale)
-         elif realesr == 'RealESRNet_x4plus.pth':  # x4 RRDBNet model
-             netscale = 4
-             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale)
-         elif realesr == 'RealESRGAN_x4plus_anime_6B.pth':  # x4 RRDBNet model with 6 blocks
-             netscale = 4
-             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=netscale)
-         elif realesr == 'RealESRGAN_x2plus.pth':  # x2 RRDBNet model
-             netscale = 2
-             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale)
-         elif realesr == 'realesr-animevideov3.pth':  # x4 VGG-style model (XS size)
-             netscale = 4
-             model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=netscale, act_type='prelu')
-         elif realesr == 'realesr-general-x4v3.pth':  # x4 VGG-style model (S size)
              netscale = 4
-             model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=netscale, act_type='prelu')
-         # elif realesr == '4x-AnimeSharp.pth':  # 4x-AnimeSharp
-         #     netscale = 4
-         #     model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale)
-         half = True if torch.cuda.is_available() else False
-         upsampler = RealESRGANer(scale=netscale, model_path=os.path.join("weights", realesr), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
-
-         face_enhancer = None
-         if version == 'GFPGANv1.2.pth':
-             face_enhancer = GFPGANer(
-                 model_path='weights/GFPGANv1.2.pth', upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
-         elif version == 'GFPGANv1.3.pth':
-             face_enhancer = GFPGANer(
-                 model_path='weights/GFPGANv1.3.pth', upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
-         elif version == 'GFPGANv1.4.pth':
-             face_enhancer = GFPGANer(
-                 model_path='weights/GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
-         elif version == 'RestoreFormer.pth':
-             face_enhancer = GFPGANer(
-                 model_path='weights/RestoreFormer.pth', upscale=scale, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
-         elif version == 'CodeFormer.pth':
-             face_enhancer = GFPGANer(
-                 model_path='weights/CodeFormer.pth', upscale=scale, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler)
-
-         files = []
-         outputs = []
-         try:
-             if face_enhancer:
-                 cropped_faces, restored_aligned, restored_img = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
-                 # save faces
-                 if cropped_faces and restored_aligned:
-                     for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_aligned)):
-                         # save cropped face
-                         save_crop_path = f"output/{basename}{idx:02d}_cropped_faces.png"
-                         cv2.imwrite(save_crop_path, cropped_face)
-                         # save restored face
-                         save_restore_path = f"output/{basename}{idx:02d}_restored_faces.png"
-                         cv2.imwrite(save_restore_path, restored_face)
-                         # save comparison image
-                         save_cmp_path = f"output/{basename}{idx:02d}_cmp.png"
-                         cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
-                         cv2.imwrite(save_cmp_path, cmp_img)
-
-                         files.append(save_crop_path)
-                         files.append(save_restore_path)
-                         files.append(save_cmp_path)
-                         outputs.append(cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB))
-                         outputs.append(cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB))
-                         outputs.append(cv2.cvtColor(cmp_img, cv2.COLOR_BGR2RGB))
-             else:
-                 restored_img, _ = upsampler.enhance(img, outscale=scale)
-         except RuntimeError as error:
-             print(traceback.format_exc())
-             print('Error', error)
-         finally:
-             if face_enhancer:
-                 face_enhancer._cleanup()
-             else:
-                 # Free GPU memory and clean up resources
-                 torch.cuda.empty_cache()
-                 gc.collect()
-
-         try:
-             if scale != 2:
-                 interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
-                 h, w = img.shape[0:2]
-                 restored_img = cv2.resize(restored_img, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
          except Exception as error:
              print(traceback.format_exc())
-             print("wrong scale input.", error)
-
-         if not extension:
-             extension = ".png" if img_mode == "RGBA" else ".jpg"  # RGBA images should be saved in png format
-         save_path = f"output/{basename}{extension}"
-         cv2.imwrite(save_path, restored_img)
-
-         restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
-         files.append(save_path)
-         outputs.append(restored_img)
-         return outputs, files
-     except Exception as error:
-         print(traceback.format_exc())
-         print("global exception", error)
-         return None, None
-
-
- title = "Image Upscaling & Restoration(esp. Face) using GFPGAN Algorithm"
- description = r"""Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration and Upscalling of the image with a Generative Facial Prior</b></a>.<br>
- Practically the algorithm is used to restore your **old photos** or improve **AI-generated faces**.<br>
- To use it, simply just upload the concerned image.<br>
- """
- article = r"""
- [![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
- [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/GFPGAN?style=social)](https://github.com/TencentARC/GFPGAN)
- [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2101.04061)
- <center><img src='https://visitor-badge.glitch.me/badge?page_id=dj_face_restoration_GFPGAN' alt='visitor badge'></center>
- """
- demo = gr.Interface(
-     inference, [
-         gr.Image(type="filepath", label="Input", format="png"),
-         gr.Dropdown(["GFPGANv1.2.pth",
-                      "GFPGANv1.3.pth",
-                      "GFPGANv1.4.pth",
-                      "RestoreFormer.pth",
-                      # "CodeFormer.pth",
-                      None], type="value", value='GFPGANv1.4.pth', label='Face Restoration version', info="Face Restoration and RealESR can be freely combined in different ways, or one can be set to \"None\" to use only the other model. Face Restoration is primarily used for face restoration in real-life images, while RealESR serves as a background restoration model."),
-         gr.Dropdown(["realesr-general-x4v3.pth",
-                      "realesr-animevideov3.pth",
-                      "RealESRGAN_x4plus_anime_6B.pth",
-                      "RealESRGAN_x2plus.pth",
-                      "RealESRNet_x4plus.pth",
-                      "RealESRGAN_x4plus.pth",
-                      # "4x-AnimeSharp.pth",
-                      None], type="value", value='realesr-general-x4v3.pth', label='RealESR version'),
-         gr.Number(label="Rescaling factor", value=2),
-         # gr.Slider(0, 100, label='Weight, only for CodeFormer. 0 for better quality, 100 for better identity', value=50)
-     ], [
-         gr.Gallery(type="numpy", label="Output (The whole image)", format="png"),
-         gr.File(label="Download the output image")
-     ],
-     title=title,
-     description=description,
-     article=article,
-     examples=[['a1.jpg', 'GFPGANv1.4.pth', "realesr-general-x4v3.pth", 2],
-               ['a2.jpg', 'GFPGANv1.4.pth', "realesr-general-x4v3.pth", 2],
-               ['a3.jpg', 'GFPGANv1.4.pth', "realesr-general-x4v3.pth", 2],
-               ['a4.jpg', 'GFPGANv1.4.pth', "realesr-general-x4v3.pth", 2]])
-
- demo.queue(default_concurrency_limit=4)
- demo.launch(inbrowser=True)

+ import os
  import gc
  import cv2
  import requests
  import numpy as np
  import gradio as gr
  import torch
  import traceback
+ from facexlib.utils.misc import download_from_url
  from realesrgan.utils import RealESRGANer
+

  # Define URLs and their corresponding local storage paths
  face_model = {
      "GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
+     "RestoreFormer++.ckpt": "https://github.com/wzhouxiff/RestoreFormerPlusPlus/releases/download/v1.0.0/RestoreFormer++.ckpt",
+     # "CodeFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth",
+     # legacy model
+     "GFPGANv1.3.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
+     "GFPGANv1.2.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth",
+     "RestoreFormer.ckpt": "https://github.com/wzhouxiff/RestoreFormerPlusPlus/releases/download/v1.0.0/RestoreFormer.ckpt",
  }
  realesr_model = {
+     # SRVGGNet
+     "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",  # x4 SRVGGNet (S size)
+     "realesr-animevideov3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",  # x4 SRVGGNet (XS size)
+     # RRDBNet
+     "RealESRGAN_x4plus_anime_6B.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",  # x4 RRDBNet with 6 blocks
      "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
      "RealESRNet_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
      "RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+     # ESRGAN (old RRDB)
+     "4x-AnimeSharp.pth": "https://huggingface.co/utnah/esrgan/resolve/main/4x-AnimeSharp.pth?download=true",  # https://openmodeldb.info/models/4x-AnimeSharp
+     "4x_IllustrationJaNai_V1_ESRGAN_135k.pth": "https://drive.google.com/uc?export=download&confirm=1&id=1qpioSqBkB_IkSBhEAewSSNFt6qgkBimP",  # https://openmodeldb.info/models/4x-IllustrationJaNai-V1-DAT2
+     # DATNet
+     "4xNomos8kDAT.pth": "https://github.com/Phhofm/models/releases/download/4xNomos8kDAT/4xNomos8kDAT.pth",  # https://openmodeldb.info/models/4x-Nomos8kDAT
+     "4x-DWTP-DS-dat2-v3.pth": "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-DWTP-DS-dat2-v3.pth",  # https://openmodeldb.info/models/4x-DWTP-DS-dat2-v3
+     "4x_IllustrationJaNai_V1_DAT2_190k.pth": "https://drive.google.com/uc?export=download&confirm=1&id=1qpioSqBkB_IkSBhEAewSSNFt6qgkBimP",  # https://openmodeldb.info/models/4x-IllustrationJaNai-V1-DAT2
+     # HAT
+     "4xNomos8kSCHAT-L.pth": "https://github.com/Phhofm/models/releases/download/4xNomos8kSCHAT/4xNomos8kSCHAT-L.pth",  # https://openmodeldb.info/models/4x-Nomos8kSCHAT-L
+     "4xNomos8kSCHAT-S.pth": "https://github.com/Phhofm/models/releases/download/4xNomos8kSCHAT/4xNomos8kSCHAT-S.pth",  # https://openmodeldb.info/models/4x-Nomos8kSCHAT-S
+     "4xNomos8kHAT-L_otf.pth": "https://github.com/Phhofm/models/releases/download/4xNomos8kHAT-L_otf/4xNomos8kHAT-L_otf.pth",  # https://openmodeldb.info/models/4x-Nomos8kHAT-L-otf
+     # RealPLKSR_dysample
+     "4xHFA2k_ludvae_realplksr_dysample.pth": "https://github.com/Phhofm/models/releases/download/4xHFA2k_ludvae_realplksr_dysample/4xHFA2k_ludvae_realplksr_dysample.pth",  # https://openmodeldb.info/models/4x-HFA2k-ludvae-realplksr-dysample
+     "4xArtFaces_realplksr_dysample.pth": "https://github.com/Phhofm/models/releases/download/4xArtFaces_realplksr_dysample/4xArtFaces_realplksr_dysample.pth",  # https://openmodeldb.info/models/4x-ArtFaces-realplksr-dysample
+     "4x-PBRify_RPLKSRd_V3.pth": "https://github.com/Kim2091/Kim2091-Models/releases/download/4x-PBRify_RPLKSRd_V3/4x-PBRify_RPLKSRd_V3.pth",  # https://openmodeldb.info/models/4x-PBRify-RPLKSRd-V3
+     "4xNomos2_realplksr_dysample.pth": "https://github.com/Phhofm/models/releases/download/4xNomos2_realplksr_dysample/4xNomos2_realplksr_dysample.pth",  # https://openmodeldb.info/models/4x-Nomos2-realplksr-dysample
+     # RealPLKSR
+     "2x-AnimeSharpV2_RPLKSR_Sharp.pth": "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV2_Set/2x-AnimeSharpV2_RPLKSR_Sharp.pth",  # https://openmodeldb.info/models/2x-AnimeSharpV2-RPLKSR-Sharp
+     "2x-AnimeSharpV2_RPLKSR_Soft.pth": "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV2_Set/2x-AnimeSharpV2_RPLKSR_Soft.pth",  # https://openmodeldb.info/models/2x-AnimeSharpV2-RPLKSR-Soft
+     "4xPurePhoto-RealPLSKR.pth": "https://github.com/starinspace/StarinspaceUpscale/releases/download/Models/4xPurePhoto-RealPLSKR.pth",  # https://openmodeldb.info/models/4x-PurePhoto-RealPLSKR
+     "2x_Text2HD_v.1-RealPLKSR.pth": "https://github.com/starinspace/StarinspaceUpscale/releases/download/Models/2x_Text2HD_v.1-RealPLKSR.pth",  # https://openmodeldb.info/models/2x-Text2HD-v-1
+     "2xVHS2HD-RealPLKSR.pth": "https://github.com/starinspace/StarinspaceUpscale/releases/download/Models/2xVHS2HD-RealPLKSR.pth",  # https://openmodeldb.info/models/2x-VHS2HD
+     "4xNomosWebPhoto_RealPLKSR.pth": "https://github.com/Phhofm/models/releases/download/4xNomosWebPhoto_RealPLKSR/4xNomosWebPhoto_RealPLKSR.pth",  # https://openmodeldb.info/models/4x-NomosWebPhoto-RealPLKSR
+ }
+
+ files_to_download = {
+     "a1.jpg":
+         "https://thumbs.dreamstime.com/b/tower-bridge-traditional-red-bus-black-white-colors-view-to-tower-bridge-london-black-white-colors-108478942.jpg",
+     "a2.jpg":
+         "https://media.istockphoto.com/id/523514029/photo/london-skyline-b-w.jpg?s=612x612&w=0&k=20&c=kJS1BAtfqYeUDaORupj0sBPc1hpzJhBUUqEFfRnHzZ0=",
+     "a3.jpg":
+         "https://i.guim.co.uk/img/media/06f614065ed82ca0e917b149a32493c791619854/0_0_3648_2789/master/3648.jpg?width=700&quality=85&auto=format&fit=max&s=05764b507c18a38590090d987c8b6202",
+     "a4.jpg":
+         "https://i.pinimg.com/736x/46/96/9e/46969eb94aec2437323464804d27706d--victorian-london-victorian-era.jpg",
  }
+
+ def get_model_type(model_name):
+     # Define model type mappings based on key parts of the model names
+     model_type = "other"
+     if any(value in model_name.lower() for value in ("realesrgan", "realesrnet")):
+         model_type = "RRDB"
+     elif "realesr" in model_name.lower():
+         model_type = "SRVGG"
+     elif "esrgan" in model_name.lower() or "4x-AnimeSharp.pth" == model_name:
+         model_type = "ESRGAN"
+     elif "dat" in model_name.lower():
+         model_type = "DAT"
+     elif "hat" in model_name.lower():
+         model_type = "HAT"
+     elif ("realplksr" in model_name.lower() and "dysample" in model_name.lower()) or "rplksrd" in model_name.lower():
+         model_type = "RealPLKSR_dysample"
+     elif "realplksr" in model_name.lower() or "rplksr" in model_name.lower():
+         model_type = "RealPLKSR"
+     return f"{model_type}, {model_name}"
+
+ typed_realesr_model = {get_model_type(key): value for key, value in realesr_model.items()}
+
+ def download_from_urls(urls, save_dir=None):
+     for file_name, url in urls.items():
+         download_from_url(url, file_name, save_dir)
+
+
+ class Upscale:
+     def inference(self, img, face_restoration, realesr, scale: float):
+         print(img)
+         print(face_restoration, realesr, scale)
+         try:
+             self.scale = scale
+             self.img_name = os.path.basename(str(img))
+             self.basename, self.extension = os.path.splitext(self.img_name)
+
+             img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED)  # cv2.imread(img, cv2.IMREAD_UNCHANGED)
+
+             self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
+             if len(img.shape) == 2:  # for gray inputs
+                 img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+             h, w = img.shape[0:2]
+             if h < 300:
+                 img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+
+             realesr_type = None
+             if face_restoration:
+                 download_from_url(face_model[face_restoration], face_restoration, os.path.join("weights", "face"))
+             if realesr:
+                 realesr_type, realesr = realesr.split(", ", 1)
+                 download_from_url(realesr_model[realesr], realesr, os.path.join("weights", "realesr"))
+
              netscale = 4
+             loadnet = None
+             model = None
+             is_auto_split_upscale = True
+             half = True if torch.cuda.is_available() else False
+             if realesr_type:
+                 from basicsr.archs.rrdbnet_arch import RRDBNet
+                 from basicsr.archs.realplksr_arch import realplksr
+                 # background enhancer with RealESRGAN
+                 if realesr_type == "RRDB":
+                     netscale = 2 if "x2" in realesr else 4
+                     num_block = 6 if "6B" in realesr else 23
+                     model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=netscale)
+                 elif realesr_type == "SRVGG":
+                     from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+                     netscale = 4
+                     num_conv = 16 if "animevideov3" in realesr else 32
+                     model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=netscale, act_type='prelu')
+                 elif realesr_type == "ESRGAN":
+                     netscale = 4
+                     model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale)
+                     loadnet = {}
+                     loadnet_origin = torch.load(os.path.join("weights", "realesr", realesr), map_location=torch.device('cpu'), weights_only=True)
+                     for key, value in loadnet_origin.items():
+                         new_key = key.replace("model.0", "conv_first").replace("model.1.sub.23.", "conv_body.").replace("model.1.sub", "body") \
+                             .replace(".0.weight", ".weight").replace(".0.bias", ".bias").replace(".RDB1.", ".rdb1.").replace(".RDB2.", ".rdb2.").replace(".RDB3.", ".rdb3.") \
+                             .replace("model.3.", "conv_up1.").replace("model.6.", "conv_up2.").replace("model.8.", "conv_hr.").replace("model.10.", "conv_last.")
+                         loadnet[new_key] = value
+                 elif realesr_type == "DAT":
+                     from basicsr.archs.dat_arch import DAT
+                     half = False
+                     netscale = 4
+                     expansion_factor = 2. if "dat2" in realesr.lower() else 4.
+                     model = DAT(img_size=64, in_chans=3, embed_dim=180, split_size=[8,32], depth=[6,6,6,6,6,6], num_heads=[6,6,6,6,6,6], expansion_factor=expansion_factor, upscale=netscale)
+                     # # Speculate on the parameters.
+                     # loadnet_origin = torch.load(os.path.join("weights", "realesr", realesr), map_location=torch.device('cpu'), weights_only=True)
+                     # inferred_params = self.infer_parameters_from_state_dict_for_dat(loadnet_origin, netscale)
+                     # for param, value in inferred_params.items():
+                     #     print(f"{param}: {value}")
+                 elif realesr_type == "HAT":
+                     half = False
+                     netscale = 4
+                     import torch.nn.functional as F
+                     from basicsr.archs.hat_arch import HAT
+                     class HATWithAutoPadding(HAT):
+                         def pad_to_multiple(self, img, multiple):
+                             """Pad the image so that both its height and width are integer multiples of `multiple`."""
+                             _, _, h, w = img.shape
+                             pad_h = (multiple - h % multiple) % multiple
+                             pad_w = (multiple - w % multiple) % multiple
+
+                             # Padding on the top, bottom, left, and right.
+                             pad_top = pad_h // 2
+                             pad_bottom = pad_h - pad_top
+                             pad_left = pad_w // 2
+                             pad_right = pad_w - pad_left
+
+                             img_padded = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode="reflect")
+                             return img_padded, (pad_top, pad_bottom, pad_left, pad_right)
+
+                         def remove_padding(self, img, pad_info):
+                             """Remove the padding and restore the original size, accounting for the upscaling factor."""
+                             pad_top, pad_bottom, pad_left, pad_right = pad_info
+
+                             # Adjust padding based on upscaling factor
+                             pad_top = int(pad_top * self.upscale)
+                             pad_bottom = int(pad_bottom * self.upscale)
+                             pad_left = int(pad_left * self.upscale)
+                             pad_right = int(pad_right * self.upscale)
+
+                             return img[:, :, pad_top:-pad_bottom if pad_bottom > 0 else None, pad_left:-pad_right if pad_right > 0 else None]
+
+                         def forward(self, x):
+                             # Step 1: Auto padding
+                             x_padded, pad_info = self.pad_to_multiple(x, self.window_size)
+
+                             # Step 2: Normal model processing
+                             x_processed = super().forward(x_padded)
+
+                             # Step 3: Remove padding
+                             x_cropped = self.remove_padding(x_processed, pad_info)
+                             return x_cropped
+
+                     # The parameters are derived from the XPixelGroup project files: HAT-L_SRx4_ImageNet-pretrain.yml and HAT-S_SRx4.yml.
+                     # https://github.com/XPixelGroup/HAT/tree/main/options/test
+                     if "hat-l" in realesr.lower():
+                         window_size = 16
+                         compress_ratio = 3
+                         squeeze_factor = 30
+                         depths = [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
+                         embed_dim = 180
+                         num_heads = [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
+                         mlp_ratio = 2
+                         upsampler = "pixelshuffle"
+                     elif "hat-s" in realesr.lower():
+                         window_size = 16
+                         compress_ratio = 24
+                         squeeze_factor = 24
+                         depths = [6, 6, 6, 6, 6, 6]
+                         embed_dim = 144
+                         num_heads = [6, 6, 6, 6, 6, 6]
+                         mlp_ratio = 2
+                         upsampler = "pixelshuffle"
+                     model = HATWithAutoPadding(img_size=64, patch_size=1, in_chans=3, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, compress_ratio=compress_ratio,
+                                                squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=netscale,)
+                 elif realesr_type == "RealPLKSR_dysample":
+                     netscale = 4
+                     model = realplksr(upscaling_factor=netscale, dysample=True)
+                 elif realesr_type == "RealPLKSR":
+                     half = False if "RealPLSKR" in realesr else half
+                     netscale = 2 if realesr.startswith("2x") else 4
+                     model = realplksr(dim=64, n_blocks=28, kernel_size=17, split_ratio=0.25, upscaling_factor=netscale)
+
+
+             self.upsampler = None
+             if loadnet:
+                 self.upsampler = RealESRGANer(scale=netscale, loadnet=loadnet, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+             elif model:
+                 self.upsampler = RealESRGANer(scale=netscale, model_path=os.path.join("weights", "realesr", realesr), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+             elif realesr:
+                 self.upsampler = None
+                 import PIL
+                 from image_gen_aux import UpscaleWithModel
+                 class UpscaleWithModel_Gfpgan(UpscaleWithModel):
+                     def cv2pil(self, image):
+                         ''' OpenCV type -> PIL type
+                         https://qiita.com/derodero24/items/f22c22b22451609908ee
+                         '''
+                         new_image = image.copy()
+                         if new_image.ndim == 2:  # Grayscale
+                             pass
+                         elif new_image.shape[2] == 3:  # Color
+                             new_image = cv2.cvtColor(new_image, cv2.COLOR_BGR2RGB)
+                         elif new_image.shape[2] == 4:  # Transparency
+                             new_image = cv2.cvtColor(new_image, cv2.COLOR_BGRA2RGBA)
+                         new_image = PIL.Image.fromarray(new_image)
+                         return new_image
+
+                     def pil2cv(self, image):
+                         ''' PIL type -> OpenCV type
+                         https://qiita.com/derodero24/items/f22c22b22451609908ee
+                         '''
+                         new_image = np.array(image, dtype=np.uint8)
+                         if new_image.ndim == 2:  # Grayscale
+                             pass
+                         elif new_image.shape[2] == 3:  # Color
+                             new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
+                         elif new_image.shape[2] == 4:  # Transparency
+                             new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
+                         return new_image
+
+                     def enhance(self, img, outscale=None):
+                         # img: numpy
+                         h_input, w_input = img.shape[0:2]
+                         pil_img = self.cv2pil(img)
+                         pil_img = self.__call__(pil_img)
+                         cv_image = self.pil2cv(pil_img)
+                         if outscale is not None and outscale != float(netscale):
+                             cv_image = cv2.resize(
+                                 cv_image, (
+                                     int(w_input * outscale),
+                                     int(h_input * outscale),
+                                 ), interpolation=cv2.INTER_LANCZOS4)
+                         return cv_image, None
+
+                 device = "cuda" if torch.cuda.is_available() else "cpu"
+                 upscaler = UpscaleWithModel.from_pretrained(os.path.join("weights", "realesr", realesr)).to(device)
+                 upscaler.__class__ = UpscaleWithModel_Gfpgan
+                 self.upsampler = upscaler
+             self.face_enhancer = None
+
+             if face_restoration:
+                 from gfpgan.utils import GFPGANer
+                 if face_restoration.startswith("GFPGANv1."):
+                     self.face_enhancer = GFPGANer(model_path=os.path.join("weights", "face", face_restoration), upscale=self.scale, arch="clean", channel_multiplier=2, bg_upsampler=self.upsampler)
+                 elif face_restoration.startswith("RestoreFormer"):
+                     arch = "RestoreFormer++" if face_restoration.startswith("RestoreFormer++") else "RestoreFormer"
+                     self.face_enhancer = GFPGANer(model_path=os.path.join("weights", "face", face_restoration), upscale=self.scale, arch=arch, channel_multiplier=2, bg_upsampler=self.upsampler)
+                 elif face_restoration == 'CodeFormer.pth':
+                     self.face_enhancer = GFPGANer(
+                         model_path='weights/CodeFormer.pth', upscale=self.scale, arch='CodeFormer', channel_multiplier=2, bg_upsampler=self.upsampler)
+
+
+             files = []
+             outputs = []
+             try:
+                 bg_upsample_img = None
+                 if self.upsampler and self.upsampler.enhance:
+                     from utils.dataops import auto_split_upscale
+                     bg_upsample_img, _ = auto_split_upscale(img, self.upsampler.enhance, self.scale) if is_auto_split_upscale else self.upsampler.enhance(img, outscale=self.scale)
+
+                 if self.face_enhancer:
+                     cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, bg_upsample_img=bg_upsample_img)
+                     # save faces
+                     if cropped_faces and restored_aligned:
+                         for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_aligned)):
+                             # save cropped face
+                             save_crop_path = f"output/{self.basename}{idx:02d}_cropped_faces.png"
+                             self.imwriteUTF8(save_crop_path, cropped_face)
+                             # save restored face
+                             save_restore_path = f"output/{self.basename}{idx:02d}_restored_faces.png"
+                             self.imwriteUTF8(save_restore_path, restored_face)
+                             # save comparison image
+                             save_cmp_path = f"output/{self.basename}{idx:02d}_cmp.png"
+                             cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
+                             self.imwriteUTF8(save_cmp_path, cmp_img)
+
+                             files.append(save_crop_path)
+                             files.append(save_restore_path)
+                             files.append(save_cmp_path)
+                             outputs.append(cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB))
+                             outputs.append(cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB))
+                             outputs.append(cv2.cvtColor(cmp_img, cv2.COLOR_BGR2RGB))
+
+                 restored_img = bg_upsample_img
+             except RuntimeError as error:
+                 print(traceback.format_exc())
+                 print('Error', error)
+             finally:
+                 if self.face_enhancer:
+                     self.face_enhancer._cleanup()
+                 else:
+                     # Free GPU memory and clean up resources
+                     torch.cuda.empty_cache()
+                     gc.collect()
+
+             if not self.extension:
+                 self.extension = ".png" if self.img_mode == "RGBA" else ".jpg"  # RGBA images should be saved in png format
+             save_path = f"output/{self.basename}{self.extension}"
+             self.imwriteUTF8(save_path, restored_img)
+
+             restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
+             files.append(save_path)
+             outputs.append(restored_img)
+             return outputs, files
          except Exception as error:
              print(traceback.format_exc())
+             print("global exception", error)
+             return None, None
+
+
+     def infer_parameters_from_state_dict_for_dat(self, state_dict, upscale=4):
+         if "params" in state_dict:
+             state_dict = state_dict["params"]
+         elif "params_ema" in state_dict:
+             state_dict = state_dict["params_ema"]
+
+         inferred_params = {}
+
+         # Speculate on the depth.
+         depth = {}
+         for key in state_dict.keys():
+             if "blocks" in key:
+                 layer = int(key.split(".")[1])
+                 block = int(key.split(".")[3])
+                 depth[layer] = max(depth.get(layer, 0), block + 1)
+         inferred_params["depth"] = [depth[layer] for layer in sorted(depth.keys())]
+
+         # Speculate on the number of num_heads per layer.
+         # ex.
+         # layers.0.blocks.1.attn.temperature: torch.Size([6, 1, 1])
+         # layers.5.blocks.5.attn.temperature: torch.Size([6, 1, 1])
+         # The shape of temperature is [num_heads, 1, 1].
+         num_heads = []
+         for layer in range(len(inferred_params["depth"])):
+             for block in range(inferred_params["depth"][layer]):
+                 key = f"layers.{layer}.blocks.{block}.attn.temperature"
+                 if key in state_dict:
+                     num_heads_layer = state_dict[key].shape[0]
+                     num_heads.append(num_heads_layer)
+                     break
+
+         inferred_params["num_heads"] = num_heads
+
+         # Speculate on embed_dim.
+         # ex. layers.0.blocks.0.attn.qkv.weight: torch.Size([540, 180])
+         for key in state_dict.keys():
+             if "attn.qkv.weight" in key:
+                 qkv_weight = state_dict[key]
+                 embed_dim = qkv_weight.shape[1]  # Note: The in_features of qkv corresponds to embed_dim.
+                 inferred_params["embed_dim"] = embed_dim
+                 break
+
+         # Speculate on split_size.
+         # ex.
+         # layers.0.blocks.0.attn.attns.0.rpe_biases: torch.Size([945, 2])
+         # layers.0.blocks.0.attn.attns.0.relative_position_index: torch.Size([256, 256])
+         # layers.0.blocks.2.attn.attn_mask_0: torch.Size([16, 256, 256])
+         # layers.0.blocks.2.attn.attn_mask_1: torch.Size([16, 256, 256])
+         for key in state_dict.keys():
+             if "relative_position_index" in key:
+                 relative_position_size = state_dict[key].shape[0]
+                 # Determine split_size[0] and split_size[1] based on the provided data.
+                 split_size_0, split_size_1 = 8, relative_position_size // 8  # 256 = 8 * 32
+                 inferred_params["split_size"] = [split_size_0, split_size_1]
+                 break
+
+         # Speculate on the expansion_factor.
+         # ex.
+         # layers.0.blocks.0.ffn.fc1.weight: torch.Size([360, 180])
+         # layers.5.blocks.5.ffn.fc1.weight: torch.Size([360, 180])
+         if "embed_dim" in inferred_params:
+             for key in state_dict.keys():
+                 if "ffn.fc1.weight" in key:
+                     fc1_weight = state_dict[key]
+                     expansion_factor = fc1_weight.shape[0] // inferred_params["embed_dim"]
+                     inferred_params["expansion_factor"] = expansion_factor
+                     break
+
+         inferred_params["img_size"] = 64
+         inferred_params["in_chans"] = 3  # Assume an RGB image.
+
+         for key in state_dict.keys():
+             print(f"{key}: {state_dict[key].shape}")
+
+         return inferred_params
+
+
+     def imwriteUTF8(self, save_path, image):  # `cv2.imwrite` does not support writing files to UTF-8 file paths.
+         img_name = os.path.basename(save_path)
+         _, extension = os.path.splitext(img_name)
+         is_success, im_buf_arr = cv2.imencode(extension, image)
+         if is_success: im_buf_arr.tofile(save_path)
+
+
+ def main():
+     if torch.cuda.is_available():
+         torch.cuda.set_per_process_memory_fraction(0.975, device='cuda:0')
+     # Ensure the target directory exists
+     os.makedirs('output', exist_ok=True)
+
+     # Download the example images
+     download_from_urls(files_to_download, ".")
+
+     title = "Image Upscaling & Restoration (esp. Face) using GFPGAN Algorithm"
+     description = r"""Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration and Upscaling of the image with a Generative Facial Prior</b></a>.<br>
+     Practically, the algorithm is used to restore your **old photos** or improve **AI-generated faces**.<br>
+     To use it, simply upload your image.<br>
+     """
+     article = r"""
+     [![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
+     [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/GFPGAN?style=social)](https://github.com/TencentARC/GFPGAN)
+     [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2101.04061)
+     <center><img src='https://visitor-badge.glitch.me/badge?page_id=dj_face_restoration_GFPGAN' alt='visitor badge'></center>
+     """
+
+     upscale = Upscale()
+
+     demo = gr.Interface(
+         upscale.inference, [
+             gr.Image(type="filepath", label="Input", format="png"),
+             gr.Dropdown(list(face_model.keys()) + [None], type="value", value='GFPGANv1.4.pth', label='Face Restoration version', info="Face Restoration and RealESR can be freely combined in different ways, or one can be set to \"None\" to use only the other model. Face Restoration is primarily used for face restoration in real-life images, while RealESR serves as a background restoration model."),
+             gr.Dropdown(list(typed_realesr_model.keys()) + [None], type="value", value='SRVGG, realesr-general-x4v3.pth', label='RealESR version'),
+             gr.Number(label="Rescaling factor", value=4),
+         ], [
+             gr.Gallery(type="numpy", label="Output (The whole image)", format="png"),
+             gr.File(label="Download the output image")
+         ],
+         title=title,
+         description=description,
+         article=article,
+         examples=[["a1.jpg", "GFPGANv1.4.pth", "SRVGG, realesr-general-x4v3.pth", 2],
+                   ["a2.jpg", "GFPGANv1.4.pth", "SRVGG, realesr-general-x4v3.pth", 2],
+                   ["a3.jpg", "GFPGANv1.4.pth", "SRVGG, realesr-general-x4v3.pth", 2],
+                   ["a4.jpg", "GFPGANv1.4.pth", "SRVGG, realesr-general-x4v3.pth", 2]])
+
+     demo.queue(default_concurrency_limit=4)
+     demo.launch(inbrowser=True)
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt CHANGED
@@ -2,10 +2,11 @@
  gradio==5.8.0

- basicsr @ git+https://github.com/XPixelGroup/BasicSR
  facexlib @ git+https://github.com/avan06/facexlib
  gfpgan @ git+https://github.com/avan06/GFPGAN
  realesrgan @ git+https://github.com/avan06/Real-ESRGAN
  numpy
  opencv-python
@@ -18,4 +19,7 @@ scipy
  tqdm
  lmdb
  pyyaml
- yapf

  gradio==5.8.0
+ basicsr @ git+https://github.com/avan06/BasicSR
  facexlib @ git+https://github.com/avan06/facexlib
  gfpgan @ git+https://github.com/avan06/GFPGAN
  realesrgan @ git+https://github.com/avan06/Real-ESRGAN
+
  numpy
  opencv-python
  tqdm
  lmdb
  pyyaml
+ yapf
+
+ image_gen_aux @ git+https://github.com/huggingface/image_gen_aux
+ gdown # supports downloading the large file from Google Drive
utils/dataops.py ADDED
@@ -0,0 +1,127 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ # The file source is from the [ESRGAN](https://github.com/xinntao/ESRGAN) project
+ # forked by authors [joeyballentine](https://github.com/joeyballentine/ESRGAN) and [BlueAmulet](https://github.com/BlueAmulet/ESRGAN).
+
+ import gc
+
+ import numpy as np
+ import torch
+
+
+ def bgr_to_rgb(image: torch.Tensor) -> torch.Tensor:
+     # flip image channels
+     # https://github.com/pytorch/pytorch/issues/229
+     out: torch.Tensor = image.flip(-3)
+     # out: torch.Tensor = image[[2, 1, 0], :, :]  # RGB to BGR; may be faster
+     return out
+
+
+ def rgb_to_bgr(image: torch.Tensor) -> torch.Tensor:
+     # same operation as bgr_to_rgb(), flip image channels
+     return bgr_to_rgb(image)
+
+
+ def bgra_to_rgba(image: torch.Tensor) -> torch.Tensor:
+     out: torch.Tensor = image[[2, 1, 0, 3], :, :]
+     return out
+
+
+ def rgba_to_bgra(image: torch.Tensor) -> torch.Tensor:
+     # same operation as bgra_to_rgba(), flip image channels
+     return bgra_to_rgba(image)
+
+
+ def auto_split_upscale(
+     lr_img: np.ndarray,
+     upscale_function,
+     scale: int = 4,
+     overlap: int = 32,
+     max_depth: int = None,
+     current_depth: int = 1,
+ ):
+     # Attempt to upscale if unknown depth or if reached known max depth
+     if max_depth is None or max_depth == current_depth:
+         try:
+             print(f"auto_split_upscale, current depth: {current_depth}")
+             result, _ = upscale_function(lr_img, scale)
+             return result, current_depth
+         except RuntimeError as e:
+             # Check to see if it's actually the CUDA out of memory error
+             if "CUDA" in str(e):
+                 # Collect garbage (clear VRAM)
+                 torch.cuda.empty_cache()
+                 gc.collect()
+             # Re-raise the exception if not an OOM error
+             else:
+                 raise RuntimeError(e)
+         finally:
+             # Free GPU memory and clean up resources
+             torch.cuda.empty_cache()
+             gc.collect()
+
+     h, w, c = lr_img.shape
+
+     # Split image into 4ths
+     top_left = lr_img[: h // 2 + overlap, : w // 2 + overlap, :]
+     top_right = lr_img[: h // 2 + overlap, w // 2 - overlap :, :]
+     bottom_left = lr_img[h // 2 - overlap :, : w // 2 + overlap, :]
+     bottom_right = lr_img[h // 2 - overlap :, w // 2 - overlap :, :]
+
+     # Recursively upscale the quadrants
+     # After we go through the top left quadrant, we know the maximum depth and no longer need to test for out-of-memory
+     top_left_rlt, depth = auto_split_upscale(
+         top_left,
+         upscale_function,
+         scale=scale,
+         overlap=overlap,
+         max_depth=max_depth,
+         current_depth=current_depth + 1,
+     )
+     top_right_rlt, _ = auto_split_upscale(
+         top_right,
+         upscale_function,
+         scale=scale,
+         overlap=overlap,
+         max_depth=depth,
+         current_depth=current_depth + 1,
+     )
+     bottom_left_rlt, _ = auto_split_upscale(
+         bottom_left,
+         upscale_function,
+         scale=scale,
+         overlap=overlap,
+         max_depth=depth,
+         current_depth=current_depth + 1,
+     )
+     bottom_right_rlt, _ = auto_split_upscale(
+         bottom_right,
+         upscale_function,
+         scale=scale,
+         overlap=overlap,
+         max_depth=depth,
+         current_depth=current_depth + 1,
+     )
+
+     # Define output shape
+     out_h = h * scale
+     out_w = w * scale
+
+     # Create blank output image
+     output_img = np.zeros((out_h, out_w, c), np.uint8)
+
+     # Fill output image with tiles, cropping out the overlaps
+     output_img[: out_h // 2, : out_w // 2, :] = top_left_rlt[: out_h // 2, : out_w // 2, :]
+     output_img[: out_h // 2, -out_w // 2 :, :] = top_right_rlt[: out_h // 2, -out_w // 2 :, :]
+     output_img[-out_h // 2 :, : out_w // 2, :] = bottom_left_rlt[-out_h // 2 :, : out_w // 2, :]
+     output_img[-out_h // 2 :, -out_w // 2 :, :] = bottom_right_rlt[-out_h // 2 :, -out_w // 2 :, :]
+
+     return output_img, depth