File size: 3,076 Bytes
5b2cc7a
 
d8f7287
3e0a809
d8f7287
 
 
 
 
 
 
 
3e0a809
 
 
d8f7287
3e0a809
d8f7287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b2cc7a
 
 
d8f7287
 
 
 
5b2cc7a
d8f7287
5b2cc7a
d8f7287
 
 
5b2cc7a
d8f7287
 
 
 
3e0a809
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8f7287
 
 
5b2cc7a
3e0a809
 
 
 
 
 
 
 
 
 
5b2cc7a
 
 
 
 
 
 
 
3e0a809
5b2cc7a
d8f7287
 
5b2cc7a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import os
import yaml
import time

from PIL import Image
import numpy as np
import torch

from cdim.noise import get_noise
from cdim.operators import get_operator
from cdim.image_utils import save_to_image
from cdim.dps_model.dps_unet import create_model
from cdim.diffusion.scheduling_ddim import DDIMScheduler
from cdim.diffusion.diffusion_pipeline import run_diffusion

# Seed PyTorch's global random number generator as soon as the module is
# loaded, so that repeated runs of the script draw the same random numbers.
torch.manual_seed(8)

def load_image(path):
    """
    Load an image from disk and normalize it to [-1, 1].

    The image is converted to 3-channel RGB, resized to 256x256 with bicubic
    interpolation, and rescaled from the [0, 255] range to [-1, 1].

    Args:
        path: Location of the image file, handed to ``PIL.Image.open``.

    Returns:
        A ``torch.float`` tensor laid out as (1, 3, 256, 256).
    """
    # The context manager releases the underlying file handle as soon as the
    # pixel data has been copied into the numpy array.
    with Image.open(path) as source:
        # Force RGB so grayscale, palette and RGBA inputs all produce an
        # (H, W, 3) array; otherwise the permute below either raises (2-D
        # array) or yields an unexpected channel count (RGBA).
        resized = source.convert("RGB").resize((256, 256), Image.BICUBIC)
        pixels = np.array(resized)

    # (H, W, C) -> (1, C, H, W)
    batch = torch.from_numpy(pixels).unsqueeze(0).permute(0, 3, 1, 2)
    return (batch / 127.5 - 1.0).to(torch.float)
    

def load_yaml(file_path: str) -> dict:
    """
    Parse a YAML file and return its contents.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The object produced by the YAML parser (callers in this file unpack
        it as keyword arguments, so a mapping is expected).
    """
    # NOTE(review): FullLoader can construct more than plain data types. If
    # config files may come from an untrusted source, consider switching to
    # yaml.safe_load after confirming no config relies on extended tags.
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(file_path, encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def main(args):
    """
    Run the measurement + diffusion pipeline for a single input image.

    Steps: pick the device, load the input image, build the noise function,
    the measurement operator and the model from their YAML configs, create
    the noisy measurement, run the diffusion pipeline on it, and save both
    the measurement and the result as PNG files.

    Args:
        args: Parsed command-line namespace. This function reads ``cuda``,
            ``output_dir``, ``input_image``, ``noise_config``,
            ``operator_config``, ``model_config``, ``T`` and ``K``.

    Side effects:
        Creates ``args.output_dir`` if needed and writes
        ``noisy_measurement.png`` and ``output.png`` into it; prints the
        selected device and the elapsed diffusion time.
    """
    # Fall back to the CPU when CUDA was disabled or is not available.
    device_str = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
    print(f"Using device {device_str}")
    device = torch.device(device_str)

    os.makedirs(args.output_dir, exist_ok=True)
    original_image = load_image(args.input_image).to(device)

    # Load the noise function
    noise_config = load_yaml(args.noise_config)
    noise_function = get_noise(**noise_config)

    # Load the measurement function A; the device is injected into its
    # config before construction.
    operator_config = load_yaml(args.operator_config)
    operator_config["device"] = device
    operator = get_operator(**operator_config)

    # Load the model and switch it to evaluation mode
    model_config = load_yaml(args.model_config)
    model = create_model(**model_config)
    model = model.to(device)
    model.eval()

    # All the models have the same scheduler.
    # you can change this for different models
    ddim_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        prediction_type="epsilon",
        timestep_spacing="leading",
        steps_offset=0,
    )

    # Apply the operator to the clean image, then the noise function.
    noisy_measurement = noise_function(operator(original_image))
    save_to_image(noisy_measurement, os.path.join(args.output_dir, "noisy_measurement.png"))

    # perf_counter is monotonic, so the reported duration is not distorted
    # by system clock adjustments during the run.
    t0 = time.perf_counter()
    output_image = run_diffusion(
        model, ddim_scheduler,
        noisy_measurement, operator, noise_function, device,
        num_inference_steps=args.T,
        K=args.K)
    print(f"total time {time.perf_counter() - t0}")

    save_to_image(output_image, os.path.join(args.output_dir, "output.png"))

if __name__ == '__main__':
    # Positional arguments, listed in the order they must be given on the
    # command line, paired with the type each one is parsed as.
    # NOTE(review): `model` is parsed here but main() in this file never
    # reads it.
    positional_args = (
        ("input_image", str),
        ("T", int),
        ("K", int),
        ("model", str),
        ("operator_config", str),
        ("noise_config", str),
        ("model_config", str),
    )

    cli = argparse.ArgumentParser()
    for arg_name, arg_type in positional_args:
        cli.add_argument(arg_name, type=arg_type)

    # Optional switches: output location, and --cuda / --no-cuda (on by default).
    cli.add_argument("--output-dir", type=str, default=".")
    cli.add_argument("--cuda", action=argparse.BooleanOptionalAction, default=True)

    main(cli.parse_args())