|
import gradio as gr |
|
import torch |
|
import yaml |
|
import os |
|
import numpy as np |
|
from PIL import Image |
|
import time |
|
from cdim.noise import get_noise |
|
from cdim.operators import get_operator |
|
from cdim.image_utils import save_to_image |
|
from cdim.dps_model.dps_unet import create_model |
|
from cdim.diffusion.scheduling_ddim import DDIMScheduler |
|
from cdim.diffusion.diffusion_pipeline import run_diffusion |
|
from cdim.eta_scheduler import EtaScheduler |
|
from diffusers import DiffusionPipeline |
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
model = None |
|
ddim_scheduler = None |
|
model_type = None |
|
|
|
|
|
def load_image(image_path): |
|
"""Process input image to tensor format.""" |
|
image = Image.open(image_path) |
|
original_image = np.array(image.resize((256, 256), Image.BICUBIC)) |
|
original_image = torch.from_numpy(original_image).unsqueeze(0).permute(0, 3, 1, 2) |
|
return (original_image / 127.5 - 1.0).to(torch.float)[:, :3] |
|
|
|
|
|
def load_yaml(file_path: str) -> dict: |
|
"""Load configurations from a YAML file.""" |
|
with open(file_path) as f: |
|
config = yaml.load(f, Loader=yaml.FullLoader) |
|
return config |
|
|
|
|
|
def convert_to_np(torch_image): |
|
return ((torch_image.detach().clamp(-1, 1).cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8) |
|
|
|
|
|
def generate_noisy_image(image_choice, noise_sigma, operator_key): |
|
"""Generate the noisy image and store necessary data for restoration.""" |
|
|
|
image_paths = { |
|
"CelebA HQ 1": "sample_images/celebhq_29999.jpg", |
|
"CelebA HQ 2": "sample_images/celebhq_00001.jpg", |
|
"CelebA HQ 3": "sample_images/celebhq_00000.jpg" |
|
} |
|
|
|
config_paths = { |
|
"Box Inpainting": "operator_configs/box_inpainting_config.yaml", |
|
"Random Inpainting": "operator_configs/random_inpainting_config.yaml", |
|
"Super Resolution": "operator_configs/super_resolution_config.yaml", |
|
"Gaussian Deblur": "operator_configs/gaussian_blur_config.yaml" |
|
} |
|
|
|
image_path = image_paths[image_choice] |
|
|
|
|
|
original_image = load_image(image_path).to(device) |
|
noise_config = load_yaml("noise_configs/gaussian_noise_config.yaml") |
|
noise_config["sigma"] = noise_sigma |
|
noise_function = get_noise(**noise_config) |
|
operator_config = load_yaml(config_paths[operator_key]) |
|
operator_config["device"] = device |
|
operator = get_operator(**operator_config) |
|
|
|
noisy_measurement = noise_function(operator(original_image)) |
|
noisy_image = Image.fromarray(convert_to_np(noisy_measurement[0])) |
|
|
|
|
|
data = { |
|
'noisy_measurement': noisy_measurement, |
|
'operator': operator, |
|
'noise_function': noise_function |
|
} |
|
|
|
return noisy_image, data |
|
|
|
|
|
def run_restoration(data, T, K): |
|
"""Run the restoration process and return the restored image.""" |
|
global model, ddim_scheduler, model_type |
|
|
|
|
|
noisy_measurement = data['noisy_measurement'] |
|
operator = data['operator'] |
|
noise_function = data['noise_function'] |
|
|
|
|
|
if model is None: |
|
model_type = "diffusers" |
|
model = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256").to(device).unet |
|
|
|
ddim_scheduler = DDIMScheduler( |
|
num_train_timesteps=1000, |
|
beta_start=0.0001, |
|
beta_end=0.02, |
|
beta_schedule="linear" |
|
) |
|
|
|
|
|
eta_scheduler = EtaScheduler("gradnorm", operator.name, T, K, 'l2', noise_function, None) |
|
output_image = run_diffusion( |
|
model, ddim_scheduler, noisy_measurement, operator, noise_function, device, |
|
eta_scheduler, num_inference_steps=T, K=K, model_type=model_type, loss_type='l2' |
|
) |
|
|
|
|
|
output_image = Image.fromarray(convert_to_np(output_image[0])) |
|
return output_image |
|
|
|
|
|
with gr.Blocks() as demo: |
|
gr.Markdown("# Noisy Image Restoration with Diffusion Models") |
|
|
|
with gr.Row(): |
|
T = gr.Slider(10, 200, value=50, step=1, label="Number of Inference Steps (T)") |
|
K = gr.Slider(1, 10, value=3, step=1, label="K Value") |
|
noise_sigma = gr.Slider(0, 0.6, value=0.05, step=0.01, label="Noise Sigma") |
|
|
|
image_select = gr.Dropdown( |
|
choices=["CelebA HQ 1", "CelebA HQ 2", "CelebA HQ 3"], |
|
value="CelebA HQ 1", |
|
label="Select Input Image" |
|
) |
|
|
|
operator_select = gr.Dropdown( |
|
choices=["Box Inpainting", "Random Inpainting", "Super Resolution", "Gaussian Deblur"], |
|
value="Box Inpainting", |
|
label="Select Task" |
|
) |
|
|
|
run_button = gr.Button("Run Inference") |
|
noisy_image = gr.Image(label="Noisy Image") |
|
restored_image = gr.Image(label="Restored Image") |
|
state = gr.State() |
|
|
|
|
|
run_button.click( |
|
fn=generate_noisy_image, |
|
inputs=[image_select, noise_sigma, operator_select], |
|
outputs=[noisy_image, state], |
|
).then( |
|
fn=run_restoration, |
|
inputs=[state, T, K], |
|
outputs=restored_image |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
demo.launch(server_name="0.0.0.0", server_port=7860) |
|
|