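# NOTE: the next three lines are web-page header text captured along with
# the file during extraction; they are not part of the program.
# Spaces:
# Running
# Running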
import datetime | |
import hashlib | |
import numpy as np | |
import os | |
import subprocess | |
from pathlib import Path | |
from typing import Any, Dict | |
import cv2 | |
import gradio as gr | |
from joblib import Parallel, delayed | |
from numpy.typing import NDArray | |
from PIL import Image | |
def _run_in_subprocess(command: str, wd: str) -> Any: | |
p = subprocess.Popen(command, shell=True, cwd=wd) | |
(output, err) = p.communicate() | |
p_status = p.wait() | |
print("Status of subprocess: ", p_status) | |
return p_status | |
# Sub-directory (relative to the app's working directory) used as the working
# directory when the SwinIR test script is launched.
SWIN_IR_WD = "KAIR"

# Folder containing the SwinIR checkpoint files.
SWINIR_CKPT_DIR: Path = Path("KAIR/model_zoo/")

# Per-model lookup tables, all keyed by the model name shown in the UI.
MODEL_NAME_TO_PATH: Dict[str, Path] = {
    "LambdaSwinIR_v0.1": SWINIR_CKPT_DIR / "805000_G.pth",
}
# Value forwarded to the script's --training_patch_size option.
SWINIR_NAME_TO_PATCH_SIZE: Dict[str, int] = {
    "LambdaSwinIR_v0.1": 96,
}
# Value forwarded to the script's --scale option.
SWINIR_NAME_TO_SCALE: Dict[str, int] = {
    "LambdaSwinIR_v0.1": 2,
}
# Whether the script's --large_model flag is passed for the model.
SWINIR_NAME_TO_LARGE_MODEL: Dict[str, bool] = {
    "LambdaSwinIR_v0.1": False,
}
def _run_swin_ir( | |
image: NDArray, | |
model_path: Path, | |
patch_size: int, | |
scale: int, | |
is_large_model: bool, | |
): | |
print("model_path: ", str(model_path)) | |
m = hashlib.sha256() | |
now_time = datetime.datetime.utcnow() | |
m.update(bytes(str(model_path), encoding='utf-8') + | |
bytes(now_time.strftime("%Y-%m-%d %H:%M:%S.%f"), encoding='utf-8')) | |
random_id = m.hexdigest()[0:20] | |
cwd = os.getcwd() | |
input_root = Path(cwd + "/sr_interactive_tmp") | |
input_root.mkdir(parents=True, exist_ok=True) | |
Image.fromarray(image).save(str(input_root) + "/gradio_img.png") | |
command = f"python main_test_swinir.py --scale {scale} " + \ | |
f"--folder_lq {input_root} --task real_sr " + \ | |
f"--model_path {cwd}/{model_path} --training_patch_size {patch_size}" | |
if is_large_model: | |
command += " --large_model" | |
print("COMMAND: ", command) | |
status = _run_in_subprocess(command, wd=cwd + "/" + SWIN_IR_WD) | |
print("STATUS: ", status) | |
if scale == 2: | |
str_scale = "2" | |
if scale == 4: | |
str_scale = "4_large" | |
output_img = Image.open(f"{cwd}/KAIR/results/swinir_real_sr_x{str_scale}/gradio_img_SwinIR.png") | |
output_root = Path("./sr_interactive_tmp_output") | |
output_root.mkdir(parents=True, exist_ok=True) | |
output_img.save(str(output_root) + "/SwinIR_" + random_id + ".png") | |
print("SAVING: SwinIR_" + random_id + ".png") | |
result = np.array(output_img) | |
return result | |
def _bilinear_upsample(image: NDArray):
    """Return ``image`` resized to twice its width and height with OpenCV.

    Note: despite the function's name, the interpolation used is Lanczos
    (``cv2.INTER_LANCZOS4``), not bilinear.
    """
    # cv2 expects the target size as (width, height), i.e. (columns, rows).
    doubled_width = image.shape[1] * 2
    doubled_height = image.shape[0] * 2
    return cv2.resize(
        image,
        dsize=(doubled_width, doubled_height),
        interpolation=cv2.INTER_LANCZOS4,
    )
def _decide_sr_algo(model_name: str, image: NDArray):
    """Run SwinIR on ``image`` using the settings registered for ``model_name``.

    The checkpoint path, patch size, scale and large-model flag are read from
    the module-level lookup tables and forwarded to ``_run_swin_ir``.

    Raises:
        KeyError: If ``model_name`` is missing from any of the lookup tables.
    """
    swin_ir_settings = {
        "model_path": MODEL_NAME_TO_PATH[model_name],
        "patch_size": SWINIR_NAME_TO_PATCH_SIZE[model_name],
        "scale": SWINIR_NAME_TO_SCALE[model_name],
        "is_large_model": SWINIR_NAME_TO_LARGE_MODEL[model_name],
    }
    return _run_swin_ir(image, **swin_ir_settings)
def _super_resolve(model_name: str, input_img):
    """Super-resolve ``input_img`` with the single model named ``model_name``.

    Delegates directly to ``_decide_sr_algo`` and returns its result.
    """
    upscaled = _decide_sr_algo(model_name, input_img)
    return upscaled
def _gradio_handler(sr_option: str, input_img: NDArray):
    """Handle one UI request: upscale ``input_img`` with the chosen model.

    Args:
        sr_option: Name of the selected super-resolution model.
        input_img: Image array to upscale.

    Returns:
        Whatever ``_super_resolve`` returns for these arguments.
    """
    return _super_resolve(model_name=sr_option, input_img=input_img)
# Close any Gradio apps that are already running before starting this one.
gr.close_all()

# Model names offered in the UI's radio selector.
SR_OPTIONS = ["LambdaSwinIR_v0.1"]

# Example rows shown in the UI: [model name, path to a sample image].
examples = [
    [SR_OPTIONS[0], example_path]
    for example_path in (
        "examples/oldphoto6.png",
        "examples/Lincoln.png",
        "examples/OST_009.png",
        "examples/00003.png",
        "examples/00000067_cropped.png",
    )
]

# Input components, in the order _gradio_handler receives its arguments.
model_picker = gr.Radio(SR_OPTIONS)
image_input = gr.Image(image_mode="RGB")

ui = gr.Interface(
    fn=_gradio_handler,
    inputs=[model_picker, image_input],
    outputs=["image"],
    live=False,
    examples=examples,
    cache_examples=True,
)
ui.launch(enable_queue=True)