Spaces:
Running
on
Zero
Running
on
Zero
from collections import OrderedDict | |
import os | |
import numpy as np | |
import random | |
import sys | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
sys.path.append('..') | |
try:
    import data
except ImportError:
    # Bind the packaged module to the name ``data`` as well: the rest of this
    # file looks up dataset classes on ``data``, and a plain
    # ``import a.b.c`` statement would only bind the top-level name ``a``.
    import foleycrafter.models.specvqgan.onset_baseline.data as data
# ---------------------------------------------------- # | |
def load_model(cp_path, net, device=None, strict=True):
    """Load weights from a checkpoint file into ``net``.

    The checkpoint must be a mapping with a ``'state_dict'`` entry and may
    carry an ``'epoch'`` entry. When every key of the state dict starts with
    ``'module.'`` the prefix is removed before loading (presumably left by a
    wrapper that stores the real network under a ``module`` attribute --
    confirm against the training code).

    Args:
        cp_path: path of the checkpoint file.
        net: module that receives the weights through ``load_state_dict``.
        device: passed to ``torch.load`` as ``map_location``; CPU when falsy.
        strict: forwarded to ``net.load_state_dict``.

    Returns:
        Tuple ``(net, start_epoch)``; ``start_epoch`` is the checkpoint's
        ``'epoch'`` value, or 0 when the checkpoint does not store one.

    Raises:
        SystemExit: when ``cp_path`` is not an existing file (exit status 1).
    """
    if not device:
        device = torch.device('cpu')

    if not os.path.isfile(cp_path):
        print("=> no checkpoint found at '{}'".format(cp_path))
        # A missing checkpoint is a failure, so do not exit with status 0.
        sys.exit(1)

    print("=> loading checkpoint '{}'".format(cp_path))
    # NOTE: torch.load unpickles the file; only load checkpoints that come
    # from a trusted source.
    checkpoint = torch.load(cp_path, map_location=device)

    state_dict = checkpoint['state_dict']
    prefix = 'module.'
    # Strip the prefix only when *all* keys carry it. Deciding from the first
    # key alone would chop seven characters off unrelated keys, and would
    # fail with IndexError on an empty state dict.
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = OrderedDict(
            (key[len(prefix):], value) for key, value in state_dict.items()
        )

    net.load_state_dict(state_dict, strict=strict)

    # Tolerate checkpoints that were saved without epoch bookkeeping.
    start_epoch = checkpoint.get('epoch', 0)
    print("=> loaded checkpoint '{}' (epoch {})".format(cp_path, start_epoch))
    return net, start_epoch
# ---------------------------------------------------- # | |
def binary_acc(pred, target, thred):
    """Accuracy of thresholded scores against ``target``.

    Scores strictly greater than ``thred`` count as positive. The number of
    positions where that decision equals ``target`` is divided by the length
    of ``target``'s first axis.
    """
    above_threshold = pred > thred
    num_matches = np.sum(above_threshold == target)
    return num_matches / target.shape[0]
def calc_acc(prob, labels, k):
    """Top-``k`` accuracy.

    The ``k`` highest-scoring indices along the last axis of ``prob`` are
    compared with ``labels`` (viewed as a column); the number of matches is
    divided by ``labels.size(0)`` and returned as a float tensor.
    """
    ranked = torch.argsort(prob, dim=-1, descending=True)
    top_k = ranked[..., :k]
    label_column = labels.view(-1, 1)
    num_hits = (top_k == label_column).sum()
    return num_hits.float() / labels.size(0)
# ---------------------------------------------------- # | |
def get_dataloader(args, pr, split='train', shuffle=False, drop_last=False, batch_size=None):
    """Build the dataset for ``split`` and wrap it in a ``DataLoader``.

    Args:
        args: run options; ``batch_size`` and ``num_workers`` are read here
            and the object is also handed to the dataset constructor.
        pr: parameter object providing ``dataloader`` (name of the dataset
            class inside the ``data`` module) and the ``list_train`` /
            ``list_val`` / ``list_test`` entries.
        split: one of ``'train'``, ``'val'`` or ``'test'``.
        shuffle: forwarded to ``DataLoader``.
        drop_last: forwarded to ``DataLoader``.
        batch_size: overrides ``args.batch_size`` when truthy.

    Returns:
        Tuple ``(dataset, loader)``.

    Raises:
        ValueError: if ``split`` is not one of the supported names.
    """
    list_attr_by_split = {
        'train': 'list_train',
        'val': 'list_val',
        'test': 'list_test',
    }
    # Reject unknown splits up front; otherwise the list variable would be
    # left unbound and surface later as an UnboundLocalError.
    if split not in list_attr_by_split:
        raise ValueError(
            "unknown split '{}'; expected one of {}".format(
                split, sorted(list_attr_by_split)))
    # Only the attribute of the requested split is read from ``pr``.
    read_list = getattr(pr, list_attr_by_split[split])

    dataset_cls = getattr(data, pr.dataloader)
    dataset = dataset_cls(args, pr, read_list, split=split)
    batch_size = batch_size if batch_size else args.batch_size
    # NOTE(review): looks like a smoke test of sample loading; what
    # ``getitem_test`` does is defined by the dataset class -- confirm.
    dataset.getitem_test(1)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=drop_last)
    return dataset, loader
# ---------------------------------------------------- # | |
def make_optimizer(model, args):
    """Create the optimizer selected by ``args.optim``.

    Only parameters of ``model`` with ``requires_grad`` set are handed to the
    optimizer.

    Args:
        model: NN to train.
        args: options object. ``optim`` must be ``'SGD'`` or ``'Adam'``;
            ``lr`` and ``weight_decay`` are always read, ``momentum`` only
            for SGD.

    Returns:
        optimizer: pytorch optimizer for updating the given model parameters.

    Raises:
        ValueError: if ``args.optim`` names an unsupported optimizer.
    """
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())

    if args.optim == 'SGD':
        return torch.optim.SGD(
            trainable_params,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=False,
        )
    if args.optim == 'Adam':
        return torch.optim.Adam(
            trainable_params,
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    # Fail with a clear message instead of an UnboundLocalError on return.
    raise ValueError(
        "unsupported optimizer '{}'; expected 'SGD' or 'Adam'".format(args.optim))
def adjust_learning_rate(optimizer, epoch, args):
    """Write the learning rate for ``epoch`` into every param group.

    With ``args.schedule == 'cos'`` the base rate ``args.lr`` is scaled by
    ``0.5 * (1 + cos(pi * epoch / args.epochs))``. Any other schedule value,
    including ``'none'``, keeps the rate at ``args.lr``.
    """
    new_lr = args.lr
    if args.schedule == 'cos':
        cosine_factor = 0.5 * (1. + np.cos(np.pi * epoch / args.epochs))
        new_lr = new_lr * cosine_factor

    for group in optimizer.param_groups:
        group['lr'] = new_lr