diff --git a/src/gan_control/inference/controller.py b/src/gan_control/inference/controller.py
index ee464ba..d1907dd 100644
--- a/src/gan_control/inference/controller.py
+++ b/src/gan_control/inference/controller.py
@@ -13,9 +13,9 @@ _log = get_logger(__name__)
class Controller(Inference):
- def __init__(self, controller_dir):
+ def __init__(self, controller_dir, device):
_log.info('Init Controller class...')
- super(Controller, self).__init__(os.path.join(controller_dir, 'generator'))
+ super(Controller, self).__init__(os.path.join(controller_dir, 'generator'), device)
self.fc_controls = {}
self.config_controls = {}
for sub_group_name in self.batch_utils.sub_group_names:
@@ -29,21 +29,21 @@ class Controller(Inference):
@torch.no_grad()
def gen_batch_by_controls(self, batch_size=1, latent=None, normalize=True, input_is_latent=False, static_noise=True, **kwargs):
if latent is None:
- latent = torch.randn(batch_size, self.config.model_config['latent_size'], device='cuda')
+ latent = torch.randn(batch_size, self.config.model_config['latent_size'], device=self.device)
latent = latent.clone()
if input_is_latent:
latent_w = latent
else:
if isinstance(self.model, torch.nn.DataParallel):
- latent_w = self.model.module.style(latent.cuda())
+ latent_w = self.model.module.style(latent.to(self.device))
else:
- latent_w = self.model.style(latent.cuda())
+ latent_w = self.model.style(latent.to(self.device))
for group_key in kwargs.keys():
if self.check_if_group_has_control(group_key):
if group_key == 'expression' and kwargs[group_key].shape[1] == 8:
- group_w_latent = self.fc_controls['expression_q'](kwargs[group_key].cuda().float())
+ group_w_latent = self.fc_controls['expression_q'](kwargs[group_key].to(self.device).float())
else:
- group_w_latent = self.fc_controls[group_key](kwargs[group_key].cuda().float())
+ group_w_latent = self.fc_controls[group_key](kwargs[group_key].to(self.device).float())
latent_w = self.insert_group_w_latent(latent_w, group_w_latent, group_key)
injection_noise = None
if static_noise:
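With the hunk above, control tensors passed through **kwargs are moved to the controller's device instead of hardcoded CUDA before the per-group FC stacks run. A minimal usage sketch, assuming an already-initialized controller; the zero-valued 8-dim expression tensor (the shape the expression_q branch handles) and the return unpacking are illustrative assumptions, not part of this patch:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
expression = torch.zeros(1, 8)  # illustrative 8-dim expression control
# gen_batch_by_controls moves the control tensor to `device` internally;
# the three-value return mirrors Inference.gen_batch below (assumed)
images, latent_z, latent_w = controller.gen_batch_by_controls(
    batch_size=1, expression=expression)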
@@ -101,12 +101,12 @@ class Controller(Inference):
ckpt_path = ckpt_list[-1]
ckpt_iter = ckpt_path.split('.')[0]
config = read_json(config_path, return_obj=True)
- ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path))
+ ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path), map_location=self.device)
group_chunk = self.batch_utils.place_in_latent_dict[sub_group_name if sub_group_name is not 'expression_q' else 'expression']
group_latent_size = group_chunk[1] - group_chunk[0]
_log.info('Init %s Controller...' % sub_group_name)
- controller = FcStack(config.model_config['lr_mlp'], config.model_config['n_mlp'], config.model_config['in_dim'], config.model_config['mid_dim'], group_latent_size).cuda()
+ controller = FcStack(config.model_config['lr_mlp'], config.model_config['n_mlp'], config.model_config['in_dim'], config.model_config['mid_dim'], group_latent_size).to(self.device)
controller.print()
_log.info('Loading Controller: %s, ckpt iter %s' % (controller_dir_path, ckpt_iter))
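The map_location change above (mirrored in inference.py below) is what lets a checkpoint saved on a CUDA machine load on a CPU-only host. A standalone sketch of the idiom; the checkpoint path is hypothetical:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Without map_location, torch.load tries to restore tensors onto the
# CUDA device they were saved from and fails when no GPU is present.
ckpt = torch.load('checkpoint.pt', map_location=device)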
diff --git a/src/gan_control/inference/inference.py b/src/gan_control/inference/inference.py
index e6ccedb..4393bb7 100644
--- a/src/gan_control/inference/inference.py
+++ b/src/gan_control/inference/inference.py
@@ -15,10 +15,11 @@ _log = get_logger(__name__)
class Inference():
- def __init__(self, model_dir):
+ def __init__(self, model_dir, device):
_log.info('Init inference class...')
self.model_dir = model_dir
- self.model, self.batch_utils, self.config, self.ckpt_iter = self.retrieve_model(model_dir)
+ self.device = device
+ self.model, self.batch_utils, self.config, self.ckpt_iter = self.retrieve_model(model_dir, device)
self.noise = None
self.reset_noise()
self.mean_w_latent = None
@@ -28,7 +29,7 @@ class Inference():
_log.info('Calc mean_w_latents...')
mean_latent_w_list = []
for i in range(100):
- latent_z = torch.randn(1000, self.config.model_config['latent_size'], device='cuda')
+ latent_z = torch.randn(1000, self.config.model_config['latent_size'], device=self.device)
if isinstance(self.model, torch.nn.DataParallel):
latent_w = self.model.module.style(latent_z).cpu()
else:
@@ -41,9 +42,9 @@ class Inference():
def reset_noise(self):
if isinstance(self.model, torch.nn.DataParallel):
- self.noise = self.model.module.make_noise(device='cuda')
+ self.noise = self.model.module.make_noise(device=self.device)
else:
- self.noise = self.model.make_noise(device='cuda')
+ self.noise = self.model.make_noise(device=self.device)
@staticmethod
def expend_noise(noise, batch_size):
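The isinstance checks above (repeated throughout both files) unwrap torch.nn.DataParallel before touching module attributes. A generic helper expressing the same pattern; the name unwrap is illustrative, not part of this patch:

import torch

def unwrap(model):
    # DataParallel keeps the real network under .module
    return model.module if isinstance(model, torch.nn.DataParallel) else model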
@@ -56,14 +57,14 @@ class Inference():
self.calc_mean_w_latents()
injection_noise = None
if latent is None:
- latent = torch.randn(batch_size, self.config.model_config['latent_size'], device='cuda')
+ latent = torch.randn(batch_size, self.config.model_config['latent_size'], device=self.device)
elif input_is_latent:
- latent = latent.cuda()
+ latent = latent.to(self.device)
for group_key in kwargs.keys():
if group_key not in self.batch_utils.sub_group_names:
raise ValueError('group_key: %s not in sub_group_names %s' % (group_key, str(self.batch_utils.sub_group_names)))
if isinstance(kwargs[group_key], str) and kwargs[group_key] == 'random':
- group_latent_w = self.model.style(torch.randn(latent.shape[0], self.config.model_config['latent_size'], device='cuda'))
+ group_latent_w = self.model.style(torch.randn(latent.shape[0], self.config.model_config['latent_size'], device=self.device))
group_latent_w = group_latent_w[:, self.batch_utils.place_in_latent_dict[group_key][0], self.batch_utils.place_in_latent_dict[group_key][0]]
latent[:, self.batch_utils.place_in_latent_dict[group_key][0], self.batch_utils.place_in_latent_dict[group_key][0]] = group_latent_w
if static_noise:
@@ -82,11 +83,11 @@ class Inference():
latent[:, place_in_latent[0]: place_in_latent[1]] = \
truncation * (latent[:, place_in_latent[0]: place_in_latent[1]] - torch.cat(
[self.mean_w_latents[key].clone().unsqueeze(0) for _ in range(latent.shape[0])], dim=0
- ).cuda()) + torch.cat(
+ ).to(self.device)) + torch.cat(
[self.mean_w_latents[key].clone().unsqueeze(0) for _ in range(latent.shape[0])], dim=0
- ).cuda()
+ ).to(self.device)
- tensor, latent_w = self.model([latent.cuda()], return_latents=True, input_is_latent=input_is_latent, noise=injection_noise)
+ tensor, latent_w = self.model([latent.to(self.device)], return_latents=True, input_is_latent=input_is_latent, noise=injection_noise)
if normalize:
tensor = tensor.mul(0.5).add(0.5).clamp(min=0., max=1.).cpu()
return tensor, latent, latent_w
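The hunk above is the standard truncation trick, now built on self.device instead of hardcoded CUDA: each latent group is pulled toward its precomputed mean w by a factor truncation. The same computation in isolation, with illustrative shapes:

import torch

truncation = 0.7           # 1.0 keeps the sample, 0.0 collapses to the mean
w = torch.randn(4, 512)    # sampled w latents (illustrative batch and width)
mean_w = torch.zeros(512)  # per-group mean, as in self.mean_w_latents
w_trunc = truncation * (w - mean_w) + mean_w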
@@ -107,7 +108,7 @@ class Inference():
return grid_image
@staticmethod
- def retrieve_model(model_dir):
+ def retrieve_model(model_dir, device):
config_path = os.path.join(model_dir, 'args.json')
_log.info('Retrieve config from %s' % config_path)
@@ -117,7 +118,7 @@ class Inference():
ckpt_path = ckpt_list[-1]
ckpt_iter = ckpt_path.split('.')[0]
config = read_json(config_path, return_obj=True)
- ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path))
+ ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path), map_location=device)
batch_utils = None
if not config.model_config['vanilla']:
@@ -140,7 +141,7 @@ class Inference():
fc_config=None if config.model_config['vanilla'] else batch_utils.get_fc_config(),
conv_transpose=config.model_config['conv_transpose'],
noise_mode=config.model_config['g_noise_mode']
- ).cuda()
+ ).to(device)
_log.info('Loading Model: %s, ckpt iter %s' % (model_dir, ckpt_iter))
model.load_state_dict(ckpt['g_ema'])
model = torch.nn.DataParallel(model)
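Taken together, the patch threads an explicit device through Inference and Controller so the code runs on CPU-only hosts as well as GPUs. A minimal end-to-end sketch after applying it; the model directory is a placeholder, and the return unpacking is assumed to mirror Inference.gen_batch above:

import torch
from gan_control.inference.controller import Controller

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
controller = Controller('path/to/controller_dir', device)  # placeholder path
images, latent_z, latent_w = controller.gen_batch_by_controls(batch_size=1)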