VIVEK JAYARAM commited on
Commit
e6c2b25
1 Parent(s): 64e8afb

Expected gradnorm

Browse files
cdim/diffusion/diffusion_pipeline.py CHANGED
@@ -23,6 +23,7 @@ def run_diffusion(
23
  operator,
24
  noise_function,
25
  device,
 
26
  num_inference_steps: int = 1000,
27
  K=5,
28
  image_dim=256,
@@ -90,7 +91,8 @@ def run_diffusion(
90
  else:
91
  raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
92
 
93
- image -= 10 / torch.linalg.norm(image.grad) * image.grad
 
94
  image = image.detach().requires_grad_()
95
 
96
  return image
 
23
  operator,
24
  noise_function,
25
  device,
26
+ eta_scheduler,
27
  num_inference_steps: int = 1000,
28
  K=5,
29
  image_dim=256,
 
91
  else:
92
  raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
93
 
94
+ step_size = eta_scheduler.get_step_size(str(t.item()), torch.linalg.norm(image.grad))
95
+ image -= step_size * image.grad
96
  image = image.detach().requires_grad_()
97
 
98
  return image
cdim/eta_scheduler.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ class EtaScheduler:
4
+ def __init__(self, method, task, T, K, loss_type, lambda_val=None):
5
+ self.task = task
6
+ self.T = T
7
+ self.K = K
8
+ self.loss_type = loss_type
9
+ self.lambda_val = lambda_val
10
+ self.method = method
11
+
12
+ self.precomputed_etas = self._load_precomputed_etas()
13
+ # Couldn't find expected gradnorm
14
+ if not self.precomputed_etas and method == "expected_gradnorm":
15
+ self.method = "gradnorm"
16
+ print("Etas for this configuration not found. Switching to gradnorm.")
17
+
18
+ # Get the best lambda_val if it's not passed
19
+ if self.lambda_val is None:
20
+ if self.method == "expected_gradnorm":
21
+ self.lambda_val = self.precomputed_etas["lambda"]
22
+ else:
23
+ self.lambda_val = self.best_guess_lambda()
24
+ print(f"Using lambda {self.lambda_val}")
25
+
26
+ def _load_precomputed_etas(self):
27
+ steps_key = f"T{self.T}_K{self.K}"
28
+ with open("cdim/etas.json") as f:
29
+ all_etas = json.load(f)
30
+
31
+ return all_etas.get(self.task, {}).get(self.loss_type, {}).get(steps_key, {})
32
+
33
+ def get_step_size(self, t, grad_norm):
34
+ """Use either precomputed expected gradnorm or gradnorm."""
35
+ if self.method == "expected_gradnorm":
36
+ step_size = self.lambda_val * 1 / self.precomputed_etas["etas"][t]
37
+ else:
38
+ step_size = self.lambda_val * 1 / grad_norm
39
+ return step_size
40
+
41
+ def best_guess_lambda(self):
42
+ """Guess a lambda value if not provided. Based on trial and error"""
43
+ total_steps = self.T * self.K
44
+
45
+ # L2 tends to over optimize too aggressively, so the default lr is lower
46
+ if self.loss_type == "kl":
47
+ return 350 / total_steps
48
+ elif self.loss_type == "l2":
49
+ return 220 / total_steps
50
+ else:
51
+ raise ValueError(f"Please provide learning rate for loss type {self.loss_type}")
52
+
cdim/etas.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "box_inpainting": {
3
+ "kl" : {
4
+ "T50_K3": {
5
+ "lambda": 4.0,
6
+ "etas": {"20": 0.0007644674769138669, "40": 0.0007352405159032115, "60": 0.0007339323611117017, "80": 0.0007407785604665967, "100": 0.0007549588527783352, "120": 0.0007624788939223991, "140": 0.0007797297457310411, "160": 0.0008114829360981001, "180": 0.0008749328277143722, "200": 0.0009383883323579534, "220": 0.0010238270356100192, "240": 0.0010784138469520864, "260": 0.0011866835135642144, "280": 0.0013024192080464585, "300": 0.001475853782213996, "320": 0.001632385599489279, "340": 0.0018509140273253787, "360": 0.0020580830457055752, "380": 0.0023878463159672523, "400": 0.0026761896618615814, "420": 0.0031795443584943708, "440": 0.0034981176066313513, "460": 0.004353874833222008, "480": 0.004711352191433642, "500": 0.005726754667302881, "520": 0.006623946294438855, "540": 0.00863887341585953, "560": 0.009790771703642386, "580": 0.011950644941955498, "600": 0.013428242186066879, "620": 0.017515641845500456, "640": 0.01918877035399789, "660": 0.024788169966172627, "680": 0.026881696669367267, "700": 0.034675169438193654, "720": 0.039756030503755736, "740": 0.04854181173586306, "760": 0.05583204092049454, "780": 0.06820393054365007, "800": 0.0775198126468748, "820": 0.09482237250830504, "840": 0.10625708906759426, "860": 0.1277688814432282, "880": 0.1429259541984142, "900": 0.1730718536839941, "920": 0.19272152879630572, "940": 0.2197799661143284, "960": 0.2523353552270739, "980": 0.24785875279032762}
7
+ },
8
+ "T25_K1": {
9
+ "lambda": 9.0,
10
+ "etas": {"40": 0.002007311321255281, "80": 0.0017914309657763605, "120": 0.0018785649978129587, "160": 0.0019939875678122057, "200": 0.002077116515854901, "240": 0.0022416421013690723, "280": 0.002525056453252015, "320": 0.002990539520210802, "360": 0.003687435867772619, "400": 0.00461278213332024, "440": 0.005670295237520474, "480": 0.007450458259252178, "520": 0.01104231419851132, "560": 0.01304287128591384, "600": 0.01770141680890715, "640": 0.02673789939069817, "680": 0.03434346958799475, "720": 0.04184474771847407, "760": 0.05826798834659389, "800": 0.07392334565984734, "840": 0.09130451088376304, "880": 0.11364396078616029, "920": 0.1595581552644202, "960": 0.17525405659705573}
11
+ }
12
+ }
13
+ },
14
+ "random_inpainting": {
15
+ "kl" : {
16
+ "T50_K3": {
17
+ "lambda": 3.0,
18
+ "etas": {"20": 4.2155235275353885e-05, "40": 2.9131803821087877e-05, "60": 3.138850996275609e-05, "80": 3.386381497611811e-05, "100": 3.631776565956598e-05, "120": 3.973121110139383e-05, "140": 4.3322363720455985e-05, "160": 4.863932933274831e-05, "180": 5.240606242179506e-05, "200": 5.814939458443417e-05, "220": 6.197553145350107e-05, "240": 6.662394612209601e-05, "260": 7.353187590939113e-05, "280": 7.957012424639822e-05, "300": 8.54105910002775e-05, "320": 9.482279846848642e-05, "340": 0.00010283310990513798, "360": 0.00011288007608863972, "380": 0.00012768665988629105, "400": 0.00014661738441383113, "420": 0.00016301237603343505, "440": 0.00018253602035619805, "460": 0.00020806616934057048, "480": 0.00022939095079289252, "500": 0.00024356857801398671, "520": 0.00027035908261852304, "540": 0.0003041859182782007, "560": 0.0003671558660027532, "580": 0.00043392769502977437, "600": 0.0004871999766289279, "620": 0.0006560545650662386, "640": 0.0011521633386963458, "660": 0.0010345512272746527, "680": 0.0014703241871573396, "700": 0.0022689739261214977, "720": 0.002714861129921377, "740": 0.005800220039199319, "760": 0.0072751867583119864, "780": 0.012355766658887645, "800": 0.014635295468316744, "820": 0.011244771247342284, "840": 0.010248225188751379, "860": 0.013109683696949682, "880": 0.01479392021174634, "900": 0.016892151099455268, "920": 0.017663442143744875, "940": 0.01820524420826548, "960": 0.1109586885644798, "980": 0.10368549774809521}
19
+ },
20
+ "T25_K1": {
21
+ "lambda": 6.0,
22
+ "etas": {"40": 3.772142176929246e-05, "80": 2.2706469625575782e-05, "120": 2.6326200988558663e-05, "160": 3.1264416616630894e-05, "200": 3.734036301684134e-05, "240": 4.846247119505167e-05, "280": 6.404192925339954e-05, "320": 8.471533678103274e-05, "360": 0.00011057770282582904, "400": 0.00015357450981873662, "440": 0.00020928137685536158, "480": 0.00028775575102761786, "520": 0.0003906344462048851, "560": 0.0006252858438352591, "600": 0.0008340419507301546, "640": 0.0011431077222103293, "680": 0.0017961608713683882, "720": 0.0031925221564149363, "760": 0.004980067265202975, "800": 0.008485425858496985, "840": 0.01032865422644551, "880": 0.01237070838947314, "920": 0.03340125151375621, "960": 0.08094067944719611}
23
+ }
24
+ }
25
+ },
26
+ "super_resolution": {
27
+ "kl" : {
28
+ "T50_K3": {
29
+ "lambda": 3.0,
30
+ "etas": {"20": 0.000778915347623841, "40": 0.0007633780339799099, "60": 0.0007908141231719096, "80": 0.000788543595116803, "100": 0.0008288430294987044, "120": 0.0008392907314376583, "140": 0.0008925030169959898, "160": 0.0009190516306847781, "180": 0.000986800913723534, "200": 0.0010337262985591182, "220": 0.0011219889907389093, "240": 0.0011943036330360538, "260": 0.001310226177130967, "280": 0.0014183460616152988, "300": 0.0015733242833945388, "320": 0.001730196133012654, "340": 0.001937516604413998, "360": 0.002161490631568267, "380": 0.0024432350890119504, "400": 0.002708168867773674, "420": 0.00307449542457626, "440": 0.003515836769282733, "460": 0.004125363116570879, "480": 0.004754074570960041, "500": 0.005524380339583777, "520": 0.006439562376976526, "540": 0.007540773754592587, "560": 0.009500516111354843, "580": 0.011261532776188401, "600": 0.01348216019559666, "620": 0.01757517906334087, "640": 0.01977655744741228, "660": 0.024533713245498356, "680": 0.02746674570831025, "700": 0.03332074050640432, "720": 0.040243771626849834, "740": 0.04550001243601292, "760": 0.05033544385241358, "780": 0.061410491652647066, "800": 0.07221948341628477, "820": 0.0856409270444583, "840": 0.09925753198252922, "860": 0.12306635464984625, "880": 0.14476034450020075, "900": 0.17468273558090733, "920": 0.20411002595502606, "940": 0.24450559755448473, "960": 0.28248054410917073, "980": 0.3147428834175}
31
+ },
32
+ "T25_K1": {
33
+ "lambda": 7.5,
34
+ "etas": {"40": 0.002727176113620816, "80": 0.002118729856897165, "120": 0.002106326729054324, "160": 0.002295490187260311, "200": 0.002577219155480812, "240": 0.0029055148578559465, "280": 0.0032434651627058597, "320": 0.003844541849902738, "360": 0.004323113740483335, "400": 0.00514218754706466, "440": 0.005974186955805667, "480": 0.0072513396445271225, "520": 0.009855938730607262, "560": 0.012369151964727335, "600": 0.015923742821581138, "640": 0.021501799745055887, "680": 0.02177964029551375, "720": 0.03612378754790826, "760": 0.04641127936774787, "800": 0.06494989916290186, "840": 0.07686161107101508, "880": 0.10411057668338695, "920": 0.21030777467686249, "960": 0.1347356123776581}
35
+ }
36
+ }
37
+ },
38
+ "gaussian_blur": {
39
+ "kl" : {
40
+ "T50_K3": {
41
+ "lambda": 2.5,
42
+ "etas": {"20": 0.001311277637655049, "40": 0.0011314792208201052, "60": 0.0013186921772428831, "80": 0.001161965484160248, "100": 0.0013730683554339122, "120": 0.0012312984890480417, "140": 0.0014760163095296494, "160": 0.0013462895586367368, "180": 0.0016383561831276197, "200": 0.0015191852141152255, "220": 0.0018704796563358576, "240": 0.0017668750655908613, "260": 0.0022029732696644934, "280": 0.002117440832022425, "300": 0.0026667582305060513, "320": 0.002615673065864779, "340": 0.0033227378642475855, "360": 0.0033246281530489056, "380": 0.004257323198960001, "400": 0.004339347273384217, "420": 0.005592718291095555, "440": 0.005808965844831082, "460": 0.007536144611100317, "480": 0.00795709164619601, "500": 0.010367841265905342, "520": 0.011111897162984073, "540": 0.014601192421942862, "560": 0.01572962570722594, "580": 0.020453046896395446, "600": 0.02245033051610993, "620": 0.028807566692203808, "640": 0.03205876293402484, "660": 0.041381670842137, "680": 0.04627552317583222, "700": 0.05792698423136997, "720": 0.0641696207414127, "740": 0.07855217157377192, "760": 0.08430869243734988, "780": 0.10214072576914401, "800": 0.11871801801180902, "820": 0.13689782336609055, "840": 0.15777806808719547, "860": 0.180580820528807, "880": 0.20387982661039958, "900": 0.2241532458656858, "920": 0.26195837225996266, "940": 0.30521510937946494, "960": 0.3967396717073996, "980": 0.40298601822973107}
43
+ },
44
+ "T25_K1": {
45
+ "lambda": 10,
46
+ "etas": {"40": 0.00223172777962139, "80": 0.002324943896556253, "120": 0.002347626232887454, "160": 0.0025855993230245048, "200": 0.0027816403864307474, "240": 0.003232656561641102, "280": 0.0035120116186731034, "320": 0.0043515958429587035, "360": 0.005171759666352569, "400": 0.006783558909602239, "440": 0.008177996254973204, "480": 0.01063599121374264, "520": 0.014335847666818375, "560": 0.018337043722309928, "600": 0.024463545207151972, "640": 0.029903566958453323, "680": 0.043206551929219614, "720": 0.053882087120507724, "760": 0.06720757469247418, "800": 0.09491236004307764, "840": 0.08980582293523451, "880": 0.1322836370288231, "920": 0.3906382071142557, "960": 1.07926636526566}
47
+ }
48
+ }
49
+ }
50
+ }
cdim/operators/__init__.py CHANGED
@@ -8,11 +8,19 @@ def register_operator(name: str):
8
  def wrapper(cls):
9
  if __OPERATOR__.get(name, None):
10
  raise NameError(f"Name {name} is already registered!")
11
- __OPERATOR__[name] = cls
 
 
 
 
 
 
 
 
 
12
  return cls
13
  return wrapper
14
 
15
-
16
  def get_operator(name: str, **kwargs):
17
  if __OPERATOR__.get(name, None) is None:
18
  raise NameError(f"Name {name} is not defined.")
 
8
  def wrapper(cls):
9
  if __OPERATOR__.get(name, None):
10
  raise NameError(f"Name {name} is already registered!")
11
+
12
+ original_init = cls.__init__
13
+
14
+ # Wrap the original __init__ to inject the `name` attribute.
15
+ def new_init(self, *args, **kwargs):
16
+ self.name = name # Set the name attribute
17
+ original_init(self, *args, **kwargs) # Call the original __init__
18
+
19
+ cls.__init__ = new_init # Replace the class's __init__ with the wrapped version.
20
+ __OPERATOR__[name] = cls # Register the class.
21
  return cls
22
  return wrapper
23
 
 
24
  def get_operator(name: str, **kwargs):
25
  if __OPERATOR__.get(name, None) is None:
26
  raise NameError(f"Name {name} is not defined.")
inference.py CHANGED
@@ -15,8 +15,9 @@ from cdim.image_utils import save_to_image
15
  from cdim.dps_model.dps_unet import create_model
16
  from cdim.diffusion.scheduling_ddim import DDIMScheduler
17
  from cdim.diffusion.diffusion_pipeline import run_diffusion
 
18
 
19
- torch.manual_seed(8)
20
 
21
  def load_image(path):
22
  """
@@ -81,10 +82,14 @@ def main(args):
81
  noisy_measurement = noise_function(operator(original_image))
82
  save_to_image(noisy_measurement, os.path.join(args.output_dir, "noisy_measurement.png"))
83
 
 
 
 
84
  t0 = time.time()
85
  output_image = run_diffusion(
86
  model, ddim_scheduler,
87
  noisy_measurement, operator, noise_function, device,
 
88
  num_inference_steps=args.T,
89
  K=args.K,
90
  model_type=model_type,
@@ -101,6 +106,11 @@ if __name__ == '__main__':
101
  parser.add_argument("operator_config", type=str)
102
  parser.add_argument("noise_config", type=str)
103
  parser.add_argument("model_config", type=str)
 
 
 
 
 
104
  parser.add_argument("--output-dir", default=".", type=str)
105
  parser.add_argument("--loss", type=str,
106
  choices=['l2', 'kl', 'categorical_kl'], default='l2',
 
15
  from cdim.dps_model.dps_unet import create_model
16
  from cdim.diffusion.scheduling_ddim import DDIMScheduler
17
  from cdim.diffusion.diffusion_pipeline import run_diffusion
18
+ from cdim.eta_scheduler import EtaScheduler
19
 
20
+ # torch.manual_seed(7)
21
 
22
  def load_image(path):
23
  """
 
82
  noisy_measurement = noise_function(operator(original_image))
83
  save_to_image(noisy_measurement, os.path.join(args.output_dir, "noisy_measurement.png"))
84
 
85
+ eta_scheduler = EtaScheduler(args.eta_type, operator.name, args.T,
86
+ args.K, args.loss, args.lambda_val)
87
+
88
  t0 = time.time()
89
  output_image = run_diffusion(
90
  model, ddim_scheduler,
91
  noisy_measurement, operator, noise_function, device,
92
+ eta_scheduler,
93
  num_inference_steps=args.T,
94
  K=args.K,
95
  model_type=model_type,
 
106
  parser.add_argument("operator_config", type=str)
107
  parser.add_argument("noise_config", type=str)
108
  parser.add_argument("model_config", type=str)
109
+ parser.add_argument("--eta-type", type=str,
110
+ choices=['gradnorm', 'expected_gradnorm'],
111
+ default='expected_gradnorm')
112
+ parser.add_argument("--lambda-val", type=float,
113
+ default=None)
114
  parser.add_argument("--output-dir", default=".", type=str)
115
  parser.add_argument("--loss", type=str,
116
  choices=['l2', 'kl', 'categorical_kl'], default='l2',
operator_configs/super_resolution_config.yaml CHANGED
@@ -1,3 +1,3 @@
1
  name: super_resolution
2
- scale: 8
3
  in_shape: !!python/tuple [1, 3, 256, 256]
 
1
  name: super_resolution
2
+ scale: 4
3
  in_shape: !!python/tuple [1, 3, 256, 256]