# MotionBERT/lib/utils/learning.py
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn

from lib.model.DSTformer import DSTformer

class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
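
# A minimal usage sketch for AverageMeter, assuming a standard training loop.
# `loader`, `model`, and `criterion` below are hypothetical placeholders:
#
#   losses = AverageMeter()
#   for x, y in loader:
#       loss = criterion(model(x), y)
#       losses.update(loss.item(), n=x.size(0))  # weight by batch size
#   print(losses.avg)  # running mean over all samples seen so far
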
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)  # indices of the maxk largest logits per sample
        pred = pred.t()  # transpose to (maxk, batch_size)
        correct = pred.eq(target.view(1, -1).expand_as(pred))  # boolean hit matrix
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
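
# A quick sketch of accuracy() on synthetic data; the only assumptions are the
# shapes: (batch, num_classes) logits and integer class targets:
#
#   logits = torch.randn(8, 10)
#   target = torch.randint(0, 10, (8,))
#   top1, top5 = accuracy(logits, target, topk=(1, 5))
#   print(top1.item(), top5.item())  # percentages in [0, 100]
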
def load_pretrained_weights(model, checkpoint):
"""Load pretrianed weights to model
Incompatible layers (unmatched in name or size) will be ignored
Args:
- model (nn.Module): network model, which must not be nn.DataParallel
- weight_path (str): path to pretrained weights
"""
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
matched_layers, discarded_layers = [], []
for k, v in state_dict.items():
# If the pretrained state_dict was saved as nn.DataParallel,
# keys would contain "module.", which should be ignored.
if k.startswith('module.'):
k = k[7:]
if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
matched_layers.append(k)
else:
discarded_layers.append(k)
model_dict.update(new_state_dict)
model.load_state_dict(model_dict, strict=True)
    print('load_weight: matched %d layers, discarded %d incompatible layers' % (len(matched_layers), len(discarded_layers)))
return model
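
# A minimal loading sketch; 'checkpoint.pth' and `model` are hypothetical
# placeholders for a saved checkpoint and an already-constructed network:
#
#   checkpoint = torch.load('checkpoint.pth', map_location='cpu')
#   model = load_pretrained_weights(model, checkpoint)
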
def partial_train_layers(model, partial_list):
"""Train partial layers of a given model."""
    # Freeze everything first, then re-enable layers whose name contains
    # any substring in partial_list.
    for name, p in model.named_parameters():
        p.requires_grad = False
for trainable in partial_list:
if trainable in name:
p.requires_grad = True
break
return model
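
# A usage sketch for partial fine-tuning; the substrings below are illustrative
# and must match the actual parameter names of the model in use:
#
#   model = partial_train_layers(model, ['head', 'pre_logits'])
#   trainable = [n for n, p in model.named_parameters() if p.requires_grad]
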
def load_backbone(args):
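    """Build the backbone network specified by args.backbone (default: DSTformer)."""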
    if not hasattr(args, "backbone"):
        args.backbone = 'DSTformer'  # Default backbone
    if args.backbone == 'DSTformer':
        model_backbone = DSTformer(dim_in=3, dim_out=3, dim_feat=args.dim_feat, dim_rep=args.dim_rep,
                                   depth=args.depth, num_heads=args.num_heads, mlp_ratio=args.mlp_ratio,
                                   norm_layer=partial(nn.LayerNorm, eps=1e-6),
                                   maxlen=args.maxlen, num_joints=args.num_joints)
    elif args.backbone == 'TCN':
        from lib.model.model_tcn import PoseTCN
        model_backbone = PoseTCN()
    elif args.backbone == 'poseformer':
        from lib.model.model_poseformer import PoseTransformer
        model_backbone = PoseTransformer(num_frame=args.maxlen, num_joints=args.num_joints, in_chans=3,
                                         embed_dim_ratio=32, depth=4, num_heads=8, mlp_ratio=2.,
                                         qkv_bias=True, qk_scale=None, drop_path_rate=0, attn_mask=None)
    elif args.backbone == 'mixste':
        from lib.model.model_mixste import MixSTE2
        model_backbone = MixSTE2(num_frame=args.maxlen, num_joints=args.num_joints, in_chans=3,
                                 embed_dim_ratio=512, depth=8, num_heads=8, mlp_ratio=2.,
                                 qkv_bias=True, qk_scale=None, drop_path_rate=0)
    elif args.backbone == 'stgcn':
        from lib.model.model_stgcn import Model as STGCN
        model_backbone = STGCN()
    else:
        raise ValueError("Undefined backbone type: %s" % args.backbone)
return model_backbone
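
# A construction sketch for load_backbone; the hyperparameter values below are
# illustrative assumptions, not the project's published training settings:
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(backbone='DSTformer', dim_feat=256, dim_rep=512,
#                          depth=5, num_heads=8, mlp_ratio=4,
#                          maxlen=243, num_joints=17)
#   backbone = load_backbone(args)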