Dusan's picture
Update fudge/util.py
52c1e95
raw
history blame
3.16 kB
import os
import time
import sys
from contextlib import contextmanager
import torch
from fudge.constants import *
@contextmanager
def suppress_stdout():
    """
    Context manager that silences writes to sys.stdout.

    While the with-block runs, sys.stdout is rebound to a handle opened on
    os.devnull. The previous sys.stdout object is put back on exit, including
    when the body raises.
    """
    saved_stream = sys.stdout
    with open(os.devnull, "w") as sink:
        sys.stdout = sink
        try:
            yield
        finally:
            # Restore before the devnull handle is closed by the enclosing `with`.
            sys.stdout = saved_stream
def save_checkpoint(state, save_path):
    """
    Serialize state to save_path with torch.save, creating missing parent directories first.

    state: object to serialize; passed to torch.save unchanged.
    save_path: destination file path. A bare filename (no directory component) is
        written relative to the current working directory.
    """
    parent_dir = os.path.dirname(save_path)
    # dirname() is '' for a bare filename, and os.makedirs('') raises FileNotFoundError,
    # so only create the directory when the path actually names one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(state, save_path)
def freeze(module):
    """
    Turn off gradient tracking for every parameter yielded by module.parameters().

    The parameters are modified in place; nothing is returned.
    """
    for weight in module.parameters():
        weight.requires_grad = False
def num_params(model):
    """
    Count the scalar elements across all parameters of model that have requires_grad set.

    Parameters with requires_grad == False are not counted.
    """
    trainable_total = 0
    for weight in model.parameters():
        if not weight.requires_grad:
            continue
        trainable_total += weight.numel()
    return trainable_total
def clamp(x, limit):
    """
    Restrict x to the symmetric range [-limit, limit].

    The upper bound is applied first, then the lower bound.
    """
    capped_above = min(x, limit)
    return max(-limit, capped_above)
def pad_to_length(tensor, length, dim, value=0):
    """
    Pad tensor to given length in given dim using given value (value should be numeric).

    tensor: tensor to pad; its size along `dim` must not exceed `length`.
    length: target size along `dim`.
    dim: dimension to pad.
    value: fill value for the padded entries.

    Returns the input tensor itself (not a copy) when it already has size `length`
    along `dim`; otherwise a new tensor with the padding appended after the existing
    entries along `dim`.
    """
    current = tensor.size(dim)
    assert current <= length, \
        'cannot pad dim {} of size {} down to length {}'.format(dim, current, length)
    if current == length:
        return tensor
    pad_shape = list(tensor.shape)
    pad_shape[dim] = length - current
    # new_full allocates the padding directly with the input's dtype and device,
    # instead of building float zeros on the CPU and then casting/moving/filling them.
    padding = tensor.new_full(tuple(pad_shape), value)
    return torch.cat([tensor, padding], dim=dim)
def pad_mask(lengths: torch.LongTensor) -> torch.Tensor:
    """
    Create a mask of seq x batch where seq = max(lengths), with 0 in padding locations and 1 otherwise.

    lengths: 1-D tensor of sequence lengths, one entry per batch element. Ex: [2, 3, 1]

    Returns the result of an elementwise comparison (boolean dtype on current PyTorch
    releases), on the same device as lengths. An empty lengths tensor yields an empty
    0 x 0 mask.
    Ex: [2, 3, 1] -> [[1, 1, 1], [1, 1, 0], [0, 1, 0]]
    """
    if lengths.numel() == 0:
        # torch.max raises on an empty tensor; an empty batch simply has an empty mask.
        return torch.zeros((0, 0), dtype=torch.bool, device=lengths.device)
    max_seqlen = int(lengths.max())
    # Column of positions [[0], [1], ..., [max_seqlen - 1]], created directly on the target device.
    positions = torch.arange(max_seqlen, device=lengths.device).unsqueeze(1)  # seqlen x 1
    # Broadcasting (1 x bs) against (seqlen x 1) gives seqlen x bs without materializing repeats.
    return lengths.unsqueeze(0) > positions  # pad locations are 0
class ProgressMeter(object):
    """
    Prints one tab-separated progress line per display() call:
    prefix + batch counter, the current wall-clock time, then str() of each meter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.prefix = prefix
        self.meters = meters
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)

    def display(self, batch):
        """Print the progress line for the given batch index."""
        counter = self.prefix + self.batch_fmtstr.format(batch)
        timestamp = time.ctime(time.time())
        meter_fields = [str(m) for m in self.meters]
        print('\t'.join([counter, timestamp] + meter_fields))

    def _get_batch_fmtstr(self, num_batches):
        """
        Build the counter template, e.g. '[{:3d}/100]' for num_batches=100.
        The batch index slot is padded to the width of num_batches.
        """
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        total = slot.format(num_batches)
        return '[{}/{}]'.format(slot, total)
class AverageMeter(object):
    """
    Computes and stores the average and current value

    name: label printed by __str__.
    fmt: format spec (including the leading ':') applied to both the current
        value and the average in __str__, e.g. ':.3f'.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all running statistics back to zero."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running weighted average: sum / count
        self.sum = 0    # running total of val * n
        self.count = 0  # running total of n

    def update(self, val, n=1):
        """
        Record a new value.

        val: the latest value; becomes self.val.
        n: weight of this value (e.g. the number of samples it summarizes).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # With nothing counted yet (e.g. update(val, n=0) on a fresh meter) there is
        # no average to compute; keep it at 0 instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Render as '<name> <val> (<avg>)' using the configured format spec."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)