Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import errno | |
import functools | |
import hashlib | |
import inspect | |
import io | |
import os | |
import random | |
import socket | |
import tempfile | |
import warnings | |
import zlib | |
from contextlib import contextmanager | |
from diffq import UniformQuantizer, DiffQuantizer | |
import torch as th | |
import tqdm | |
from torch import distributed | |
from torch.nn import functional as F | |
def center_trim(tensor, reference):
    """
    Trim `tensor` symmetrically along its last dimension so that it matches `reference`.

    `reference` is either an object exposing `.size(-1)` (its last dimension is
    the target length) or directly the target length as a number.
    When the number of samples to drop is odd, the right side loses one more
    sample than the left side.

    Raises:
        ValueError: if `tensor` is shorter than the target length.
    """
    target = reference.size(-1) if hasattr(reference, "size") else reference
    excess = tensor.size(-1) - target
    if excess < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {excess}.")
    if not excess:
        # Already the right length: hand back the very same object.
        return tensor
    left = excess // 2
    right = excess - left
    return tensor[..., left:-right]
def average_metric(metric, count=1.):
    """
    Compute the weighted average of `metric` over all hosts.

    Each host contributes its float `metric` with weight `count` (e.g. the
    number of examples it processed). The weights and the weighted values are
    summed across hosts with a single `all_reduce` on a CUDA tensor, and the
    ratio of the two sums is returned as a Python float.
    """
    totals = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
    distributed.all_reduce(totals, op=distributed.ReduceOp.SUM)
    weighted_sum = totals[1].item()
    weight_sum = totals[0].item()
    return weighted_sum / weight_sum
def free_port(host='', low=20000, high=40000):
    """
    Return a port number that is most likely free.

    Random ports in `[low, high]` are tried until one can be bound on `host`.
    This could suffer from a race condition although it should be quite rare:
    the probing socket is closed before returning, so another process may grab
    the port before the caller does.

    Args:
        host (str): address to bind the probing socket to.
        low (int): smallest candidate port (inclusive).
        high (int): largest candidate port (inclusive).

    Returns:
        int: a port that could be bound at the time of the call.

    Raises:
        OSError: if binding fails for any reason other than the address
            already being in use.
    """
    # The socket is only a probe; the context manager guarantees it is closed
    # (and the port released) on every exit path instead of relying on
    # garbage collection.
    with socket.socket() as sock:
        while True:
            port = random.randint(low, high)
            try:
                sock.bind((host, port))
            except OSError as error:
                if error.errno == errno.EADDRINUSE:
                    continue
                raise
            return port
def sizeof_fmt(num, suffix='B'):
    """
    Format a byte count `num` as a human readable string using binary prefixes.

    The value is divided by 1024 until its magnitude drops below 1024, and is
    then printed with one decimal followed by the matching prefix and `suffix`.
    Anything that is still too large after `Zi` is reported in `Yi`.
    Adapted from https://stackoverflow.com/a/1094933
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    scaled = num
    for prefix in prefixes:
        if abs(scaled) < 1024.0:
            return "%3.1f%s%s" % (scaled, prefix, suffix)
        scaled /= 1024.0
    return "%.1f%s%s" % (scaled, 'Yi', suffix)
def human_seconds(seconds, display='.2f'):
    """
    Format a duration given in `seconds` as a human readable string.

    The duration starts out expressed in microseconds and is promoted to the
    next larger unit (ms, s, min, hrs, days) for as long as the promoted value
    would be at least 0.3. `display` is the format spec applied to the number.
    """
    # Each entry is (unit name, how many of the previous unit make one of it).
    promotions = (('ms', 1e3), ('s', 1e3), ('min', 60), ('hrs', 60), ('days', 24))
    amount = seconds * 1e6
    unit = 'us'
    for bigger_unit, factor in promotions:
        if amount / factor < 0.3:
            break
        amount /= factor
        unit = bigger_unit
    return f"{format(amount, display)} {unit}"
class TensorChunk:
    """
    Window over the last dimension of a tensor.

    Only a reference to the tensor plus the window position (`offset`,
    `length`) is stored. `shape` and `device` mirror the tensor attributes of
    the same name so that a chunk can be inspected like a tensor.
    """

    def __init__(self, tensor, offset=0, length=None):
        """
        Args:
            tensor: underlying tensor; its `.shape` and `.device` are read.
            offset (int): index of the first entry of the window along the
                last dimension. Must satisfy `0 <= offset < tensor.shape[-1]`.
            length (int or None): window size. It is clipped so that the window
                never runs past the end of `tensor`; `None` means everything
                from `offset` to the end.
        """
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            length = min(total_length - offset, length)

        self.tensor = tensor
        self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        """
        Shape of the window as a list: the shape of the underlying tensor with
        its last entry replaced by the window length.

        This is a property so that code written for tensors, such as
        `channels, length = mix.shape`, also works on a chunk.
        """
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        """
        Return a tensor whose last dimension is `target_length`, built around
        the window.

        The extra `target_length - self.length` entries are taken from the
        underlying tensor on both sides of the window where they exist, and
        filled by `F.pad` (default constant padding) where the requested span
        falls outside of the tensor. When the extra amount is odd, the right
        side receives one more entry than the left side.
        """
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        # Requested span, possibly reaching outside of the tensor.
        start = self.offset - delta // 2
        end = start + target_length

        # Part of the span that actually exists in the tensor.
        correct_start = max(0, start)
        correct_end = min(total_length, end)

        # Whatever is missing on each side is made up with padding.
        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out
def tensor_chunk(tensor_or_chunk):
    """
    Coerce the argument to a `TensorChunk`.

    A `TensorChunk` is returned unchanged; a `th.Tensor` is wrapped in a chunk
    covering it entirely. Any other type trips the assertion.
    """
    if not isinstance(tensor_or_chunk, TensorChunk):
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)
    return tensor_or_chunk
def apply_model(model, mix, shifts=None, split=False,
                overlap=0.25, transition_power=1., progress=False):
    """
    Apply model to a given mixture.

    Args:
        model: the model to run. This function reads `model.sources`,
            `model.segment_length` and `model.samplerate`, and calls
            `model.valid_length(length)` and `model(batch)`.
        mix: the mixture, a tensor or a `TensorChunk` whose shape unpacks as
            `(channels, length)`.
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in extracts of
            `model.segment_length` samples and predictions will be performed
            individually on each and blended together with a weighted average.
            Useful for model with large memory footprint like Tasnet.
        overlap (float): fraction of a segment shared by two consecutive
            extracts when `split` is True (the stride between extracts is
            `int((1 - overlap) * segment)`).
        transition_power (float): exponent applied to the normalized triangular
            blending weight when `split` is True. Must be >= 1; larger values
            give sharper transitions between extracts.
        progress (bool): if True, show a progress bar (requires split=True)

    Returns:
        The model output for `mix`, with its last dimension equal to `length`.
        With `split=True` it has shape `(len(model.sources), channels, length)`.
    """
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    device = mix.device
    channels, length = mix.shape
    if split:
        out = th.zeros(len(model.sources), channels, length, device=device)
        sum_weight = th.zeros(length, device=device)
        segment = model.segment_length
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = stride / model.samplerate
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment // 2 + 1),
                         th.arange(segment - segment // 2, 0, -1)]).to(device)
        assert len(weight) == segment
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            chunk_out = apply_model(model, chunk, shifts=shifts)
            # The last extract may be shorter than a full segment.
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
            sum_weight[offset:offset + segment] += weight[:chunk_length]
        # Every sample must have been covered by at least one extract.
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    elif shifts:
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        # Context on both sides so that any shift in [0, max_shift] stays in bounds.
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted)
            # Drop the leading context to undo the shift; `length` samples remain.
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out
    else:
        valid_length = model.valid_length(length)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length)
        with th.no_grad():
            out = model(padded_mix.unsqueeze(0))[0]
        return center_trim(out, length)
@contextmanager
def temp_filenames(count, delete=True):
    """
    Context manager yielding a list of `count` freshly created temporary file names.

    The files exist (empty) on disk while the `with` body runs.

    Args:
        count (int): number of temporary files to create.
        delete (bool): if True, every created file is removed when the context
            exits, including when the body or the creation itself raises.

    Yields:
        list of str: paths of the temporary files.
    """
    names = []
    try:
        for _ in range(count):
            # Close the handle right away: only the name is needed, and the
            # file must outlive the handle (hence `delete=False`).
            with tempfile.NamedTemporaryFile(delete=False) as handle:
                names.append(handle.name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)
def get_quantizer(model, args, optimizer=None):
    """
    Build the quantizer requested by `args` for `model`, if any.

    - `args.diffq` truthy: a `DiffQuantizer` (group size 8), registered on
      `optimizer` when one is given.
    - otherwise `args.qat` truthy: a `UniformQuantizer` using `args.qat` bits.
    - otherwise: `None`.

    Both quantizers receive `args.q_min_size` as `min_size`.
    """
    if args.diffq:
        quantizer = DiffQuantizer(model, min_size=args.q_min_size, group_size=8)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
        return quantizer
    if args.qat:
        return UniformQuantizer(model, bits=args.qat, min_size=args.q_min_size)
    return None
def load_model(path, strict=False):
    """
    Rebuild a model from a package stored at `path`.

    The package is a dict with the keys `klass`, `args`, `kwargs`, `state` and
    `training_args`. The model is instantiated as `klass(*args, **kwargs)`,
    a quantizer is derived from `training_args`, and the stored state is
    restored into the model.

    Args:
        path: location of the package, passed to `th.load`.
        strict (bool): if False, keyword arguments that are not part of the
            signature of `klass` are dropped with a warning instead of being
            passed to the constructor.

    Returns:
        The instantiated model with its state restored.
    """
    # NOTE(review): `th.load` unpickles the file, which can execute arbitrary
    # code; only load packages from trusted sources.
    with warnings.catch_warnings():
        # Only warnings raised while loading are silenced.
        warnings.simplefilter("ignore")
        package = th.load(path, 'cpu')

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if not strict:
        accepted = inspect.signature(klass).parameters
        for key in list(kwargs):
            if key in accepted:
                continue
            warnings.warn("Dropping inexistant parameter " + key)
            del kwargs[key]
    model = klass(*args, **kwargs)

    quantizer = get_quantizer(model, package["training_args"])
    set_state(model, quantizer, package["state"])
    return model
def get_state(model, quantizer):
    """
    Extract a serializable state from `model`.

    Without a quantizer, this is the model state dict with every entry moved
    to the CPU. With a quantizer, the quantized state is serialized with
    `th.save`, compressed with zlib and returned as `{'compressed': bytes}`.
    """
    if quantizer is not None:
        raw = io.BytesIO()
        th.save(quantizer.get_quantized_state(), raw)
        return {'compressed': zlib.compress(raw.getvalue())}
    return {name: param.data.to('cpu') for name, param in model.state_dict().items()}
def set_state(model, quantizer, state):
    """
    Restore `state` into `model`.

    Without a quantizer, `state` is loaded as a regular state dict and
    returned as is. With a quantizer, `state["compressed"]` is decompressed
    with zlib, deserialized on the CPU with `th.load`, handed to the quantizer,
    and the deserialized state is returned.
    """
    if quantizer is None:
        model.load_state_dict(state)
        return state
    raw = zlib.decompress(state["compressed"])
    restored = th.load(io.BytesIO(raw), "cpu")
    quantizer.restore_quantized_state(restored)
    return restored
def save_state(state, path):
    """
    Serialize `state` with `th.save` and write it next to `path`.

    The first 8 hex digits of the SHA-256 of the serialized bytes are inserted
    between the stem and the suffix of `path`, i.e. `model.th` becomes
    `model-<digest>.th` in the same directory. `path` must be a path object
    (its `parent`, `stem` and `suffix` are used).
    """
    buffer = io.BytesIO()
    th.save(state, buffer)
    payload = buffer.getvalue()
    digest = hashlib.sha256(payload).hexdigest()[:8]
    target = path.parent / f"{path.stem}-{digest}{path.suffix}"
    target.write_bytes(payload)
def save_model(model, quantizer, training_args, path):
    """
    Save to `path` everything needed to rebuild `model`.

    The saved package is a dict holding the model class (`klass`), the
    positional and keyword arguments recorded in `model._init_args_kwargs`
    (`args`, `kwargs`), the state produced by `get_state` (`state`) and
    `training_args`.
    """
    init_args, init_kwargs = model._init_args_kwargs
    th.save(
        {
            'klass': model.__class__,
            'args': init_args,
            'kwargs': init_kwargs,
            'state': get_state(model, quantizer),
            'training_args': training_args,
        },
        path,
    )
def capture_init(init):
    """
    Decorate an `__init__` method so that the arguments it was called with are
    recorded on the instance as `self._init_args_kwargs = (args, kwargs)`
    before the original `__init__` runs.

    The wrapper is created with `functools.wraps`, so the decorated method
    keeps the name, docstring and, through `__wrapped__`, the signature of
    `init`. This matters for `inspect.signature(cls)`, which would otherwise
    only report `(*args, **kwargs)`.

    Args:
        init: the `__init__` function to wrap.

    Returns:
        The wrapping `__init__` function.
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__