|
import logging |
|
import os |
|
import random |
|
from typing import Union |
|
import warnings |
|
from functools import partial |
|
|
|
import numpy as np |
|
import torch |
|
|
|
try: |
|
import augment |
|
except ImportError: |
|
raise ImportError( |
|
"augment is not installed, please install it first using:" |
|
"\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e" |
|
) |
|
|
|
from .base import Effect |
|
|
|
_logger = logging.getLogger(__name__) |
|
_DEBUG = bool(os.environ.get("DEBUG", False)) |
|
|
|
|
|
class AttachableEffect(Effect): |
|
def attach(self, chain: augment.EffectChain) -> augment.EffectChain: |
|
raise NotImplementedError |
|
|
|
def apply(self, wav: np.ndarray, sr: int): |
|
chain = augment.EffectChain() |
|
chain = self.attach(chain) |
|
tensor = torch.from_numpy(wav)[None].float() |
|
tensor = chain.apply( |
|
tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr} |
|
) |
|
wav = tensor.numpy()[0] |
|
return wav |
|
|
|
|
|
class SoxEffect(AttachableEffect): |
|
def __init__(self, effect_name: str, *args, **kwargs): |
|
self.effect_name = effect_name |
|
self.args = args |
|
self.kwargs = kwargs |
|
|
|
def attach(self, chain: augment.EffectChain) -> augment.EffectChain: |
|
_logger.debug( |
|
f"Attaching {self.effect_name} with {self.args} and {self.kwargs}" |
|
) |
|
if not hasattr(chain, self.effect_name): |
|
raise ValueError(f"EffectChain has no attribute {self.effect_name}") |
|
return getattr(chain, self.effect_name)(*self.args, **self.kwargs) |
|
|
|
|
|
class Maybe(AttachableEffect): |
|
""" |
|
Attach an effect with a probability. |
|
""" |
|
|
|
def __init__(self, prob: float, effect: AttachableEffect): |
|
self.prob = prob |
|
self.effect = effect |
|
if _DEBUG: |
|
warnings.warn("DEBUG mode is on. Maybe -> Must.") |
|
self.prob = 1 |
|
|
|
def attach(self, chain: augment.EffectChain) -> augment.EffectChain: |
|
if random.random() > self.prob: |
|
return chain |
|
return self.effect.attach(chain) |
|
|
|
|
|
class Chain(AttachableEffect): |
|
""" |
|
Attach a chain of effects. |
|
""" |
|
|
|
def __init__(self, *effects: AttachableEffect): |
|
self.effects = effects |
|
|
|
def attach(self, chain: augment.EffectChain) -> augment.EffectChain: |
|
for effect in self.effects: |
|
chain = effect.attach(chain) |
|
return chain |
|
|
|
|
|
class Choice(AttachableEffect): |
|
""" |
|
Attach one of the effects randomly. |
|
""" |
|
|
|
def __init__(self, *effects: AttachableEffect): |
|
self.effects = effects |
|
|
|
def attach(self, chain: augment.EffectChain) -> augment.EffectChain: |
|
return random.choice(self.effects).attach(chain) |
|
|
|
|
|
class Generator: |
|
def __call__(self) -> str: |
|
raise NotImplementedError |
|
|
|
|
|
class Uniform(Generator): |
|
def __init__(self, low, high): |
|
self.low = low |
|
self.high = high |
|
|
|
def __call__(self) -> str: |
|
return str(random.uniform(self.low, self.high)) |
|
|
|
|
|
class Randint(Generator): |
|
def __init__(self, low, high): |
|
self.low = low |
|
self.high = high |
|
|
|
def __call__(self) -> str: |
|
return str(random.randint(self.low, self.high)) |
|
|
|
|
|
class Concat(Generator): |
|
def __init__(self, *parts: Union[Generator, str]): |
|
self.parts = parts |
|
|
|
def __call__(self): |
|
return "".join( |
|
[part if isinstance(part, str) else part() for part in self.parts] |
|
) |
|
|
|
|
|
class RandomLowpassDistorter(SoxEffect): |
|
def __init__(self, low=2000, high=16000): |
|
super().__init__( |
|
"sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)) |
|
) |
|
|
|
|
|
class RandomBandpassDistorter(SoxEffect): |
|
def __init__(self, low=100, high=1000, min_width=2000, max_width=4000): |
|
super().__init__( |
|
"sinc", |
|
"-n", |
|
Randint(50, 200), |
|
partial(self._fn, low, high, min_width, max_width), |
|
) |
|
|
|
@staticmethod |
|
def _fn(low, high, min_width, max_width): |
|
start = random.randint(low, high) |
|
stop = start + random.randint(min_width, max_width) |
|
return f"{start}-{stop}" |
|
|
|
|
|
class RandomEqualizer(SoxEffect): |
|
def __init__( |
|
self, |
|
low=100, |
|
high=4000, |
|
q_low=1, |
|
q_high=5, |
|
db_low: int = -30, |
|
db_high: int = 30, |
|
): |
|
super().__init__( |
|
"equalizer", |
|
Uniform(low, high), |
|
lambda: f"{random.randint(q_low, q_high)}q", |
|
lambda: random.randint(db_low, db_high), |
|
) |
|
|
|
|
|
class RandomOverdrive(SoxEffect): |
|
def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80): |
|
super().__init__( |
|
"overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high) |
|
) |
|
|
|
|
|
class RandomReverb(Chain): |
|
def __init__(self, deterministic=False): |
|
super().__init__( |
|
SoxEffect( |
|
"reverb", |
|
Uniform(50, 50) if deterministic else Uniform(0, 100), |
|
Uniform(50, 50) if deterministic else Uniform(0, 100), |
|
Uniform(50, 50) if deterministic else Uniform(0, 100), |
|
), |
|
SoxEffect("channels", 1), |
|
) |
|
|
|
|
|
class Flanger(SoxEffect): |
|
def __init__(self): |
|
super().__init__("flanger") |
|
|
|
|
|
class Phaser(SoxEffect): |
|
def __init__(self): |
|
super().__init__("phaser") |
|
|