Add Checkpoint 195 and 70
- layers.py +0 -2
- networks.py +1 -2
- pix2pix.py +118 -13
- pix2pix_disc_ckpt_195.pt +3 -0
- pix2pix_disc_ckpt_70.pt +3 -0
- pix2pix_gen_ckpt_195.pt +3 -0
- pix2pix_gen_ckpt_70.pt +3 -0
layers.py
CHANGED
@@ -1,6 +1,4 @@
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class DownsamplingBlock(nn.Module):
networks.py
CHANGED
@@ -1,8 +1,7 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
-from …
+from layers import DownsamplingBlock, UpsamplingBlock
 
 class UnetEncoder(nn.Module):
     """Create the Unet Encoder Network.
pix2pix.py
CHANGED
@@ -1,10 +1,12 @@
 import torch
 import torch.nn as nn
-…
+from huggingface_hub import PyTorchModelHubMixin
 
-from …
+from networks import UnetGenerator, PatchGAN
 
-class Pix2Pix(nn.Module):
+class Pix2Pix(
+              nn.Module,
+              PyTorchModelHubMixin):
     """Create a Pix2Pix class. It is a model for image to image translation tasks.
     By default, the model uses a Unet architecture for generator with transposed
     convolution. The discriminator is 70x70 PatchGAN discriminator, by default.
@@ -44,11 +46,12 @@ class Pix2Pix(nn.Module):
         super(Pix2Pix, self).__init__()
         self.is_CGAN = is_CGAN
         self.lambda_L1 = lambda_L1
+        self.is_train = is_train
 
         self.gen = UnetGenerator(c_in=c_in, c_out=c_out, use_upsampling=use_upsampling, mode=mode)
         self.gen = self.gen.apply(self.weights_init)
 
-        if is_train:
+        if self.is_train:
             # Conditional GANs need both input and output together, the total input channel is c_in+c_out
             disc_in = c_in + c_out if is_CGAN else c_out
             self.disc = PatchGAN(c_in=disc_in, c_hid=c_hid, mode=netD, n_layers=n_layers)
@@ -64,7 +67,7 @@ class Pix2Pix(nn.Module):
             self.criterion = nn.BCEWithLogitsLoss()
             self.criterion_L1 = nn.L1Loss()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         return self.gen(x)
 
     @staticmethod
@@ -82,7 +85,11 @@ class Pix2Pix(nn.Module):
             nn.init.normal_(m.weight, 1.0, 0.02)
             nn.init.constant_(m.bias, 0)
 
-    def _get_disc_inputs(self, …
+    def _get_disc_inputs(self,
+                         real_images: torch.Tensor,
+                         target_images: torch.Tensor,
+                         fake_images: torch.Tensor
+                         ):
         """Prepare discriminator inputs based on conditional/unconditional setup."""
         if self.is_CGAN:
             # Conditional GANs need both input and output together,
@@ -96,7 +103,10 @@ class Pix2Pix(nn.Module):
             fake_AB = fake_images.detach()
         return real_AB, fake_AB
 
-    def _get_gen_inputs(self, …
+    def _get_gen_inputs(self,
+                        real_images: torch.Tensor,
+                        fake_images: torch.Tensor
+                        ):
         """Prepare generator inputs based on conditional/unconditional setup."""
         if self.is_CGAN:
             # Conditional GANs need both input and output together,
@@ -109,7 +119,11 @@ class Pix2Pix(nn.Module):
         return fake_AB
 
 
-    def step_discriminator(self, …
+    def step_discriminator(self,
+                           real_images: torch.Tensor,
+                           target_images: torch.Tensor,
+                           fake_images: torch.Tensor
+                           ):
         """Discriminator forward/backward pass.
 
         Args:
@@ -134,7 +148,11 @@ class Pix2Pix(nn.Module):
         lossD = (lossD_real + lossD_fake) * 0.5  # Combined Loss
         return lossD
 
-    def step_generator(self, …
+    def step_generator(self,
+                       real_images: torch.Tensor,
+                       target_images: torch.Tensor,
+                       fake_images: torch.Tensor
+                       ):
         """Generator forward/backward pass.
 
         Args:
@@ -162,7 +180,10 @@ class Pix2Pix(nn.Module):
             'loss_G_L1': lossG_L1.item()
         }
 
-    def train_step(self, …
+    def train_step(self,
+                   real_images: torch.Tensor,
+                   target_images: torch.Tensor
+                   ):
         """Performs a single training step.
 
         Args:
@@ -177,13 +198,13 @@ class Pix2Pix(nn.Module):
 
         # Update discriminator
         self.disc_optimizer.zero_grad()  # Reset the gradients for D
-        lossD = self.…
+        lossD = self.step_discriminator(real_images, target_images, fake_images)  # Compute the loss
         lossD.backward()
         self.disc_optimizer.step()  # Update D
 
         # Update generator
         self.gen_optimizer.zero_grad()  # Reset the gradients for G
-        lossG, G_losses = self.…
+        lossG, G_losses = self.step_generator(real_images, target_images, fake_images)  # Compute the loss
         lossG.backward()
         self.gen_optimizer.step()  # Update G
 
@@ -193,7 +214,91 @@ class Pix2Pix(nn.Module):
             **G_losses
         }
 
-    def …
+    def validation_step(self,
+                        real_images: torch.Tensor,
+                        target_images: torch.Tensor
+                        ):
+        """Performs a single validation step.
+
+        Args:
+            real_images: Input images
+            target_images: Ground truth images
+
+        Returns:
+            Dictionary containing all loss values from this step
+        """
+        with torch.no_grad():
+            # Forward pass through the generator
+            fake_images = self.forward(real_images)
+
+            # Compute the loss for D
+            lossD = self.step_discriminator(real_images, target_images, fake_images)
+
+            # Compute the loss for G
+            _, G_losses = self.step_generator(real_images, target_images, fake_images)
+
+        # Return all losses
+        return {
+            'loss_D': lossD.item(),
+            **G_losses
+        }
+
+    def generate(self,
+                 real_images: torch.Tensor,
+                 is_scaled: bool = False,
+                 to_uint8: bool = False
+                 ):
+        if not is_scaled:
+            real_images = real_images.to(dtype=torch.float32)  # Make sure it's a float tensor
+            real_images = real_images / 255.0  # Normalize to [0, 1]
+            real_images = (real_images - 0.5) / 0.5  # Scale to [-1, 1]
+
+        with torch.no_grad():  # generate image
+            generated_images = self.forward(real_images)
+
+        generated_images = (generated_images + 1) / 2  # Rescale to [0, 1]
+        if to_uint8:
+            generated_images = (generated_images * 255).to(dtype=torch.uint8)  # Scale to [0, 255] and convert to uint8
+
+        return generated_images
+
+
+    def save_model(self, gen_path: str, disc_path: str = None):
+        """
+        Saves the generator model's state dictionary to the specified path.
+        If in training mode and a discriminator path is provided, saves the
+        discriminator model's state dictionary as well.
+
+        Args:
+            gen_path (str): The file path where the generator model's state dictionary will be saved.
+            disc_path (str, optional): The file path where the discriminator model's state dictionary will be saved. Defaults to None.
+        """
+        torch.save(self.gen.state_dict(), gen_path)
+        if self.is_train and disc_path is not None:
+            torch.save(self.disc.state_dict(), disc_path)
+
+    def load_model(self, gen_path: str, disc_path: str = None, device: str = None):
+        """
+        Loads the generator and optionally the discriminator model from the specified file paths.
+
+        Args:
+            gen_path (str): Path to the generator model file.
+            disc_path (str, optional): Path to the discriminator model file. Defaults to None.
+            device (torch.device, optional): The device on which to load the models. If None, the device of the model's parameters will be used. Defaults to None.
+
+        Returns:
+            None
+        """
+        device = device if device else next(self.gen.parameters()).device
+        self.gen.load_state_dict(torch.load(gen_path, map_location=device, weights_only=True), strict=False)
+        if disc_path is not None and self.is_train:
+            device = device if device else next(self.disc.parameters()).device
+            self.disc.load_state_dict(torch.load(disc_path, map_location=device, weights_only=True), strict=False)
+
+    def get_current_visuals(self,
+                            real_images: torch.Tensor,
+                            target_images: torch.Tensor
+                            ):
         """Return visualization images.
 
         Args:
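Taken together, the new validation_step, generate, save_model, and load_model methods give the model a small train/infer API, and inheriting PyTorchModelHubMixin additionally exposes the standard save_pretrained / from_pretrained / push_to_hub helpers from huggingface_hub. A minimal inference sketch (not part of the commit), assuming pix2pix.py is importable from the repository root, that the constructor defaults match the shipped weights, and a 256x256 input size:

import torch
from pix2pix import Pix2Pix

# Inference-only setup: is_train=False skips the discriminator and its optimizers.
model = Pix2Pix(is_train=False)
model.load_model(gen_path="pix2pix_gen_ckpt_195.pt", device="cpu")
model.eval()

# generate() accepts raw uint8 images when is_scaled=False and rescales them
# to [-1, 1] before the forward pass; to_uint8=True returns uint8 in [0, 255].
images = torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8)  # dummy batch
outputs = model.generate(images, is_scaled=False, to_uint8=True)
print(outputs.shape, outputs.dtype)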
pix2pix_disc_ckpt_195.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a548668f92493f28819f8c693254aa3933035d995c383014a0ff7e54c674e71
+size 11090624
pix2pix_disc_ckpt_70.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1c8049a9fe6b87ee521b33e237053189086d330cf0b9b320e7a4776edd5be43
+size 11090598
pix2pix_gen_ckpt_195.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bdcbe8c9f4d28858fd51de8d4fcf37194b0ba632b10c70231bb2556989803975
+size 218246966
pix2pix_gen_ckpt_70.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f062c7083ffbedb96c98a1c2241d0101602c8756102dff650b96a80d2a87aae7
+size 218246868
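The four .pt files are stored as Git LFS pointers, so a plain clone only contains the three-line metadata above; the actual weights come down with git lfs pull (or a direct download from the repository). A hedged sketch (not part of the commit) of restoring the _195 generator/discriminator pair into a single trainable model, again assuming the constructor defaults match the shipped checkpoints:

from pix2pix import Pix2Pix

# is_train=True builds the PatchGAN discriminator and optimizers, so both
# state dicts added in this commit can be restored and training can resume.
model = Pix2Pix(is_train=True)
model.load_model(
    gen_path="pix2pix_gen_ckpt_195.pt",
    disc_path="pix2pix_disc_ckpt_195.pt",
)

# save_model writes the state dicts back out, e.g. after further training
# (the output filenames here are illustrative).
model.save_model("pix2pix_gen_ckpt_new.pt", "pix2pix_disc_ckpt_new.pt")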