"""Command-line script that applies a measurement operator and noise to an
input image, runs the diffusion pipeline on the noisy measurement, and saves
both the measurement and the output as PNG files.
"""
import argparse
import os
import yaml
import time
from PIL import Image
import numpy as np
import torch
from diffusers import DiffusionPipeline
from cdim.noise import get_noise
from cdim.operators import get_operator
from cdim.image_utils import save_to_image
from cdim.dps_model.dps_unet import create_model
from cdim.diffusion.scheduling_ddim import DDIMScheduler
from cdim.diffusion.diffusion_pipeline import run_diffusion
# Seed PyTorch's global random number generator with a fixed value at import
# time, before any noise is sampled.
torch.manual_seed(8)
def load_image(path):
    """
    Load an image from disk as a float tensor normalized to [-1, 1].

    The image is converted to 3-channel RGB, resized to 256x256 with bicubic
    resampling, and rearranged from height-width-channel to a batched
    channel-first layout.

    Args:
        path: Location of the image file, in any form accepted by
            ``PIL.Image.open``.

    Returns:
        A ``torch.float`` tensor of shape (1, 3, 256, 256); pixel value 0
        maps to -1.0 and 255 maps to 1.0.
    """
    # The context manager releases the underlying file handle once the pixel
    # data has been decoded.
    with Image.open(path) as source:
        # Grayscale, palette and RGBA files would otherwise yield arrays
        # without exactly three channels and break the permute below.
        rgb = source.convert("RGB")
        # The resize is unconditional: every input ends up 256x256.
        pixels = np.array(rgb.resize((256, 256), Image.BICUBIC))

    # (H, W, C) uint8 -> (1, C, H, W)
    batch = torch.from_numpy(pixels).unsqueeze(0).permute(0, 3, 1, 2)
    return (batch / 127.5 - 1.0).to(torch.float)
def load_yaml(file_path: str) -> dict:
    """Read a YAML file and return its top-level mapping.

    Args:
        file_path: Path to the YAML document.

    Returns:
        The parsed mapping, or an empty dict when the document is empty.

    Raises:
        TypeError: If the document's top level is not a mapping.
    """
    # Decode as UTF-8 explicitly so parsing does not depend on the platform's
    # default locale encoding.
    with open(file_path, encoding="utf-8") as f:
        # NOTE(review): FullLoader can still construct some Python objects.
        # If these configs only hold plain data, or may come from an untrusted
        # source, yaml.safe_load is the safer choice.
        config = yaml.load(f, Loader=yaml.FullLoader)

    # An empty document parses to None; callers unpack the result with **,
    # so hand back an empty mapping instead.
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {file_path}, "
            f"got {type(config).__name__}"
        )
    return config
def main(args):
    """Run the measurement and diffusion steps for a single input image.

    Steps, in order: pick the device, load the input image, build the noise
    function and the measurement operator from their YAML configs, load the
    model, produce and save the noisy measurement, run ``run_diffusion``,
    print the elapsed time, and save the output image.

    Args:
        args: Parsed command-line namespace. Reads ``input_image``, ``T``,
            ``K``, ``operator_config``, ``noise_config``, ``model_config``,
            ``output_dir`` and ``cuda``.
    """
    use_cuda = args.cuda and torch.cuda.is_available()
    device_str = "cuda" if use_cuda else "cpu"
    print(f"Using device {device_str}")
    device = torch.device(device_str)

    os.makedirs(args.output_dir, exist_ok=True)
    original_image = load_image(args.input_image).to(device)

    # Load the noise function
    noise_config = load_yaml(args.noise_config)
    noise_function = get_noise(**noise_config)

    # Load the measurement function A
    operator_config = load_yaml(args.operator_config)
    operator_config["device"] = device
    operator = get_operator(**operator_config)

    if args.model_config.endswith(".yaml"):
        # Local model from DPS
        model_type = "dps"
        model_config = load_yaml(args.model_config)
        model = create_model(**model_config)
        model = model.to(device)
        model.eval()
    else:
        # Huggingface diffusers model
        model_type = "diffusers"
        # Move the pipeline to the selected device rather than a hard-coded
        # "cuda", so CPU-only hosts and --no-cuda work and the model lives on
        # the same device as the input tensors.
        pipeline = DiffusionPipeline.from_pretrained(args.model_config)
        model = pipeline.to(device).unet

    # All the models have the same scheduler.
    # you can change this for different models
    ddim_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        prediction_type="epsilon",
        timestep_spacing="leading",
        steps_offset=0,
    )

    # Apply the operator first, then the noise function, to the clean image.
    noisy_measurement = noise_function(operator(original_image))
    save_to_image(
        noisy_measurement,
        os.path.join(args.output_dir, "noisy_measurement.png"),
    )

    # Only the diffusion run itself is timed.
    t0 = time.time()
    output_image = run_diffusion(
        model,
        ddim_scheduler,
        noisy_measurement,
        operator,
        noise_function,
        device,
        num_inference_steps=args.T,
        K=args.K,
        model_type=model_type,
    )
    print(f"total time {time.time() - t0}")

    save_to_image(output_image, os.path.join(args.output_dir, "output.png"))
# Script entry point: parse the command line and hand the namespace to main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("input_image", type=str,
                        help="Path to the input image file.")
    parser.add_argument("T", type=int,
                        help="Number of inference steps passed to run_diffusion.")
    parser.add_argument("K", type=int,
                        help="Value of K passed to run_diffusion.")
    parser.add_argument("model", type=str,
                        help="Model name (currently not read by main).")
    parser.add_argument("operator_config", type=str,
                        help="Path to the YAML config for the measurement operator.")
    parser.add_argument("noise_config", type=str,
                        help="Path to the YAML config for the noise function.")
    parser.add_argument("model_config", type=str,
                        help="Path to a .yaml model config for a local DPS model; "
                             "any other value is passed to "
                             "DiffusionPipeline.from_pretrained.")
    parser.add_argument("--output-dir", default=".", type=str,
                        help="Directory where the PNG files are written.")
    parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction,
                        help="Use CUDA when it is available.")

    main(parser.parse_args())