# torch_utils.py
from collections import OrderedDict
import os
import numpy as np
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
sys.path.append('..')
try:
    import data
except ImportError:
    import foleycrafter.models.specvqgan.onset_baseline.data as data
# ---------------------------------------------------- #
def load_model(cp_path, net, device=None, strict=True):
if not device:
device = torch.device('cpu')
if os.path.isfile(cp_path):
print("=> loading checkpoint '{}'".format(cp_path))
checkpoint = torch.load(cp_path, map_location=device)
        # strip the 'module.' prefix that nn.DataParallel adds to parameter names
        if list(checkpoint['state_dict'].keys())[0].startswith('module.'):
            state_dict = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                name = k[7:]  # drop the leading 'module.'
                state_dict[name] = v
        else:
            state_dict = checkpoint['state_dict']
net.load_state_dict(state_dict, strict=strict)
print("=> loaded checkpoint '{}' (epoch {})"
.format(cp_path, checkpoint['epoch']))
start_epoch = checkpoint['epoch']
    else:
        print("=> no checkpoint found at '{}'".format(cp_path))
        sys.exit()
return net, start_epoch
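# Minimal usage sketch for load_model (the network and checkpoint path
# below are hypothetical placeholders):
#
#   net = torch.nn.Linear(10, 2)
#   net, start_epoch = load_model('checkpoints/latest.pth', net, strict=True)
#   print('resuming from epoch', start_epoch)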
# ---------------------------------------------------- #
def binary_acc(pred, target, thred):
    # binarize the predictions at the given threshold, then measure
    # element-wise agreement with the targets
    pred = pred > thred
    acc = np.sum(pred == target) / target.shape[0]
    return acc
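# Example (sketch): binary accuracy with a 0.5 threshold.
#
#   pred = np.array([0.9, 0.2, 0.7, 0.4])
#   target = np.array([1, 0, 0, 0])
#   binary_acc(pred, target, 0.5)  # -> 0.75 (3 of 4 predictions correct)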
def calc_acc(prob, labels, k):
    # top-k accuracy: fraction of samples whose true label is among the
    # k highest-scoring predictions
    pred = torch.argsort(prob, dim=-1, descending=True)[..., :k]
    top_k_acc = torch.sum(pred == labels.view(-1, 1)).float() / labels.size(0)
    return top_k_acc
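# Example (sketch): top-2 accuracy over 2 samples and 3 classes.
#
#   prob = torch.tensor([[0.1, 0.7, 0.2],
#                        [0.5, 0.3, 0.2]])
#   labels = torch.tensor([2, 0])
#   calc_acc(prob, labels, k=2)  # -> 1.0 (each label is in its row's top 2)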
# ---------------------------------------------------- #
def get_dataloader(args, pr, split='train', shuffle=False, drop_last=False, batch_size=None):
    data_loader = getattr(data, pr.dataloader)
    if split == 'train':
        read_list = pr.list_train
    elif split == 'val':
        read_list = pr.list_val
    elif split == 'test':
        read_list = pr.list_test
    else:
        raise ValueError("split must be 'train', 'val' or 'test', got '{}'".format(split))
    dataset = data_loader(args, pr, read_list, split=split)
    batch_size = batch_size if batch_size else args.batch_size
    dataset.getitem_test(1)  # sanity-check that a single item can be read before building the loader
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=args.num_workers,
pin_memory=True,
drop_last=drop_last)
return dataset, loader
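# Usage sketch, assuming `args` and `pr` come from this project's argument
# parser and params config (both defined outside this file):
#
#   train_set, train_loader = get_dataloader(args, pr, split='train',
#                                            shuffle=True, drop_last=True)
#   for batch in train_loader:
#       ...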
# ---------------------------------------------------- #
def make_optimizer(model, args):
    '''
    Args:
        model: NN to train
    Returns:
        optimizer: pytorch optimizer for updating the given model parameters.
    '''
if args.optim == 'SGD':
optimizer = torch.optim.SGD(
filter(lambda p: p.requires_grad, model.parameters()),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov=False
)
    elif args.optim == 'Adam':
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    else:
        raise ValueError("unsupported optimizer: '{}'".format(args.optim))
    return optimizer
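# Usage sketch (assumes args provides optim, lr, momentum and weight_decay):
#
#   optimizer = make_optimizer(net, args)  # e.g. args.optim == 'Adam'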
def adjust_learning_rate(optimizer, epoch, args):
"""Decay the learning rate based on schedule"""
lr = args.lr
if args.schedule == 'cos': # cosine lr schedule
lr *= 0.5 * (1. + np.cos(np.pi * epoch / args.epochs))
elif args.schedule == 'none': # no lr schedule
lr = args.lr
for param_group in optimizer.param_groups:
param_group['lr'] = lr
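# Usage sketch: call once per epoch before the training step. With the 'cos'
# schedule the lr decays from args.lr toward 0; e.g. at epoch == args.epochs / 2
# it is exactly 0.5 * args.lr.
#
#   for epoch in range(args.epochs):
#       adjust_learning_rate(optimizer, epoch, args)
#       train_one_epoch(...)  # hypothetical training step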