|
|
|
|
|
|
|
|
|
|
|
|
|
import io |
|
import random |
|
import subprocess as sp |
|
import tempfile |
|
|
|
import numpy as np |
|
import torch |
|
from scipy.io import wavfile |
|
|
|
|
|
def i16_pcm(wav):
    """Convert audio to 16-bit integer PCM.

    Args:
        wav: a `torch.Tensor` or a NumPy array. Input that is already int16
            is returned unchanged (same object). Anything else is treated as
            float PCM, scaled by 2**15 and saturated to the int16 range.

    Returns:
        An int16 tensor for tensor input, an int16 array for array input.
    """
    if isinstance(wav, torch.Tensor):
        # A torch dtype never compares equal to `np.int16`, so tensors need
        # their own check; otherwise int16 tensors would be rescaled again.
        if wav.dtype == torch.int16:
            return wav
        # `wav * 2**15` is a fresh tensor, so the in-place clamp does not
        # touch the caller's data.
        return (wav * 2**15).clamp_(-2**15, 2**15 - 1).short()
    if wav.dtype == np.int16:
        return wav
    return np.clip(wav * 2**15, -2**15, 2**15 - 1).astype(np.int16)
|
|
|
|
|
def f32_pcm(wav):
    """Convert audio to 32-bit float PCM.

    Args:
        wav: a `torch.Tensor` or a NumPy array. Floating point input is
            returned unchanged (same object). Anything else is treated as
            16-bit integer PCM and divided by 2**15.

    Returns:
        A float tensor for tensor input, a float array for array input.
    """
    if isinstance(wav, torch.Tensor):
        # `np.float` no longer exists in NumPy >= 1.24, and a torch dtype
        # never compared equal to it anyway; ask the dtype directly.
        if wav.dtype.is_floating_point:
            return wav
        return wav.float() / 2**15
    if wav.dtype.kind == "f":
        return wav
    return wav.astype(np.float32) / 2**15
|
|
|
|
|
class RepitchedWrapper:
    """
    Wrap a dataset to apply online change of pitch / tempo.

    With probability `proba`, every stream of an item is passed through
    `repitch` with one shared random pitch shift (integer semitones in
    `[-max_pitch, max_pitch]`) and one shared tempo delta (Gaussian with
    standard deviation `tempo_std`, clipped to `[-max_tempo, max_tempo]`
    percent). Streams whose index is in `vocals` are processed with
    `voice=True`.

    Items are always cropped on the last axis to
    `int((1 - 0.01 * max_tempo) * in_length)`, whether or not they were
    repitched, so all returned items share the same length.
    """

    def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12, tempo_std=5, vocals=(3,)):
        self.dataset = dataset
        self.proba = proba
        self.max_pitch = max_pitch
        self.max_tempo = max_tempo
        self.tempo_std = tempo_std
        # Copy into a per-instance list: a shared mutable default would let a
        # change to one wrapper's `vocals` leak into every other wrapper.
        self.vocals = list(vocals)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # presumably shaped (sources, channels, time) — only the last axis
        # being time is relied upon here; TODO confirm against the dataset.
        streams = self.dataset[index]
        in_length = streams.shape[-1]
        # Fixed output length derived from the largest allowed speed-up, so
        # the crop does not depend on the tempo actually sampled.
        out_length = int((1 - 0.01 * self.max_tempo) * in_length)

        if random.random() < self.proba:
            # One pitch / tempo draw per item, shared by all of its streams.
            delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
            delta_tempo = random.gauss(0, self.tempo_std)
            delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
            outs = [
                repitch(
                    stream,
                    delta_pitch,
                    delta_tempo,
                    voice=idx in self.vocals)[:, :out_length]
                for idx, stream in enumerate(streams)
            ]
            streams = torch.stack(outs)
        else:
            streams = streams[..., :out_length]
        return streams
|
|
|
|
|
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
    """
    Change the pitch and tempo of `wav` using the `soundstretch` command.

    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semi tones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html

    Args:
        wav: audio tensor; it is transposed before being written as a WAV
            file, so it is presumably (channels, time) — TODO confirm.
        pitch: pitch change in semitones, passed as `-pitch=`.
        tempo: tempo change in percent, passed as `-tempo=`.
        voice: if True, add the `-speech` flag.
        quick: if True, add the `-quick` flag.
        samplerate: sample rate written to the input WAV and expected back.

    Returns:
        The processed audio as a float tensor, transposed back to the
        layout of the input.

    Raises:
        RuntimeError: if `soundstretch` exits with a non-zero status; the
            message carries its stderr output.
    """
    in_ = io.BytesIO()
    wavfile.write(in_, samplerate, i16_pcm(wav).t().numpy())
    # The context manager guarantees the temporary file is closed and removed
    # even if `soundstretch` fails, instead of waiting for garbage collection.
    with tempfile.NamedTemporaryFile(suffix=".wav") as outfile:
        command = [
            "soundstretch",
            "stdin",
            outfile.name,
            f"-pitch={pitch}",
            f"-tempo={tempo:.6f}",
        ]
        if quick:
            command += ["-quick"]
        if voice:
            command += ["-speech"]
        try:
            sp.run(command, capture_output=True, input=in_.getvalue(), check=True)
        except sp.CalledProcessError as error:
            stderr = error.stderr.decode('utf-8')
            # Chain the original error so the exit code and command stay visible.
            raise RuntimeError(f"Could not change bpm because {stderr}") from error
        # Must be read while the temporary file still exists.
        sr, wav = wavfile.read(outfile.name)
    # Copy so the tensor built below owns writable memory.
    wav = wav.copy()
    wav = f32_pcm(torch.from_numpy(wav).t())
    assert sr == samplerate
    return wav
|
|