VIVEK JAYARAM
committed on
Commit
•
e6c2b25
1
Parent(s):
64e8afb
Expected gradnorm
Browse files- cdim/diffusion/diffusion_pipeline.py +3 -1
- cdim/eta_scheduler.py +52 -0
- cdim/etas.json +50 -0
- cdim/operators/__init__.py +10 -2
- inference.py +11 -1
- operator_configs/super_resolution_config.yaml +1 -1
cdim/diffusion/diffusion_pipeline.py
CHANGED
@@ -23,6 +23,7 @@ def run_diffusion(
|
|
23 |
operator,
|
24 |
noise_function,
|
25 |
device,
|
|
|
26 |
num_inference_steps: int = 1000,
|
27 |
K=5,
|
28 |
image_dim=256,
|
@@ -90,7 +91,8 @@ def run_diffusion(
|
|
90 |
else:
|
91 |
raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
|
92 |
|
93 |
-
|
|
|
94 |
image = image.detach().requires_grad_()
|
95 |
|
96 |
return image
|
|
|
23 |
operator,
|
24 |
noise_function,
|
25 |
device,
|
26 |
+
eta_scheduler,
|
27 |
num_inference_steps: int = 1000,
|
28 |
K=5,
|
29 |
image_dim=256,
|
|
|
91 |
else:
|
92 |
raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
|
93 |
|
94 |
+
step_size = eta_scheduler.get_step_size(str(t.item()), torch.linalg.norm(image.grad))
|
95 |
+
image -= step_size * image.grad
|
96 |
image = image.detach().requires_grad_()
|
97 |
|
98 |
return image
|
cdim/eta_scheduler.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
class EtaScheduler:
    """Pick the gradient step size for a given diffusion timestep.

    Two methods are supported:

    * ``"expected_gradnorm"`` -- ``lambda_val`` divided by a precomputed
      value looked up by timestep in ``etas.json``.
    * ``"gradnorm"`` -- ``lambda_val`` divided by the gradient norm that the
      caller passes to :meth:`get_step_size`.

    If ``"expected_gradnorm"`` is requested but no precomputed table is
    available for the ``(task, loss_type, T, K)`` combination, the scheduler
    switches itself to ``"gradnorm"`` and prints a notice.

    Args:
        method: ``"expected_gradnorm"`` or ``"gradnorm"``.
        task: Top-level key in ``etas.json`` (e.g. ``"box_inpainting"``).
        T: Step count used to build the ``"T{T}_K{K}"`` lookup key and the
            total step count ``T * K``.
        K: Second component of the lookup key and of ``T * K``
            (presumably the number of updates per timestep -- confirm
            against the caller).
        loss_type: Second-level key in ``etas.json``; also selects the
            default lambda in :meth:`best_guess_lambda`.
        lambda_val: Numerator of the step size. When ``None`` it is taken
            from the precomputed table (``"expected_gradnorm"``) or from
            :meth:`best_guess_lambda` (``"gradnorm"``).

    Raises:
        ValueError: If ``lambda_val`` must be guessed and ``loss_type`` has
            no default (see :meth:`best_guess_lambda`).
    """

    # Path used historically; it only resolves when the process is started
    # from the repository root, so it is tried first for compatibility and
    # the file next to this module is used otherwise.
    _CWD_ETAS_PATH = "cdim/etas.json"

    def __init__(self, method, task, T, K, loss_type, lambda_val=None):
        self.task = task
        self.T = T
        self.K = K
        self.loss_type = loss_type
        self.lambda_val = lambda_val
        self.method = method

        self.precomputed_etas = self._load_precomputed_etas()
        # Couldn't find expected gradnorm
        if not self.precomputed_etas and method == "expected_gradnorm":
            self.method = "gradnorm"
            print("Etas for this configuration not found. Switching to gradnorm.")

        # Get the best lambda_val if it's not passed
        if self.lambda_val is None:
            if self.method == "expected_gradnorm":
                self.lambda_val = self.precomputed_etas["lambda"]
            else:
                self.lambda_val = self.best_guess_lambda()
            print(f"Using lambda {self.lambda_val}")

    def _load_precomputed_etas(self):
        """Return the precomputed entry for this configuration, or ``{}``.

        The JSON file is looked up first at the working-directory-relative
        path ``cdim/etas.json`` and then next to this module. A missing file
        is treated the same as a missing configuration, so the caller can
        fall back to ``"gradnorm"`` instead of crashing.
        """
        import os  # local import keeps the class self-contained

        steps_key = f"T{self.T}_K{self.K}"
        module_dir = os.path.dirname(os.path.abspath(__file__))
        candidates = (
            self._CWD_ETAS_PATH,
            os.path.join(module_dir, "etas.json"),
        )

        for path in candidates:
            try:
                with open(path, encoding="utf-8") as f:
                    all_etas = json.load(f)
            except FileNotFoundError:
                continue
            return all_etas.get(self.task, {}).get(self.loss_type, {}).get(steps_key, {})

        return {}

    def get_step_size(self, t, grad_norm):
        """Use either precomputed expected gradnorm or gradnorm.

        Args:
            t: Timestep key; converted with ``str`` before the table lookup.
            grad_norm: Norm of the current gradient. Used by the
                ``"gradnorm"`` method, and as a fallback when the
                precomputed table has no entry for ``t``.

        Returns:
            ``lambda_val`` divided by the selected norm.
        """
        if self.method == "expected_gradnorm":
            expected_norm = self.precomputed_etas.get("etas", {}).get(str(t))
            if expected_norm is not None:
                return self.lambda_val / expected_norm
            # No precomputed value for this timestep: use the measured norm,
            # mirroring the constructor's fallback to "gradnorm".

        return self.lambda_val / grad_norm

    def best_guess_lambda(self):
        """Guess a lambda value if not provided. Based on trial and error.

        Returns:
            A default lambda scaled inversely with the total number of
            steps ``T * K``.

        Raises:
            ValueError: If ``loss_type`` is neither ``"kl"`` nor ``"l2"``.
        """
        total_steps = self.T * self.K

        # L2 tends to over optimize too aggressively, so the default lr is lower
        if self.loss_type == "kl":
            return 350 / total_steps
        elif self.loss_type == "l2":
            return 220 / total_steps
        else:
            raise ValueError(f"Please provide learning rate for loss type {self.loss_type}")
|
52 |
+
|
cdim/etas.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"box_inpainting": {
|
3 |
+
"kl" : {
|
4 |
+
"T50_K3": {
|
5 |
+
"lambda": 4.0,
|
6 |
+
"etas": {"20": 0.0007644674769138669, "40": 0.0007352405159032115, "60": 0.0007339323611117017, "80": 0.0007407785604665967, "100": 0.0007549588527783352, "120": 0.0007624788939223991, "140": 0.0007797297457310411, "160": 0.0008114829360981001, "180": 0.0008749328277143722, "200": 0.0009383883323579534, "220": 0.0010238270356100192, "240": 0.0010784138469520864, "260": 0.0011866835135642144, "280": 0.0013024192080464585, "300": 0.001475853782213996, "320": 0.001632385599489279, "340": 0.0018509140273253787, "360": 0.0020580830457055752, "380": 0.0023878463159672523, "400": 0.0026761896618615814, "420": 0.0031795443584943708, "440": 0.0034981176066313513, "460": 0.004353874833222008, "480": 0.004711352191433642, "500": 0.005726754667302881, "520": 0.006623946294438855, "540": 0.00863887341585953, "560": 0.009790771703642386, "580": 0.011950644941955498, "600": 0.013428242186066879, "620": 0.017515641845500456, "640": 0.01918877035399789, "660": 0.024788169966172627, "680": 0.026881696669367267, "700": 0.034675169438193654, "720": 0.039756030503755736, "740": 0.04854181173586306, "760": 0.05583204092049454, "780": 0.06820393054365007, "800": 0.0775198126468748, "820": 0.09482237250830504, "840": 0.10625708906759426, "860": 0.1277688814432282, "880": 0.1429259541984142, "900": 0.1730718536839941, "920": 0.19272152879630572, "940": 0.2197799661143284, "960": 0.2523353552270739, "980": 0.24785875279032762}
|
7 |
+
},
|
8 |
+
"T25_K1": {
|
9 |
+
"lambda": 9.0,
|
10 |
+
"etas": {"40": 0.002007311321255281, "80": 0.0017914309657763605, "120": 0.0018785649978129587, "160": 0.0019939875678122057, "200": 0.002077116515854901, "240": 0.0022416421013690723, "280": 0.002525056453252015, "320": 0.002990539520210802, "360": 0.003687435867772619, "400": 0.00461278213332024, "440": 0.005670295237520474, "480": 0.007450458259252178, "520": 0.01104231419851132, "560": 0.01304287128591384, "600": 0.01770141680890715, "640": 0.02673789939069817, "680": 0.03434346958799475, "720": 0.04184474771847407, "760": 0.05826798834659389, "800": 0.07392334565984734, "840": 0.09130451088376304, "880": 0.11364396078616029, "920": 0.1595581552644202, "960": 0.17525405659705573}
|
11 |
+
}
|
12 |
+
}
|
13 |
+
},
|
14 |
+
"random_inpainting": {
|
15 |
+
"kl" : {
|
16 |
+
"T50_K3": {
|
17 |
+
"lambda": 3.0,
|
18 |
+
"etas": {"20": 4.2155235275353885e-05, "40": 2.9131803821087877e-05, "60": 3.138850996275609e-05, "80": 3.386381497611811e-05, "100": 3.631776565956598e-05, "120": 3.973121110139383e-05, "140": 4.3322363720455985e-05, "160": 4.863932933274831e-05, "180": 5.240606242179506e-05, "200": 5.814939458443417e-05, "220": 6.197553145350107e-05, "240": 6.662394612209601e-05, "260": 7.353187590939113e-05, "280": 7.957012424639822e-05, "300": 8.54105910002775e-05, "320": 9.482279846848642e-05, "340": 0.00010283310990513798, "360": 0.00011288007608863972, "380": 0.00012768665988629105, "400": 0.00014661738441383113, "420": 0.00016301237603343505, "440": 0.00018253602035619805, "460": 0.00020806616934057048, "480": 0.00022939095079289252, "500": 0.00024356857801398671, "520": 0.00027035908261852304, "540": 0.0003041859182782007, "560": 0.0003671558660027532, "580": 0.00043392769502977437, "600": 0.0004871999766289279, "620": 0.0006560545650662386, "640": 0.0011521633386963458, "660": 0.0010345512272746527, "680": 0.0014703241871573396, "700": 0.0022689739261214977, "720": 0.002714861129921377, "740": 0.005800220039199319, "760": 0.0072751867583119864, "780": 0.012355766658887645, "800": 0.014635295468316744, "820": 0.011244771247342284, "840": 0.010248225188751379, "860": 0.013109683696949682, "880": 0.01479392021174634, "900": 0.016892151099455268, "920": 0.017663442143744875, "940": 0.01820524420826548, "960": 0.1109586885644798, "980": 0.10368549774809521}
|
19 |
+
},
|
20 |
+
"T25_K1": {
|
21 |
+
"lambda": 6.0,
|
22 |
+
"etas": {"40": 3.772142176929246e-05, "80": 2.2706469625575782e-05, "120": 2.6326200988558663e-05, "160": 3.1264416616630894e-05, "200": 3.734036301684134e-05, "240": 4.846247119505167e-05, "280": 6.404192925339954e-05, "320": 8.471533678103274e-05, "360": 0.00011057770282582904, "400": 0.00015357450981873662, "440": 0.00020928137685536158, "480": 0.00028775575102761786, "520": 0.0003906344462048851, "560": 0.0006252858438352591, "600": 0.0008340419507301546, "640": 0.0011431077222103293, "680": 0.0017961608713683882, "720": 0.0031925221564149363, "760": 0.004980067265202975, "800": 0.008485425858496985, "840": 0.01032865422644551, "880": 0.01237070838947314, "920": 0.03340125151375621, "960": 0.08094067944719611}
|
23 |
+
}
|
24 |
+
}
|
25 |
+
},
|
26 |
+
"super_resolution": {
|
27 |
+
"kl" : {
|
28 |
+
"T50_K3": {
|
29 |
+
"lambda": 3.0,
|
30 |
+
"etas": {"20": 0.000778915347623841, "40": 0.0007633780339799099, "60": 0.0007908141231719096, "80": 0.000788543595116803, "100": 0.0008288430294987044, "120": 0.0008392907314376583, "140": 0.0008925030169959898, "160": 0.0009190516306847781, "180": 0.000986800913723534, "200": 0.0010337262985591182, "220": 0.0011219889907389093, "240": 0.0011943036330360538, "260": 0.001310226177130967, "280": 0.0014183460616152988, "300": 0.0015733242833945388, "320": 0.001730196133012654, "340": 0.001937516604413998, "360": 0.002161490631568267, "380": 0.0024432350890119504, "400": 0.002708168867773674, "420": 0.00307449542457626, "440": 0.003515836769282733, "460": 0.004125363116570879, "480": 0.004754074570960041, "500": 0.005524380339583777, "520": 0.006439562376976526, "540": 0.007540773754592587, "560": 0.009500516111354843, "580": 0.011261532776188401, "600": 0.01348216019559666, "620": 0.01757517906334087, "640": 0.01977655744741228, "660": 0.024533713245498356, "680": 0.02746674570831025, "700": 0.03332074050640432, "720": 0.040243771626849834, "740": 0.04550001243601292, "760": 0.05033544385241358, "780": 0.061410491652647066, "800": 0.07221948341628477, "820": 0.0856409270444583, "840": 0.09925753198252922, "860": 0.12306635464984625, "880": 0.14476034450020075, "900": 0.17468273558090733, "920": 0.20411002595502606, "940": 0.24450559755448473, "960": 0.28248054410917073, "980": 0.3147428834175}
|
31 |
+
},
|
32 |
+
"T25_K1": {
|
33 |
+
"lambda": 7.5,
|
34 |
+
"etas": {"40": 0.002727176113620816, "80": 0.002118729856897165, "120": 0.002106326729054324, "160": 0.002295490187260311, "200": 0.002577219155480812, "240": 0.0029055148578559465, "280": 0.0032434651627058597, "320": 0.003844541849902738, "360": 0.004323113740483335, "400": 0.00514218754706466, "440": 0.005974186955805667, "480": 0.0072513396445271225, "520": 0.009855938730607262, "560": 0.012369151964727335, "600": 0.015923742821581138, "640": 0.021501799745055887, "680": 0.02177964029551375, "720": 0.03612378754790826, "760": 0.04641127936774787, "800": 0.06494989916290186, "840": 0.07686161107101508, "880": 0.10411057668338695, "920": 0.21030777467686249, "960": 0.1347356123776581}
|
35 |
+
}
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"gaussian_blur": {
|
39 |
+
"kl" : {
|
40 |
+
"T50_K3": {
|
41 |
+
"lambda": 2.5,
|
42 |
+
"etas": {"20": 0.001311277637655049, "40": 0.0011314792208201052, "60": 0.0013186921772428831, "80": 0.001161965484160248, "100": 0.0013730683554339122, "120": 0.0012312984890480417, "140": 0.0014760163095296494, "160": 0.0013462895586367368, "180": 0.0016383561831276197, "200": 0.0015191852141152255, "220": 0.0018704796563358576, "240": 0.0017668750655908613, "260": 0.0022029732696644934, "280": 0.002117440832022425, "300": 0.0026667582305060513, "320": 0.002615673065864779, "340": 0.0033227378642475855, "360": 0.0033246281530489056, "380": 0.004257323198960001, "400": 0.004339347273384217, "420": 0.005592718291095555, "440": 0.005808965844831082, "460": 0.007536144611100317, "480": 0.00795709164619601, "500": 0.010367841265905342, "520": 0.011111897162984073, "540": 0.014601192421942862, "560": 0.01572962570722594, "580": 0.020453046896395446, "600": 0.02245033051610993, "620": 0.028807566692203808, "640": 0.03205876293402484, "660": 0.041381670842137, "680": 0.04627552317583222, "700": 0.05792698423136997, "720": 0.0641696207414127, "740": 0.07855217157377192, "760": 0.08430869243734988, "780": 0.10214072576914401, "800": 0.11871801801180902, "820": 0.13689782336609055, "840": 0.15777806808719547, "860": 0.180580820528807, "880": 0.20387982661039958, "900": 0.2241532458656858, "920": 0.26195837225996266, "940": 0.30521510937946494, "960": 0.3967396717073996, "980": 0.40298601822973107}
|
43 |
+
},
|
44 |
+
"T25_K1": {
|
45 |
+
"lambda": 10,
|
46 |
+
"etas": {"40": 0.00223172777962139, "80": 0.002324943896556253, "120": 0.002347626232887454, "160": 0.0025855993230245048, "200": 0.0027816403864307474, "240": 0.003232656561641102, "280": 0.0035120116186731034, "320": 0.0043515958429587035, "360": 0.005171759666352569, "400": 0.006783558909602239, "440": 0.008177996254973204, "480": 0.01063599121374264, "520": 0.014335847666818375, "560": 0.018337043722309928, "600": 0.024463545207151972, "640": 0.029903566958453323, "680": 0.043206551929219614, "720": 0.053882087120507724, "760": 0.06720757469247418, "800": 0.09491236004307764, "840": 0.08980582293523451, "880": 0.1322836370288231, "920": 0.3906382071142557, "960": 1.07926636526566}
|
47 |
+
}
|
48 |
+
}
|
49 |
+
}
|
50 |
+
}
|
cdim/operators/__init__.py
CHANGED
@@ -8,11 +8,19 @@ def register_operator(name: str):
|
|
8 |
def wrapper(cls):
|
9 |
if __OPERATOR__.get(name, None):
|
10 |
raise NameError(f"Name {name} is already registered!")
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
return cls
|
13 |
return wrapper
|
14 |
|
15 |
-
|
16 |
def get_operator(name: str, **kwargs):
|
17 |
if __OPERATOR__.get(name, None) is None:
|
18 |
raise NameError(f"Name {name} is not defined.")
|
|
|
8 |
def wrapper(cls):
|
9 |
if __OPERATOR__.get(name, None):
|
10 |
raise NameError(f"Name {name} is already registered!")
|
11 |
+
|
12 |
+
original_init = cls.__init__
|
13 |
+
|
14 |
+
# Wrap the original __init__ to inject the `name` attribute.
|
15 |
+
def new_init(self, *args, **kwargs):
|
16 |
+
self.name = name # Set the name attribute
|
17 |
+
original_init(self, *args, **kwargs) # Call the original __init__
|
18 |
+
|
19 |
+
cls.__init__ = new_init # Replace the class's __init__ with the wrapped version.
|
20 |
+
__OPERATOR__[name] = cls # Register the class.
|
21 |
return cls
|
22 |
return wrapper
|
23 |
|
|
|
24 |
def get_operator(name: str, **kwargs):
|
25 |
if __OPERATOR__.get(name, None) is None:
|
26 |
raise NameError(f"Name {name} is not defined.")
|
inference.py
CHANGED
@@ -15,8 +15,9 @@ from cdim.image_utils import save_to_image
|
|
15 |
from cdim.dps_model.dps_unet import create_model
|
16 |
from cdim.diffusion.scheduling_ddim import DDIMScheduler
|
17 |
from cdim.diffusion.diffusion_pipeline import run_diffusion
|
|
|
18 |
|
19 |
-
torch.manual_seed(
|
20 |
|
21 |
def load_image(path):
|
22 |
"""
|
@@ -81,10 +82,14 @@ def main(args):
|
|
81 |
noisy_measurement = noise_function(operator(original_image))
|
82 |
save_to_image(noisy_measurement, os.path.join(args.output_dir, "noisy_measurement.png"))
|
83 |
|
|
|
|
|
|
|
84 |
t0 = time.time()
|
85 |
output_image = run_diffusion(
|
86 |
model, ddim_scheduler,
|
87 |
noisy_measurement, operator, noise_function, device,
|
|
|
88 |
num_inference_steps=args.T,
|
89 |
K=args.K,
|
90 |
model_type=model_type,
|
@@ -101,6 +106,11 @@ if __name__ == '__main__':
|
|
101 |
parser.add_argument("operator_config", type=str)
|
102 |
parser.add_argument("noise_config", type=str)
|
103 |
parser.add_argument("model_config", type=str)
|
|
|
|
|
|
|
|
|
|
|
104 |
parser.add_argument("--output-dir", default=".", type=str)
|
105 |
parser.add_argument("--loss", type=str,
|
106 |
choices=['l2', 'kl', 'categorical_kl'], default='l2',
|
|
|
15 |
from cdim.dps_model.dps_unet import create_model
|
16 |
from cdim.diffusion.scheduling_ddim import DDIMScheduler
|
17 |
from cdim.diffusion.diffusion_pipeline import run_diffusion
|
18 |
+
from cdim.eta_scheduler import EtaScheduler
|
19 |
|
20 |
+
# torch.manual_seed(7)
|
21 |
|
22 |
def load_image(path):
|
23 |
"""
|
|
|
82 |
noisy_measurement = noise_function(operator(original_image))
|
83 |
save_to_image(noisy_measurement, os.path.join(args.output_dir, "noisy_measurement.png"))
|
84 |
|
85 |
+
eta_scheduler = EtaScheduler(args.eta_type, operator.name, args.T,
|
86 |
+
args.K, args.loss, args.lambda_val)
|
87 |
+
|
88 |
t0 = time.time()
|
89 |
output_image = run_diffusion(
|
90 |
model, ddim_scheduler,
|
91 |
noisy_measurement, operator, noise_function, device,
|
92 |
+
eta_scheduler,
|
93 |
num_inference_steps=args.T,
|
94 |
K=args.K,
|
95 |
model_type=model_type,
|
|
|
106 |
parser.add_argument("operator_config", type=str)
|
107 |
parser.add_argument("noise_config", type=str)
|
108 |
parser.add_argument("model_config", type=str)
|
109 |
+
parser.add_argument("--eta-type", type=str,
|
110 |
+
choices=['gradnorm', 'expected_gradnorm'],
|
111 |
+
default='expected_gradnorm')
|
112 |
+
parser.add_argument("--lambda-val", type=float,
|
113 |
+
default=None)
|
114 |
parser.add_argument("--output-dir", default=".", type=str)
|
115 |
parser.add_argument("--loss", type=str,
|
116 |
choices=['l2', 'kl', 'categorical_kl'], default='l2',
|
operator_configs/super_resolution_config.yaml
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
name: super_resolution
|
2 |
-
scale:
|
3 |
in_shape: !!python/tuple [1, 3, 256, 256]
|
|
|
1 |
name: super_resolution
|
2 |
+
scale: 4
|
3 |
in_shape: !!python/tuple [1, 3, 256, 256]
|