Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import pretty_midi | |
from collections import defaultdict | |
from copy import deepcopy | |
from dataclasses import dataclass, field | |
from itertools import chain | |
import logging | |
import math | |
from pathlib import Path | |
import random | |
import re | |
import typing as tp | |
import warnings | |
import einops | |
from num2words import num2words | |
import spacy | |
from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torch.nn.utils.rnn import pad_sequence | |
from .chroma import ChromaExtractor | |
from .streaming import StreamingModule | |
from .transformer import create_sin_embedding | |
from ..data.audio import audio_read | |
from ..data.audio_dataset import SegmentInfo | |
from ..data.audio_utils import convert_audio | |
from ..environment import AudioCraftEnvironment | |
from ..quantization import ResidualVectorQuantizer | |
from ..utils.autocast import TorchAutocast | |
from ..utils.cache import EmbeddingCache | |
from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type aliases shared by the conditioners defined in this module.
TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
class WavCondition(tp.NamedTuple):
    """Waveform condition: the audio tensor plus its per-item metadata.

    Fields:
        wav: Waveform tensor (batch first — presumably [B, C, T] or [B, T]; confirm with callers).
        length: Tensor holding the length of each item.
        sample_rate: One sample rate per item.
        path: Optional source path per item.
        seek_time: Optional seek time per item.

    NOTE: the `[]` defaults of `path` and `seek_time` are single list objects shared by
    every instance built without those fields; never mutate them in place.
    """
    wav: torch.Tensor
    length: torch.Tensor
    sample_rate: tp.List[int]
    path: tp.List[tp.Optional[str]] = []
    seek_time: tp.List[tp.Optional[float]] = []
class ChordCondition(tp.NamedTuple):
    """Chord condition: the chord tensor plus its per-item metadata.

    Fields:
        chord: Chord tensor (batch first — presumably [B, C, T]; confirm with callers).
        length: Tensor holding the length of each item.
        bpm: Optional tempo per item.
        path: Optional source path per item.
        seek_frame: Optional seek position per item.

    NOTE: the `[]` defaults are single list objects shared by every instance built
    without those fields; never mutate them in place.
    """
    chord: torch.Tensor
    length: torch.Tensor
    bpm: tp.List[tp.Optional[float]] = []
    path: tp.List[tp.Optional[str]] = []
    seek_frame: tp.List[tp.Optional[float]] = []
class BeatCondition(tp.NamedTuple):
    """Beat condition: the beat tensor plus its per-item metadata.

    Fields:
        beat: Beat tensor (batch first — presumably [B, C, T]; confirm with callers).
        length: Tensor holding the length of each item.
        bpm: Optional tempo per item.
        path: Optional source path per item.
        seek_frame: Optional seek position per item.

    NOTE: the `[]` defaults are single list objects shared by every instance built
    without those fields; never mutate them in place.
    """
    beat: torch.Tensor
    length: torch.Tensor
    bpm: tp.List[tp.Optional[float]] = []
    path: tp.List[tp.Optional[str]] = []
    seek_frame: tp.List[tp.Optional[float]] = []
class JointEmbedCondition(tp.NamedTuple):
    """Joint audio/text condition: a waveform, its text and per-item metadata.

    Fields:
        wav: Waveform tensor (batch first — presumably [B, C, T]; confirm with callers).
        text: Optional text per item.
        length: Tensor holding the length of each item.
        sample_rate: One sample rate per item.
        path: Optional source path per item.
        seek_time: Optional seek time per item.

    NOTE: the `[]` defaults are single list objects shared by every instance built
    without those fields; never mutate them in place.
    """
    wav: torch.Tensor
    text: tp.List[tp.Optional[str]]
    length: torch.Tensor
    sample_rate: tp.List[int]
    path: tp.List[tp.Optional[str]] = []
    seek_time: tp.List[tp.Optional[float]] = []
@dataclass
class ConditioningAttributes:
    """Container grouping every conditioning attribute of one sample by kind.

    Each field maps an attribute name to its value for one kind of condition
    ("text", "wav", "beat", "chord", "joint_embed"). The class is a dataclass so that
    the `field(default_factory=dict)` defaults give every instance its own dicts.
    """
    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
    beat: tp.Dict[str, BeatCondition] = field(default_factory=dict)
    chord: tp.Dict[str, ChordCondition] = field(default_factory=dict)
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    def __getitem__(self, item):
        """Return the dict of conditions for the given kind (e.g. `attrs["text"]`)."""
        return getattr(self, item)

    @property
    def text_attributes(self):
        """Names of the text conditions."""
        return self.text.keys()

    @property
    def wav_attributes(self):
        """Names of the waveform conditions."""
        return self.wav.keys()

    @property
    def beat_attributes(self):
        """Names of the beat conditions."""
        return self.beat.keys()

    @property
    def chord_attributes(self):
        """Names of the chord conditions."""
        return self.chord.keys()

    @property
    def joint_embed_attributes(self):
        """Names of the joint-embedding conditions."""
        return self.joint_embed.keys()

    @property
    def attributes(self):
        """Mapping from condition kind to the names of the conditions of that kind."""
        return {
            "text": self.text_attributes,
            "wav": self.wav_attributes,
            "beat": self.beat_attributes,
            "chord": self.chord_attributes,
            "joint_embed": self.joint_embed_attributes,
        }

    def to_flat_dict(self):
        """Flatten all conditions into one dict keyed by "<kind>.<attribute name>"."""
        return {
            **{f"text.{k}": v for k, v in self.text.items()},
            **{f"wav.{k}": v for k, v in self.wav.items()},
            **{f"beat.{k}": v for k, v in self.beat.items()},
            **{f"chord.{k}": v for k, v in self.chord.items()},
            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
        }

    @classmethod
    def from_flat_dict(cls, x):
        """Rebuild an instance from a dict produced by `to_flat_dict`."""
        out = cls()
        for k, v in x.items():
            # Split on the first dot only so that attribute names containing dots round-trip.
            kind, att = k.split(".", 1)
            out[kind][att] = v
        return out
class SegmentWithAttributes(SegmentInfo):
    """Common parent of the segment metadata classes that can be used for conditioning.

    Subclasses are expected to override `to_condition_attributes` so that their own
    fields are exposed as a ConditioningAttributes object.
    """

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Convert this segment's attributes to ConditioningAttributes (must be overridden)."""
        raise NotImplementedError()
def nullify_condition(condition: ConditionType, dim: int = 1):
    """Turn a (condition, mask) pair into its null counterpart.

    The condition is cut down to a single position along `dim` and multiplied by zero;
    the returned mask is an all-zero integer tensor of shape [B, 1]. This mirrors the
    single pad entry emitted by WhiteSpaceTokenizer and NoopTokenizer for missing inputs.

    Args:
        condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor]).
        dim (int): The dimension to truncate (should be the time dimension).
            WARNING!: dim should not be the batch dimension!
    Returns:
        ConditionType: A tuple of null condition and mask.
    """
    assert dim != 0, "dim cannot be the batch dimension!"
    well_formed = (
        isinstance(condition, tuple)
        and isinstance(condition[0], torch.Tensor)
        and isinstance(condition[1], torch.Tensor)
    )
    assert well_formed, "'nullify_condition' got an unexpected input type!"
    cond, _ = condition
    batch_size = cond.shape[0]
    # Keep only the first position along `dim`, every other dimension is left whole.
    first_only = [slice(None)] * cond.dim()
    first_only[dim] = slice(0, 1)
    nulled = 0. * cond[tuple(first_only)]
    null_mask = torch.zeros((batch_size, 1), device=nulled.device).int()
    assert cond.dim() == nulled.dim()
    return nulled, null_mask
def nullify_wav(cond: WavCondition) -> WavCondition:
    """Build the null version of a WavCondition.

    The waveform is nullified along its last dimension, every length is set to 0,
    the sample rates are kept and `path` / `seek_time` are filled with None.

    Args:
        cond (WavCondition): Wav condition to nullify.
    Returns:
        WavCondition: Nullified wav condition.
    """
    wav = cond.wav
    batch_size = wav.shape[0]
    time_dim = wav.dim() - 1
    # The mask half of the pair is irrelevant here, a zero tensor is passed as placeholder.
    silent_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=time_dim)
    zero_lengths = torch.tensor([0] * batch_size, device=wav.device)
    return WavCondition(
        wav=silent_wav,
        length=zero_lengths,
        sample_rate=cond.sample_rate,
        path=[None] * batch_size,
        seek_time=[None] * batch_size,
    )
def nullify_chord(cond: ChordCondition) -> ChordCondition:
    """Build the null version of a ChordCondition.

    The chord tensor is nullified along its last dimension, every length is set to 0
    and `bpm` / `path` / `seek_frame` are filled with None.

    Args:
        cond (ChordCondition): Chord condition to nullify.
    Returns:
        ChordCondition: Nullified chord condition.
    """
    chord = cond.chord
    batch_size = chord.shape[0]
    time_dim = chord.dim() - 1
    # The mask half of the pair is irrelevant here, a zero tensor is passed as placeholder.
    empty_chord, _ = nullify_condition((chord, torch.zeros_like(chord)), dim=time_dim)
    zero_lengths = torch.tensor([0] * batch_size, device=chord.device)
    return ChordCondition(
        chord=empty_chord,
        length=zero_lengths,
        bpm=[None] * batch_size,
        path=[None] * batch_size,
        seek_frame=[None] * batch_size,
    )
def nullify_beat(cond: BeatCondition) -> BeatCondition:
    """Build the null version of a BeatCondition.

    The beat tensor is nullified along its last dimension, every length is set to 0
    and `bpm` / `path` / `seek_frame` are filled with None.

    Args:
        cond (BeatCondition): Beat condition to nullify.
    Returns:
        BeatCondition: Nullified beat condition.
    """
    beat = cond.beat
    batch_size = beat.shape[0]
    time_dim = beat.dim() - 1
    # The mask half of the pair is irrelevant here, a zero tensor is passed as placeholder.
    empty_beat, _ = nullify_condition((beat, torch.zeros_like(beat)), dim=time_dim)
    zero_lengths = torch.tensor([0] * batch_size, device=beat.device)
    return BeatCondition(
        beat=empty_beat,
        length=zero_lengths,
        bpm=[None] * batch_size,
        path=[None] * batch_size,
        seek_frame=[None] * batch_size,
    )
def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
    """Build the null version of a JointEmbedCondition.

    The waveform is nullified along its last dimension, every text becomes None,
    the sample rates are kept and the remaining metadata is replaced by dummy values.

    Args:
        embed (JointEmbedCondition): Joint embedding condition to nullify.
    Returns:
        JointEmbedCondition: Nullified joint embedding condition.
    """
    wav = embed.wav
    batch_size = wav.shape[0]
    time_dim = wav.dim() - 1
    # The mask half of the pair is irrelevant here, a zero tensor is passed as placeholder.
    silent_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=time_dim)
    # NOTE(review): unlike the other nullify_* helpers, the length is a single-element tensor
    # (not one entry per batch item) and seek_time is filled with 0 rather than None.
    # Kept as is; confirm with consumers before aligning.
    return JointEmbedCondition(
        wav=silent_wav,
        text=[None] * len(embed.text),
        length=torch.LongTensor([0]).to(wav.device),
        sample_rate=embed.sample_rate,
        path=[None] * batch_size,
        seek_time=[0] * batch_size,
    )
class Tokenizer:
    """Interface shared by all tokenizers of this module
    (kept as a base class in case more advanced tokenizers are introduced later).
    """

    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Convert a list of optional strings to a (tokens, mask) pair; must be overridden."""
        raise NotImplementedError()
class WhiteSpaceTokenizer(Tokenizer):
    """Word-level tokenizer meant for natural language descriptions.

    Each text is processed by a spaCy pipeline, optionally stripped of stopwords,
    stripped of punctuation, optionally lemmatized, and every remaining word is mapped
    to an index with `hash_trick`.
    For example:
    ["he didn't, know he's going home.", 'shorter sentence'] =>
    [[78, 62, 31, 4, 78, 25, 19, 34],
    [59, 77, 0, 0, 0, 0, 0, 0]]
    """
    PUNCTUATION = "?:!.,;"

    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
                 lemma: bool = True, stopwords: bool = True) -> None:
        """Store the tokenizer options and load the spaCy pipeline.

        Args:
            n_bins (int): Number of bins passed to `hash_trick`.
            pad_idx (int): Index used for padding and for missing texts.
            language (str): Name of the spaCy pipeline to load.
            lemma (bool): Whether to use the lemma of each word instead of its raw text.
            stopwords (bool): Whether to remove stopwords.
        """
        self.n_bins = n_bins
        self.pad_idx = pad_idx
        self.lemma = lemma
        self.stopwords = stopwords
        try:
            self.nlp = spacy.load(language)
        except IOError:
            # The pipeline could not be loaded: download it, then try once more.
            spacy.cli.download(language)  # type: ignore
            self.nlp = spacy.load(language)

    def __call__(self, texts: tp.List[tp.Optional[str]],
                 return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Take a list of strings and convert them to a tensor of indices.

        Args:
            texts (list[str]): List of strings; None entries stand for missing texts.
            return_text (bool, optional): Whether to return the normalized texts as an
                additional tuple item. Defaults to False.
        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - Indices of words in the LUT.
                - And a mask indicating where the padding tokens are
        """
        texts = deepcopy(texts)  # the normalized strings are written back, keep the caller's list intact
        token_rows: tp.List[torch.Tensor] = []
        word_counts: tp.List[int] = []
        word_field = "lemma_" if self.lemma else "text"
        for idx, raw in enumerate(texts):
            if raw is None:
                # missing attribute for this sample: a single pad token of length zero
                token_rows.append(torch.Tensor([self.pad_idx]))
                word_counts.append(0)
                continue
            # spell numbers out as words
            spelled = re.sub(r"(\d+)", lambda match: num2words(int(match.group(0))), raw)  # type: ignore
            # run the spaCy pipeline, then drop stopwords (if requested) and punctuation
            kept = [tok for tok in self.nlp(spelled) if not (self.stopwords and tok.is_stop)]
            kept = [tok for tok in kept if tok.text not in self.PUNCTUATION]
            words = [getattr(tok, word_field) for tok in kept]
            texts[idx] = " ".join(words)
            word_counts.append(len(words))
            token_rows.append(torch.Tensor([hash_trick(word, self.n_bins) for word in words]))
        mask = length_to_mask(torch.IntTensor(word_counts)).int()
        padded_output = pad_sequence(token_rows, padding_value=self.pad_idx).int().t()
        if not return_text:
            return padded_output, mask
        return padded_output, mask, texts  # type: ignore
class NoopTokenizer(Tokenizer):
    """Tokenizer for global conditioners such as: artist, genre, key, etc.

    Unlike WhiteSpaceTokenizer, the string is not split: the whole string is hashed to a
    single index, so "Jeff Buckley" gets its own index, whereas WhiteSpaceTokenizer
    would split it to ["Jeff", "Buckley"] and return an index per word.
    For example:
    ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
    ["Metal", "Rock", "Classical"] => [0, 223, 51]
    """

    def __init__(self, n_bins: int, pad_idx: int = 0):
        """Store the number of hash bins and the index used for missing texts."""
        self.n_bins = n_bins
        self.pad_idx = pad_idx

    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Map every text to one index; missing texts get the pad index and length 0.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: tokens of shape [len(texts), 1] and the
                mask derived from the per-text lengths (0 or 1).
        """
        # One (index, length) pair per text, built in a single pass over the input.
        pairs = [
            (self.pad_idx, 0) if text is None else (hash_trick(text, self.n_bins), 1)
            for text in texts
        ]
        indices = [index for index, _ in pairs]
        lengths = [length for _, length in pairs]
        tokens = torch.LongTensor(indices).unsqueeze(1)
        mask = length_to_mask(torch.IntTensor(lengths)).int()
        return tokens, mask
class BaseConditioner(nn.Module):
    """Parent class of every conditioner module.

    The output dim may differ from the hidden dim for two reasons:
    1) keep our LUTs small when the vocab is large;
    2) make all condition dims consistent.

    Args:
        dim (int): Hidden dim of the model.
        output_dim (int): Output dim of the conditioner.
    """

    def __init__(self, dim: int, output_dim: int):
        super().__init__()
        self.dim, self.output_dim = dim, output_dim
        # Linear projection from the hidden dim to the common output dim.
        self.output_proj = nn.Linear(in_features=dim, out_features=output_dim)

    def tokenize(self, *args, **kwargs) -> tp.Any:
        """Pre-processing step, to be overridden.

        Should hold any part of the processing that leads to a synchronization point,
        e.g. BPE tokenization with transfer to the GPU.
        The returned value will be saved and returned later when calling forward().
        """
        raise NotImplementedError()

    def forward(self, inputs: tp.Any) -> ConditionType:
        """Embed the conditioning input (e.g, genre, description or a waveform), to be overridden.

        Returns:
            ConditionType:
                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
                  output embedding and D is the dimension of the embedding.
                - And a mask indicating where the padding tokens are.
        """
        raise NotImplementedError()
class TextConditioner(BaseConditioner):
    """Common parent of the text conditioners; adds nothing to BaseConditioner."""
class LUTConditioner(TextConditioner):
    """TextConditioner backed by an embedding lookup table.

    Args:
        n_bins (int): Number of bins.
        dim (int): Hidden dim of the model (text-encoder/LUT).
        output_dim (int): Output dim of the conditioner.
        tokenizer (str): Name of the tokenizer ('whitespace' or 'noop').
        pad_idx (int, optional): Index for padding token. Defaults to 0.
    """

    def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
        super().__init__(dim, output_dim)
        self.embed = nn.Embedding(n_bins, dim)
        self.tokenizer: Tokenizer
        known_tokenizers = {'whitespace': WhiteSpaceTokenizer, 'noop': NoopTokenizer}
        if tokenizer not in known_tokenizers:
            raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
        self.tokenizer = known_tokenizers[tokenizer](n_bins, pad_idx=pad_idx)

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize the texts and move tokens and mask to the device of the embedding table."""
        target = self.embed.weight.device
        tokens, mask = self.tokenizer(x)
        return tokens.to(target), mask.to(target)

    def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
        """Look up, project and mask the embeddings of pre-tokenized inputs.

        Args:
            inputs: The (tokens, mask) pair returned by `tokenize`.
        Returns:
            ConditionType: The masked embeddings and the unchanged mask.
        """
        tokens, mask = inputs
        projected = self.output_proj(self.embed(tokens))
        # zero the embeddings of the masked-out positions
        return projected * mask.unsqueeze(-1), mask
class T5Conditioner(TextConditioner):
    """T5-based TextConditioner.

    Args:
        name (str): Name of the T5 model (must be one of `MODELS`).
        output_dim (int): Output dim of the conditioner.
        finetune (bool): Whether to fine-tune T5 at train time.
        device (str): Device for T5 Conditioner.
        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
        word_dropout (float, optional): Word dropout probability.
        normalize_text (bool, optional): Whether to apply text normalization.
    """
    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
              "google/flan-t5-xl", "google/flan-t5-xxl"]
    # Hidden size of each supported encoder; every name in MODELS needs an entry here
    # because __init__ looks the dimension up by name.
    MODELS_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
        # Legacy keys, not present in MODELS; kept so that existing lookups keep working.
        "google/flan-t5-3b": 1024,
        "google/flan-t5-11b": 1024,
    }

    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
                 normalize_text: bool = False):
        assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
        super().__init__(self.MODELS_DIMS[name], output_dim)
        self.device = device
        self.name = name
        self.finetune = finetune
        self.word_dropout = word_dropout
        if autocast_dtype is None or self.device == 'cpu':
            self.autocast = TorchAutocast(enabled=False)
            if self.device != 'cpu':
                logger.warning("T5 has no autocast, this might lead to NaN")
        else:
            dtype = getattr(torch, autocast_dtype)
            assert isinstance(dtype, torch.dtype)
            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
        # thanks https://gist.github.com/simon-weber/7853144
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
            finally:
                # always restore the previous logging level, even if loading failed
                logging.disable(previous_level)
        if finetune:
            self.t5 = t5
        else:
            # this makes sure that the t5 models is not part
            # of the saved checkpoint
            self.__dict__['t5'] = t5.to(device)

        self.normalize_text = normalize_text
        if normalize_text:
            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
        """Run the T5 tokenizer on the texts and move the result to the conditioner device.

        Missing texts (None) are replaced by empty strings. Texts are optionally normalized
        and, in training mode, optionally subject to word dropout. The attention mask of
        every entry that ends up empty is zeroed.

        Args:
            x: One optional text per batch item.
        Returns:
            The tokenizer output (including 'attention_mask'), to be passed to `forward`.
        """
        # if current sample doesn't have a certain attribute, replace with empty string
        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
        if self.normalize_text:
            _, _, entries = self.text_normalizer(entries, return_text=True)
        if self.word_dropout > 0. and self.training:
            new_entries = []
            for entry in entries:
                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
                new_entries.append(" ".join(words))
            entries = new_entries

        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])

        inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
        mask = inputs['attention_mask']
        mask[empty_idx, :] = 0  # zero-out index where the input is non-existent
        return inputs

    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
        """Encode tokenized texts with T5, project them and apply the attention mask.

        Gradients flow through T5 only when `finetune` is set.

        Args:
            inputs: The value returned by `tokenize`.
        Returns:
            ConditionType: The masked embeddings and the attention mask.
        """
        mask = inputs['attention_mask']
        with torch.set_grad_enabled(self.finetune), self.autocast:
            embeds = self.t5(**inputs).last_hidden_state
        # match the dtype/device of the projection weights before projecting
        embeds = self.output_proj(embeds.to(self.output_proj.weight))
        embeds = (embeds * mask.unsqueeze(-1))
        return embeds, mask
class WaveformConditioner(BaseConditioner):
    """Base class for all conditioners that take a waveform as input.

    Classes that inherit must implement `_get_wav_embedding` that outputs
    a continuous tensor, and `_downsampling_factor` that returns the down-sampling
    factor of the embedding model.

    Args:
        dim (int): The internal representation dimension.
        output_dim (int): Output dimension.
        device (tp.Union[torch.device, str]): Device.
    """

    def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
        super().__init__(dim, output_dim)
        self.device = device

    def tokenize(self, x: WavCondition) -> WavCondition:
        """Move the waveform and its lengths to the conditioner device; metadata is kept as is."""
        wav, length, sample_rate, path, seek_time = x
        assert length is not None
        return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)

    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
        """Gets as input a WavCondition and returns a dense embedding."""
        raise NotImplementedError()

    def _downsampling_factor(self):
        """Returns the downsampling factor of the embedding model."""
        raise NotImplementedError()

    def forward(self, x: WavCondition) -> ConditionType:
        """Extract condition embedding and mask from a waveform and its metadata.

        Args:
            x (WavCondition): Waveform condition containing raw waveform and metadata.
        Returns:
            ConditionType: a dense vector representing the conditioning along with its mask
        """
        wav, lengths, *_ = x
        with torch.no_grad():
            embeds = self._get_wav_embedding(x)
        embeds = embeds.to(self.output_proj.weight)
        embeds = self.output_proj(embeds)

        if lengths is not None:
            # convert lengths from samples to embedding frames
            lengths = lengths / self._downsampling_factor()
            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
        else:
            # No lengths available: keep every frame. The mask must be [B, T] (one entry per
            # frame) so that it broadcasts against the [B, T, D] embeddings once unsqueezed.
            mask = torch.ones(embeds.shape[:2], dtype=torch.int, device=embeds.device)
        embeds = (embeds * mask.unsqueeze(2).to(self.device))
        return embeds, mask
class ChromaStemConditioner(WaveformConditioner):
    """Chroma conditioner based on stems.

    The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
    the drums and bass often dominate the chroma leading to the chroma features
    not containing information about the melody.

    Args:
        output_dim (int): Output dimension for the conditioner.
        sample_rate (int): Sample rate for the chroma extractor.
        n_chroma (int): Number of chroma bins for the chroma extractor.
        radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
        duration (float): duration used during training. This is later used for correct padding
            in case we are using chroma as prefix.
        match_len_on_eval (bool, optional): if True then all chromas are truncated or repeated to
            the training duration. Defaults to True.
        eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as
            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
            Defaults to None.
        n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
        cache_path (str or Path, optional): when given, an embedding cache is created under its
            'wav' sub-folder and used for pre-computed chromas. Defaults to None.
        device (tp.Union[torch.device, str], optional): Device for the conditioner.
        **kwargs: Additional parameters for the chroma extractor.
    """

    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
                 n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
                 device: tp.Union[torch.device, str] = 'cpu', **kwargs):
        from demucs import pretrained
        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
        # `device` may be a torch.device or an indexed string such as 'cuda:0'. A torch.device
        # does not compare equal to the string 'cpu', and the autocast device type is presumably
        # forwarded to torch.autocast, which wants the bare type ('cuda', 'cpu'), so normalise first.
        device_type = torch.device(device).type
        self.autocast = TorchAutocast(enabled=device_type != 'cpu', device_type=device_type, dtype=torch.float32)
        self.sample_rate = sample_rate
        self.match_len_on_eval = match_len_on_eval
        self.duration = duration
        # stored through __dict__ so that demucs is not registered as a sub-module
        self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
        stem_sources: list = self.demucs.sources  # type: ignore
        self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
                                      radix2_exp=radix2_exp, **kwargs).to(device)
        self.chroma_len = self._get_chroma_len()
        self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
        self.cache = None
        if cache_path is not None:
            self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
                                        compute_embed_fn=self._get_full_chroma_for_cache,
                                        extract_embed_fn=self._extract_chroma_chunk)

    def _downsampling_factor(self) -> int:
        """Return the hop size of the chroma extractor."""
        return self.chroma.winhop

    def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
        """Load pre-defined waveforms from a json.

        These waveforms will be used for chroma extraction during evaluation.
        This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).

        Raises:
            ValueError: if the dataset built from `path` is empty.
        """
        if path is None:
            return None

        logger.info(f"Loading evaluation wavs from {path}")
        from audiocraft.data.audio_dataset import AudioDataset
        dataset: AudioDataset = AudioDataset.from_meta(
            path, segment_duration=self.duration, min_audio_duration=self.duration,
            sample_rate=self.sample_rate, channels=1)

        if len(dataset) > 0:
            eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
            logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
            return eval_wavs
        else:
            raise ValueError("Could not find evaluation wavs, check lengths of wavs")

    def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
        """Replace the evaluation waveforms (None disables them)."""
        self.eval_wavs = eval_wavs

    def has_eval_wavs(self) -> bool:
        """Tell whether evaluation waveforms are set."""
        return self.eval_wavs is not None

    def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
        """Sample wavs from a predefined list."""
        assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
        total_eval_wavs = len(self.eval_wavs)
        out = self.eval_wavs
        if num_samples > total_eval_wavs:
            # not enough wavs: tile them until at least num_samples are available
            out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
        return out[torch.randperm(len(out))][:num_samples]

    def _get_chroma_len(self) -> int:
        """Get length of chroma during training."""
        dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
        dummy_chr = self.chroma(dummy_wav)
        return dummy_chr.shape[1]

    def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Get parts of the wav that holds the melody, extracting the main stems from the wav."""
        from demucs.apply import apply_model
        # NOTE: this local import deliberately shadows the module-level convert_audio in this method.
        from demucs.audio import convert_audio
        with self.autocast:
            wav = convert_audio(
                wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels)  # type: ignore
            stems = apply_model(self.demucs, wav, device=self.device)
            stems = stems[:, self.stem_indices]  # extract relevant stems for melody conditioning
            mix_wav = stems.sum(1)  # merge extracted stems to single waveform
            mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1)  # type: ignore
            return mix_wav

    def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
        """Extract chroma features from the waveform."""
        with self.autocast:
            return self.chroma(wav)

    def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Compute wav embedding, applying stem and chroma extraction."""
        # avoid 0-size tensors when we are working with null conds
        if wav.shape[-1] == 1:
            return self._extract_chroma(wav)
        stems = self._get_stemmed_wav(wav, sample_rate)
        chroma = self._extract_chroma(stems)
        return chroma

    def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
        """Extract chroma from the whole audio waveform at the given path."""
        wav, sr = audio_read(path)
        wav = wav[None].to(self.device)
        wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
        chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
        return chroma

    def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
        """Extract a chunk of chroma from the full chroma derived from the full waveform."""
        wav_length = x.wav.shape[-1]
        seek_time = x.seek_time[idx]
        assert seek_time is not None, (
            "WavCondition seek_time is required "
            "when extracting chroma chunks from pre-computed chroma.")
        full_chroma = full_chroma.float()
        frame_rate = self.sample_rate / self._downsampling_factor()
        target_length = int(frame_rate * wav_length / self.sample_rate)
        index = int(frame_rate * seek_time)
        out = full_chroma[index: index + target_length]
        # pad at the end when the chunk runs past the end of the full chroma
        out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
        return out.to(self.device)

    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
        """Get the wav embedding from the WavCondition.

        The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
        or will rely on the embedding cache to load the pre-computed embedding if relevant.
        """
        sampled_wav: tp.Optional[torch.Tensor] = None
        if not self.training and self.eval_wavs is not None:
            warn_once(logger, "Using precomputed evaluation wavs!")
            sampled_wav = self._sample_eval_wavs(len(x.wav))

        no_undefined_paths = all(p is not None for p in x.path)
        no_nullified_cond = x.wav.shape[-1] > 1
        if sampled_wav is not None:
            chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
        elif self.cache is not None and no_undefined_paths and no_nullified_cond:
            paths = [Path(p) for p in x.path if p is not None]
            chroma = self.cache.get_embed_from_cache(paths, x)
        else:
            assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
            chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])

        if self.match_len_on_eval:
            B, T, C = chroma.shape
            if T > self.chroma_len:
                chroma = chroma[:, :self.chroma_len]
                logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
            elif T < self.chroma_len:
                n_repeat = int(math.ceil(self.chroma_len / T))
                chroma = chroma.repeat(1, n_repeat, 1)
                chroma = chroma[:, :self.chroma_len]
                logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")

        return chroma

    def tokenize(self, x: WavCondition) -> WavCondition:
        """Apply WavConditioner tokenization and populate cache if needed."""
        x = super().tokenize(x)
        no_undefined_paths = all(p is not None for p in x.path)
        if self.cache is not None and no_undefined_paths:
            paths = [Path(p) for p in x.path if p is not None]
            self.cache.populate_embed_cache(paths, x)
        return x
class ChordProgressionConditioner(BaseConditioner):
    """Conditioner projecting a chord progression tensor to the output dimension.

    The hidden dimension is fixed to 12 chroma bins.

    Args:
        output_dim (int): Output dimension.
        device (str): Device.
        name (str): Not used by this class; kept for interface compatibility.
    """

    def __init__(self, output_dim: int, device: str, name: str):
        n_chroma = 12
        super().__init__(dim=n_chroma, output_dim=output_dim)
        self.device = device

    def forward(self, x: ChordCondition) -> ConditionType:
        """Project the chord tensor and mask it according to the lengths.

        Args:
            x (ChordCondition): Condition as returned by `tokenize` (chord laid out as [B, T, C]).
        Returns:
            ConditionType: The masked embeddings and the mask.
        """
        chord, lengths, *_ = x
        embeds = chord.to(self.output_proj.weight)  # chord is already a tensor
        embeds = self.output_proj(embeds)
        if lengths is not None:
            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
        else:
            # No lengths available: keep every frame. The mask must be [B, T] (one entry per
            # frame) so that it broadcasts against the [B, T, D] embeddings once unsqueezed.
            mask = torch.ones(embeds.shape[:2], dtype=torch.int, device=embeds.device)
        embeds = (embeds * mask.unsqueeze(2).to(self.device))
        return embeds, mask

    def tokenize(self, x: ChordCondition) -> ChordCondition:
        """Pad the chord to the first item's length, move time before channels and move to device."""
        chord, length, bpm, path, seek_frame = x
        # F.pad expects plain ints: convert the tensor element explicitly.
        pad_right = int(length[0]) - chord.shape[-1]
        chord = F.pad(chord, (0, pad_right))  # [B, C, t] -> [B, C, T]
        chord = chord.permute(0, 2, 1)  # [B, T, C]
        x = ChordCondition(chord.to(self.device), length.to(self.device), bpm, path, seek_frame)
        return x
class BeatConditioner(BaseConditioner):
    """Conditioner projecting a beat tensor to the output dimension.

    The hidden dimension is fixed to a single beat channel.

    Args:
        output_dim (int): Output dimension.
        device (str): Device.
        name (str): Not used by this class; kept for interface compatibility.
    """

    def __init__(self, output_dim: int, device: str, name: str):
        beat_channel = 1
        super().__init__(dim=beat_channel, output_dim=output_dim)
        self.device = device

    def forward(self, x: BeatCondition) -> ConditionType:
        """Project the beat tensor and mask it according to the lengths.

        Args:
            x (BeatCondition): Condition as returned by `tokenize` (beat laid out as [B, T, C]).
        Returns:
            ConditionType: The masked embeddings and the mask.
        """
        beat, lengths, *_ = x
        embeds = beat.to(self.output_proj.weight)  # beat is already a tensor
        embeds = self.output_proj(embeds)
        if lengths is not None:
            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
        else:
            # No lengths available: keep every frame. The mask must be [B, T] (one entry per
            # frame) so that it broadcasts against the [B, T, D] embeddings once unsqueezed.
            mask = torch.ones(embeds.shape[:2], dtype=torch.int, device=embeds.device)
        embeds = (embeds * mask.unsqueeze(2).to(self.device))
        return embeds, mask

    def tokenize(self, x: BeatCondition) -> BeatCondition:
        """Pad the beat to the first item's length, move time before channels and move to device."""
        beat, length, bpm, path, seek_frame = x
        # F.pad expects plain ints: convert the tensor element explicitly.
        pad_right = int(length[0]) - beat.shape[-1]
        beat = F.pad(beat, (0, pad_right))  # [B, C, t] -> [B, C, T]
        beat = beat.permute(0, 2, 1)  # [B, T, C]
        x = BeatCondition(beat.to(self.device), length.to(self.device), bpm, path, seek_frame)
        return x
class JointEmbeddingConditioner(BaseConditioner):
    """Base conditioner working on a joint embedding shared by audio and text.

    Subclasses provide the embedding through `_get_embed`; this class optionally
    passes it through a residual vector quantizer and projects it to `output_dim`.

    Args:
        dim (int): Dimension.
        output_dim (int): Output dimension.
        device (str): Device.
        attribute (str): Attribute used by the conditioner.
        autocast_dtype (str): Autocast for the conditioner.
        quantize (bool): Whether to quantize the CLAP embedding.
        n_q (int): Number of residual quantizers (used if quantize is true).
        bins (int): Quantizers' codebooks size (used if quantize is true).
        kwargs: Additional parameters for residual vector quantizer.
    """
    def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
                 autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
                 n_q: int = 12, bins: int = 1024, **kwargs):
        super().__init__(dim=dim, output_dim=output_dim)
        self.device = device
        self.attribute = attribute
        autocast_requested = autocast_dtype is not None and device != 'cpu'
        if autocast_requested:
            dtype = getattr(torch, autocast_dtype)
            assert isinstance(dtype, torch.dtype)
            logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
        else:
            self.autocast = TorchAutocast(enabled=False)
            logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
        # optional residual vector quantizer discretizing the conditioning embedding
        self.quantizer: tp.Optional[ResidualVectorQuantizer] = (
            ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs) if quantize else None
        )

    def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Return the latent joint embedding for the inputs (to be implemented by subclasses).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
                and corresponding empty indexes.
        """
        raise NotImplementedError()

    def forward(self, x: JointEmbedCondition) -> ConditionType:
        """Embed, optionally quantize, project and mask the joint condition.

        Returns:
            ConditionType: Projected embedding viewed as [-1, 1, output_dim] and a
                mask over its first two axes, zeroed on the rows listed as empty.
        """
        with self.autocast:
            latent, empty_idx = self._get_embed(x)
            if self.quantizer is not None:
                quantized = self.quantizer(latent.view(-1, self.dim, 1), frame_rate=1)
                latent = quantized.x.view(-1, self.dim)
            projected = self.output_proj(latent).view(-1, 1, self.output_dim)
            n_items, n_steps = projected.shape[:2]
            mask = torch.ones(n_items, n_steps, device=projected.device)
            # rows whose input does not exist contribute nothing
            mask[empty_idx, :] = 0
            return projected * mask.unsqueeze(-1), mask

    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
        """Return the condition unchanged; no tokenization is done at this level."""
        return x
class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
    """Joint Embedding conditioner based on pre-trained CLAP model.

    This CLAP-based conditioner supports a caching mechanism
    over the computed embeddings for faster training.

    Args:
        dim (int): Dimension.
        output_dim (int): Output dimension.
        device (str): Device.
        attribute (str): Attribute used by the conditioner.
        quantize (bool): Whether to quantize the CLAP embedding.
        n_q (int): Number of residual quantizers (used if quantize is true).
        bins (int): Quantizers' codebooks size (used if quantize is true).
        checkpoint (str): Path to CLAP checkpoint.
        model_arch (str): CLAP model architecture.
        enable_fusion (bool): Enable fusion for CLAP model.
        sample_rate (int): Sample rate used by CLAP model.
        max_audio_length (float): Maximum audio length for CLAP model.
        audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
        normalize (bool): Whether to normalize the CLAP embedding.
        text_p (float): Probability of using text representation instead of audio at train time.
        batch_size (Optional[int]): Batch size for CLAP embedding computation.
        autocast_dtype (str): Autocast for the conditioner.
        cache_path (Optional[str]): Path for pre-computed embeddings caching.
        kwargs: Additional parameters for residual vector quantizer.
    """
    def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
                 quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
                 enable_fusion: bool, sample_rate: int, max_audio_length: float, audio_stride: float,
                 normalize: bool, text_p: float, batch_size: tp.Optional[int] = None,
                 autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
        try:
            import laion_clap  # type: ignore
        except ImportError as err:
            # keep the original failure attached so the root cause stays visible
            raise ImportError(
                "Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'"
            ) from err
        checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
        clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
        clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
        load_clap_state_dict(clap_model, checkpoint)
        clap_model.eval()
        clap_model.to(device)
        super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
                         autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
                         **kwargs)
        self.checkpoint = checkpoint
        self.enable_fusion = enable_fusion
        self.model_arch = model_arch
        self.clap: laion_clap.CLAP_Module
        self.clap_tokenize: RobertaTokenizer
        self.clap_sample_rate = sample_rate
        self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
        self.clap_stride = int(self.clap_sample_rate * audio_stride)
        self.batch_size = batch_size or 1
        self.normalize = normalize
        self.text_p = text_p
        # Written straight into __dict__ so nn.Module.__setattr__ is bypassed for
        # the CLAP model and tokenizer (the model is not registered as a submodule).
        self.__dict__['clap_tokenize'] = clap_tokenize
        self.__dict__['clap'] = clap_model
        self.wav_cache, self.text_cache = None, None
        if cache_path is not None:
            self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
                                            compute_embed_fn=self._get_wav_embedding_for_cache,
                                            extract_embed_fn=self._extract_wav_embedding_chunk)
            self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
                                             compute_embed_fn=self._get_text_embedding_for_cache)

    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
        """Tokenize text with the RoBERTa tokenizer, padded/truncated to 77 tokens."""
        # we use the default params from CLAP module here as well
        return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")

    def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
        """Compute text embedding from CLAP model on a given a batch of text.

        Args:
            text (list[str]): List of text for the batch, with B items.
        Returns:
            torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
        """
        with torch.no_grad():
            embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
            return embed.view(embed.size(0), 1, embed.size(-1))

    def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
                                      x: JointEmbedCondition, idx: int) -> torch.Tensor:
        """Get text embedding function for the cache.

        `path` is part of the cache callback signature and is not used here;
        a missing text (None) is embedded as the empty string.
        """
        text = x.text[idx]
        text = text if text is not None else ""
        return self._compute_text_embedding([text])[0]

    def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
        """Preprocess wav to expected format by CLAP model.

        Args:
            wav (torch.Tensor): Audio wav, of shape [B, C, T].
            length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
            sample_rates (list[int]): Sample rates for each sample in the batch
        Returns:
            torch.Tensor: Audio wav of shape [B, T].
        """
        assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
        if sample_rates is not None:
            # resample each item separately, since sample rates may differ within the batch
            _wav = []
            for i, audio in enumerate(wav):
                sr = sample_rates[i]
                audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
                _wav.append(audio)
            wav = torch.stack(_wav, dim=0)
        wav = wav.mean(dim=1)
        return wav

    def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
                               sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
        """Compute audio wave embedding from CLAP model.

        Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
        we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
        average the resulting embeddings.

        Args:
            wav (torch.Tensor): Audio wav, of shape [B, C, T].
            length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
            sample_rates (list[int]): Sample rates for each sample in the batch.
            reduce_mean (bool): Whether to get the average tensor.
        Returns:
            torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
        """
        with torch.no_grad():
            wav = self._preprocess_wav(wav, length, sample_rates)
            B, T = wav.shape
            if T >= self.clap_max_frames:
                wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride)  # [B, F, T]
            else:
                wav = wav.view(-1, 1, T)  # [B, F, T] with F=1
            wav = einops.rearrange(wav, 'b f t -> (b f) t')
            # run CLAP on at most `batch_size` windows at a time
            embed_list = []
            for i in range(0, wav.size(0), self.batch_size):
                _wav = wav[i:i+self.batch_size, ...]
                _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
                embed_list.append(_embed)
            embed = torch.cat(embed_list, dim=0)
            embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
        if reduce_mean:
            embed = embed.mean(dim=1, keepdim=True)
        return embed  # [B, F, D] with F=1 if reduce_mean is True

    def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
                                     x: JointEmbedCondition, idx: int) -> torch.Tensor:
        """Compute audio wave embedding for the cache.
        The embedding is computed on a given audio read from file.

        Args:
            path (str or Path): Path to the full audio file.
        Returns:
            torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
        """
        wav, sr = audio_read(path)  # [C, T]
        wav = wav.unsqueeze(0).to(self.device)  # [1, C, T]
        wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
        embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False)  # [B, F, D]
        return embed.squeeze(0)  # [F, D]

    def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
        """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.

        Args:
            full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
            x (JointEmbedCondition): Joint embedding condition for the full batch.
            idx (int): Index considered for the given embedding to extract.
        Returns:
            torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
        """
        sample_rate = x.sample_rate[idx]
        seek_time = x.seek_time[idx]
        seek_time = 0. if seek_time is None else seek_time
        clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
        end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
        start_offset = int(seek_time * sample_rate // clap_stride)
        end_offset = int(end_seek_time * sample_rate // clap_stride)
        wav_embed = full_embed[start_offset:end_offset, ...]
        wav_embed = wav_embed.mean(dim=0, keepdim=True)
        return wav_embed.to(self.device)  # [1, D]

    def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
        """Get CLAP embedding from a batch of text descriptions."""
        no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from cache when condition dropout
        if self.text_cache is not None and no_nullified_cond:
            assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
            paths = [Path(p) for p in x.path if p is not None]
            embed = self.text_cache.get_embed_from_cache(paths, x)
        else:
            text = [xi if xi is not None else "" for xi in x.text]
            embed = self._compute_text_embedding(text)
        if self.normalize:
            embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
        return embed

    def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
        """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
        no_undefined_paths = all(p is not None for p in x.path)
        no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from cache when condition dropout
        if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
            paths = [Path(p) for p in x.path if p is not None]
            embed = self.wav_cache.get_embed_from_cache(paths, x)
        else:
            embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
        if self.normalize:
            embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
        return embed

    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
        """Populate the embedding caches (when configured) and return `x` unchanged.

        Caches are only populated when every item of the batch has a path.
        """
        # Trying to limit as much as possible sync points when the cache is warm.
        no_undefined_paths = all(p is not None for p in x.path)
        if no_undefined_paths and (self.wav_cache is not None or self.text_cache is not None):
            # every path is known to be set here, so no further checks are needed
            paths = [Path(p) for p in x.path if p is not None]
            if self.wav_cache is not None:
                self.wav_cache.populate_embed_cache(paths, x)
            if self.text_cache is not None:
                self.text_cache.populate_embed_cache(paths, x)
        return x

    def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Extract shared latent representation from either the wav or the text using CLAP."""
        # decide whether to use text embedding at train time or not
        use_text_embed = random.random() < self.text_p
        if self.training and not use_text_embed:
            embed = self._get_wav_embedding(x)
            empty_idx = torch.LongTensor([])  # we assume we always have the audio wav
        else:
            embed = self._get_text_embedding(x)
            empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
        return embed, empty_idx
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
    """Utility function for nullifying an attribute inside an ConditioningAttributes object.

    Conditions of type "wav", "beat", "chord" and "joint_embed" are replaced by the
    result of the matching `nullify_*` function; a "text" condition is set to None.
    Works in-place.

    Args:
        sample (ConditioningAttributes): Sample to modify.
        condition_type (str): One of 'text', 'wav', 'beat', 'chord', 'joint_embed'.
        condition (str): Name of the attribute to drop inside that condition type.
    Returns:
        ConditioningAttributes: The same `sample` object, after modification.
    Raises:
        ValueError: If `condition_type` is not supported or `condition` is not
            present in the sample for that type.
    """
    supported_types = ['text', 'wav', 'beat', 'chord', 'joint_embed']
    if condition_type not in supported_types:
        raise ValueError(
            "dropout_condition got an unexpected condition type!"
            f" expected one of {supported_types} but got '{condition_type}'"
        )

    available = getattr(sample, condition_type)
    if condition not in available:
        # report the keys of the type actually requested, not only wav/text
        raise ValueError(
            "dropout_condition received an unexpected condition!"
            f" expected {condition_type}={available.keys()}"
            f" but got '{condition}' of type '{condition_type}'!"
        )

    if condition_type == 'wav':
        wav_cond = sample.wav[condition]
        sample.wav[condition] = nullify_wav(wav_cond)
    elif condition_type == 'beat':
        beat_cond = sample.beat[condition]
        sample.beat[condition] = nullify_beat(beat_cond)
    elif condition_type == 'chord':
        chord_cond = sample.chord[condition]
        sample.chord[condition] = nullify_chord(chord_cond)
    elif condition_type == 'joint_embed':
        embed = sample.joint_embed[condition]
        sample.joint_embed[condition] = nullify_joint_embed(embed)
    else:
        sample.text[condition] = None

    return sample
class DropoutModule(nn.Module):
    """Common parent of the dropout modules; owns a dedicated seeded random generator."""
    def __init__(self, seed: int = 1234):
        super().__init__()
        generator = torch.Generator()
        generator.manual_seed(seed)
        # private generator so dropout decisions do not depend on the global RNG state
        self.rng = generator
class AttributeDropout(DropoutModule):
    """Dropout applied independently to each attribute, with its own probability.

    Unlike ClassifierFreeGuidanceDropout, attributes are dropped separately:
    "artist" may be dropped while "genre" is kept. One random draw is made per
    attribute and, when it triggers, the attribute is dropped for every sample.

    Args:
        p (tp.Dict[str, tp.Dict[str, float]]): Maps a condition type to a dict of
            attribute name -> dropout probability. For example:
                {
                    "text": {"genre": 0.1, "artist": 0.5},
                    "wav": {"self_wav": 0.25},
                }
        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
        seed (int, optional): Random seed.
    """
    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
        super().__init__(seed=seed)
        self.active_on_eval = active_on_eval
        # attributes that are not listed fall back to a probability of 0
        self.p = {}
        for kind, per_attribute in p.items():
            self.p[kind] = defaultdict(lambda: 0, per_attribute)

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """Drop attributes at random and return the resulting conditions.

        Args:
            samples (list[ConditioningAttributes]): List of conditions.
        Returns:
            list[ConditioningAttributes]: The input list itself when dropout is inactive,
                otherwise a deep copy in which the selected attributes were nullified.
        """
        dropout_active = self.training or self.active_on_eval
        if not dropout_active:
            return samples

        dropped = deepcopy(samples)
        for condition_type, probabilities in self.p.items():
            for condition, prob in probabilities.items():
                draw = torch.rand(1, generator=self.rng).item()
                if draw < prob:
                    # the same decision applies to the whole batch
                    for item in dropped:
                        dropout_condition(item, condition_type, condition)
        return dropped

    def __repr__(self):
        return f"AttributeDropout({dict(self.p)})"
class ClassifierFreeGuidanceDropout(DropoutModule):
    """Dropout used for classifier free guidance.

    A single random draw decides, for the whole batch, whether all wav, text,
    beat and chord attributes are dropped together.

    Args:
        p (float): Probability to apply condition dropout during training.
        seed (int): Random seed.
    """
    def __init__(self, p: float, seed: int = 1234):
        super().__init__(seed=seed)
        self.p = p

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """Nullify every condition of the batch with probability `p` (training only).

        Args:
            samples (list[ConditioningAttributes]): List of conditions.
        Returns:
            list[ConditioningAttributes]: The input list itself when nothing is dropped,
                otherwise a deep copy in which the attributes were nullified.
        """
        if not self.training:
            return samples

        # one draw for the whole batch
        draw = torch.rand(1, generator=self.rng).item()
        if not (draw < self.p):
            return samples

        dropped = deepcopy(samples)
        for condition_type in ["wav", "text", "beat", "chord"]:
            for item in dropped:
                for condition in item.attributes[condition_type]:
                    dropout_condition(item, condition_type, condition)
        return dropped

    def __repr__(self):
        return f"ClassifierFreeGuidanceDropout(p={self.p})"
class ConditioningProvider(nn.Module):
    """Prepare and provide conditions given all the supported conditioners.

    Args:
        conditioners (dict): Dictionary of conditioners.
        device (torch.device or str, optional): Device for conditioners and output condition types.
    """
    def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
        super().__init__()
        self.device = device
        self.conditioners = nn.ModuleDict(conditioners)

    # The accessors below are properties: the rest of the class reads them as
    # attributes (e.g. `for c in self.text_conditions`, `len(self.wav_conditions)`).

    @property
    def joint_embed_conditions(self):
        """Attributes of the conditioners that are JointEmbeddingConditioner instances."""
        return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]

    @property
    def has_joint_embed_conditions(self):
        """Whether at least one joint embedding conditioner is configured."""
        return len(self.joint_embed_conditions) > 0

    @property
    def text_conditions(self):
        """Names of the conditioners that are TextConditioner instances."""
        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]

    @property
    def wav_conditions(self):
        """Names of the conditioners that are WaveformConditioner instances."""
        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]

    @property
    def beat_conditions(self):
        """Names of the conditioners that are BeatConditioner instances."""
        return [k for k, v in self.conditioners.items() if isinstance(v, BeatConditioner)]

    @property
    def chord_conditions(self):
        """Names of the conditioners that are ChordProgressionConditioner instances."""
        return [k for k, v in self.conditioners.items() if isinstance(v, ChordProgressionConditioner)]

    @property
    def has_wav_condition(self):
        """Whether at least one waveform conditioner is configured."""
        return len(self.wav_conditions) > 0

    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
        This should be called before starting any real GPU work to avoid synchronization points.
        This will return a dict matching conditioner names to their arbitrary tokenized representations.

        Args:
            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
                text and wav conditions.
        """
        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]"
            f" but types were {set([type(x) for x in inputs])}"
        )

        output = {}
        text = self._collate_text(inputs)
        beats = self._collate_beats(inputs)
        chords = self._collate_chords(inputs)
        wavs = self._collate_wavs(inputs)
        joint_embeds = self._collate_joint_embeds(inputs)

        collated_keys = set(text.keys() | wavs.keys() | chords.keys() | beats.keys() | joint_embeds.keys())
        assert collated_keys.issubset(set(self.conditioners.keys())), (
            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, "
            f"got {text.keys(), wavs.keys(), chords.keys(), beats.keys(), joint_embeds.keys()}"
        )

        for attribute, batch in chain(text.items(), wavs.items(), chords.items(),
                                      beats.items(), joint_embeds.items()):
            output[attribute] = self.conditioners[attribute].tokenize(batch)
        return output

    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
        """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
        The output is for example:
        {
            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
            ...
        }

        Args:
            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
        """
        output = {}
        for attribute, inputs in tokenized.items():
            condition, mask = self.conditioners[attribute](inputs)
            output[attribute] = (condition, mask)
        return output

    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
        are the attributes and the values are the aggregated input per attribute.
        For example:
        Input:
        [
            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
        ]
        Output:
        {
            "genre": ["Rock", "Hip-hop"],
            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
        }

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
        """
        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
        texts = [x.text for x in samples]
        for text in texts:
            for condition in self.text_conditions:
                out[condition].append(text[condition])
        return out

    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
        """Generate a dict where the keys are attributes by which we fetch similar wavs,
        and the values are Tensors of wavs according to said attributes.

        *Note*: by the time the samples reach this function, each sample should have some waveform
        inside the "wav" attribute. It should be either:
        1. A real waveform
        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
        """
        wavs = defaultdict(list)
        lengths = defaultdict(list)
        sample_rates = defaultdict(list)
        paths = defaultdict(list)
        seek_times = defaultdict(list)
        out: tp.Dict[str, WavCondition] = {}

        for sample in samples:
            for attribute in self.wav_conditions:
                wav, length, sample_rate, path, seek_time = sample.wav[attribute]
                assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
                assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
                # mono-channel conditioning
                wav = wav.mean(1, keepdim=True)  # [1, 1, T]
                wavs[attribute].append(wav.flatten())  # [T]
                lengths[attribute].append(length)
                sample_rates[attribute].extend(sample_rate)
                paths[attribute].extend(path)
                seek_times[attribute].extend(seek_time)

        # stack all wavs to a single tensor
        for attribute in self.wav_conditions:
            stacked_wav, _ = collate(wavs[attribute], dim=0)
            out[attribute] = WavCondition(
                stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
                paths[attribute], seek_times[attribute])

        return out

    def _collate_chords(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, ChordCondition]:
        """Generate a dict where the keys are the chord attributes and the values are
        the batched chord conditions for said attributes.

        Each sample is expected to carry, for every chord attribute, a chord tensor
        with a leading batch axis of size 1 (checked by assertions).

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, ChordCondition]: A dictionary mapping an attribute name to chords.
        """
        chords = defaultdict(list)
        lengths = defaultdict(list)
        bpms = defaultdict(list)
        paths = defaultdict(list)
        seek_frames = defaultdict(list)
        out: tp.Dict[str, ChordCondition] = {}

        for sample in samples:
            for attribute in self.chord_conditions:
                chord, length, bpm, path, seek_frame = sample.chord[attribute]
                assert chord.dim() == 3, f"Got chord with dim={chord.dim()}, but expected 3 [1, C, T]"
                assert chord.size(0) == 1, f"Got chord [B, C, T] with shape={chord.shape}, but expected B == 1"
                chords[attribute].append(chord.squeeze(0))  # [1, C, T] -> [C, T]
                lengths[attribute].append(length)
                bpms[attribute].extend(bpm)
                paths[attribute].extend(path)
                seek_frames[attribute].extend(seek_frame)

        # stack all chords to a single tensor
        for attribute in self.chord_conditions:
            # presumably `collate` pads along dim=1 (time) before stacking -- TODO confirm
            stacked_chord, _ = collate(chords[attribute], dim=1)
            out[attribute] = ChordCondition(
                stacked_chord, torch.cat(lengths[attribute]), bpms[attribute],
                paths[attribute], seek_frames[attribute])

        return out

    def _collate_beats(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, BeatCondition]:
        """Generate a dict where the keys are the beat attributes and the values are
        the batched beat conditions for said attributes.

        Each sample is expected to carry, for every beat attribute, a beat tensor
        with a leading batch axis of size 1 (checked by assertions).

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, BeatCondition]: A dictionary mapping an attribute name to beats.
        """
        beats = defaultdict(list)
        lengths = defaultdict(list)
        bpms = defaultdict(list)
        paths = defaultdict(list)
        seek_frames = defaultdict(list)
        out: tp.Dict[str, BeatCondition] = {}

        for sample in samples:
            for attribute in self.beat_conditions:
                beat, length, bpm, path, seek_frame = sample.beat[attribute]
                assert beat.dim() == 3, f"Got beat with dim={beat.dim()}, but expected 3 [1, C, T]"
                assert beat.size(0) == 1, f"Got beat [B, C, T] with shape={beat.shape}, but expected B == 1"
                beats[attribute].append(beat.squeeze(0))  # [1, C, T] -> [C, T]
                lengths[attribute].append(length)
                bpms[attribute].extend(bpm)
                paths[attribute].extend(path)
                seek_frames[attribute].extend(seek_frame)

        # stack all beats to a single tensor
        for attribute in self.beat_conditions:
            # presumably `collate` pads along dim=1 (time) before stacking -- TODO confirm
            stacked_beat, _ = collate(beats[attribute], dim=1)
            out[attribute] = BeatCondition(
                stacked_beat, torch.cat(lengths[attribute]), bpms[attribute],
                paths[attribute], seek_frames[attribute])

        return out

    def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
        """Generate a dict where the keys are attributes by which we compute joint embeddings,
        and the values are Tensors of pre-computed embeddings and the corresponding text attributes.

        Args:
            samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
        Returns:
            A dictionary mapping an attribute name to joint embeddings.
        """
        texts = defaultdict(list)
        wavs = defaultdict(list)
        lengths = defaultdict(list)
        sample_rates = defaultdict(list)
        paths = defaultdict(list)
        seek_times = defaultdict(list)
        channels: int = 0

        out = {}
        for sample in samples:
            for attribute in self.joint_embed_conditions:
                wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
                assert wav.dim() == 3
                if channels == 0:
                    channels = wav.size(1)
                else:
                    assert channels == wav.size(1), "not all audio has same number of channels in batch"
                assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
                # flatten so that pad_sequence can pad along a single axis
                wav = einops.rearrange(wav, "b c t -> (b c t)")  # [1, C, T] => [C * T]
                wavs[attribute].append(wav)
                texts[attribute].extend(text)
                lengths[attribute].append(length)
                sample_rates[attribute].extend(sample_rate)
                paths[attribute].extend(path)
                seek_times[attribute].extend(seek_time)

        for attribute in self.joint_embed_conditions:
            stacked_texts = texts[attribute]
            stacked_paths = paths[attribute]
            stacked_seek_times = seek_times[attribute]
            stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
            stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
            stacked_sample_rates = sample_rates[attribute]
            stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
            assert stacked_lengths.size(0) == stacked_wavs.size(0)
            assert len(stacked_sample_rates) == stacked_wavs.size(0)
            assert len(stacked_texts) == stacked_wavs.size(0)
            out[attribute] = JointEmbedCondition(
                text=stacked_texts, wav=stacked_wavs,
                length=stacked_lengths, sample_rate=stacked_sample_rates,
                path=stacked_paths, seek_time=stacked_seek_times)

        return out
class ConditionFuser(StreamingModule):
    """Condition fuser handles the logic to combine the different conditions
    to the actual model input.

    Args:
        fuse2cond (tp.Dict[str, tp.List[str]]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
                "cross": ["description"],
            }
        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
        cross_attention_pos_emb_scale (float): Scale for positional embeddings in cross attention if used.
        in_attn (bool, optional): If True, `forward` additionally returns a tensor shaped like the
            input that accumulates every condition fused with the "sum" method.
    """
    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate", "concat"]
    # Number of chord frames expected for a non-dropped chord condition (30 * 50 frames).
    CHORD_FRAMES = 1500
    # Number of time steps a chord condition is reduced to before being prepended
    # (the original musicgen melody has 235 length chroma).
    CHORD_PREPEND_STEPS = 235

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
                 cross_attention_pos_emb_scale: float = 1.0, in_attn: bool = False):
        super().__init__()
        assert all(
            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
        ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.cross_attention_pos_emb = cross_attention_pos_emb
        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        # Reverse mapping: condition name -> every fuse method applied to it
        # (a condition may be listed under several methods).
        self.cond2fuse: tp.Dict[str, tp.List[str]] = {}
        self.in_attn = in_attn
        for fuse_method, conditions in fuse2cond.items():
            for condition in conditions:
                self.cond2fuse.setdefault(condition, []).append(fuse_method)

    def _downsample_chord(self, cond: torch.Tensor) -> torch.Tensor:
        """Reduce a chord condition to `CHORD_PREPEND_STEPS` time steps.

        Returns zeros when `cond` does not have exactly `CHORD_FRAMES` frames
        (the case the original code labels as a dropped-out condition).
        """
        steps, frames = self.CHORD_PREPEND_STEPS, self.CHORD_FRAMES
        # No dtype is passed on purpose: the buffer uses the default dtype, as before.
        cond_prepend = torch.zeros(cond.shape[0], steps, cond.shape[2], device=cond.device)
        if cond.shape[1] == frames:  # if condition not dropout
            # Same source indices as picking frame round(i * frames / steps) for each step i.
            frame_idx = torch.tensor(
                [round(i * (frames / steps)) for i in range(steps)],
                dtype=torch.long, device=cond.device)
            cond_prepend[:] = cond[:, frame_idx, :]
        return cond_prepend

    def forward(
        self,
        input: torch.Tensor,
        conditions: tp.Dict[str, ConditionType]
    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor], tp.Optional[torch.Tensor]]:
        """Fuse the conditions to the provided model input.

        Note: the "sum" and "input_interpolate" methods update `input` in place.

        Args:
            input (torch.Tensor): Transformer input.
            conditions (dict[str, ConditionType]): Dict of conditions.
        Returns:
            tuple[torch.Tensor, torch.Tensor or None, torch.Tensor or None]: The first tensor
                is the transformer input after the conditions have been fused. The second is
                the accumulated "sum" conditions (same shape as the original input) when
                `in_attn` is enabled, else None. The third is the tensor used for
                cross-attention or None if no cross attention inputs exist.
        """
        B, T, _ = input.shape  # [B, T, C]

        in_attn_cond = torch.zeros_like(input) if self.in_attn else None

        if 'offsets' in self._streaming_state:
            first_step = False
            offsets = self._streaming_state['offsets']
        else:
            first_step = True
            offsets = torch.zeros(B, dtype=torch.long, device=input.device)

        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
            f"given conditions contain unknown attributes for fuser, " \
            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"

        # Only the first batch item's offset is used to slice the "sum" conditions;
        # read it once instead of once per condition.
        start = offsets[0]
        cross_attention_output = None
        for cond_type, (cond, cond_mask) in conditions.items():
            fuse_methods = self.cond2fuse[cond_type]
            for op in fuse_methods:
                if op == 'sum':
                    cond_sum = cond[:, start:start + T]
                    if cond_sum.shape[1] != 0:
                        if cond_sum.shape[1] < T:
                            # pad last special token dim
                            cond_sum = F.pad(cond_sum, (0, 0, 0, T - cond_sum.shape[1]), "constant", 0)
                        # Add onto the trailing positions so any previously
                        # prepended steps are left untouched.
                        input[:, -cond_sum.shape[1]:, :] = input[:, -cond_sum.shape[1]:, :] + cond_sum
                        if self.in_attn:
                            in_attn_cond += cond_sum
                elif op == 'input_interpolate':
                    cond = einops.rearrange(cond, "b t d -> b d t")
                    cond = F.interpolate(cond, size=input.shape[1])
                    input += einops.rearrange(cond, "b d t -> b t d")
                elif op == 'prepend':
                    if cond_type == 'chord':
                        cond_prepend = self._downsample_chord(cond)
                    else:
                        cond_prepend = cond
                    # Prepending only happens on the first (non-streaming or initial) step.
                    if first_step:
                        input = torch.cat([cond_prepend, input], dim=1)
                elif op == 'cross':
                    if cross_attention_output is not None:
                        cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
                    else:
                        cross_attention_output = cond
                else:
                    raise ValueError(f"unknown op ({op})")

        if self.cross_attention_pos_emb and cross_attention_output is not None:
            positions = torch.arange(
                cross_attention_output.shape[1],
                device=cross_attention_output.device
            ).view(1, -1, 1)
            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return input, in_attn_cond, cross_attention_output