yuulind committed
Commit 1e50af8
Parent: a664a45

Add Checkpoint 195 and 70
layers.py CHANGED
@@ -1,6 +1,4 @@
- import torch
  import torch.nn as nn
- import torch.nn.functional as F


  class DownsamplingBlock(nn.Module):
networks.py CHANGED
@@ -1,8 +1,7 @@
  import torch
  import torch.nn as nn
- import torch.nn.functional as F

- from .layers import DownsamplingBlock, UpsamplingBlock
+ from layers import DownsamplingBlock, UpsamplingBlock

  class UnetEncoder(nn.Module):
      """Create the Unet Encoder Network.
pix2pix.py CHANGED
@@ -1,10 +1,12 @@
  import torch
  import torch.nn as nn
- import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin

- from .networks import UnetGenerator, PatchGAN
+ from networks import UnetGenerator, PatchGAN

- class Pix2Pix(nn.Module):
+ class Pix2Pix(
+     nn.Module,
+     PyTorchModelHubMixin):
      """Create a Pix2Pix class. It is a model for image to image translation tasks.
      By default, the model uses a Unet architecture for generator with transposed
      convolution. The discriminator is 70x70 PatchGAN discriminator, by default.
@@ -44,11 +46,12 @@ class Pix2Pix(nn.Module):
          super(Pix2Pix, self).__init__()
          self.is_CGAN = is_CGAN
          self.lambda_L1 = lambda_L1
+         self.is_train = is_train

          self.gen = UnetGenerator(c_in=c_in, c_out=c_out, use_upsampling=use_upsampling, mode=mode)
          self.gen = self.gen.apply(self.weights_init)

-         if is_train:
+         if self.is_train:
              # Conditional GANs need both input and output together, the total input channel is c_in+c_out
              disc_in = c_in + c_out if is_CGAN else c_out
              self.disc = PatchGAN(c_in=disc_in, c_hid=c_hid, mode=netD, n_layers=n_layers)
@@ -64,7 +67,7 @@ class Pix2Pix(nn.Module):
          self.criterion = nn.BCEWithLogitsLoss()
          self.criterion_L1 = nn.L1Loss()

-     def forward(self, x):
+     def forward(self, x: torch.Tensor):
          return self.gen(x)

      @staticmethod
@@ -82,7 +85,11 @@ class Pix2Pix(nn.Module):
              nn.init.normal_(m.weight, 1.0, 0.02)
              nn.init.constant_(m.bias, 0)

-     def _get_disc_inputs(self, real_images, target_images, fake_images):
+     def _get_disc_inputs(self,
+                          real_images: torch.Tensor,
+                          target_images: torch.Tensor,
+                          fake_images: torch.Tensor
+                          ):
          """Prepare discriminator inputs based on conditional/unconditional setup."""
          if self.is_CGAN:
              # Conditional GANs need both input and output together,
@@ -96,7 +103,10 @@ class Pix2Pix(nn.Module):
              fake_AB = fake_images.detach()
          return real_AB, fake_AB

-     def _get_gen_inputs(self, real_images, fake_images):
+     def _get_gen_inputs(self,
+                         real_images: torch.Tensor,
+                         fake_images: torch.Tensor
+                         ):
          """Prepare generator inputs based on conditional/unconditional setup."""
          if self.is_CGAN:
              # Conditional GANs need both input and output together,
@@ -109,7 +119,11 @@ class Pix2Pix(nn.Module):
          return fake_AB


-     def step_discriminator(self, real_images, target_images, fake_images):
+     def step_discriminator(self,
+                            real_images: torch.Tensor,
+                            target_images: torch.Tensor,
+                            fake_images: torch.Tensor
+                            ):
          """Discriminator forward/backward pass.

          Args:
@@ -134,7 +148,11 @@ class Pix2Pix(nn.Module):
          lossD = (lossD_real + lossD_fake) * 0.5  # Combined Loss
          return lossD

-     def step_generator(self, real_images, target_images, fake_images):
+     def step_generator(self,
+                        real_images: torch.Tensor,
+                        target_images: torch.Tensor,
+                        fake_images: torch.Tensor
+                        ):
          """Generator forward/backward pass.

          Args:
@@ -162,7 +180,10 @@ class Pix2Pix(nn.Module):
              'loss_G_L1': lossG_L1.item()
          }

-     def train_step(self, real_images, target_images):
+     def train_step(self,
+                    real_images: torch.Tensor,
+                    target_images: torch.Tensor
+                    ):
          """Performs a single training step.

          Args:
@@ -177,13 +198,13 @@ class Pix2Pix(nn.Module):

          # Update discriminator
          self.disc_optimizer.zero_grad()  # Reset the gradients for D
-         lossD = self.stepD(real_images, target_images, fake_images)  # Compute the loss
+         lossD = self.step_discriminator(real_images, target_images, fake_images)  # Compute the loss
          lossD.backward()
          self.disc_optimizer.step()  # Update D

          # Update generator
          self.gen_optimizer.zero_grad()  # Reset the gradients for G
-         lossG, G_losses = self.stepG(real_images, target_images, fake_images)  # Compute the loss
+         lossG, G_losses = self.step_generator(real_images, target_images, fake_images)  # Compute the loss
          lossG.backward()
          self.gen_optimizer.step()  # Update G

@@ -193,7 +214,91 @@ class Pix2Pix(nn.Module):
              **G_losses
          }

+     def validation_step(self,
+                         real_images: torch.Tensor,
+                         target_images: torch.Tensor
+                         ):
+         """Performs a single validation step.
+
+         Args:
+             real_images: Input images
+             target_images: Ground truth images
+
+         Returns:
+             Dictionary containing all loss values from this step
+         """
+         with torch.no_grad():
+             # Forward pass through the generator
+             fake_images = self.forward(real_images)
+
+             # Compute the loss for D
+             lossD = self.step_discriminator(real_images, target_images, fake_images)
+
+             # Compute the loss for G
+             _, G_losses = self.step_generator(real_images, target_images, fake_images)
+
+             # Return all losses
+             return {
+                 'loss_D': lossD.item(),
+                 **G_losses
+             }
+
+     def generate(self,
+                  real_images: torch.Tensor,
+                  is_scaled: bool = False,
+                  to_uint8: bool = False
+                  ):
+         if not is_scaled:
+             real_images = real_images.to(dtype=torch.float32)  # Make sure it's a float tensor
+             real_images = real_images / 255.0  # Normalize to [0, 1]
+             real_images = (real_images - 0.5) / 0.5  # Scale to [-1, 1]
+
+         with torch.no_grad():  # Generate images
+             generated_images = self.forward(real_images)
+
+         generated_images = (generated_images + 1) / 2  # Rescale to [0, 1]
+         if to_uint8:
+             generated_images = (generated_images * 255).to(dtype=torch.uint8)  # Scale to [0, 255] and convert to uint8
+
+         return generated_images
+
+
+     def save_model(self, gen_path: str, disc_path: str = None):
+         """
+         Saves the generator model's state dictionary to the specified path.
+         If in training mode and a discriminator path is provided, saves the
+         discriminator model's state dictionary as well.
+
+         Args:
+             gen_path (str): The file path where the generator model's state dictionary will be saved.
+             disc_path (str, optional): The file path where the discriminator model's state dictionary will be saved. Defaults to None.
+         """
+         torch.save(self.gen.state_dict(), gen_path)
+         if self.is_train and disc_path is not None:
+             torch.save(self.disc.state_dict(), disc_path)
+
+     def load_model(self, gen_path: str, disc_path: str = None, device: str = None):
+         """
+         Loads the generator and optionally the discriminator model from the specified file paths.
+
+         Args:
+             gen_path (str): Path to the generator model file.
+             disc_path (str, optional): Path to the discriminator model file. Defaults to None.
+             device (torch.device, optional): The device on which to load the models. If None, the device of the model's parameters will be used. Defaults to None.
+
+         Returns:
+             None
+         """
+         device = device if device else next(self.gen.parameters()).device
+         self.gen.load_state_dict(torch.load(gen_path, map_location=device, weights_only=True), strict=False)
+         if disc_path is not None and self.is_train:
+             device = device if device else next(self.disc.parameters()).device
+             self.disc.load_state_dict(torch.load(disc_path, map_location=device, weights_only=True), strict=False)
+
-     def get_current_visuals(self, real_images, target_images):
+     def get_current_visuals(self,
+                             real_images: torch.Tensor,
+                             target_images: torch.Tensor
+                             ):
          """Return visualization images.

          Args:
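
Note: the new PyTorchModelHubMixin base class is what lets this model be serialized to and restored from the Hugging Face Hub. Below is a minimal sketch of that workflow, assuming the remaining constructor arguments have defaults; the constructor values and the repo id are hypothetical, not recorded in this commit.

    from pix2pix import Pix2Pix

    # Assumed constructor arguments; the exact signature and defaults live in pix2pix.py.
    model = Pix2Pix(c_in=3, c_out=3, is_train=False)

    # The mixin writes the weights plus the init kwargs (config.json) to a local folder.
    model.save_pretrained("pix2pix-checkpoint")
    # model.push_to_hub("your-username/pix2pix")  # hypothetical repo id

    # Later the model can be rebuilt without repeating the constructor arguments.
    restored = Pix2Pix.from_pretrained("pix2pix-checkpoint")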
pix2pix_disc_ckpt_195.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2a548668f92493f28819f8c693254aa3933035d995c383014a0ff7e54c674e71
+ size 11090624
pix2pix_disc_ckpt_70.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f1c8049a9fe6b87ee521b33e237053189086d330cf0b9b320e7a4776edd5be43
+ size 11090598
pix2pix_gen_ckpt_195.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bdcbe8c9f4d28858fd51de8d4fcf37194b0ba632b10c70231bb2556989803975
+ size 218246966
pix2pix_gen_ckpt_70.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f062c7083ffbedb96c98a1c2241d0101602c8756102dff650b96a80d2a87aae7
+ size 218246868
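
The four .pt files above are Git LFS pointers to the raw state dicts for the generator and discriminator at checkpoints 70 and 195. Below is a minimal sketch of restoring one of them with the load_model helper added in this commit; the constructor arguments and the 256x256 input resolution are assumptions, and the file must be present locally (e.g. after git lfs pull).

    import torch
    from pix2pix import Pix2Pix

    # Assumed constructor arguments; is_train=False skips building the discriminator.
    model = Pix2Pix(c_in=3, c_out=3, is_train=False)
    model.load_model(gen_path="pix2pix_gen_ckpt_195.pt")  # generator weights only
    model.eval()  # switch dropout/batch norm to inference behavior

    # generate() accepts raw uint8 images and handles the [-1, 1] scaling internally.
    images = torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8)
    outputs = model.generate(images, is_scaled=False, to_uint8=True)  # uint8 images in [0, 255]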