# Spaces: Runtime error  (NOTE(review): looks like a hosting-page status banner captured with this snippet, not code)
from big.BigGAN2 import Generator, Discriminator
from big.losses import generator_loss, discriminator_loss

# Build both networks and move them to the GPU.
generator = Generator().cuda()
discriminator = Discriminator().cuda()

# NOTE(review): label_fc_net, label_fake, label_real, decoder_input and images
# are not defined in this excerpt; they must be provided by surrounding code.
label_transformed_fake = label_fc_net(label_fake)
label_transformed_real = label_fc_net(label_real)
generated_images = generator(decoder_input, label_transformed_fake)

# disc training
# Clear stale gradients first: the generator update below backpropagates
# through the discriminator and leaves gradients on its parameters.
# NOTE(review): assumes `.optim` exposes zero_grad() like a torch optimizer.
discriminator.optim.zero_grad()
# detach() keeps the discriminator loss from backpropagating into the generator.
prediction_fake = discriminator(generated_images.detach(), label_transformed_fake).view(-1)
prediction_real = discriminator(images, label_transformed_real).view(-1)
d_loss_real, d_loss_fake = discriminator_loss(prediction_fake, prediction_real)
# The original stepped the optimizer without ever backpropagating these
# losses. NOTE(review): assumes both terms are tensors that can be summed.
d_loss = d_loss_real + d_loss_fake
# label_transformed_fake is reused by the generator pass below, so keep the
# graph that produced it alive for the second backward().
d_loss.backward(retain_graph=True)
discriminator.optim.step()

# gen training
generator.optim.zero_grad()
prediction = discriminator(generated_images, label_transformed_fake).view(-1)
g_loss = generator_loss(prediction)
g_loss.backward()
generator.optim.step()