mterris commited on
Commit
9037f29
·
1 Parent(s): c7ae131
Files changed (1) hide show
  1. factories.py +7 -8
factories.py CHANGED
@@ -50,7 +50,7 @@ class PhysicsWithGenerator(torch.nn.Module):
50
 
51
  if self.name == "MotionBlur_medium":
52
  psf_size = 31
53
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05),
54
  padding="valid", device=device_str)
55
  self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str)
56
  self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
@@ -60,7 +60,7 @@ class PhysicsWithGenerator(torch.nn.Module):
60
  "psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}}
61
  elif self.name == "MotionBlur_hard":
62
  psf_size = 31
63
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1),
64
  padding="valid", device=device_str)
65
  self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str)
66
  self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
@@ -70,7 +70,7 @@ class PhysicsWithGenerator(torch.nn.Module):
70
  "psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}}
71
  elif self.name == "GaussianBlur_easy":
72
  psf_size = 31
73
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.01),
74
  padding="valid", device=device_str)
75
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
76
  sigma_min=1.0, sigma_max=1.0,
@@ -83,7 +83,7 @@ class PhysicsWithGenerator(torch.nn.Module):
83
  "blur_sigma": 1.0, "psf_size": 31, "num_channels": 1}}
84
  elif self.name == "GaussianBlur_medium":
85
  psf_size = 31
86
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05),
87
  padding="valid", device=device_str)
88
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
89
  sigma_min=2.0, sigma_max=2.0,
@@ -96,7 +96,7 @@ class PhysicsWithGenerator(torch.nn.Module):
96
  "blur_sigma": 2.0, "psf_size": 31, "num_channels": 1}}
97
  elif self.name == "GaussianBlur_hard":
98
  psf_size = 31
99
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05),
100
  padding="valid", device=device_str)
101
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
102
  sigma_min=4.0, sigma_max=4.0,
@@ -108,7 +108,7 @@ class PhysicsWithGenerator(torch.nn.Module):
108
  "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
109
  "blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
110
  elif self.name == "MRI":
111
- self.physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(sigma=.01),
112
  img_size=(640, 320), device=device_str)
113
  self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4)
114
  self.generator = self.physics_generator
@@ -125,7 +125,7 @@ class PhysicsWithGenerator(torch.nn.Module):
125
  circle=False,
126
  normalize=True,
127
  device=device_str,
128
- noise_model=dinv.physics.GaussianNoise(sigma=1e-4),
129
  max_iter=10,
130
  )
131
  self.physics_generator = SigmaGenerator(sigma_min=1e-4, sigma_max=1e-4, device=device_str)
@@ -151,7 +151,6 @@ class PhysicsWithGenerator(torch.nn.Module):
151
  """Update value of an existing key in save_params."""
152
  if value != "" and key in list(self.saved_params["updatable_params"].keys()):
153
  if type(value) == str: # it may be only a str representation
154
- # type: str -> ???
155
  value = self.saved_params["updatable_params_converter"][key](value)
156
  elif isinstance(value, torch.Tensor):
157
  value = value.item() # type: torch.Tensor -> float
 
50
 
51
  if self.name == "MotionBlur_medium":
52
  psf_size = 31
53
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05).to(device_str),
54
  padding="valid", device=device_str)
55
  self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str)
56
  self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
 
60
  "psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}}
61
  elif self.name == "MotionBlur_hard":
62
  psf_size = 31
63
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1).to(device_str),
64
  padding="valid", device=device_str)
65
  self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str)
66
  self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
 
70
  "psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}}
71
  elif self.name == "GaussianBlur_easy":
72
  psf_size = 31
73
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.01).to(device_str),
74
  padding="valid", device=device_str)
75
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
76
  sigma_min=1.0, sigma_max=1.0,
 
83
  "blur_sigma": 1.0, "psf_size": 31, "num_channels": 1}}
84
  elif self.name == "GaussianBlur_medium":
85
  psf_size = 31
86
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05).to(device_str),
87
  padding="valid", device=device_str)
88
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
89
  sigma_min=2.0, sigma_max=2.0,
 
96
  "blur_sigma": 2.0, "psf_size": 31, "num_channels": 1}}
97
  elif self.name == "GaussianBlur_hard":
98
  psf_size = 31
99
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05).to(device_str),
100
  padding="valid", device=device_str)
101
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
102
  sigma_min=4.0, sigma_max=4.0,
 
108
  "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
109
  "blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
110
  elif self.name == "MRI":
111
+ self.physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(sigma=.01).to(device_str),
112
  img_size=(640, 320), device=device_str)
113
  self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4)
114
  self.generator = self.physics_generator
 
125
  circle=False,
126
  normalize=True,
127
  device=device_str,
128
+ noise_model=dinv.physics.GaussianNoise(sigma=1e-4).to(device_str),
129
  max_iter=10,
130
  )
131
  self.physics_generator = SigmaGenerator(sigma_min=1e-4, sigma_max=1e-4, device=device_str)
 
151
  """Update value of an existing key in save_params."""
152
  if value != "" and key in list(self.saved_params["updatable_params"].keys()):
153
  if type(value) == str: # it may be only a str representation
 
154
  value = self.saved_params["updatable_params_converter"][key](value)
155
  elif isinstance(value, torch.Tensor):
156
  value = value.item() # type: torch.Tensor -> float