svoice_demo / svoice /utils.py
ahmedghani's picture
initial commit
8235b4f
raw
history blame
No virus
7.28 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Yossi Adi (adiyoss) and Alexandre Defossez (adefossez)
import functools
import logging
from contextlib import contextmanager
import inspect
import os
import time
import math
import torch
# Module-level logger; `deserialize_model` uses it to warn about dropped kwargs.
logger = logging.getLogger(__name__)
def capture_init(init):
    """
    Wrap an ``__init__`` so the constructor arguments are remembered.

    After construction, ``self._init_args_kwargs`` holds the tuple
    ``(args, kwargs)`` exactly as the caller passed them, which lets the
    object be rebuilt later (see ``serialize_model``).
    """
    def _recording_init(self, *args, **kwargs):
        # Stash the raw arguments before running the real initializer.
        self._init_args_kwargs = args, kwargs
        init(self, *args, **kwargs)

    # functools.wraps keeps the wrapped __init__'s name, docstring and
    # __wrapped__ (so inspect.signature still sees the real parameters).
    return functools.wraps(init)(_recording_init)
def deserialize_model(package, strict=False):
    """
    Rebuild a model from a package produced by ``serialize_model``.

    Args:
        package: dict with keys ``'class'`` (the model class), ``'args'``,
            ``'kwargs'`` (constructor arguments) and ``'state'`` (passed to
            ``load_state_dict``).
        strict: when True, pass the stored kwargs unchanged (an unknown key
            makes the constructor raise). When False, kwargs that the class
            signature does not name are dropped with a warning. This lets
            checkpoints saved by older versions of a model still load.
    Returns:
        The constructed model with its state loaded.

    The package itself is never modified.
    """
    klass = package['class']
    kwargs = package['kwargs']
    if not strict:
        accepted = inspect.signature(klass).parameters
        # Work on a copy: deleting from package['kwargs'] would silently
        # change the caller's package (e.g. if it is saved again or reused).
        kwargs = dict(kwargs)
        for key in list(kwargs):
            if key not in accepted:
                logger.warning("Dropping inexistant parameter %s", key)
                del kwargs[key]
    model = klass(*package['args'], **kwargs)
    model.load_state_dict(package['state'])
    return model
def copy_state(state):
    """
    Return a detached-from-device copy of a state dict.

    Every value is moved to CPU and cloned, so later in-place updates to the
    source tensors do not affect the copy. Key order is preserved.
    """
    copied = {}
    for name, tensor in state.items():
        copied[name] = tensor.cpu().clone()
    return copied
def serialize_model(model):
    """
    Package a model so that ``deserialize_model`` can rebuild it.

    The model must have been built through an ``__init__`` decorated with
    ``capture_init`` (it reads ``model._init_args_kwargs``). The returned dict
    holds the class, the constructor args/kwargs and a CPU copy of the state
    dict.
    """
    init_args, init_kwargs = model._init_args_kwargs
    weights = copy_state(model.state_dict())
    return {
        "class": model.__class__,
        "args": init_args,
        "kwargs": init_kwargs,
        "state": weights,
    }
@contextmanager
def swap_state(model, state):
    """
    Temporarily load ``state`` into ``model``; restore the original on exit.

    A CPU copy of the current state dict is taken first (via ``copy_state``)
    and loaded back when the ``with`` block ends, whether it exits normally
    or by an exception.
    """
    old_state = copy_state(model.state_dict())
    try:
        # Load inside the try: load_state_dict may copy some tensors before
        # raising (e.g. on a key or shape mismatch). The finally clause then
        # still puts the original weights back instead of leaving the model
        # half-updated.
        model.load_state_dict(state)
        yield
    finally:
        model.load_state_dict(old_state)
@contextmanager
def swap_cwd(cwd):
    """
    Run the ``with`` body with ``cwd`` as the current working directory.

    The previous working directory is restored on exit, even if the body
    raises.
    """
    restore_to = os.getcwd()
    os.chdir(cwd)
    try:
        yield
    finally:
        # Always go back, regardless of how the body exited.
        os.chdir(restore_to)
def pull_metric(history, name):
    """
    Collect the value of metric ``name`` from each entry of ``history``.

    ``history`` is an iterable of dict-like records; records that lack
    ``name`` are skipped. Values are returned in history order.
    """
    return [entry[name] for entry in history if name in entry]
class LogProgress:
    """
    Sort of like tqdm but using log lines and not as real time.

    Wraps ``iterable`` and emits roughly ``updates`` log lines over one full
    pass. Each line reports ``index/total``, the throughput and any metrics
    last passed to :meth:`update`.

    Args:
        logger: logger used to emit the progress lines.
        iterable: the iterable to wrap.
        updates: approximate number of log lines per pass.
        total: number of items; ``len(iterable)`` is used when this is falsy.
        name: prefix of every log line.
        level: logging level of the emitted lines.
    """

    def __init__(self, logger, iterable, updates=5, total=None,
                 name="LogProgress", level=logging.INFO):
        self.iterable = iterable
        self.total = total or len(iterable)
        self.updates = updates
        self.name = name
        self.logger = logger
        self.level = level

    def update(self, **infos):
        """Replace the metrics displayed on the next log line."""
        self._infos = infos

    def __iter__(self):
        self._iterator = iter(self.iterable)
        self._index = -1
        self._infos = {}
        # perf_counter is monotonic and high resolution; time.time() can jump
        # with wall-clock changes and is coarse on some platforms.
        self._begin = time.perf_counter()
        return self

    def __next__(self):
        self._index += 1
        try:
            return next(self._iterator)
        finally:
            # Runs even when next() raises (including StopIteration), so the
            # final count can be logged. Logging is delayed by one item so
            # that metrics passed to `update` for the previous item show up.
            log_every = max(1, self.total // self.updates)
            if self._index >= 1 and self._index % log_every == 0:
                self._log()

    def _log(self):
        elapsed = time.perf_counter() - self._begin
        # A zero elapsed time used to raise ZeroDivisionError; treat it as
        # "infinitely fast" instead.
        if elapsed > 0:
            self._speed = (1 + self._index) / elapsed
        else:
            self._speed = float("inf")
        infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
        if self._speed < 1e-4:
            speed = "oo sec/it"
        elif self._speed < 0.1:
            speed = f"{1/self._speed:.1f} sec/it"
        else:
            speed = f"{self._speed:.1f} it/sec"
        out = f"{self.name} | {self._index}/{self.total} | {speed}"
        if infos:
            out += " | " + infos
        self.logger.log(self.level, out)
def colorize(text, color):
    """
    Wrap ``text`` in an ANSI SGR escape sequence.

    ``color`` is the SGR parameter (e.g. ``"1"`` for bold, ``"31"`` for red).
    The text is followed by the reset sequence ``ESC[0m``.
    """
    prefix = f"\033[{color}m"
    reset = "\033[0m"
    return "".join((prefix, text, reset))
def bold(text):
    """Return ``text`` wrapped in the ANSI bold escape (SGR parameter 1)."""
    bold_code = "1"
    return colorize(text, bold_code)
def _total_l2_norm(tensors):
    """
    Return the L2 norm of ``tensors`` taken together as one flat vector.

    sqrt(sum_i ||t_i||_2^2) equals
    ``torch.cat([t.flatten() for t in tensors]).norm(2)`` without building the
    concatenation. An empty input yields a 0-dim zero tensor.
    """
    per_tensor = [t.norm(2) for t in tensors]
    if not per_tensor:
        return torch.tensor(0.)
    return torch.stack(per_tensor).norm(2)


def calculate_grad_norm(model):
    """
    Return the global L2 norm of the gradients of ``model``'s parameters.

    Parameters whose ``.grad`` is ``None`` (frozen, unused, or not yet
    backpropagated) are skipped. Returns a 0-dim tensor (zero if no parameter
    has a gradient).
    """
    # Read p.grad, not p.data.grad: the tensor returned by `.data` carries no
    # gradient, so p.data.grad is None.
    return _total_l2_norm(
        p.grad.detach() for p in model.parameters() if p.grad is not None)


def calculate_weight_norm(model):
    """
    Return the global L2 norm of all of ``model``'s parameters.

    Returns a 0-dim tensor (zero for a model without parameters). This is the
    plain L2 norm; the square root is not applied a second time.
    """
    return _total_l2_norm(p.detach() for p in model.parameters())
def remove_pad(inputs, inputs_lengths):
    """
    Strip the right-padding from every item of a padded batch.

    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
        inputs_lengths: torch.Tensor or sequence of ints, [B]; the valid
            number of time steps of each item
    Returns:
        results: a list of B numpy arrays, each [C, T_b] for 3-D input or
            [T_b] for 2-D input, where T_b = inputs_lengths[b]
    Raises:
        ValueError: if ``inputs`` is neither 2-D nor 3-D.
    """
    if inputs.dim() not in (2, 3):
        raise ValueError(
            f"remove_pad expects a [B, C, T] or [B, T] tensor, got {inputs.dim()} dims")
    # `...` keeps the channel axis when present, so one slice handles both
    # layouts (and zero lengths). detach() is needed because numpy() refuses
    # tensors that require grad.
    return [item[..., :length].detach().cpu().numpy()
            for item, length in zip(inputs, inputs_lengths)]
def overlap_and_add(signal, frame_step):
    """Reconstructs a signal from a framed representation.
    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where
        output_size = (frames - 1) * frame_step + frame_length
    Args:
        signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
            Non-contiguous tensors are accepted.
        frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
        output_size = (frames - 1) * frame_step + frame_length
    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Work in "subframes" of gcd(frame_length, frame_step) samples: every
    # frame then starts on a subframe boundary, so the whole overlap-add is a
    # single index_add_ over subframes.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # reshape (not view) so non-contiguous input, e.g. after a transpose, works.
    subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)

    # frame[i, j] is the output subframe that subframe j of frame i lands on.
    # Build it directly on signal's device instead of copying from the CPU.
    frame = torch.arange(output_subframes, device=signal.device).unfold(
        0, subframes_per_frame, subframe_step)
    frame = frame.reshape(-1)

    result = signal.new_zeros(
        *outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    return result.view(*outer_dimensions, -1)