from dataclasses import dataclass
from functools import partial
import logging
import math
import typing as tp

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from audiocraft.transformer import StreamingTransformer
from audiocraft.activations import get_activation_fn

logger = logging.getLogger(__name__)
def _shift(x):
    """Apply a random cyclic shift along time to every [n_q, seq_len] slice of x ([bs, n_q, seq_len])."""
    n = x.shape[2]
    for i, _slice in enumerate(x):
        # Draw an offset roughly in [0.24 * n, 0.74 * n); max(1, .) guards very short sequences.
        offset = int(np.random.randint(int(.24 * n), max(1, int(.74 * n))))
        logger.debug('_shift: batch item %d rolled by %d', i, offset)
        x[i, :, :] = torch.roll(_slice, offset, dims=1)
    return x
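
# Minimal sketch (illustration only, not used by the model) of what `_shift` does:
# each batch element is rolled along the time axis by its own random offset,
# e.g. torch.roll(torch.tensor([0, 1, 2, 3, 4, 5]), 2) -> [4, 5, 0, 1, 2, 3].
def _demo_shift():
    x = torch.arange(2 * 4 * 6).reshape(2, 4, 6)  # [bs=2, n_q=4, seq_len=6]
    return _shift(x.clone())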
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
    """LM layer initialization.
    Inspired from xlformers: https://github.com/fairinternal/xlformers

    Args:
        method (str): Method name for init function. Valid options are:
            'gaussian', 'uniform'.
        input_dim (int): Input dimension of the initialized module.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
    """
    # Compute std
    std = 1 / math.sqrt(input_dim)
    # Rescale with depth
    if init_depth is not None:
        std = std / math.sqrt(2 * init_depth)

    if method == 'gaussian':
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
        )
    elif method == 'uniform':
        bound = math.sqrt(3) * std  # ensure the standard deviation is `std`
        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
    else:
        raise ValueError("Unsupported layer initialization method")
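
# Sketch of how `get_init_fn` is meant to be used (illustration only): build a
# partial initializer for a given fan-in and apply it in place to a weight.
def _demo_get_init_fn():
    linear = nn.Linear(512, 512)
    init_fn = get_init_fn('gaussian', input_dim=linear.in_features, init_depth=12)
    init_fn(linear.weight)  # truncated normal with std = 1 / (sqrt(512) * sqrt(2 * 12))
    return linear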
def init_layer(m: nn.Module,
               method: str,
               init_depth: tp.Optional[int] = None,
               zero_bias_init: bool = False):
    """Wrapper around ``get_init_fn`` for proper initialization of LM modules.

    Args:
        m (nn.Module): Module to initialize.
        method (str): Method name for the init function.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
    """
    if isinstance(m, nn.Linear):
        init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)
        if zero_bias_init and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)
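
# Sketch of applying `init_layer` to a whole block with `nn.Module.apply`, the
# same pattern `LMModel._init_weights` uses below (illustration only).
def _demo_init_layer():
    block = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))
    block.apply(partial(init_layer, method='uniform', init_depth=4, zero_bias_init=True))
    return block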
class ScaledEmbedding(nn.Embedding):
    """Embedding layer with an optional dedicated learning rate (see ``make_optim_group``)."""
    def __init__(self, *args, lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr = lr

    def make_optim_group(self):
        group = {"params": list(self.parameters())}
        if self.lr is not None:
            group["lr"] = self.lr
        return group
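
# Sketch of how `make_optim_group` plugs a per-module learning rate into a torch
# optimizer via parameter groups (illustration only; the lr values are arbitrary).
def _demo_scaled_embedding_optim():
    emb = ScaledEmbedding(2049, 128, lr=1e-3)
    head = nn.Linear(128, 2048)
    groups = [emb.make_optim_group(), {"params": list(head.parameters())}]
    return torch.optim.AdamW(groups, lr=1e-4)  # emb trains at 1e-3, head at the base 1e-4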
@dataclass
class LMOutput:
    # The logits are already re-aligned with the input codes,
    # hence no extra shift is required, e.g. when computing CE.
    logits: torch.Tensor  # [B, K, T, card]
    mask: torch.Tensor    # [B, K, T]
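
# Sketch of consuming an LMOutput: because logits are already aligned with the
# input codes, a masked cross-entropy needs no extra shifting (illustration only;
# `targets` is assumed to hold ground-truth codes with shape [B, K, T]).
def _demo_masked_ce(output: LMOutput, targets: torch.Tensor) -> torch.Tensor:
    logits = output.logits[output.mask]   # [N, card]
    codes = targets[output.mask]          # [N]
    return F.cross_entropy(logits, codes)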
class LMModel(nn.Module):
    """Transformer-based language model on multiple streams of codes.

    Args:
        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
        condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
        n_q (int): Number of parallel streams to model.
        card (int): Cardinality, vocabulary size.
        dim (int): Dimension of the transformer encoder.
        num_heads (int): Number of heads for the transformer encoder.
        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
        norm (str): Normalization method.
        norm_first (bool): Use pre-norm instead of post-norm.
        emb_lr (float, optional): Embedding-specific learning rate.
        bias_proj (bool): Use bias for output projections.
        weight_init (str, optional): Method for weight initialization.
        depthwise_init (str, optional): Method for depthwise weight initialization.
        zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
        cfg_dropout (float): Classifier-free guidance dropout.
        cfg_coef (float): Classifier-free guidance coefficient.
        two_step_cfg (bool): Whether to run classifier-free guidance with 2 distinct steps.
        **kwargs: Additional parameters for the transformer encoder.
    """
    def __init__(self,
                 pattern_provider,
                 condition_provider,
                 n_q: int = 8,
                 card: int = 1024,
                 dim: int = 128,
                 num_heads: int = 8,
                 hidden_scale: int = 4,
                 norm: str = 'layer_norm',
                 norm_first: bool = False,
                 emb_lr: tp.Optional[float] = None,
                 bias_proj: bool = True,
                 weight_init: tp.Optional[str] = None,
                 depthwise_init: tp.Optional[str] = None,
                 zero_bias_init: bool = False,
                 cfg_dropout: float = 0,
                 cfg_coef: float = 1.0,
                 two_step_cfg: bool = False,
                 **kwargs):
        super().__init__()
        self.cfg_coef = cfg_coef
        self.condition_provider = condition_provider
        self.card = card
        self.n_draw = 1  # how many samples to draw per text in the batch at each step
        embed_dim = self.card + 1
        self.n_q = n_q
        self.dim = dim
        self.pattern_provider = pattern_provider
        self.two_step_cfg = two_step_cfg
        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
        if 'activation' in kwargs:
            kwargs['activation'] = get_activation_fn(kwargs['activation'])
        # ========================================================================
        # Typical transformer kwargs forwarded to StreamingTransformer:
        # {
        #     'dtype': torch.float16, 'device': 'cuda',
        #     'num_layers': 48, 'dropout': 0.0, 'activation': 'gelu',
        #     'bias_ff': False, 'bias_attn': False,
        #     'past_context': None, 'causal': True,
        #     'custom': False, 'memory_efficient': True,
        #     'attention_as_float32': False, 'positional_embedding': 'sin', 'xpos': False,
        #     'checkpointing': 'none', 'cross_attention': True, 'qk_layer_norm': False,
        #     'qk_layer_norm_cross': False, 'attention_dropout': None, 'kv_repeat': 1
        # }
        # ========================================================================
        kwargs.pop('layer_scale', None)  # audiocraft would build nn.Identity() for this; dropped here
        self.transformer = StreamingTransformer(
            d_model=dim,
            num_heads=num_heads,
            dim_feedforward=int(hidden_scale * dim),
            norm=norm,
            norm_first=norm_first, **kwargs)
        self.out_norm: tp.Optional[nn.Module] = None
        if norm_first:
            self.out_norm = nn.LayerNorm(dim, eps=1e-5)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
        self._init_weights(weight_init, depthwise_init, zero_bias_init)
        self._fsdp: tp.Optional[nn.Module]
        self.__dict__['_fsdp'] = None
    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
        """Initialization of the transformer module weights.

        Args:
            weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
            depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
                'current' where the depth corresponds to the current layer index or 'global' where the total number
                of layers is used as depth. If not set, no depthwise initialization strategy is used.
            zero_bias_init (bool): Whether to initialize bias to zero or not.
        """
        assert depthwise_init is None or depthwise_init in ['current', 'global']
        assert depthwise_init is None or weight_init is not None, \
            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
        assert not zero_bias_init or weight_init is not None, \
            "If 'zero_bias_init', a 'weight_init' method should be provided"
        if weight_init is None:
            return

        for emb_layer in self.emb:
            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

        for layer_idx, tr_layer in enumerate(self.transformer.layers):
            depth = None
            if depthwise_init == 'current':
                depth = layer_idx + 1
            elif depthwise_init == 'global':
                depth = len(self.transformer.layers)
            init_fn = partial(init_layer,
                              method=weight_init,
                              init_depth=depth,
                              zero_bias_init=zero_bias_init)
            tr_layer.apply(init_fn)

        for linear in self.linears:
            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
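
    # Worked example of the depthwise rescaling above (illustration only): with
    # dim=1024 and 48 layers, 'global' uses std = 1/sqrt(1024)/sqrt(2*48) ≈ 0.0032
    # for every layer, while 'current' substitutes sqrt(2*(layer_idx+1)), so early
    # layers get a larger std than late ones.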
    @property
    def special_token_id(self) -> int:
        return self.card
    def forward(self,
                sequence,
                condition_tensors=None,
                token_count=None):
        # Duplicates the batch to run the text and the null condition in a single pass,
        # applies classifier-free guidance on the logits, then samples with top-k.
        bs, _, _ = sequence.shape  # sequence: [bs, n_q, time=1]
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
        out = self.transformer(torch.cat([input_, input_], 0),
                               cross_attention_src=condition_tensors,
                               token_count=token_count)
        if self.out_norm:
            out = self.out_norm(out)

        logits = torch.stack([self.linears[k](out) for k in range(self.n_q)], dim=1)  # [2*bs, n_q, 1, card]
        # Classifier-free guidance with a fixed coefficient of 3:
        # logits = uncond + 3 * (cond - uncond) = 3 * cond - 2 * uncond
        logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :]  # [bs, n_q, 1, card]

        # Top-k sampling
        k = 250
        p = torch.softmax(logits, dim=3)
        top_k_value, _ = torch.topk(p, k, dim=3)  # [bs, n_q, 1, k]
        min_value_top_k = top_k_value[:, :, :, -1:]
        p *= (p >= min_value_top_k).float()  # zero out probabilities below the k-th largest
        p.div_(p.sum(dim=-1, keepdim=True))  # renormalise the remaining ones

        # Bring the n_q codebooks into the batch dimension and draw n_draw samples each
        p = p.reshape(bs * self.n_q, self.card)
        out = torch.multinomial(p,                      # p: [bs * n_q, card]
                                num_samples=self.n_draw,
                                replacement=True)       # [bs * n_q, n_draw]
        return out.reshape(bs, self.n_q, self.n_draw).transpose(1, 2)  # [bs, n_draw, n_q]
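
    # Spelled-out sketch of the top-k renormalisation above (illustration only):
    # with probabilities p = [0.5, 0.3, 0.15, 0.05] and k = 2, the k-th largest
    # value is 0.3, everything below it is zeroed, and renormalising gives
    # [0.5, 0.3, 0, 0] / 0.8 = [0.625, 0.375, 0, 0]; torch.multinomial then draws
    # n_draw indices from that truncated distribution.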
    def generate(self,
                 descriptions=['windy day', 'rain storm'],
                 max_gen_len=256):
        """Generate token codes for the given text descriptions, using the codebook
        delay pattern and classifier-free guidance."""
        text_condition = self.condition_provider(descriptions)
        bs, _, _ = text_condition.shape
        # Append an all-zero (null) condition for classifier-free guidance.
        text_condition = torch.cat(
            [
                text_condition,
                torch.zeros_like(text_condition)
            ], 0)
        pattern = self.pattern_provider.get_pattern(max_gen_len)
        gen_codes = torch.full((bs,
                                self.n_q,
                                max_gen_len), -1, dtype=torch.long,
                               device=text_condition.device)
        gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
        _, _, audiodur = gen_sequence.shape  # [bs, n_q=4, audiodur], e.g. audiodur=7
        # The mask has no batch dimension; e.g. for audiodur=7:
        #
        # gen_sequence: torch.Size([3, 4, 7]), mask: torch.Size([4, 7])
        # mask=tensor([[False,  True,  True,  True, False, False, False],
        #              [False, False,  True,  True,  True, False, False],
        #              [False, False, False,  True,  True,  True, False],
        #              [False, False, False, False,  True,  True,  True]], device='cuda:0')
        mask = mask[None, None, :, :].repeat(bs, self.n_draw, 1, 1)              # [bs, n_draw, 4, audiodur]
        gen_sequence = gen_sequence[:, None, :, :].repeat(1, self.n_draw, 1, 1)  # [bs, n_draw, 4, audiodur]
        for offset in range(1, audiodur):
            # forward duplicates the query onto the null condition, applies CFG and
            # returns de-duplicated sampled tokens for every draw.
            # TODO: should the attention mask of the text condition be used here?
            next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset],
                                      condition_tensors=text_condition,
                                      token_count=offset-1)  # [bs, n_draw, 4]
            # The mask is not all True: it encodes the 4 x audiodur delay pattern.
            m = mask[:, :, :, offset]
            next_token[~m] = self.special_token_id
            gen_sequence[:, :, :, offset] = torch.where(
                gen_sequence[:, :, :, offset] == -1,  # unknown token
                next_token,
                gen_sequence[:, :, :, offset]
            )
        # 1. fold n_draw into the batch dimension
        # 2. revert the delay pattern of every short sequence
        # 3. reshape bs * n_draw -> [bs, 4, n_draw * new_len], i.e. elongate along time
        out_codes, _, _ = pattern.revert_pattern_sequence(
            gen_sequence.reshape(bs * self.n_draw, 4, audiodur),  # [bs * n_draw, 4, audiodur]
            special_token=-1)
        # Reverting the pattern shortens the time axis but preserves the 4 codebooks.
        _, _, new_len = out_codes.shape
        out_codes = out_codes.reshape(bs, self.n_draw, 4, new_len)
        out_codes = out_codes.transpose(1, 2).reshape(bs, 4, self.n_draw * new_len)
        logger.debug('out_codes shape after revert: %s', out_codes.shape)

        # Apply the random cyclic time shift several times.
        for _ in range(7):
            out_codes = _shift(out_codes)

        # Clear the transformer's k/v history (each of the self-attention layers keeps its own).
        for lay in self.transformer.layers:
            lay.self_attn.k_history = None
            lay.self_attn.v_history = None

        return out_codes
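
# Usage sketch (illustration only). The pattern and condition providers are assumed
# to be constructed elsewhere (e.g. a delayed codebook pattern provider and a text
# conditioner); the transformer kwargs below follow the commented config in
# __init__ and are assumptions, not requirements of this module.
def _demo_generate(pattern_provider, condition_provider):
    lm = LMModel(pattern_provider, condition_provider,
                 n_q=4, card=2048, dim=1536, num_heads=24,
                 num_layers=48, causal=True, cross_attention=True)
    # Token codes of shape [bs, n_q, n_draw * reverted_len], to be decoded into
    # audio by a codec (e.g. EnCodec) outside of this module.
    return lm.generate(descriptions=['windy day', 'rain storm'], max_gen_len=256)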