from typing import TypedDict

import torch
import torchaudio


class AudioDict(TypedDict):
    """Comfy's representation of AUDIO data."""

    sample_rate: int
    waveform: torch.Tensor


AudioData = AudioDict | list[AudioDict]


class MtbAudio:
    """Base class for audio processing."""

    @classmethod
    def is_stereo(
        cls,
        audios: AudioData,
    ) -> bool:
        if isinstance(audios, list):
            return any(cls.is_stereo(audio) for audio in audios)
        else:
            return audios["waveform"].shape[1] == 2

    @staticmethod
    def resample(audio: AudioDict, common_sample_rate: int) -> AudioDict:
        if audio["sample_rate"] != common_sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=audio["sample_rate"], new_freq=common_sample_rate
            )
            return {
                "sample_rate": common_sample_rate,
                "waveform": resampler(audio["waveform"]),
            }
        else:
            return audio

    @staticmethod
    def to_stereo(audio: AudioDict) -> AudioDict:
        if audio["waveform"].shape[1] == 1:
            return {
                "sample_rate": audio["sample_rate"],
                "waveform": torch.cat(
                    [audio["waveform"], audio["waveform"]], dim=1
                ),
            }
        else:
            return audio

    @classmethod
    def preprocess_audios(
        cls, audios: list[AudioDict]
    ) -> tuple[list[AudioDict], bool, int]:
        """Normalize a batch of audios to a common sample rate and layout.

        All inputs are resampled to the highest sample rate found, and if
        any input is stereo, the mono ones are upmixed to stereo.
        """
        max_sample_rate = max(audio["sample_rate"] for audio in audios)

        resampled_audios = [
            cls.resample(audio, max_sample_rate) for audio in audios
        ]

        is_stereo = cls.is_stereo(audios)
        if is_stereo:
            resampled_audios = [
                cls.to_stereo(audio) for audio in resampled_audios
            ]

        return (resampled_audios, is_stereo, max_sample_rate)


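# Usage sketch for MtbAudio.preprocess_audios (illustration only, not part of
# the node API). ComfyUI AUDIO waveforms are assumed to have the shape
# (batch, channels, samples):
#
#     mono_22k: AudioDict = {
#         "sample_rate": 22050,
#         "waveform": torch.zeros(1, 1, 22050),  # 1 s, mono
#     }
#     stereo_44k: AudioDict = {
#         "sample_rate": 44100,
#         "waveform": torch.zeros(1, 2, 44100),  # 1 s, stereo
#     }
#     prepped, is_stereo, rate = MtbAudio.preprocess_audios(
#         [mono_22k, stereo_44k]
#     )
#     # is_stereo -> True, rate -> 44100, and the mono clip comes back
#     # resampled and upmixed with waveform shape (1, 2, 44100).

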
class MTB_AudioCut(MtbAudio):
    """Basic audio cutter, values are in ms."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "audio": ("AUDIO",),
                "length": (
                    "FLOAT",
                    {
                        "default": 1000.0,
                        "min": 0.0,
                        "max": 999999.0,
                        "step": 1,
                    },
                ),
                "offset": (
                    "FLOAT",
                    {"default": 0.0, "min": 0.0, "max": 999999.0, "step": 1},
                ),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("cut_audio",)
    CATEGORY = "mtb/audio"
    FUNCTION = "cut"

    def cut(self, audio: AudioDict, length: float, offset: float):
        sample_rate = audio["sample_rate"]
        start_idx = int(offset * sample_rate / 1000)
        end_idx = min(
            start_idx + int(length * sample_rate / 1000),
            audio["waveform"].shape[-1],
        )
        cut_waveform = audio["waveform"][:, :, start_idx:end_idx]

        return (
            {
                "sample_rate": sample_rate,
                "waveform": cut_waveform,
            },
        )


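# Cut arithmetic (illustrative numbers, not node defaults): offsets and
# lengths are milliseconds, converted with int(ms * sample_rate / 1000).
# At 44100 Hz, offset=250 and length=500 select samples 11025:33075, i.e.
# half a second, and end_idx is clamped to the waveform length for clips
# shorter than offset + length.

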
class MTB_AudioStack(MtbAudio):
    """Stack/overlay audio inputs (dynamic inputs).

    - pad audios to the length of the longest input.
    - resample audios to the highest sample rate in the inputs.
    - convert them all to stereo if one of the inputs is.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {}}

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("stacked_audio",)
    CATEGORY = "mtb/audio"
    FUNCTION = "stack"

    def stack(self, **kwargs: AudioDict) -> tuple[AudioDict]:
        audios, is_stereo, max_rate = self.preprocess_audios(
            list(kwargs.values())
        )

        max_length = max(audio["waveform"].shape[-1] for audio in audios)

        padded_audios: list[torch.Tensor] = []
        for audio in audios:
            # Zero-pad each clip at the end so all waveforms share max_length.
            padding = torch.zeros(
                (
                    1,
                    2 if is_stereo else 1,
                    max_length - audio["waveform"].shape[-1],
                )
            )
            padded_audio = torch.cat([audio["waveform"], padding], dim=-1)
            padded_audios.append(padded_audio)

        # Overlay by summing the aligned waveforms.
        stacked_waveform = torch.stack(padded_audios, dim=0).sum(dim=0)

        return (
            {
                "sample_rate": max_rate,
                "waveform": stacked_waveform,
            },
        )


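# Stacking sketch (illustrative shapes): given a 1 s clip and a 0.5 s clip at
# 44100 Hz, the shorter waveform is zero-padded from 22050 to 44100 samples
# and the aligned tensors are summed, yielding a single 1 s overlay. The sum
# is not normalized, so loud inputs can exceed the [-1, 1] range.

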
class MTB_AudioSequence(MtbAudio):
    """Sequence audio inputs (dynamic inputs).

    - add silence_duration seconds of silence between each segment;
      the value can also be negative to overlap the clips, safely bounded
      by the input lengths.
    - resample audios to the highest sample rate in the inputs.
    - convert them all to stereo if one of the inputs is.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "silence_duration": (
                    "FLOAT",
                    {"default": 0.0, "min": -999.0, "max": 999.0, "step": 0.01},
                )
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("sequenced_audio",)
    CATEGORY = "mtb/audio"
    FUNCTION = "sequence"

    def sequence(self, silence_duration: float, **kwargs: AudioDict):
        audios, is_stereo, max_rate = self.preprocess_audios(
            list(kwargs.values())
        )

        sequence: list[torch.Tensor] = []
        for i, audio in enumerate(audios):
            if i > 0:
                if silence_duration > 0:
                    # Insert silence between the previous clip and this one.
                    silence = torch.zeros(
                        (
                            1,
                            2 if is_stereo else 1,
                            int(silence_duration * max_rate),
                        )
                    )
                    sequence.append(silence)
                elif silence_duration < 0:
                    # Negative duration: overlap the tail of the previous
                    # clip with the head of this one, bounded by both lengths.
                    overlap = int(abs(silence_duration) * max_rate)
                    previous_audio = sequence[-1]
                    overlap = min(
                        overlap,
                        previous_audio.shape[-1],
                        audio["waveform"].shape[-1],
                    )
                    if overlap > 0:
                        overlap_part = (
                            previous_audio[:, :, -overlap:]
                            + audio["waveform"][:, :, :overlap]
                        )
                        sequence[-1] = previous_audio[:, :, :-overlap]
                        sequence.append(overlap_part)
                        audio["waveform"] = audio["waveform"][:, :, overlap:]

            sequence.append(audio["waveform"])

        sequenced_waveform = torch.cat(sequence, dim=-1)
        return (
            {
                "sample_rate": max_rate,
                "waveform": sequenced_waveform,
            },
        )


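# Sequencing sketch (illustrative numbers): at a 44100 Hz output rate, a
# silence_duration of -0.5 requests a 22050-sample overlap. The overlap is
# clamped to the shorter of the two adjoining clips, the overlapping region
# is mixed by summation, and the rest of the second clip follows, so two 1 s
# clips become one 1.5 s waveform.

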
__nodes__ = [MTB_AudioSequence, MTB_AudioStack, MTB_AudioCut]
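

if __name__ == "__main__":
    # Minimal smoke test, runnable outside ComfyUI (assumption: the node
    # classes above are plain Python and only need torch/torchaudio).
    # It builds a 1 s, 440 Hz mono tone, cuts 500 ms starting at 250 ms,
    # then overlays the cut on the original tone.
    tone: AudioDict = {
        "sample_rate": 44100,
        "waveform": torch.sin(
            2 * torch.pi * 440.0 * torch.arange(44100) / 44100
        ).reshape(1, 1, -1),
    }
    (cut_audio,) = MTB_AudioCut().cut(tone, length=500.0, offset=250.0)
    print(cut_audio["waveform"].shape)  # torch.Size([1, 1, 22050])
    (stacked,) = MTB_AudioStack().stack(a=tone, b=cut_audio)
    print(stacked["waveform"].shape)  # torch.Size([1, 1, 44100])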