File size: 4,085 Bytes
5b2cc7a d8f7287 3e0a809 d8f7287 c63740a d8f7287 3e0a809 e6c2b25 d8f7287 e6c2b25 d8f7287 366a67c d8f7287 5b2cc7a d8f7287 5b2cc7a d8f7287 5b2cc7a d8f7287 5b2cc7a d8f7287 3e0a809 c63740a 3e0a809 d8f7287 5b2cc7a e6c2b25 3e0a809 e6c2b25 3e0a809 c63740a 95aa1d5 3e0a809 5b2cc7a 3e0a809 e6c2b25 366a67c 5b2cc7a 95aa1d5 d8f7287 5b2cc7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import argparse
import os
import yaml
import time
from PIL import Image
import numpy as np
import torch
from diffusers import DiffusionPipeline
from cdim.noise import get_noise
from cdim.operators import get_operator
from cdim.image_utils import save_to_image
from cdim.dps_model.dps_unet import create_model
from cdim.diffusion.scheduling_ddim import DDIMScheduler
from cdim.diffusion.diffusion_pipeline import run_diffusion
from cdim.eta_scheduler import EtaScheduler
# torch.manual_seed(7)
def load_image(path):
    """
    Load an image from disk as a float tensor normalized to [-1, 1].

    The image is resized to 256x256 with bicubic resampling and returned in
    channels-first layout with exactly three channels.

    Args:
        path: Anything accepted by ``PIL.Image.open`` (typically a file path).

    Returns:
        A ``torch.float`` tensor of shape (1, 3, 256, 256).
    """
    # The context manager releases the underlying file handle as soon as the
    # pixel data has been consumed by ``resize``.
    with Image.open(path) as source_image:
        # Grayscale, palette, CMYK, ... images do not yield an (H, W, C>=3)
        # uint8 array, which the permute/slice below relies on. RGB and RGBA
        # inputs are passed through untouched so their results are unchanged
        # (the alpha channel is dropped by the final slice).
        if source_image.mode not in ("RGB", "RGBA"):
            source_image = source_image.convert("RGB")
        # Resize if needed
        resized = source_image.resize((256, 256), Image.BICUBIC)

    # (H, W, C) -> (1, H, W, C) -> (1, C, H, W)
    pixels = torch.from_numpy(np.array(resized)).unsqueeze(0).permute(0, 3, 1, 2)
    # Map 8-bit values [0, 255] onto [-1, 1] and keep only the first 3 channels.
    return (pixels / 127.5 - 1.0).to(torch.float)[:, :3]
def load_yaml(file_path: str) -> dict:
    """
    Parse a YAML file and return its contents.

    Args:
        file_path: Path to the YAML file to read.

    Returns:
        The object produced by the YAML document (callers in this file
        expect a mapping).
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding.
    with open(file_path, encoding="utf-8") as f:
        # NOTE(review): FullLoader can construct more than plain data types;
        # only point this at trusted config files, or switch to
        # yaml.safe_load if the configs never need custom tags.
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config
def main(args):
    """
    Run the restoration pipeline on a single input image.

    Loads the image, noise function, measurement operator and model, builds
    a noisy measurement of the image, runs the diffusion procedure on it and
    writes ``noisy_measurement.png`` and ``output.png`` into
    ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. Reads ``input_image``, ``T``,
            ``K``, ``operator_config``, ``noise_config``, ``model_config``,
            ``eta_type``, ``lambda_val``, ``output_dir``, ``loss`` and
            ``cuda``.
    """
    # Fall back to CPU when CUDA is disabled on the command line or missing.
    device_str = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
    print(f"Using device {device_str}")
    device = torch.device(device_str)

    os.makedirs(args.output_dir, exist_ok=True)
    original_image = load_image(args.input_image).to(device)

    # Load the noise function
    noise_config = load_yaml(args.noise_config)
    noise_function = get_noise(**noise_config)

    # Load the measurement function A
    operator_config = load_yaml(args.operator_config)
    operator_config["device"] = device
    operator = get_operator(**operator_config)

    if args.model_config.endswith(".yaml"):
        # Local model from DPS
        model_type = "dps"
        model_config = load_yaml(args.model_config)
        model = create_model(**model_config)
        model = model.to(device)
        model.eval()
    else:
        # Huggingface diffusers model
        model_type = "diffusers"
        # Move the pipeline to the selected device (not a hard-coded "cuda")
        # so that --no-cuda and CPU-only machines work for this branch too.
        model = DiffusionPipeline.from_pretrained(args.model_config).to(device).unet

    # All the models have the same scheduler.
    # you can change this for different models
    ddim_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        prediction_type="epsilon",
        timestep_spacing="leading",
        steps_offset=0,
    )

    # Apply the operator, then the noise, to obtain the measurement that the
    # diffusion run is conditioned on; save it for inspection.
    noisy_measurement = noise_function(operator(original_image))
    save_to_image(noisy_measurement, os.path.join(args.output_dir, "noisy_measurement.png"))

    eta_scheduler = EtaScheduler(args.eta_type, operator.name, args.T,
                                 args.K, args.loss, args.lambda_val)

    # Time only the diffusion run itself.
    t0 = time.time()
    output_image = run_diffusion(
        model, ddim_scheduler,
        noisy_measurement, operator, noise_function, device,
        eta_scheduler,
        num_inference_steps=args.T,
        K=args.K,
        model_type=model_type,
        loss_type=args.loss)
    print(f"total time {time.time() - t0}")

    save_to_image(output_image, os.path.join(args.output_dir, "output.png"))
if __name__ == '__main__':
    # Command-line entry point: six required positionals followed by
    # optional tuning flags, all forwarded to main() as a namespace.
    parser = argparse.ArgumentParser()

    # Required positional arguments.
    parser.add_argument("input_image", type=str)
    parser.add_argument("T", type=int)
    parser.add_argument("K", type=int)
    parser.add_argument("operator_config", type=str)
    parser.add_argument("noise_config", type=str)
    # A value ending in ".yaml" selects a local DPS model config; anything
    # else is handed to DiffusionPipeline.from_pretrained.
    parser.add_argument("model_config", type=str)

    # Optional arguments.
    parser.add_argument("--eta-type", type=str,
                        choices=['gradnorm', 'expected_gradnorm'],
                        default='expected_gradnorm')
    parser.add_argument("--lambda-val", type=float,
                        default=None, help="Constant to scale learning rate. Leave empty to use a heuristic best guess.")
    parser.add_argument("--output-dir", default=".", type=str)
    parser.add_argument("--loss", type=str,
                        choices=['l2', 'kl', 'categorical_kl'], default='l2',
                        help="Algorithm to use. Options: 'l2', 'kl', 'categorical_kl'. Default is 'l2'."
                        )
    # BooleanOptionalAction provides both --cuda and --no-cuda.
    parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)

    main(parser.parse_args())