# LambdaSuperRes / app.py
# cooperll — "LambdaSuperRes initial commit" (2514fb4)
# Page residue from the source listing: raw | history | blame | contribute | delete | No virus | 4.89 kB
import datetime
import hashlib
import os
import shlex
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict, Optional

import cv2
import gradio as gr
import numpy as np
from joblib import Parallel, delayed
from numpy.typing import NDArray
from PIL import Image
def _run_in_subprocess(command: str, wd: str) -> Any:
p = subprocess.Popen(command, shell=True, cwd=wd)
(output, err) = p.communicate()
p_status = p.wait()
print("Status of subprocess: ", p_status)
return p_status
# Directory (relative to the app's working directory) holding the KAIR
# checkout; main_test_swinir.py is executed from here.
SWIN_IR_WD: str = "KAIR"
# Directory holding the SwinIR checkpoints (relative to the app's cwd).
SWINIR_CKPT_DIR: Path = Path(SWIN_IR_WD) / "model_zoo"

# Per-model configuration. Every table is keyed by the model name offered in
# the UI, and all tables must contain the same keys.
MODEL_NAME_TO_PATH: Dict[str, Path] = {
    "LambdaSwinIR_v0.1": SWINIR_CKPT_DIR / "805000_G.pth",
}
# Forwarded to main_test_swinir.py as --training_patch_size.
SWINIR_NAME_TO_PATCH_SIZE: Dict[str, int] = {
    "LambdaSwinIR_v0.1": 96,
}
# Upscaling factor, forwarded as --scale.
SWINIR_NAME_TO_SCALE: Dict[str, int] = {
    "LambdaSwinIR_v0.1": 2,
}
# Whether to pass --large_model to the test script.
SWINIR_NAME_TO_LARGE_MODEL: Dict[str, bool] = {
    "LambdaSwinIR_v0.1": False,
}
def _swinir_result_dir(cwd: str, scale: int, is_large_model: bool) -> Path:
    """Return the directory main_test_swinir.py writes ``real_sr`` results to."""
    # KAIR's test script names its output directory
    # results/swinir_real_sr_x{scale} and appends "_large" when --large_model
    # is passed. The suffix therefore follows the model flag, not the scale.
    suffix = "_large" if is_large_model else ""
    return Path(cwd) / SWIN_IR_WD / "results" / f"swinir_real_sr_x{scale}{suffix}"


def _run_swin_ir(
    image: NDArray,
    model_path: Path,
    patch_size: int,
    scale: int,
    is_large_model: bool,
) -> NDArray:
    """Super-resolve ``image`` by running KAIR's SwinIR test script.

    The image is written to ``<cwd>/sr_interactive_tmp/gradio_img.png``.
    ``main_test_swinir.py`` is then run from the KAIR checkout with
    ``--task real_sr``. The result it writes is read back, archived under
    ``./sr_interactive_tmp_output`` with a per-call name, and returned.

    Args:
        image: Image array accepted by ``PIL.Image.fromarray``.
        model_path: Checkpoint path, relative to the current working directory
            (an absolute path is also accepted).
        patch_size: Value forwarded as ``--training_patch_size``.
        scale: Upscaling factor forwarded as ``--scale``.
        is_large_model: Whether to pass ``--large_model``.

    Returns:
        The super-resolved image as a NumPy array.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. nothing was uploaded).
        RuntimeError: If the SwinIR subprocess exits with a non-zero status.
    """
    if image is None:
        raise ValueError("No input image was provided.")
    print("model_path: ", str(model_path))

    # Id derived from model path + timestamp; only used to name the archive copy.
    now_time = datetime.datetime.utcnow()
    seed = str(model_path) + now_time.strftime("%Y-%m-%d %H:%M:%S.%f")
    random_id = hashlib.sha256(seed.encode("utf-8")).hexdigest()[:20]

    cwd = os.getcwd()
    input_root = Path(cwd) / "sr_interactive_tmp"
    input_root.mkdir(parents=True, exist_ok=True)
    # A fixed file name means each call overwrites the previous input, so the
    # script (which processes the whole --folder_lq) only sees this image.
    Image.fromarray(image).save(input_root / "gradio_img.png")

    result_path = _swinir_result_dir(cwd, scale, is_large_model) / "gradio_img_SwinIR.png"
    # Remove any previous result so a failed run can never return stale output.
    if result_path.exists():
        result_path.unlink()

    # Use this interpreter and quote every argument, since the command is run
    # through the shell and cwd/model paths may contain spaces.
    args = [
        sys.executable, "main_test_swinir.py",
        "--scale", str(scale),
        "--folder_lq", str(input_root),
        "--task", "real_sr",
        "--model_path", str(Path(cwd) / model_path),
        "--training_patch_size", str(patch_size),
    ]
    if is_large_model:
        args.append("--large_model")
    command = " ".join(shlex.quote(arg) for arg in args)
    print("COMMAND: ", command)

    status = _run_in_subprocess(command, wd=str(Path(cwd) / SWIN_IR_WD))
    print("STATUS: ", status)
    if status != 0:
        raise RuntimeError(f"SwinIR subprocess failed with exit status {status}")

    output_root = Path("./sr_interactive_tmp_output")
    output_root.mkdir(parents=True, exist_ok=True)
    output_name = "SwinIR_" + random_id + ".png"
    # Context manager closes the file handle that Image.open keeps open.
    with Image.open(result_path) as output_img:
        output_img.save(output_root / output_name)
        result = np.array(output_img)
    print("SAVING: " + output_name)
    return result
def _bilinear_upsample(
    image: NDArray,
    scale: int = 2,
    interpolation: Optional[int] = None,
):
    """Upscale ``image`` by an integer factor using OpenCV interpolation.

    Despite its name, this function uses Lanczos interpolation
    (``cv2.INTER_LANCZOS4``) by default, not bilinear. The name is kept for
    compatibility; pass ``interpolation=cv2.INTER_LINEAR`` for true bilinear.

    Args:
        image: Array indexed as ``image.shape[0]`` = height and
            ``image.shape[1]`` = width.
        scale: Positive integer upscaling factor (default 2, the original
            hard-coded factor).
        interpolation: OpenCV interpolation flag; ``None`` means
            ``cv2.INTER_LANCZOS4``.

    Returns:
        The resized image.

    Raises:
        ValueError: If ``scale`` is less than 1.
    """
    if scale < 1:
        raise ValueError(f"scale must be a positive integer, got {scale}")
    if interpolation is None:
        interpolation = cv2.INTER_LANCZOS4
    height, width = image.shape[0], image.shape[1]
    # cv2.resize expects dsize as (width, height).
    return cv2.resize(
        image,
        dsize=(width * scale, height * scale),
        interpolation=interpolation,
    )
def _decide_sr_algo(model_name: str, image: NDArray):
    """Run the super-resolution model registered as ``model_name`` on ``image``.

    The checkpoint, patch size, scale and large-model flag are looked up in
    ``MODEL_NAME_TO_PATH`` and the ``SWINIR_NAME_TO_*`` tables.

    Args:
        model_name: Key of ``MODEL_NAME_TO_PATH`` (one of the UI options).
        image: Input image array.

    Returns:
        The super-resolved image array produced by ``_run_swin_ir``.

    Raises:
        KeyError: If ``model_name`` is not a registered model (including ``None``).
    """
    if model_name not in MODEL_NAME_TO_PATH:
        known = ", ".join(sorted(MODEL_NAME_TO_PATH))
        raise KeyError(f"Unknown model {model_name!r}; expected one of: {known}")
    return _run_swin_ir(
        image,
        model_path=MODEL_NAME_TO_PATH[model_name],
        patch_size=SWINIR_NAME_TO_PATCH_SIZE[model_name],
        scale=SWINIR_NAME_TO_SCALE[model_name],
        is_large_model=SWINIR_NAME_TO_LARGE_MODEL[model_name],
    )
def _super_resolve(model_name: str, input_img):
    """Super-resolve ``input_img`` with the single model named ``model_name``.

    Args:
        model_name: Registered model name (see ``_decide_sr_algo``).
        input_img: Input image array.

    Returns:
        The super-resolved image array.
    """
    return _decide_sr_algo(model_name, input_img)
def _gradio_handler(sr_option: str, input_img: NDArray):
    """Gradio callback: forward the selected model and uploaded image to the SR pipeline.

    Args:
        sr_option: Model name chosen in the radio input.
        input_img: Image array from the image input.

    Returns:
        The super-resolved image array shown in the output panel.
    """
    upscaled = _super_resolve(sr_option, input_img)
    return upscaled
# Close any Gradio servers left over from a previous run in this process.
gr.close_all()

# Models selectable in the UI; each must be a key of MODEL_NAME_TO_PATH.
SR_OPTIONS = ["LambdaSwinIR_v0.1"]

# Example rows (model name, image path) shown under the interface and, with
# cache_examples=True, pre-computed at startup.
examples = [
    [SR_OPTIONS[0], image_path]
    for image_path in (
        "examples/oldphoto6.png",
        "examples/Lincoln.png",
        "examples/OST_009.png",
        "examples/00003.png",
        "examples/00000067_cropped.png",
    )
]

ui = gr.Interface(
    fn=_gradio_handler,
    inputs=[
        # Pre-select the first model. Otherwise, submitting without touching
        # the radio sends None as the model name and the lookup fails.
        gr.Radio(SR_OPTIONS, value=SR_OPTIONS[0]),
        gr.Image(image_mode="RGB"),
    ],
    outputs=["image"],
    live=False,
    examples=examples,
    cache_examples=True,
)
ui.launch(enable_queue=True)