Spaces:
Running
on
T4
Running
on
T4
File size: 1,609 Bytes
9e275b8 70399da 9e275b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
from Modules.ControllabilityGAN.wgan.resnet_init import init_resnet
from Modules.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost
def create_wgan(parameters, device, optimizer='adam'):
    """Build a WassersteinGanQuadraticCost model with its generator/discriminator optimizers.

    Args:
        parameters: Configuration mapping. Keys read here: 'model' (only
            "resnet" is supported), 'learning_rate', 'betas' (Adam only),
            'data_dim', 'epochs', 'batch_size', 'n_max_iterations', 'gamma'.
            The whole mapping is also passed to ``init_resnet``.
        device: Device passed through to ``WassersteinGanQuadraticCost``.
        optimizer: Either 'adam' or 'rmsprop'; the same optimizer type is used
            for both the generator and the discriminator.

    Returns:
        The constructed ``WassersteinGanQuadraticCost`` instance.

    Raises:
        NotImplementedError: if ``parameters['model']`` or ``optimizer`` is not
            one of the supported values.
    """
    # Validate configuration before building any networks, so a bad option
    # fails fast with a clear message instead of an UnboundLocalError later.
    if parameters['model'] != "resnet":
        raise NotImplementedError(f"Unsupported model type: {parameters['model']!r}")
    if optimizer not in ('adam', 'rmsprop'):
        raise NotImplementedError(f"Unsupported optimizer: {optimizer!r}")

    generator, discriminator = init_resnet(parameters)

    learning_rate = parameters['learning_rate']
    if optimizer == 'adam':
        optimizer_g = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=parameters['betas'])
        optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=parameters['betas'])
    else:  # 'rmsprop'
        optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=learning_rate)
        # Each optimizer must own its own network's parameters; the discriminator
        # optimizer previously (incorrectly) stepped the generator's parameters.
        optimizer_d = torch.optim.RMSprop(discriminator.parameters(), lr=learning_rate)

    criterion = torch.nn.MSELoss()
    gan = WassersteinGanQuadraticCost(generator,
                                      discriminator,
                                      optimizer_g,
                                      optimizer_d,
                                      criterion=criterion,
                                      data_dimensions=parameters['data_dim'],
                                      epochs=parameters['epochs'],
                                      batch_size=parameters['batch_size'],
                                      device=device,
                                      n_max_iterations=parameters['n_max_iterations'],
                                      gamma=parameters['gamma'])
    return gan
|