from typing import Callable

import torch

# Both networks are trained with binary cross-entropy. BCELoss expects its
# input to already be a probability in [0, 1], so the discriminator passed to
# the functions below must end in a sigmoid-like activation.
disc_loss_criterion = torch.nn.BCELoss()
gen_loss_criterion = torch.nn.BCELoss()

# Target values the discriminator is trained to emit for each kind of image.
real_label = 1
fake_label = 0


def discriminator_loss(
    gen_images: torch.Tensor,
    real_images: torch.Tensor,
    disc_net: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Return the discriminator's BCE loss, averaged over real and fake batches.

    Args:
        gen_images: Batch of generator outputs. It is detached before being
            scored, so no gradient flows back into the generator.
        real_images: Batch of real samples.
        disc_net: Callable mapping a batch of images to probabilities in
            [0, 1]. Any output shape is accepted; the targets are built to
            match it.

    Returns:
        ``(fake_loss + real_loss) / 2`` as computed by ``disc_loss_criterion``.
    """
    # Score the real batch first, then the detached fake batch.
    real_scores = disc_net(real_images)
    fake_scores = disc_net(gen_images.detach())

    # Build the targets from the discriminator outputs (not from the images)
    # so their shape, dtype and device always match what BCELoss compares
    # them against -- the same approach generator_loss uses.
    real_targets = real_scores.new_full(real_scores.shape, real_label)
    fake_targets = fake_scores.new_full(fake_scores.shape, fake_label)

    real_loss = disc_loss_criterion(real_scores, real_targets)
    fake_loss = disc_loss_criterion(fake_scores, fake_targets)
    return (fake_loss + real_loss) / 2


def generator_loss(
    gen_images: torch.Tensor,
    disc_net: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Return the generator's BCE loss for a batch of generated images.

    The generated images are scored by ``disc_net`` (without detaching, so the
    loss is differentiable w.r.t. the generator) and compared against
    ``real_label``: the loss is low when the discriminator rates them as real.

    Args:
        gen_images: Batch of generator outputs.
        disc_net: Callable mapping a batch of images to probabilities in
            [0, 1].

    Returns:
        The loss computed by ``gen_loss_criterion``.
    """
    scores = disc_net(gen_images)
    targets = scores.new_full(scores.shape, real_label)
    return gen_loss_criterion(scores, targets)