File size: 3,076 Bytes
5b2cc7a d8f7287 3e0a809 d8f7287 3e0a809 d8f7287 3e0a809 d8f7287 5b2cc7a d8f7287 5b2cc7a d8f7287 5b2cc7a d8f7287 5b2cc7a d8f7287 3e0a809 d8f7287 5b2cc7a 3e0a809 5b2cc7a 3e0a809 5b2cc7a d8f7287 5b2cc7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import argparse
import os
import yaml
import time
from PIL import Image
import numpy as np
import torch
from cdim.noise import get_noise
from cdim.operators import get_operator
from cdim.image_utils import save_to_image
from cdim.dps_model.dps_unet import create_model
from cdim.diffusion.scheduling_ddim import DDIMScheduler
from cdim.diffusion.diffusion_pipeline import run_diffusion
# Seed torch's global RNG once, at import time, so that torch random draws
# made later in this script are repeatable from run to run.
torch.manual_seed(8)
def load_image(path):
    """
    Load an image file as a float tensor normalized to [-1, 1].

    The image is converted to 3-channel RGB and resized to 256x256 with
    bicubic resampling, so grayscale, palette and RGBA inputs all produce
    the same layout as a plain RGB file.

    Args:
        path: Location of the image file, in any form accepted by
            ``PIL.Image.open``.

    Returns:
        A ``torch.float`` tensor of shape (1, 3, 256, 256) with values in
        [-1, 1].
    """
    # The context manager releases the underlying file handle as soon as the
    # pixel data has been decoded, instead of waiting for garbage collection.
    with Image.open(path) as source:
        # Without the RGB conversion a grayscale file yields a 2-D array
        # (the 4-axis permute below would raise) and an RGBA file would
        # yield 4 channels.
        resized = source.convert("RGB").resize((256, 256), Image.BICUBIC)

    pixels = torch.from_numpy(np.array(resized))
    # (H, W, C) -> (1, C, H, W)
    batched = pixels.unsqueeze(0).permute(0, 3, 1, 2)
    # Map the 8-bit range [0, 255] onto [-1, 1].
    return (batched / 127.5 - 1.0).to(torch.float)
def load_yaml(file_path: str) -> dict:
    """
    Read a YAML configuration file and return its parsed contents.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The parsed document. An empty file yields ``{}`` instead of ``None``
        so the result can always be unpacked with ``**``.
    """
    # Pin the encoding so parsing does not depend on the platform's default
    # locale encoding.
    with open(file_path, encoding="utf-8") as f:
        # NOTE(security): FullLoader can still build some Python objects from
        # tagged YAML. Only load configuration files you trust; switch to
        # yaml.safe_load if configs can come from an untrusted source.
        config = yaml.load(f, Loader=yaml.FullLoader)
    # yaml.load returns None for an empty document.
    return {} if config is None else config
def main(args):
    """
    Build a noisy measurement of the input image and reconstruct it.

    Reads the noise, operator and model YAML configs named in ``args``,
    applies the measurement operator followed by the noise function to the
    loaded image, runs ``run_diffusion`` on the result, and writes
    ``noisy_measurement.png`` and ``output.png`` into ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``cuda``, ``output_dir``, ``input_image``, ``noise_config``,
            ``operator_config``, ``model_config``, ``T`` and ``K``.
    """
    # Use the GPU only when it was requested and is actually available.
    use_gpu = args.cuda and torch.cuda.is_available()
    if use_gpu:
        device_str = "cuda"
    else:
        device_str = "cpu"
    print(f"Using device {device_str}")
    device = torch.device(device_str)

    out_dir = args.output_dir
    os.makedirs(out_dir, exist_ok=True)

    original_image = load_image(args.input_image).to(device)

    # Noise function
    noise_function = get_noise(**load_yaml(args.noise_config))

    # Measurement function A; it is additionally told which device to use
    operator_kwargs = load_yaml(args.operator_config)
    operator_kwargs["device"] = device
    operator = get_operator(**operator_kwargs)

    # Model, moved to the target device and switched to eval mode
    model = create_model(**load_yaml(args.model_config)).to(device)
    model.eval()

    # All the models share this scheduler configuration;
    # change it here when using a different model.
    scheduler_kwargs = {
        "num_train_timesteps": 1000,
        "beta_start": 0.0001,
        "beta_end": 0.02,
        "beta_schedule": "linear",
        "prediction_type": "epsilon",
        "timestep_spacing": "leading",
        "steps_offset": 0,
    }
    ddim_scheduler = DDIMScheduler(**scheduler_kwargs)

    measurement = operator(original_image)
    noisy_measurement = noise_function(measurement)
    measurement_path = os.path.join(out_dir, "noisy_measurement.png")
    save_to_image(noisy_measurement, measurement_path)

    # Time only the diffusion run itself.
    t0 = time.time()
    output_image = run_diffusion(
        model,
        ddim_scheduler,
        noisy_measurement,
        operator,
        noise_function,
        device,
        num_inference_steps=args.T,
        K=args.K,
    )
    elapsed = time.time() - t0
    print(f"total time {elapsed}")

    result_path = os.path.join(out_dir, "output.png")
    save_to_image(output_image, result_path)
if __name__ == '__main__':
    # Command-line entry point: seven required positionals followed by
    # optional flags. Argument names, order, types and defaults are part of
    # the CLI contract and are kept as they were.
    parser = argparse.ArgumentParser(
        description="Reconstruct an image from a noisy measurement with the diffusion pipeline."
    )
    parser.add_argument("input_image", type=str,
                        help="path of the image to measure and reconstruct")
    parser.add_argument("T", type=int,
                        help="passed to run_diffusion as num_inference_steps")
    parser.add_argument("K", type=int,
                        help="passed to run_diffusion as K")
    # Still a required positional for compatibility with existing
    # invocations, although main() does not read it.
    parser.add_argument("model", type=str,
                        help="model name (required, but not read by main)")
    parser.add_argument("operator_config", type=str,
                        help="YAML file with the measurement operator settings")
    parser.add_argument("noise_config", type=str,
                        help="YAML file with the noise settings")
    parser.add_argument("model_config", type=str,
                        help="YAML file with the model settings")
    parser.add_argument("--output-dir", default=".", type=str,
                        help="directory that receives the output images (default: current directory)")
    # BooleanOptionalAction also generates the matching --no-cuda flag.
    parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction,
                        help="use CUDA when it is available (default: enabled)")
    main(parser.parse_args())