# MusicLM / musiclm_pytorch.py
# Last change: commit e77217f "Update musiclm_pytorch.py" (Gertie01)
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking
from audiolm_pytorch import AudioLM
from audiolm_pytorch.utils import AudioConditionerBase
from x_clip.tokenizer import tokenizer
from vector_quantize_pytorch import ResidualVQ
from einops import rearrange, repeat, reduce, pack, unpack
from beartype.typing import List, Optional, Tuple
from beartype import beartype
# functions
def exists(val):
    """Tell whether *val* carries a value, i.e. is anything other than None."""
    return not (val is None)
def default(val, d):
    """Return *val* unless it is None, in which case fall back to *d*."""
    if exists(val):
        return val
    return d
def round_down_nearest_multiple(n, divisor):
    """Return the largest multiple of *divisor* that is <= *n* (floor semantics)."""
    quotient, _remainder = divmod(n, divisor)
    return quotient * divisor
# tensor functions
def log(t, eps = 1e-20):
    """Natural log of *t* with inputs clamped from below at *eps*, so zeros give a finite value."""
    clamped = t.clamp(min = eps)
    return clamped.log()
def l2norm(t):
    """Rescale every vector along the last axis of *t* to unit Euclidean length."""
    last_axis = -1
    return F.normalize(t, dim = last_axis, p = 2)
# 2d sinusoidal positional embedding
# simple vit paper shows it is good enough compared to learned
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    """Return a fixed 2-d sine/cosine positional embedding for a grid of patches.

    Args:
        patches: 4-d tensor unpacked as (batch, h, w, dim). Only its shape,
            device and dtype are read, never its values.
        temperature: base of the geometric progression of frequencies.
        dtype: accepted for backward compatibility but ignored; the result is
            always cast to ``patches.dtype`` (unchanged previous behavior).

    Returns:
        Tensor of shape (h, w, dim). The last axis is laid out as
        [sin(x*omega), cos(x*omega), sin(y*omega), cos(y*omega)], each part
        dim // 4 wide, where y is the row (h) index and x the column (w) index.

    Raises:
        AssertionError: if dim is not a multiple of 4.
    """
    _, h, w, dim = patches.shape
    device = patches.device
    dtype = patches.dtype  # deliberately overrides the argument, as before

    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')

    num_freqs = dim // 4
    # exponents are spread evenly over [0, 1]; clamp the denominator so dim == 4
    # (a single frequency) no longer computes 0 / 0 and fills the embedding with NaN
    omega = torch.arange(num_freqs, device = device) / max(num_freqs - 1, 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    pe = pe.type(dtype)
    # meshgrid(indexing='ij') flattened positions in (h, w) row-major order, so a reshape restores the grid
    return pe.reshape(h, w, dim)
# biasless layernorm
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learnable scale and a fixed zero bias.

    The bias is registered as a buffer rather than a Parameter, so it is saved
    with the module but never returned by ``parameters()`` or trained.
    """

    def __init__(self, dim):
        super().__init__()
        # learnable per-feature scale, starting at identity
        self.gamma = nn.Parameter(torch.ones(dim))
        # constant zero bias (buffer, not a parameter)
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        normalized_shape = (x.shape[-1],)
        return F.layer_norm(x, normalized_shape, weight = self.gamma, bias = self.beta)
# feedforward
class GEGLU(nn.Module):
    """Gated GELU: split the last axis in two and gate the first half by GELU of the second."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim = -1)
        return value * F.gelu(gate)
def FeedForward(dim, mult = 4, dropout = 0.):
    """Build a pre-norm GEGLU feedforward block mapping `dim` -> `dim`.

    Args:
        dim: input/output feature size.
        mult: width multiplier for the hidden layer.
        dropout: dropout applied after the gated activation.

    Returns:
        nn.Sequential of LayerNorm, up-projection to 2 * hidden, GEGLU, Dropout, down-projection.
    """
    # hidden width is scaled by 2/3, presumably so the gated pair of projections
    # (3 * dim * hidden weights) costs about as much as a plain dim -> dim * mult -> dim MLP
    hidden = int(dim * mult * 2 / 3)
    layers = [
        LayerNorm(dim),
        nn.Linear(dim, hidden * 2, bias = False),  # produces value and gate halves for GEGLU
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim, bias = False),
    ]
    return nn.Sequential(*layers)
# attention
class Attention(nn.Module):
    """Pre-norm multi-head attention with optional key mask and causal masking.

    Args:
        dim: input and output feature size.
        causal: when True, key j is hidden from query i whenever j > i + (n_keys - n_queries);
            since queries and keys both come from `x` here, that is ordinary causal masking.
        dim_head: feature size per head.
        heads: number of heads.
        dropout: dropout rate for the attention weights and after the output projection.
    """

    def __init__(
        self,
        dim,
        causal = False,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.causal = causal
        self.scale = dim_head ** -0.5

        self.norm = LayerNorm(dim)
        self.attn_dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias = False)  # keys and values in one projection
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        mask = None
    ):
        """Attend over `x` and return a tensor of the same (batch, seq, dim) shape.

        Args:
            x: rank-3 input, assumed (batch, seq, dim) — the unpack below enforces rank 3.
            mask: optional (batch, seq) tensor, presumably boolean; positions where it is
                False are filled with -max before the softmax and so receive ~0 weight.
        """
        _batch, _seq, _features = x.shape  # values unused; the unpack only enforces a rank-3 input

        normed = self.norm(x)

        queries = self.to_q(normed)
        keys, values = self.to_kv(normed).chunk(2, dim = -1)

        split_heads = lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads)
        queries, keys, values = split_heads(queries), split_heads(keys), split_heads(values)

        scores = einsum('b h i d, b h j d -> b h i j', queries * self.scale, keys)
        mask_value = -torch.finfo(scores.dtype).max

        if exists(mask):
            key_mask = rearrange(mask, 'b j -> b 1 1 j')
            scores = scores.masked_fill(~key_mask, mask_value)

        if self.causal:
            n_queries, n_keys = scores.shape[-2:]
            future = torch.ones((n_queries, n_keys), dtype = torch.bool, device = x.device).triu(n_keys - n_queries + 1)
            scores = scores.masked_fill(future, mask_value)

        weights = self.attn_dropout(scores.softmax(dim = -1))

        out = einsum('b h i j, b h j d -> b h i d', weights, values)
        out = rearrange(out, 'b h n d -> b n (h d)')  # merge heads back into the feature axis
        return self.to_out(out)
# transformer
class Transformer(nn.Module):
    """Stack of `depth` residual blocks, each an Attention layer followed by a FeedForward layer.

    Both sub-layers normalise their own input (pre-norm) and are added back to the residual stream.
    """

    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x, mask = None):
        """Run every block in order; `mask` is handed unchanged to each attention layer."""
        for attention, feedforward in self.layers:
            x = x + attention(x, mask = mask)
            x = x + feedforward(x)
        return x
# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778
def pair(t):
    """Return *t* unchanged if it is already a tuple, otherwise duplicate it into ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)
class AudioSpectrogramTransformer(nn.Module):
    """Encode a batch of waveforms into one pooled embedding each (Audio Spectrogram Transformer).

    The waveform is turned into a spectrogram, augmented with SpecAugment-style
    transforms while training, cropped to a multiple of the patch size, split into
    non-overlapping patches, projected to `dim`, given a fixed 2-d sin/cos positional
    embedding, passed through a Transformer, mean-pooled over patches and layer-normed.

    Args:
        dim: embedding size; also stored as `self.dim` so callers can size projections from it.
        depth, dim_head, heads, attn_dropout, ff_mult, ff_dropout: forwarded to `Transformer`.
        patch_size: int or (height, width) patch size over (frequency, time).
        spec_*: forwarded to torchaudio's `Spectrogram`.
        spec_aug_*: configure the training-time augmentation stack `self.aug`.
    """
    def __init__(
        self,
        dim,
        depth,
        patch_size = 16,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
        spec_hop_length = None,
        spec_pad = 0,
        spec_center = True,
        spec_pad_mode = 'reflect',
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80
    ):
        super().__init__()
        self.dim = dim
        self.patch_size = pair(patch_size)
        # a 1x1 conv over the (p1 * p2) flattened patch values is a per-patch linear projection to `dim`
        self.to_patch_tokens = nn.Conv2d(self.patch_size[0] * self.patch_size[1], dim, 1)
        self.spec = Spectrogram(
            n_fft = spec_n_fft,
            power = spec_power,
            win_length = spec_win_length,
            hop_length = spec_hop_length,
            pad = spec_pad,
            center = spec_center,
            pad_mode = spec_pad_mode
        )
        # SpecAugment - seems to be widely used in audio field https://arxiv.org/abs/1904.08779
        # NOTE(review): torchaudio documents TimeStretch(hop_length=None, n_freq=201, fixed_rate=None).
        # Here the stretch factor lands in `hop_length`, and fixed_rate=True (== 1.0) appears to make
        # the stretch a no-op. TimeStretch also expects a complex spectrogram, while spec_power=2
        # yields a real one. Confirm against the installed torchaudio before relying on time stretching;
        # changing n_freq would also change the shape of TimeStretch's saved buffer.
        self.aug = torch.nn.Sequential(
            TimeStretch(spec_aug_stretch_factor, fixed_rate=True),
            FrequencyMasking(freq_mask_param = spec_aug_freq_mask),
            TimeMasking(time_mask_param = spec_aug_time_mask),
        )
        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout
        )
        self.norm = LayerNorm(dim)

    def forward(self, x):
        """Embed waveforms.

        Args:
            x: waveform batch, assumed (batch, samples) — TODO confirm; the patch
                rearrange below requires the spectrogram to be (batch, freq, time).

        Returns:
            Tensor of shape (batch, dim).
        """
        x = self.spec(x)
        # augmentation only while training
        if self.training:
            x = self.aug(x)
        # automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes
        height, width = x.shape[-2:]
        patch_height, patch_width = self.patch_size
        rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))
        if (height, width) != (rounded_height, rounded_width): # just keep printing to be annoying until it is fixed
            print(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')
        x = x[..., :rounded_height, :rounded_width]
        # to patches: each (p1, p2) patch becomes a channel vector at grid position (h, w)
        x = rearrange(x, 'b (h p1) (w p2) -> b (p1 p2) h w', p1 = patch_height, p2 = patch_width)
        x = self.to_patch_tokens(x)
        # 2d sinusoidal positional embedding
        x = rearrange(x, 'b c h w -> b h w c')
        x = x + posemb_sincos_2d(x)
        # attention, what else
        x = rearrange(x, 'b ... c -> b (...) c')
        x = self.transformer(x)
        # final global average and norm (most recent papers show this is superior to CLS token)
        x = reduce(x, 'b n d -> b d', 'mean')
        return self.norm(x)
# text transformer
@beartype
class TextTransformer(nn.Module):
    """Encode token ids (or raw strings) into one embedding per sequence.

    A learned CLS token is prepended to token + learned positional embeddings,
    the sequence runs through a Transformer with a padding mask, and the CLS
    output, layer-normed, is returned with shape (batch, dim).
    """
    def __init__(
        self,
        dim,
        depth,
        num_tokens = tokenizer.vocab_size,
        max_seq_len = 256,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4,
        pad_id = 0
    ):
        super().__init__()
        self.dim = dim

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.cls_token = nn.Parameter(torch.randn(dim))

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        self.pad_id = pad_id
        self.norm = LayerNorm(dim)

    def forward(
        self,
        x = None,
        raw_texts: Optional[List[str]] = None,
        mask = None
    ):
        """Embed either pre-tokenized ids `x` or `raw_texts` (exactly one must be given).

        Args:
            x: token ids of shape (batch, seq).
            raw_texts: strings to tokenize with the module-level tokenizer.
            mask: optional (batch, seq) mask; defaults to `x != pad_id`.

        Returns:
            Tensor of shape (batch, dim).

        Raises:
            AssertionError: if both or neither of `x` / `raw_texts` are given, or the
                sequence is longer than `max_seq_len`.
        """
        assert exists(x) ^ exists(raw_texts)

        if exists(raw_texts):
            # the tokenizer's output is presumably created on the default (CPU) device;
            # move it next to the embedding weights so this path also works on an accelerator
            x = tokenizer.tokenize(raw_texts).to(self.token_emb.weight.device)

        if not exists(mask):
            mask = x != self.pad_id

        b, n, device = *x.shape, x.device

        max_seq_len = self.pos_emb.num_embeddings
        assert n <= max_seq_len, f'sequence length {n} exceeds max_seq_len {max_seq_len}'

        # token embedding + positional embedding
        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(n, device = device))

        # cls tokens, as in bert
        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
        x, ps = pack([cls_tokens, x], 'b * d')

        # account for attending to cls token with self attention mask
        mask = F.pad(mask, (1, 0), value = True)

        # attention
        x = self.transformer(x, mask = mask)

        # unpack the cls tokens
        cls_tokens, _ = unpack(x, ps, 'b * d')

        return self.norm(cls_tokens)
# main classes
@beartype
class MuLaN(nn.Module):
    """Contrastive audio-text model mapping both modalities into a shared L2-normalised latent space.

    Args:
        audio_transformer: audio encoder; its `.dim` sizes the audio projection.
        text_transformer: text encoder; its `.dim` sizes the text projection.
        dim_latent: size of the shared latent space.
        decoupled_contrastive_learning: when True, the positive pair is excluded
            from the softmax denominator.
    """
    def __init__(
        self,
        audio_transformer: AudioSpectrogramTransformer,
        text_transformer: TextTransformer,
        dim_latent = 128, # they use 128
        decoupled_contrastive_learning = True, # think this was used, make it optional
    ):
        super().__init__()
        self.dim_latent = dim_latent

        self.audio = audio_transformer
        self.text = text_transformer

        # learnable log-scale; forward multiplies similarities by temperature.exp()
        self.temperature = nn.Parameter(torch.tensor(1.))

        self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
        self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)

        self.decoupled_contrastive_learning = decoupled_contrastive_learning

    def get_audio_latents(
        self,
        wavs
    ):
        """Project audio embeddings into the shared latent space and L2-normalise them."""
        audio_embeds = self.audio(wavs)
        audio_latents = self.audio_to_latents(audio_embeds)
        return l2norm(audio_latents)

    def get_text_latents(
        self,
        texts = None,
        raw_texts: Optional[List[str]] = None
    ):
        """Project text embeddings into the shared latent space and L2-normalise them.

        Exactly one of `texts` (token ids) or `raw_texts` must be given; both are
        forwarded to the text transformer (previously `raw_texts` was dropped).
        """
        text_embeds = self.text(texts, raw_texts = raw_texts)
        text_latents = self.text_to_latents(text_embeds)
        return l2norm(text_latents)

    def forward(
        self,
        wavs,
        texts = None,
        raw_texts: Optional[List[str]] = None,
        return_similarities = False
    ):
        """Return the mean contrastive loss, or the raw cosine-similarity matrix.

        Row i of the similarity matrix is audio i against every text; the matching
        pair sits on the diagonal. The loss is computed audio -> text only.
        """
        batch, device = wavs.shape[0], wavs.device

        audio_latents = self.get_audio_latents(wavs)
        text_latents = self.get_text_latents(texts, raw_texts = raw_texts)

        cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)

        assert cosine_sim.shape[0] == cosine_sim.shape[1], 'batch sizes for audio and text are not equal'

        if return_similarities:
            return cosine_sim

        cosine_sim = cosine_sim * self.temperature.exp()

        cosine_sim_exp = cosine_sim.exp()

        numerator = cosine_sim_exp.diag()

        if self.decoupled_contrastive_learning:
            # masked_fill requires a boolean mask; a float eye raised a RuntimeError here
            # NOTE(review): with batch == 1 this leaves an all-zero denominator and an infinite loss
            eye = torch.eye(batch, device = device, dtype = torch.bool)
            cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)

        denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum')

        contrastive_loss = -log(numerator / denominator)
        return contrastive_loss.mean()
# music lm
@beartype
class MuLaNEmbedQuantizer(AudioConditionerBase):
    """Turn MuLaN audio or text latents into per-namespace conditioning embeddings.

    MuLaN is run in eval mode without gradients. A residual VQ maps its latent to
    `rq_num_quantizers` code indices, and each index selects a learned vector from
    the conditioning table of the requested namespace.

    Args:
        mulan: trained MuLaN model; its `dim_latent` sizes the quantizer.
        conditioning_dims: embedding size per namespace, same length as `namespaces`.
        rq_num_quantizers: number of residual quantizers (codes per latent).
        rq_ema_decay: EMA decay for the codebooks.
        codebook_size: entries per codebook.
        namespaces: names of the conditioning tables; the first is the default.
    """
    def __init__(
        self,
        mulan: MuLaN,
        conditioning_dims: Tuple[int, ...],
        rq_num_quantizers = 8,
        rq_ema_decay = 0.9,
        codebook_size = 1024,
        namespaces: Tuple[str, ...] = ('semantic', 'coarse', 'fine'),
    ):
        super().__init__()
        self.mulan = mulan

        assert len(namespaces) > 0
        self.namespaces = namespaces
        self.conditioning_dims = conditioning_dims

        assert len(conditioning_dims) == len(namespaces), 'number of conditioning dimensions must be equal to number of namespaces'

        dim = mulan.dim_latent

        self.rq = ResidualVQ(
            dim = dim,
            num_quantizers = rq_num_quantizers,
            codebook_size = codebook_size,
            decay = rq_ema_decay,
            commitment_weight = 0, # only use EMA to update codebooks
            kmeans_init = True,
            threshold_ema_dead_code = 2,
            quantize_dropout = False # no quantize dropout
        )

        self.dim = dim
        self.num_codebooks = rq_num_quantizers

        # one (num_quantizers, codebook_size, conditioning_dim) table per namespace
        self.cond_embeddings = nn.ParameterDict({})

        for namespace, conditioning_dim in zip(namespaces, conditioning_dims):
            cond_embeddings = nn.Parameter(torch.randn(rq_num_quantizers, codebook_size, conditioning_dim))
            nn.init.normal_(cond_embeddings, std = 0.02)

            self.cond_embeddings[namespace] = cond_embeddings

        self.set_default_namespace(namespaces[0])

    def parameters(self, recurse = True):
        """Expose only the conditioning tables as parameters.

        MuLaN and the residual VQ are left out on purpose (the codebooks are EMA-updated).
        `recurse` is accepted and forwarded so the signature matches `nn.Module.parameters`;
        the old zero-argument override raised TypeError for callers passing it.
        """
        return self.cond_embeddings.parameters(recurse = recurse)

    def set_default_namespace(self, namespace):
        """Select the namespace `forward` uses when none is passed explicitly."""
        self._default_namespace = namespace

    def forward(
        self,
        wavs = None,
        texts = None,
        namespace = None
    ):
        """Return conditioning embeddings of shape (batch, num_quantizers, conditioning_dim).

        Exactly one of `wavs` or `texts` (token ids) must be given.
        """
        assert exists(wavs) ^ exists(texts)

        namespace = default(namespace, self._default_namespace)
        assert namespace in self.namespaces, f'namespace {namespace} not found'

        cond_embeddings = self.cond_embeddings[namespace]

        with torch.no_grad():
            self.mulan.eval()

            # sound and language live in joint embedding space because of contrastive learning

            if exists(wavs):
                latents = self.mulan.get_audio_latents(wavs)
            elif exists(texts):
                latents = self.mulan.get_text_latents(texts)

        _, indices, _ = self.rq(latents)

        batch, num_codebooks, dim = indices.shape[0], self.num_codebooks, cond_embeddings.shape[-1]

        # look up, for every quantizer q, row indices[b, q] of that quantizer's table
        cond_embeddings = repeat(cond_embeddings, 'q c d -> b q c d', b = batch)
        indices = repeat(indices, 'b q -> b q 1 d', q = num_codebooks, d = dim)

        cond_embeddings = cond_embeddings.gather(2, indices)
        return rearrange(cond_embeddings, 'b q 1 d -> b q d')
@beartype
class MusicLM(nn.Module):
    """Text-to-music wrapper: MuLaN-quantized text embeddings condition an AudioLM.

    The AudioLM must be built without its own audio conditioner; text conditioning
    is supplied here through `mulan_embed_quantizer`.
    """
    def __init__(
        self,
        audio_lm: AudioLM,
        mulan_embed_quantizer: MuLaNEmbedQuantizer
    ):
        super().__init__()
        assert not exists(audio_lm.audio_conditioner), 'mulan must not have been passed into AudioLM. it will be managed externally now, embedding the text into the joint embedding space for text-to-audio synthesis'

        self.mulan_embed_quantizer = mulan_embed_quantizer
        self.audio_lm = audio_lm

    @torch.no_grad()
    def forward(
        self,
        raw_texts: List[str],
        **audio_lm_kwargs
    ):
        """Generate with AudioLM conditioned on `raw_texts`.

        Puts the whole module in eval mode, tokenizes the texts, embeds them with
        the MuLaN quantizer and forwards everything else to `self.audio_lm`.
        """
        self.eval()

        texts = tokenizer.tokenize(raw_texts)

        # tokenized ids are presumably created on CPU; move them to the model's device
        # (assumes the model lives on a single device) to avoid a device mismatch on GPU
        param = next(self.parameters(), None)
        if exists(param):
            texts = texts.to(param.device)

        text_embeds = self.mulan_embed_quantizer(texts = texts)

        return self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs)