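# Gradio demo (Hugging Face Space) for Thera: aliasing-free, arbitrary-scale
# super-resolution with neural heat fields. Project page: https://therasr.github.io/
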
import pickle
import json
import os

import gradio as gr
from PIL import Image
import numpy as np
import jax
from gradio_dualvision import DualVisionApp
from gradio_dualvision.gradio_patches.radio import Radio
from huggingface_hub import hf_hub_download

from model import build_thera
from super_resolve import process
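
# Pretrained Thera checkpoints on the Hugging Face Hub, one repo per backbone.
# MAX_SIZE caps the input resolution; larger images are center-cropped in
# TheraApp.process() to keep inference responsive for all users.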
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"
MAX_SIZE = 500

print(f"JAX devices: {jax.devices()}")
print(f"JAX device type: {jax.devices()[0].device_kind}")
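
# Each checkpoint is a pickled dict with the trained parameters ('model'), the
# backbone identifier ('backbone'), and its size variant ('size'), from which
# build_thera() reconstructs the network.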
model_path = hf_hub_download(repo_id=REPO_ID_EDSR, filename="model.pkl")
with open(model_path, 'rb') as fh:
    check = pickle.load(fh)
params_edsr, backbone, size = check['model'], check['backbone'], check['size']
model_edsr = build_thera(3, backbone, size)

model_path = hf_hub_download(repo_id=REPO_ID_RDN, filename="model.pkl")
with open(model_path, 'rb') as fh:
    check = pickle.load(fh)
params_rdn, backbone, size = check['model'], check['backbone'], check['size']
model_rdn = build_thera(3, backbone, size)
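
# Minimal usage sketch (kept as a comment; file paths are placeholders),
# following the same conventions as TheraApp.process() below:
# super_resolve.process takes a float image in [0, 1] plus a (height, width)
# target shape and returns the upscaled array.
#
#   src = np.asarray(Image.open("input.png").convert("RGB")) / 255.
#   up = process(src, model_edsr, params_edsr,
#                (2 * src.shape[0], 2 * src.shape[1]), do_ensemble=False)
#   Image.fromarray(np.asarray(up)).save("upscaled.png")
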
class TheraApp(DualVisionApp):
    DEFAULT_SCALE = 3.92
    DEFAULT_DO_ENSEMBLE = False
    DEFAULT_MODEL = 'edsr'

    def make_header(self):
        gr.Markdown(
            """
            ## Thera: Aliasing-Free Arbitrary-Scale Super-Resolution with Neural Heat Fields
            <p align="center">
                <a title="Website" href="https://therasr.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%E2%99%A5%20Project%20-Website-blue">
                </a>
                <a title="arXiv" href="https://arxiv.org/pdf/2311.17643" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-AF3436">
                </a>
                <a title="Github" href="https://github.com/prs-eth/thera" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/github/stars/prs-eth/thera?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
                </a>
            </p>
            <p align="center" style="margin-top: 0px;">
                <strong>Upload a photo or select an example below to do arbitrary-scale super-resolution in real time!</strong>
            </p>
            <p align="center" style="margin-top: 0px;">
                <strong>Note: the model was not trained on images with JPEG compression artifacts, so such inputs may not work well.</strong>
            </p>
            """
        )
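
    # One component per setting consumed by process(): a continuous scale
    # factor, the backbone choice, and an ensembling toggle.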
    def build_user_components(self):
        with gr.Row():
            scale = gr.Slider(
                label="Scaling factor",
                minimum=1,
                maximum=6,
                step=0.01,
                value=self.DEFAULT_SCALE,
            )
            model = gr.Radio(
                [
                    ("EDSR", 'edsr'),
                    ("RDN", 'rdn'),
                ],
                label="Backbone",
                value=self.DEFAULT_MODEL,
            )
            do_ensemble = gr.Radio(
                [
                    ("No", False),
                    ("Yes", True),
                ],
                label="Do Ensemble",
                value=self.DEFAULT_DO_ENSEMBLE,
            )
        return {
            "scale": scale,
            "model": model,
            "do_ensemble": do_ensemble,
        }
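
    # DualVisionApp hands the component values back as keyword arguments, so
    # the keys of the dict returned above must match the kwargs read in
    # process() below.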
    def process(self, image_in: Image.Image, **kwargs):
        scale = kwargs.get("scale", self.DEFAULT_SCALE)
        do_ensemble = kwargs.get("do_ensemble", self.DEFAULT_DO_ENSEMBLE)
        model = kwargs.get("model", self.DEFAULT_MODEL)

        # Center-crop oversized inputs to at most MAX_SIZE on each side.
        if max(*image_in.size) > MAX_SIZE:
            gr.Warning("The image has been cropped to enable a smooth experience for all users.")
            width, height = image_in.size
            crop_width = min(width, MAX_SIZE)
            crop_height = min(height, MAX_SIZE)
            left = (width - crop_width) // 2
            top = (height - crop_height) // 2
            right = left + crop_width
            bottom = top + crop_height
            image_in = image_in.crop((left, top, right, bottom))

        source = np.asarray(image_in) / 255.

        # determine target shape
        target_shape = (
            round(source.shape[0] * scale),
            round(source.shape[1] * scale),
        )

        if model == 'edsr':
            m, p = model_edsr, params_edsr
        elif model == 'rdn':
            m, p = model_rdn, params_rdn
        else:
            raise NotImplementedError(f"Unknown model: {model}")

        out = process(source, m, p, target_shape, do_ensemble=do_ensemble)
        out = Image.fromarray(np.asarray(out))

        # Nearest-neighbor upscaling of the input serves as the comparison view.
        nearest = image_in.resize(out.size, Image.NEAREST)

        out_modalities = {
            "nearest": nearest,
            "out": out,
        }
        out_settings = {
            'scale': scale,
            'model': model,
            'do_ensemble': do_ensemble,
        }
        return out_modalities, out_settings
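
    # Glue between the UI and process(): loads and validates the input image,
    # restores per-example settings from an optional .settings.json sidecar,
    # runs process(), and rebuilds the gallery, slider, and modality selectors.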
    def process_components(
        self, image_in, modality_selector_left, modality_selector_right, **kwargs
    ):
        if image_in is None:
            raise gr.Error("Input image is required")

        # Example images arrive as file paths; per-example settings are
        # restored from "<path>.settings.json" if present.
        image_settings = {}
        if isinstance(image_in, str):
            image_settings_path = image_in + ".settings.json"
            if os.path.isfile(image_settings_path):
                with open(image_settings_path, "r") as f:
                    image_settings = json.load(f)
            image_in = Image.open(image_in).convert("RGB")
        else:
            if not isinstance(image_in, Image.Image):
                raise gr.Error(f"Input must be a PIL image, got {type(image_in)}")
            image_in = image_in.convert("RGB")
        image_settings.update(kwargs)

        results_dict, results_settings = self.process(image_in, **image_settings)

        # Validate the contract of process(): a non-empty dict of PIL images
        # keyed by modality name, plus one setting per input component.
        if not isinstance(results_dict, dict):
            raise gr.Error(
                f"`process` must return a dict[str, PIL.Image]. Got type: {type(results_dict)}"
            )
        if len(results_dict) == 0:
            raise gr.Error("`process` did not return any modalities")
        for k, v in results_dict.items():
            if not isinstance(k, str):
                raise gr.Error(
                    f"Output dict must have string keys. Found key of type {type(k)}: {repr(k)}"
                )
            if k == self.key_original_image:
                raise gr.Error(
                    f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input"
                )
            if not isinstance(v, Image.Image):
                raise gr.Error(
                    f"Value for key '{k}' must be a PIL Image, got type {type(v)}"
                )
        if len(results_settings) != len(self.input_keys):
            raise gr.Error(
                f"Expected {len(self.input_keys)} settings, got {len(results_settings)}"
            )
        if any(k not in results_settings for k in self.input_keys):
            raise gr.Error("Mismatching settings keys")

        results_settings = {
            k: cls(**ctor_args, value=results_settings[k])
            for k, cls, ctor_args in zip(
                self.input_keys, self.input_cls, self.input_kwargs
            )
        }
        results_dict = {
            **results_dict,
            self.key_original_image: image_in,
        }
        results_state = [[v, k] for k, v in results_dict.items()]

        modalities = list(results_dict.keys())
        modality_left = (
            modality_selector_left
            if modality_selector_left in modalities
            else modalities[0]
        )
        modality_right = (
            modality_selector_right
            if modality_selector_right in modalities
            else modalities[1]
        )

        return [
            results_state,  # goes to a gr.Gallery
            [
                results_dict[modality_left],
                results_dict[modality_right],
            ],  # ImageSliderPlus
            Radio(
                choices=modalities,
                value=modality_left,
                label="Left",
                key="Left",
            ),
            Radio(
                choices=modalities if self.left_selector_visible else modalities[1:],
                value=modality_right,
                label="Right",
                key="Right",
            ),
            *results_settings.values(),
        ]
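

# Build and serve the demo. Port 7860 is the standard port for a
# Hugging Face Space.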
with TheraApp(
    title="Thera Arbitrary-Scale Super-Resolution",
    examples_path="files",
    examples_per_page=12,
    squeeze_canvas=True,
    advanced_settings_can_be_half_width=False,
    # spaces_zero_gpu_enabled=True,
) as demo:
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )