import os
import sys
import time
from contextlib import contextmanager

import torch

from constants import *


@contextmanager
def suppress_stdout():
    """Temporarily redirect stdout to os.devnull, restoring it on exit."""
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout
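

# Usage sketch (the prints are illustrative):
# >>> with suppress_stdout():
# ...     print("swallowed")
# >>> print("visible")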


def save_checkpoint(state, save_path):
    """Save a checkpoint object to save_path, creating parent directories as needed."""
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(state, save_path)
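

# Usage sketch (hypothetical model and path):
# >>> state = {"epoch": 3, "model": model.state_dict()}
# >>> save_checkpoint(state, "checkpoints/run1/epoch3.pt")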


def freeze(module):
    """Disable gradients for all parameters of the given module."""
    for param in module.parameters():
        param.requires_grad = False


def num_params(model):
    """Count the model's trainable (requires_grad) parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
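

# Usage sketch (assumes some nn.Module `model` with an `encoder` submodule):
# >>> freeze(model.encoder)   # encoder params stop receiving gradients
# >>> num_params(model)       # now counts only the remaining trainable params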


def clamp(x, limit):
    """Clamp x to the symmetric range [-limit, limit]."""
    return max(-limit, min(x, limit))
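

# Quick check of the symmetric clamp:
# >>> clamp(5.0, 2.0), clamp(-5.0, 2.0), clamp(1.0, 2.0)
# (2.0, -2.0, 1.0)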


def pad_to_length(tensor, length, dim, value=0):
    """
    Pad tensor to the given length in the given dim, filling with the given numeric value.
    """
    assert tensor.size(dim) <= length
    if tensor.size(dim) < length:
        pad_shape = list(tensor.shape)
        pad_shape[dim] = length - tensor.size(dim)
        padding = torch.full(pad_shape, value, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, padding], dim=dim)
    else:
        return tensor
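

# Usage sketch (shapes and fill value are illustrative):
# >>> x = torch.ones(2, 3)
# >>> pad_to_length(x, 5, dim=1, value=-1)[0]
# tensor([ 1.,  1.,  1., -1., -1.])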


def pad_mask(lengths: torch.LongTensor) -> torch.BoolTensor:
    """
    Create a mask of seq x batch where seq = max(lengths), with False in
    padding locations and True otherwise.
    """
    # lengths: bs. Ex: [2, 3, 1]
    max_seqlen = torch.max(lengths)
    # [[2, 3, 1], [2, 3, 1], [2, 3, 1]]
    expanded_lengths = lengths.unsqueeze(0).repeat((max_seqlen, 1))
    # [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
    indices = torch.arange(max_seqlen).unsqueeze(1).repeat((1, lengths.size(0))).to(lengths.device)
    # Pad locations are False. seqlen x bs
    return expanded_lengths > indices
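

# Usage sketch, matching the example in the comments above:
# >>> pad_mask(torch.tensor([2, 3, 1]))
# tensor([[ True,  True,  True],
#         [ True,  True, False],
#         [False,  True, False]])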


class ProgressMeter(object):
    """
    Display meter: prints the batch progress, a timestamp, and every tracked meter.
    """
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries.append(time.ctime(time.time()))
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
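

# Usage sketch (meter values and epoch prefix are made up):
# >>> losses = AverageMeter('Loss', ':.4f')
# >>> progress = ProgressMeter(100, [losses], prefix='Epoch [1]')
# >>> losses.update(0.75, n=32)
# >>> progress.display(10)   # e.g. "Epoch [1][ 10/100]\t<timestamp>\tLoss 0.7500 (0.7500)"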


class AverageMeter(object):
    """
    Display meter
    Computes and stores the average and current value
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
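

# Usage sketch (values are made up):
# >>> acc = AverageMeter('Acc@1', ':6.2f')
# >>> acc.update(80.0, n=64)
# >>> acc.update(90.0, n=64)
# >>> str(acc)
# 'Acc@1  90.00 ( 85.00)'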