Spaces:
Sleeping
Sleeping
fix
Browse files- factories.py +7 -8
factories.py
CHANGED
@@ -50,7 +50,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
50 |
|
51 |
if self.name == "MotionBlur_medium":
|
52 |
psf_size = 31
|
53 |
-
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05),
|
54 |
padding="valid", device=device_str)
|
55 |
self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str)
|
56 |
self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
|
@@ -60,7 +60,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
60 |
"psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}}
|
61 |
elif self.name == "MotionBlur_hard":
|
62 |
psf_size = 31
|
63 |
-
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1),
|
64 |
padding="valid", device=device_str)
|
65 |
self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str)
|
66 |
self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
|
@@ -70,7 +70,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
70 |
"psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}}
|
71 |
elif self.name == "GaussianBlur_easy":
|
72 |
psf_size = 31
|
73 |
-
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.01),
|
74 |
padding="valid", device=device_str)
|
75 |
self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
|
76 |
sigma_min=1.0, sigma_max=1.0,
|
@@ -83,7 +83,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
83 |
"blur_sigma": 1.0, "psf_size": 31, "num_channels": 1}}
|
84 |
elif self.name == "GaussianBlur_medium":
|
85 |
psf_size = 31
|
86 |
-
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05),
|
87 |
padding="valid", device=device_str)
|
88 |
self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
|
89 |
sigma_min=2.0, sigma_max=2.0,
|
@@ -96,7 +96,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
96 |
"blur_sigma": 2.0, "psf_size": 31, "num_channels": 1}}
|
97 |
elif self.name == "GaussianBlur_hard":
|
98 |
psf_size = 31
|
99 |
-
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05),
|
100 |
padding="valid", device=device_str)
|
101 |
self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
|
102 |
sigma_min=4.0, sigma_max=4.0,
|
@@ -108,7 +108,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
108 |
"fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
|
109 |
"blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
|
110 |
elif self.name == "MRI":
|
111 |
-
self.physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(sigma=.01),
|
112 |
img_size=(640, 320), device=device_str)
|
113 |
self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4)
|
114 |
self.generator = self.physics_generator
|
@@ -125,7 +125,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
125 |
circle=False,
|
126 |
normalize=True,
|
127 |
device=device_str,
|
128 |
-
noise_model=dinv.physics.GaussianNoise(sigma=1e-4),
|
129 |
max_iter=10,
|
130 |
)
|
131 |
self.physics_generator = SigmaGenerator(sigma_min=1e-4, sigma_max=1e-4, device=device_str)
|
@@ -151,7 +151,6 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
151 |
"""Update value of an existing key in save_params."""
|
152 |
if value != "" and key in list(self.saved_params["updatable_params"].keys()):
|
153 |
if type(value) == str: # it may be only a str representation
|
154 |
-
# type: str -> ???
|
155 |
value = self.saved_params["updatable_params_converter"][key](value)
|
156 |
elif isinstance(value, torch.Tensor):
|
157 |
value = value.item() # type: torch.Tensor -> float
|
|
|
50 |
|
51 |
if self.name == "MotionBlur_medium":
|
52 |
psf_size = 31
|
53 |
+
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05).to(device_str),
|
54 |
padding="valid", device=device_str)
|
55 |
self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str)
|
56 |
self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
|
|
|
60 |
"psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}}
|
61 |
elif self.name == "MotionBlur_hard":
|
62 |
psf_size = 31
|
63 |
+
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1).to(device_str),
|
64 |
padding="valid", device=device_str)
|
65 |
self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str)
|
66 |
self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
|
|
|
70 |
"psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}}
|
71 |
elif self.name == "GaussianBlur_easy":
|
72 |
psf_size = 31
|
73 |
+
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.01).to(device_str),
|
74 |
padding="valid", device=device_str)
|
75 |
self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
|
76 |
sigma_min=1.0, sigma_max=1.0,
|
|
|
83 |
"blur_sigma": 1.0, "psf_size": 31, "num_channels": 1}}
|
84 |
elif self.name == "GaussianBlur_medium":
|
85 |
psf_size = 31
|
86 |
+
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05).to(device_str),
|
87 |
padding="valid", device=device_str)
|
88 |
self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
|
89 |
sigma_min=2.0, sigma_max=2.0,
|
|
|
96 |
"blur_sigma": 2.0, "psf_size": 31, "num_channels": 1}}
|
97 |
elif self.name == "GaussianBlur_hard":
|
98 |
psf_size = 31
|
99 |
+
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05).to(device_str),
|
100 |
padding="valid", device=device_str)
|
101 |
self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
|
102 |
sigma_min=4.0, sigma_max=4.0,
|
|
|
108 |
"fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
|
109 |
"blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
|
110 |
elif self.name == "MRI":
|
111 |
+
self.physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(sigma=.01).to(device_str),
|
112 |
img_size=(640, 320), device=device_str)
|
113 |
self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4)
|
114 |
self.generator = self.physics_generator
|
|
|
125 |
circle=False,
|
126 |
normalize=True,
|
127 |
device=device_str,
|
128 |
+
noise_model=dinv.physics.GaussianNoise(sigma=1e-4).to(device_str),
|
129 |
max_iter=10,
|
130 |
)
|
131 |
self.physics_generator = SigmaGenerator(sigma_min=1e-4, sigma_max=1e-4, device=device_str)
|
|
|
151 |
"""Update value of an existing key in save_params."""
|
152 |
if value != "" and key in list(self.saved_params["updatable_params"].keys()):
|
153 |
if type(value) == str: # it may be only a str representation
|
|
|
154 |
value = self.saved_params["updatable_params_converter"][key](value)
|
155 |
elif isinstance(value, torch.Tensor):
|
156 |
value = value.item() # type: torch.Tensor -> float
|