# Spaces: Running on T4
import torch

from Modules.ControllabilityGAN.wgan.resnet_init import init_resnet
from Modules.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost
def create_wgan(parameters, device, optimizer='adam'):
    """Build a ``WassersteinGanQuadraticCost`` GAN from a configuration mapping.

    Args:
        parameters: Configuration mapping. Keys read directly here:
            ``'model'`` (only ``"resnet"`` is supported), ``'learning_rate'``,
            ``'betas'`` (read only when ``optimizer == 'adam'``),
            ``'data_dim'``, ``'epochs'``, ``'batch_size'``,
            ``'n_max_iterations'`` and ``'gamma'``. The whole mapping is also
            passed to ``init_resnet``, which may read further keys.
        device: Forwarded unchanged to ``WassersteinGanQuadraticCost``.
        optimizer: Name of the optimizer used for BOTH networks:
            ``'adam'`` (default) or ``'rmsprop'``.

    Returns:
        The constructed ``WassersteinGanQuadraticCost`` instance, using an
        ``MSELoss`` criterion.

    Raises:
        NotImplementedError: If ``optimizer`` or ``parameters['model']`` is not
            one of the supported choices.
    """
    # Check the optimizer name before building the networks so that an
    # unsupported choice fails fast with a clear message, instead of leaving
    # the optimizer variables unbound (UnboundLocalError further down).
    if optimizer not in ('adam', 'rmsprop'):
        raise NotImplementedError(f"Unsupported optimizer: {optimizer!r}")

    if parameters['model'] == "resnet":
        generator, discriminator = init_resnet(parameters)
    else:
        raise NotImplementedError(f"Unsupported model: {parameters['model']!r}")

    def make_optimizer(network):
        """Create the configured optimizer over ``network``'s own parameters."""
        if optimizer == 'adam':
            return torch.optim.Adam(network.parameters(),
                                    lr=parameters['learning_rate'],
                                    betas=parameters['betas'])
        return torch.optim.RMSprop(network.parameters(),
                                   lr=parameters['learning_rate'])

    # Each optimizer must own its own network's parameters. The old rmsprop
    # branch handed the generator's parameters to the discriminator optimizer,
    # so the discriminator was never trained.
    optimizer_g = make_optimizer(generator)
    optimizer_d = make_optimizer(discriminator)

    criterion = torch.nn.MSELoss()
    gan = WassersteinGanQuadraticCost(generator,
                                      discriminator,
                                      optimizer_g,
                                      optimizer_d,
                                      criterion=criterion,
                                      data_dimensions=parameters['data_dim'],
                                      epochs=parameters['epochs'],
                                      batch_size=parameters['batch_size'],
                                      device=device,
                                      n_max_iterations=parameters['n_max_iterations'],
                                      gamma=parameters['gamma'])
    return gan