import torch

from Modules.ControllabilityGAN.wgan.resnet_init import init_resnet
from Modules.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost


def create_wgan(parameters, device, optimizer='adam'):
    # Build the generator and critic networks. Only the ResNet architecture
    # from resnet_init is supported at the moment.
    if parameters['model'] == "resnet":
        generator, discriminator = init_resnet(parameters)
    else:
        raise NotImplementedError

    # Separate optimizers for the generator and the critic.
    if optimizer == 'adam':
        optimizer_g = torch.optim.Adam(generator.parameters(), lr=parameters['learning_rate'], betas=parameters['betas'])
        optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=parameters['learning_rate'], betas=parameters['betas'])
    elif optimizer == 'rmsprop':
        optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=parameters['learning_rate'])
        optimizer_d = torch.optim.RMSprop(discriminator.parameters(), lr=parameters['learning_rate'])
    else:
        raise NotImplementedError(f"Unknown optimizer: {optimizer}")

    # Regression loss used to fit the critic in the quadratic-cost formulation.
    criterion = torch.nn.MSELoss()

    gan = WassersteinGanQuadraticCost(generator,
                                      discriminator,
                                      optimizer_g,
                                      optimizer_d,
                                      criterion=criterion,
                                      data_dimensions=parameters['data_dim'],
                                      epochs=parameters['epochs'],
                                      batch_size=parameters['batch_size'],
                                      device=device,
                                      n_max_iterations=parameters['n_max_iterations'],
                                      gamma=parameters['gamma'])

    return gan
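

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the original module):
    # the parameter values below are assumptions chosen to satisfy the lookups
    # performed in create_wgan; init_resnet may require additional keys.
    example_parameters = {'model'           : 'resnet',
                          'learning_rate'   : 1e-4,
                          'betas'           : (0.5, 0.9),
                          'data_dim'        : 64,  # assumed dimensionality of the data being modelled
                          'epochs'          : 10,
                          'batch_size'      : 32,
                          'n_max_iterations': 100000,
                          'gamma'           : 1.0}
    example_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    wgan = create_wgan(example_parameters, example_device, optimizer='adam')
    print(type(wgan).__name__)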