mterris commited on
Commit
94e6664
·
1 Parent(s): 539d600
Files changed (2) hide show
  1. factories.py +11 -7
  2. model_factory.py +2 -2
factories.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any, List
3
  import deepinv as dinv
4
  import numpy as np
5
  import torch
6
- from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
7
  from torchvision import transforms
8
 
9
  from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
@@ -127,7 +127,7 @@ class PhysicsWithGenerator(torch.nn.Module):
127
  "fixed_params": {"acceleration_factor": 4}}
128
  elif self.name == "CT":
129
  acceleration_factor = 10
130
- img_h = 480
131
  angles = int(img_h / acceleration_factor)
132
  # angles = torch.linspace(0, 180, steps=10)
133
  self.physics = dinv.physics.Tomography(
@@ -139,11 +139,11 @@ class PhysicsWithGenerator(torch.nn.Module):
139
  noise_model=dinv.physics.GaussianNoise(sigma=1e-4),
140
  max_iter=10,
141
  )
142
- self.physics_generator = None
143
- self.generator = SigmaGenerator(sigma_min=1e-5, sigma_max=1e-4, device=device_str)
144
- self.saved_params = {"updatable_params": {"sigma": 1e-4},
145
  "updatable_params_converter": {"sigma": float},
146
- "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
147
  "angles": angles, "max_iter": 10}}
148
 
149
  def display_saved_params(self) -> str:
@@ -235,6 +235,10 @@ class EvalModel(torch.nn.Module):
235
  self.model.eval()
236
 
237
  def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
 
 
 
 
238
  return self.model(y, physics=physics)
239
 
240
 
@@ -401,7 +405,7 @@ class BaselineModel(torch.nn.Module):
401
  # Set the DPIR algorithm parameters
402
  sigma_float = physics.noise_model.sigma.item() # sigma should be a single value
403
  lip_const = physics.compute_norm(physics.A_adjoint(y))
404
- lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_CT_params(sigma_float, max_iter=8,
405
  lip_cons=lip_const.item())
406
  params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}
407
  early_stop = False # Do not stop algorithm with convergence criteria
 
3
  import deepinv as dinv
4
  import numpy as np
5
  import torch
6
+ from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator, GainGenerator
7
  from torchvision import transforms
8
 
9
  from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
 
127
  "fixed_params": {"acceleration_factor": 4}}
128
  elif self.name == "CT":
129
  acceleration_factor = 10
130
+ img_h = 512
131
  angles = int(img_h / acceleration_factor)
132
  # angles = torch.linspace(0, 180, steps=10)
133
  self.physics = dinv.physics.Tomography(
 
139
  noise_model=dinv.physics.GaussianNoise(sigma=1e-4),
140
  max_iter=10,
141
  )
142
+ self.physics_generator = SigmaGenerator(sigma_min=1e-4, sigma_max=1e-4, device=device_str)
143
+ self.generator = SigmaGenerator(sigma_min=1e-4, sigma_max=1e-4, device=device_str)
144
+ self.saved_params = {"updatable_params": {"sigma": 0.1},
145
  "updatable_params_converter": {"sigma": float},
146
+ "fixed_params": {"noise_sigma_min": 1e-4, "noise_sigma_max": 1e-4,
147
  "angles": angles, "max_iter": 10}}
148
 
149
  def display_saved_params(self) -> str:
 
235
  self.model.eval()
236
 
237
  def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
238
+ physics.noise_model.sigma = torch.nn.Parameter(torch.tensor([1e-06]))
239
+ physics.noise_model.gain = torch.nn.Parameter(torch.tensor([1e-06]))
240
+ print('sigma = ', physics.noise_model.sigma)
241
+ print('gain = ', physics.noise_model.gain)
242
  return self.model(y, physics=physics)
243
 
244
 
 
405
  # Set the DPIR algorithm parameters
406
  sigma_float = physics.noise_model.sigma.item() # sigma should be a single value
407
  lip_const = physics.compute_norm(physics.A_adjoint(y))
408
+ lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_CT_params(sigma_float, max_iter=1, # for debugging speed
409
  lip_cons=lip_const.item())
410
  params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}
411
  early_stop = False # Do not stop algorithm with convergence criteria
model_factory.py CHANGED
@@ -51,12 +51,12 @@ class ArtifactRemoval(nn.Module):
51
  if hasattr(physics.noise_model, "sigma"):
52
  sigma = physics.noise_model.sigma
53
  else:
54
- sigma = 1e-3 # WARNING: this is a default value that we may not want to use?
55
 
56
  if hasattr(physics.noise_model, "gain"):
57
  gamma = physics.noise_model.gain
58
  else:
59
- gamma = 1e-3 # WARNING: this is a default value that we may not want to use?
60
 
61
  out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
62
 
 
51
  if hasattr(physics.noise_model, "sigma"):
52
  sigma = physics.noise_model.sigma
53
  else:
54
+ sigma = 1e-5 # WARNING: this is a default value that we may not want to use?
55
 
56
  if hasattr(physics.noise_model, "gain"):
57
  gamma = physics.noise_model.gain
58
  else:
59
+ gamma = 1e-5 # WARNING: this is a default value that we may not want to use?
60
 
61
  out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
62