# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
"""
All required modules and functions are merged into this single Python file,
so that the pre-trained model can easily be used to extract features.
"""
import math | |
import numpy as np | |
import logging | |
import torch
import torch.onnx.operators  # used by SinusoidalPositionalEmbedding.forward
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import Parameter | |
from torch import Tensor | |
from typing import Any, Dict, List, Tuple, Callable, Optional | |
logger = logging.getLogger(__name__) | |
# rewrite name for backward compatibility in `make_generation_fast_` | |
def module_name_fordropout(module_name: str) -> str: | |
if module_name == "TransformerEncoderBase": | |
return "TransformerEncoder" | |
else: | |
return module_name | |
def utils_make_positions(tensor, padding_idx: int, onnx_trace: bool = False): | |
"""Replace non-padding symbols with their position numbers. | |
Position numbers begin at padding_idx+1. Padding symbols are ignored. | |
""" | |
# The series of casts and type-conversions here are carefully | |
# balanced to both work with ONNX export and XLA. In particular XLA | |
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know | |
# how to handle the dtype kwarg in cumsum. | |
mask = tensor.ne(padding_idx).int() | |
return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx | |
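# Illustrative usage sketch (not part of the original module): with padding_idx = 1,
# non-padding tokens are numbered starting at padding_idx + 1 and pads keep padding_idx.
#
#   tokens = torch.tensor([[5, 6, 7, 1], [8, 9, 1, 1]])
#   utils_make_positions(tokens, padding_idx=1)
#   # -> tensor([[2, 3, 4, 1],
#   #            [2, 3, 1, 1]])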
def utils_item(tensor): | |
# tpu-comment: making this a no-op for xla devices. | |
if torch.is_tensor(tensor) and tensor.device.type == "xla": | |
return tensor.detach() | |
if hasattr(tensor, "item"): | |
return tensor.item() | |
if hasattr(tensor, "__getitem__"): | |
return tensor[0] | |
return tensor | |
def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): | |
""" | |
Helper to wrap layers/modules in FSDP. This falls back to a no-op if | |
fairscale is not available. | |
Args: | |
module (nn.Module): module to (maybe) wrap | |
min_num_params (int, Optional): minimum number of layer params to wrap | |
""" | |
try: | |
from fairscale.nn import wrap | |
if min_num_params is not None: | |
num_params = sum(p.numel() for p in module.parameters()) | |
if num_params >= min_num_params: | |
return wrap(module, **kwargs) | |
else: | |
return module | |
else: | |
return wrap(module, **kwargs) | |
except ImportError: | |
return module | |
def quant_noise(module, p, block_size): | |
""" | |
Wraps modules and applies quantization noise to the weights for | |
subsequent quantization with Iterative Product Quantization as | |
described in "Training with Quantization Noise for Extreme Model Compression" | |
Args: | |
- module: nn.Module | |
- p: amount of Quantization Noise | |
- block_size: size of the blocks for subsequent quantization with iPQ | |
Remarks: | |
- Module weights must have the right sizes wrt the block size | |
- Only Linear, Embedding and Conv2d modules are supported for the moment | |
- For more detail on how to quantize by blocks with convolutional weights, | |
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" | |
- We implement the simplest form of noise here as stated in the paper | |
which consists in randomly dropping blocks | |
""" | |
# if no quantization noise, don't register hook | |
if p <= 0: | |
return module | |
# supported modules | |
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) | |
# test whether module.weight has the right sizes wrt block_size | |
is_conv = module.weight.ndim == 4 | |
# 2D matrix | |
if not is_conv: | |
assert ( | |
module.weight.size(1) % block_size == 0 | |
), "Input features must be a multiple of block sizes" | |
# 4D matrix | |
else: | |
# 1x1 convolutions | |
if module.kernel_size == (1, 1): | |
assert ( | |
module.in_channels % block_size == 0 | |
), "Input channels must be a multiple of block sizes" | |
# regular convolutions | |
else: | |
k = module.kernel_size[0] * module.kernel_size[1] | |
assert k % block_size == 0, "Kernel size must be a multiple of block size" | |
def _forward_pre_hook(mod, input): | |
# no noise for evaluation | |
if mod.training: | |
if not is_conv: | |
# gather weight and sizes | |
weight = mod.weight | |
in_features = weight.size(1) | |
out_features = weight.size(0) | |
# split weight matrix into blocks and randomly drop selected blocks | |
mask = torch.zeros( | |
in_features // block_size * out_features, device=weight.device | |
) | |
mask.bernoulli_(p) | |
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) | |
else: | |
# gather weight and sizes | |
weight = mod.weight | |
in_channels = mod.in_channels | |
out_channels = mod.out_channels | |
# split weight matrix into blocks and randomly drop selected blocks | |
if mod.kernel_size == (1, 1): | |
mask = torch.zeros( | |
int(in_channels // block_size * out_channels), | |
device=weight.device, | |
) | |
mask.bernoulli_(p) | |
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) | |
else: | |
mask = torch.zeros( | |
weight.size(0), weight.size(1), device=weight.device | |
) | |
mask.bernoulli_(p) | |
mask = ( | |
mask.unsqueeze(2) | |
.unsqueeze(3) | |
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) | |
) | |
# scale weights and apply mask | |
mask = mask.to( | |
torch.bool | |
) # x.bool() is not currently supported in TorchScript | |
s = 1 / (1 - p) | |
mod.weight.data = s * weight.masked_fill(mask, 0) | |
module.register_forward_pre_hook(_forward_pre_hook) | |
return module | |
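# Illustrative usage sketch (not part of the original module): wrap a linear layer so that,
# while it is in training mode, weight blocks of size `block_size` are randomly dropped with
# probability p and the surviving weights are rescaled by 1 / (1 - p).
#
#   linear = quant_noise(nn.Linear(512, 512), p=0.1, block_size=8)  # 512 is divisible by 8
#   linear.train()
#   y = linear(torch.randn(4, 512))  # noise applied via the registered forward pre-hook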
def relu_squared(x: torch.Tensor): | |
return F.relu(x).pow(2) | |
def gelu(x: torch.Tensor) -> torch.Tensor: | |
return torch.nn.functional.gelu(x.float()).type_as(x) | |
def gelu_accurate(x): | |
if not hasattr(gelu_accurate, "_a"): | |
gelu_accurate._a = math.sqrt(2 / math.pi) | |
return ( | |
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) | |
) | |
def get_activation_fn(activation: str) -> Callable: | |
"""Returns the activation function corresponding to `activation`""" | |
if activation == "relu": | |
return F.relu | |
elif activation == "relu_squared": | |
return relu_squared | |
elif activation == "gelu": | |
return gelu | |
elif activation == "gelu_fast": | |
        logger.warning(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
        )
return gelu_accurate | |
elif activation == "gelu_accurate": | |
return gelu_accurate | |
elif activation == "tanh": | |
return torch.tanh | |
elif activation == "linear": | |
return lambda x: x | |
elif activation == "swish": | |
        return torch.nn.SiLU()  # instantiate so the returned object is directly callable on tensors
else: | |
raise RuntimeError("--activation-fn {} not supported".format(activation)) | |
def softmax(x, dim: int, onnx_trace: bool = False): | |
if onnx_trace: | |
return F.softmax(x.float(), dim=dim) | |
else: | |
return F.softmax(x, dim=dim, dtype=torch.float32) | |
def compute_mask_indices( | |
shape: Tuple[int, int], | |
padding_mask: Optional[torch.Tensor], | |
mask_prob: float, | |
mask_length: int, | |
mask_type: str = "static", | |
mask_other: float = 0.0, | |
min_masks: int = 0, | |
no_overlap: bool = False, | |
min_space: int = 0, | |
require_same_masks: bool = True, | |
mask_dropout: float = 0.0, | |
) -> np.ndarray: | |
""" | |
Computes random mask spans for a given shape | |
Args: | |
        shape: the shape for which to compute masks.
            should be of size 2 where the first element is the batch size and the second is the number of timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from Poisson distribution with lambda = mask_length
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
        mask_dropout: randomly drop out this percentage of masks in each example
""" | |
bsz, all_sz = shape | |
mask = np.full((bsz, all_sz), False) | |
all_num_mask = int( | |
# add a random number for probabilistic rounding | |
mask_prob * all_sz / float(mask_length) | |
+ np.random.rand() | |
) | |
all_num_mask = max(min_masks, all_num_mask) | |
mask_idcs = [] | |
for i in range(bsz): | |
if padding_mask is not None: | |
sz = all_sz - padding_mask[i].long().sum().item() | |
num_mask = int( | |
# add a random number for probabilistic rounding | |
mask_prob * sz / float(mask_length) | |
+ np.random.rand() | |
) | |
num_mask = max(min_masks, num_mask) | |
else: | |
sz = all_sz | |
num_mask = all_num_mask | |
if mask_type == "static": | |
lengths = np.full(num_mask, mask_length) | |
elif mask_type == "uniform": | |
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) | |
elif mask_type == "normal": | |
lengths = np.random.normal(mask_length, mask_other, size=num_mask) | |
lengths = [max(1, int(round(x))) for x in lengths] | |
elif mask_type == "poisson": | |
lengths = np.random.poisson(mask_length, size=num_mask) | |
lengths = [int(round(x)) for x in lengths] | |
else: | |
raise Exception("unknown mask selection " + mask_type) | |
if sum(lengths) == 0: | |
lengths[0] = min(mask_length, sz - 1) | |
if no_overlap: | |
mask_idc = [] | |
def arrange(s, e, length, keep_length): | |
span_start = np.random.randint(s, e - length) | |
mask_idc.extend(span_start + i for i in range(length)) | |
new_parts = [] | |
if span_start - s - min_space >= keep_length: | |
new_parts.append((s, span_start - min_space + 1)) | |
if e - span_start - keep_length - min_space > keep_length: | |
new_parts.append((span_start + length + min_space, e)) | |
return new_parts | |
parts = [(0, sz)] | |
min_length = min(lengths) | |
for length in sorted(lengths, reverse=True): | |
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,  # np.int was removed from NumPy; the builtin int is equivalent here
                )
l_sum = np.sum(lens) | |
if l_sum == 0: | |
break | |
probs = lens / np.sum(lens) | |
c = np.random.choice(len(parts), p=probs) | |
s, e = parts.pop(c) | |
parts.extend(arrange(s, e, length, min_length)) | |
mask_idc = np.asarray(mask_idc) | |
else: | |
min_len = min(lengths) | |
if sz - min_len <= num_mask: | |
min_len = sz - num_mask - 1 | |
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) | |
mask_idc = np.asarray( | |
[ | |
mask_idc[j] + offset | |
for j in range(len(mask_idc)) | |
for offset in range(lengths[j]) | |
] | |
) | |
mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) | |
min_len = min([len(m) for m in mask_idcs]) | |
for i, mask_idc in enumerate(mask_idcs): | |
if len(mask_idc) > min_len and require_same_masks: | |
mask_idc = np.random.choice(mask_idc, min_len, replace=False) | |
if mask_dropout > 0: | |
num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int) | |
mask_idc = np.random.choice( | |
mask_idc, len(mask_idc) - num_holes, replace=False | |
) | |
mask[i, mask_idc] = True | |
return mask | |
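# Illustrative usage sketch (not part of the original module): sample wav2vec 2.0 / HuBERT-style
# span masks for 2 sequences of 100 frames (spans of length 10, mask_prob=0.65) and zero out
# the masked frames of a feature tensor.
#
#   mask = compute_mask_indices((2, 100), padding_mask=None, mask_prob=0.65, mask_length=10)
#   mask = torch.from_numpy(mask)          # (2, 100) boolean
#   features = torch.randn(2, 100, 768)
#   features[mask] = 0.0                   # or substitute a learned mask embedding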
def init_bert_params(module): | |
""" | |
Initialize the weights specific to the BERT Model. | |
This overrides the default initializations depending on the specified arguments. | |
1. If normal_init_linear_weights is set then weights of linear | |
layer will be initialized using the normal distribution and | |
       bias will be set to the specified value.
2. If normal_init_embed_weights is set then weights of embedding | |
layer will be initialized using the normal distribution. | |
3. If normal_init_proj_weights is set then weights of | |
in_project_weight for MultiHeadAttention initialized using | |
the normal distribution (to be validated). | |
""" | |
def normal_(data): | |
# with FSDP, module params will be on CUDA, so we cast them back to CPU | |
# so that the RNG is consistent with and without FSDP | |
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) | |
if isinstance(module, nn.Linear): | |
normal_(module.weight.data) | |
if module.bias is not None: | |
module.bias.data.zero_() | |
if isinstance(module, nn.Embedding): | |
normal_(module.weight.data) | |
if module.padding_idx is not None: | |
module.weight.data[module.padding_idx].zero_() | |
if isinstance(module, MultiheadAttention): | |
normal_(module.q_proj.weight.data) | |
normal_(module.k_proj.weight.data) | |
normal_(module.v_proj.weight.data) | |
def pad_to_multiple(x, multiple, dim=-1, value=0): | |
# Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41 | |
if x is None: | |
return None, 0 | |
tsz = x.size(dim) | |
m = tsz / multiple | |
remainder = math.ceil(m) * multiple - tsz | |
if m.is_integer(): | |
return x, 0 | |
pad_offset = (0,) * (-1 - dim) * 2 | |
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder | |
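# Illustrative usage sketch (not part of the original module): pad the time dimension of a
# (batch, time, channels) tensor up to the next multiple of 8, then strip the padding again.
#
#   x = torch.randn(2, 13, 768)
#   x, pad_length = pad_to_multiple(x, multiple=8, dim=-2, value=0)  # x is now (2, 16, 768)
#   if pad_length > 0:
#       x = x[:, :-pad_length]                                       # back to (2, 13, 768)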
def is_xla_tensor(tensor): | |
return torch.is_tensor(tensor) and tensor.device.type == "xla" | |
def index_put(tensor, indices, value): | |
if is_xla_tensor(tensor): | |
for _ in range(indices.dim(), tensor.dim()): | |
indices = indices.unsqueeze(-1) | |
if indices.size(-1) < tensor.size(-1): | |
indices = indices.expand_as(tensor) | |
tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) | |
else: | |
tensor[indices] = value | |
return tensor | |
def PositionalEmbedding( | |
num_embeddings: int, | |
embedding_dim: int, | |
padding_idx: int, | |
learned: bool = False, | |
): | |
if learned: | |
# if padding_idx is specified then offset the embedding ids by | |
# this index and adjust num_embeddings appropriately | |
# TODO: The right place for this offset would be inside | |
# LearnedPositionalEmbedding. Move this there for a cleaner implementation. | |
if padding_idx is not None: | |
num_embeddings = num_embeddings + padding_idx + 1 | |
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) | |
nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) | |
if padding_idx is not None: | |
nn.init.constant_(m.weight[padding_idx], 0) | |
else: | |
m = SinusoidalPositionalEmbedding( | |
embedding_dim, | |
padding_idx, | |
init_size=num_embeddings + padding_idx + 1, | |
) | |
return m | |
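# Illustrative usage sketch (not part of the original module): build fixed sinusoidal (or
# learned) positional embeddings for a maximum source length of 1024 and model dimension 768.
#
#   pos_emb = PositionalEmbedding(1024, 768, padding_idx=1, learned=False)
#   tokens = torch.tensor([[5, 6, 7, 1]])  # 1 is the pad index
#   pos = pos_emb(tokens)                  # (1, 4, 768); the padded position gets the zero vector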
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): | |
if torch.jit.is_scripting() or torch.jit.is_tracing(): | |
export = True | |
if not export and torch.cuda.is_available() and has_fused_layernorm: | |
return FusedLayerNorm(normalized_shape, eps, elementwise_affine) | |
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) | |
class TransformerEncoderBase(nn.Module): | |
""" | |
Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer | |
is a :class:`TransformerEncoderLayer`. | |
Args: | |
args (argparse.Namespace): parsed command-line arguments | |
dictionary: deprecated(None) | |
embed_tokens (torch.nn.Embedding): input embedding | |
""" | |
def __init__(self, cfg, dictionary, embed_tokens, use_rel_pos_enc=False, scaling_for_att=1.0): | |
self.cfg = cfg | |
super().__init__() | |
self.register_buffer("version", torch.Tensor([3])) | |
self.dropout_module = FairseqDropout( | |
cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__) | |
) | |
self.encoder_layerdrop = cfg.encoder.layerdrop | |
embed_dim = embed_tokens.embedding_dim if embed_tokens is not None else cfg.encoder.embed_dim | |
self.padding_idx = embed_tokens.padding_idx if embed_tokens is not None else 1 | |
self.max_source_positions = cfg.max_source_positions | |
self.embed_tokens = embed_tokens | |
self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) | |
self.embed_positions = ( | |
PositionalEmbedding( | |
cfg.max_source_positions, | |
embed_dim, | |
self.padding_idx, | |
learned=cfg.encoder.learned_pos, | |
) | |
if not cfg.no_token_positional_embeddings | |
else None | |
) | |
if cfg.layernorm_embedding: | |
self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) | |
else: | |
self.layernorm_embedding = None | |
if not cfg.adaptive_input and cfg.quant_noise.pq > 0: | |
self.quant_noise = quant_noise( | |
nn.Linear(embed_dim, embed_dim, bias=False), | |
cfg.quant_noise.pq, | |
cfg.quant_noise.pq_block_size, | |
) | |
else: | |
self.quant_noise = None | |
if self.encoder_layerdrop > 0.0: | |
self.layers = LayerDropModuleList(p=self.encoder_layerdrop) | |
else: | |
self.layers = nn.ModuleList([]) | |
self.use_rel_pos_enc = use_rel_pos_enc | |
self.scaling_for_att = scaling_for_att | |
self.layers.extend( | |
[self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)] | |
) | |
self.num_layers = len(self.layers) | |
if cfg.encoder.normalize_before: | |
self.layer_norm = LayerNorm(embed_dim, export=cfg.export) | |
else: | |
self.layer_norm = None | |
if self.use_rel_pos_enc: | |
self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.encoder.attention_heads, 160) | |
def build_encoder_layer(self, cfg): | |
layer = TransformerEncoderLayerBase(cfg, has_relative_attention_bias=self.use_rel_pos_enc, scaling_for_att=self.scaling_for_att) | |
checkpoint = cfg.checkpoint_activations | |
if checkpoint: | |
raise ValueError("We don't support checkpoint_activations for now! Please set cfg.checkpoint_activations=False.") | |
min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 | |
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) | |
return layer | |
def forward_embedding( | |
self, src_tokens, token_embedding: Optional[torch.Tensor] = None | |
): | |
# embed tokens and positions | |
if token_embedding is None: | |
token_embedding = self.embed_tokens(src_tokens) | |
x = embed = self.embed_scale * token_embedding | |
if self.embed_positions is not None: | |
x = embed + self.embed_positions(src_tokens) | |
if self.layernorm_embedding is not None: | |
x = self.layernorm_embedding(x) | |
x = self.dropout_module(x) | |
if self.quant_noise is not None: | |
x = self.quant_noise(x) | |
return x, embed | |
def forward( | |
self, | |
src_tokens, | |
src_lengths: Optional[torch.Tensor] = None, | |
return_all_hiddens: bool = False, | |
token_embeddings: Optional[torch.Tensor] = None, | |
uniformity_layers: Optional[List[int]] = None, | |
): | |
""" | |
Args: | |
src_tokens (LongTensor): tokens in the source language of shape | |
`(batch, src_len)` | |
src_lengths (torch.LongTensor): lengths of each source sentence of | |
shape `(batch)` | |
return_all_hiddens (bool, optional): also return all of the | |
intermediate hidden states (default: False). | |
token_embeddings (torch.Tensor, optional): precomputed embeddings | |
default `None` will recompute embeddings | |
Returns: | |
dict: | |
- **encoder_out** (Tensor): the last encoder layer's output of | |
shape `(src_len, batch, embed_dim)` | |
- **encoder_padding_mask** (ByteTensor): the positions of | |
padding elements of shape `(batch, src_len)` | |
- **encoder_embedding** (Tensor): the (scaled) embedding lookup | |
of shape `(batch, src_len, embed_dim)` | |
- **encoder_states** (List[Tensor]): all intermediate | |
hidden states of shape `(src_len, batch, embed_dim)`. | |
Only populated if *return_all_hiddens* is True. | |
""" | |
return self.forward_scriptable( | |
src_tokens, src_lengths, return_all_hiddens, token_embeddings, uniformity_layers | |
) | |
# TorchScript doesn't support super() method so that the scriptable Subclass | |
# can't access the base class model in Torchscript. | |
# Current workaround is to add a helper function with different name and | |
# call the helper function from scriptable Subclass. | |
def forward_scriptable( | |
self, | |
src_tokens, | |
src_lengths: Optional[torch.Tensor] = None, | |
return_all_hiddens: bool = False, | |
token_embeddings: Optional[torch.Tensor] = None, | |
uniformity_layers: Optional[List[int]] = None, | |
): | |
""" | |
Args: | |
src_tokens (LongTensor): tokens in the source language of shape | |
`(batch, src_len)` | |
src_lengths (torch.LongTensor): lengths of each source sentence of | |
shape `(batch)` | |
return_all_hiddens (bool, optional): also return all of the | |
intermediate hidden states (default: False). | |
token_embeddings (torch.Tensor, optional): precomputed embeddings | |
default `None` will recompute embeddings | |
Returns: | |
dict: | |
- **encoder_out** (Tensor): the last encoder layer's output of | |
shape `(src_len, batch, embed_dim)` | |
- **encoder_padding_mask** (ByteTensor): the positions of | |
padding elements of shape `(batch, src_len)` | |
- **encoder_embedding** (Tensor): the (scaled) embedding lookup | |
of shape `(batch, src_len, embed_dim)` | |
- **encoder_states** (List[Tensor]): all intermediate | |
hidden states of shape `(src_len, batch, embed_dim)`. | |
Only populated if *return_all_hiddens* is True. | |
""" | |
# compute padding mask | |
encoder_padding_mask = src_tokens.eq(self.padding_idx) | |
has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() | |
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) | |
# account for padding while computing the representation | |
if has_pads: | |
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) | |
# B x T x C -> T x B x C | |
x = x.transpose(0, 1) | |
if self.use_rel_pos_enc: | |
x_len = x.shape[0] | |
pos_seq = torch.arange(0, x_len).long().to(x.device) | |
pos_seq = pos_seq[:, None] - pos_seq[None, :] | |
pos_k, pos_v = self.pos_emb(pos_seq) | |
else: | |
pos_k = None | |
encoder_states = [] | |
uniformity_hiddens = [] | |
if return_all_hiddens: | |
encoder_states.append(x) | |
if uniformity_layers is not None and 0 in uniformity_layers: | |
x = F.normalize(x.float(), dim=-1).type_as(x) | |
uniformity_hiddens.append(x) | |
# encoder layers | |
for i, layer in enumerate(self.layers): | |
x = layer( | |
x, encoder_padding_mask=encoder_padding_mask if has_pads else None, | |
pos_bias=pos_k, | |
) | |
if uniformity_layers is not None and i+1 in uniformity_layers: | |
x = F.normalize(x.float(), dim=-1).type_as(x) | |
uniformity_hiddens.append(x) | |
if return_all_hiddens: | |
assert encoder_states is not None | |
encoder_states.append(x) | |
if self.layer_norm is not None: | |
x = self.layer_norm(x) | |
        # The PyTorch Mobile lite interpreter does not support returning NamedTuple in
# `forward` so we use a dictionary instead. | |
# TorchScript does not support mixed values so the values are all lists. | |
# The empty list is equivalent to None. | |
src_lengths = ( | |
src_tokens.ne(self.padding_idx) | |
.sum(dim=1, dtype=torch.int32) | |
.reshape(-1, 1) | |
.contiguous() | |
) | |
return { | |
"encoder_out": [x], # T x B x C | |
"encoder_padding_mask": [encoder_padding_mask], # B x T | |
"encoder_embedding": [encoder_embedding], # B x T x C | |
"encoder_states": encoder_states, # List[T x B x C] | |
"uniformity_hiddens": uniformity_hiddens, # List[T x B x C] | |
"src_tokens": [], | |
"src_lengths": [src_lengths], | |
} | |
def forward_torchscript(self, net_input: Dict[str, Tensor]): | |
"""A TorchScript-compatible version of forward. | |
Encoders which use additional arguments may want to override | |
this method for TorchScript compatibility. | |
""" | |
if torch.jit.is_scripting(): | |
return self.forward( | |
src_tokens=net_input["src_tokens"], | |
src_lengths=net_input["src_lengths"], | |
) | |
else: | |
return self.forward_non_torchscript(net_input) | |
def forward_non_torchscript(self, net_input: Dict[str, Tensor]): | |
encoder_input = { | |
k: v for k, v in net_input.items() if k != "prev_output_tokens" | |
} | |
return self.forward(**encoder_input) | |
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): | |
""" | |
Reorder encoder output according to *new_order*. | |
Args: | |
encoder_out: output from the ``forward()`` method | |
new_order (LongTensor): desired order | |
Returns: | |
*encoder_out* rearranged according to *new_order* | |
""" | |
if len(encoder_out["encoder_out"]) == 0: | |
new_encoder_out = [] | |
else: | |
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] | |
if len(encoder_out["encoder_padding_mask"]) == 0: | |
new_encoder_padding_mask = [] | |
else: | |
new_encoder_padding_mask = [ | |
encoder_out["encoder_padding_mask"][0].index_select(0, new_order) | |
] | |
if len(encoder_out["encoder_embedding"]) == 0: | |
new_encoder_embedding = [] | |
else: | |
new_encoder_embedding = [ | |
encoder_out["encoder_embedding"][0].index_select(0, new_order) | |
] | |
if len(encoder_out["src_tokens"]) == 0: | |
src_tokens = [] | |
else: | |
src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] | |
if len(encoder_out["src_lengths"]) == 0: | |
src_lengths = [] | |
else: | |
src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] | |
encoder_states = encoder_out["encoder_states"] | |
if len(encoder_states) > 0: | |
for idx, state in enumerate(encoder_states): | |
encoder_states[idx] = state.index_select(1, new_order) | |
return { | |
"encoder_out": new_encoder_out, # T x B x C | |
"encoder_padding_mask": new_encoder_padding_mask, # B x T | |
"encoder_embedding": new_encoder_embedding, # B x T x C | |
"encoder_states": encoder_states, # List[T x B x C] | |
"src_tokens": src_tokens, # B x T | |
"src_lengths": src_lengths, # B x 1 | |
} | |
def max_positions(self): | |
"""Maximum input length supported by the encoder.""" | |
if self.embed_positions is None: | |
return self.max_source_positions | |
return min(self.max_source_positions, self.embed_positions.max_positions) | |
def upgrade_state_dict_named(self, state_dict, name): | |
"""Upgrade a (possibly old) state dict for new versions.""" | |
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): | |
weights_key = "{}.embed_positions.weights".format(name) | |
if weights_key in state_dict: | |
print("deleting {0}".format(weights_key)) | |
del state_dict[weights_key] | |
state_dict[ | |
"{}.embed_positions._float_tensor".format(name) | |
] = torch.FloatTensor(1) | |
for i in range(self.num_layers): | |
# update layer norms | |
self.layers[i].upgrade_state_dict_named( | |
state_dict, "{}.layers.{}".format(name, i) | |
) | |
version_key = "{}.version".format(name) | |
if utils_item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: | |
# earlier checkpoints did not normalize after the stack of layers | |
self.layer_norm = None | |
self.normalize = False | |
state_dict[version_key] = torch.Tensor([1]) | |
return state_dict | |
def set_num_updates(self, num_updates): | |
"""State from trainer to pass along to model at every update.""" | |
def _apply(m): | |
if hasattr(m, "set_num_updates") and m != self: | |
m.set_num_updates(num_updates) | |
self.apply(_apply) | |
class TransformerEncoderLayerBase(nn.Module): | |
"""Encoder layer block. | |
In the original paper each operation (multi-head attention or FFN) is | |
postprocessed with: `dropout -> add residual -> layernorm`. In the | |
tensor2tensor code they suggest that learning is more robust when | |
preprocessing each layer with layernorm and postprocessing with: | |
`dropout -> add residual`. We default to the approach in the paper, but the | |
tensor2tensor approach can be enabled by setting | |
*cfg.encoder.normalize_before* to ``True``. | |
Args: | |
args (argparse.Namespace): parsed command-line arguments | |
""" | |
def __init__(self, cfg, has_relative_attention_bias=False, scaling_for_att=1.0): | |
super().__init__() | |
self.cfg = cfg | |
self.embed_dim = cfg.encoder.embed_dim | |
self.quant_noise = cfg.quant_noise.pq | |
self.quant_noise_block_size = cfg.quant_noise.pq_block_size | |
self.self_attn = self.build_self_attention(self.embed_dim, cfg, has_relative_attention_bias=has_relative_attention_bias, scaling_for_att=scaling_for_att) | |
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) | |
self.dropout_module = FairseqDropout( | |
cfg.dropout, module_name=self.__class__.__name__ | |
) | |
self.activation_fn = get_activation_fn(activation=cfg.activation_fn) | |
activation_dropout_p = cfg.activation_dropout | |
if activation_dropout_p == 0: | |
# for backwards compatibility with models that use cfg.relu_dropout | |
activation_dropout_p = cfg.relu_dropout or 0 | |
self.activation_dropout_module = FairseqDropout( | |
float(activation_dropout_p), module_name=self.__class__.__name__ | |
) | |
self.normalize_before = cfg.encoder.normalize_before | |
self.fc1 = self.build_fc1( | |
self.embed_dim, | |
cfg.encoder.ffn_embed_dim, | |
self.quant_noise, | |
self.quant_noise_block_size, | |
) | |
self.fc2 = self.build_fc2( | |
cfg.encoder.ffn_embed_dim, | |
self.embed_dim, | |
self.quant_noise, | |
self.quant_noise_block_size, | |
) | |
self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) | |
if has_relative_attention_bias: | |
self.norm_k = LayerNorm(self.embed_dim // cfg.encoder.attention_heads) | |
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): | |
return quant_noise( | |
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size | |
) | |
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): | |
return quant_noise( | |
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size | |
) | |
def build_self_attention(self, embed_dim, cfg, has_relative_attention_bias=False, scaling_for_att=1.0): | |
return MultiheadAttention( | |
embed_dim, | |
cfg.encoder.attention_heads, | |
dropout=cfg.attention_dropout, | |
self_attention=True, | |
q_noise=self.quant_noise, | |
qn_block_size=self.quant_noise_block_size, | |
has_relative_attention_bias=has_relative_attention_bias, | |
scaling_for_att=scaling_for_att, | |
) | |
def residual_connection(self, x, residual): | |
return residual + x | |
def upgrade_state_dict_named(self, state_dict, name): | |
""" | |
Rename layer norm states from `...layer_norms.0.weight` to | |
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to | |
`...final_layer_norm.weight` | |
""" | |
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} | |
for old, new in layer_norm_map.items(): | |
for m in ("weight", "bias"): | |
k = "{}.layer_norms.{}.{}".format(name, old, m) | |
if k in state_dict: | |
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] | |
del state_dict[k] | |
def forward( | |
self, | |
x, | |
encoder_padding_mask: Optional[Tensor], | |
attn_mask: Optional[Tensor] = None, | |
pos_bias=None, | |
): | |
""" | |
Args: | |
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` | |
encoder_padding_mask (ByteTensor): binary ByteTensor of shape | |
`(batch, seq_len)` where padding elements are indicated by ``1``. | |
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, | |
where `tgt_len` is the length of output and `src_len` is the | |
length of input, though here both are equal to `seq_len`. | |
`attn_mask[tgt_i, src_j] = 1` means that when calculating the | |
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is | |
useful for strided self-attention. | |
Returns: | |
encoded output of shape `(seq_len, batch, embed_dim)` | |
""" | |
# anything in original attn_mask = 1, becomes -1e8 | |
# anything in original attn_mask = 0, becomes 0 | |
# Note that we cannot use -inf here, because at some edge cases, | |
# the attention weight (before softmax) for some padded element in query | |
# will become -inf, which results in NaN in model parameters | |
if attn_mask is not None: | |
attn_mask = attn_mask.masked_fill( | |
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4 | |
) | |
residual = x | |
if self.normalize_before: | |
x = self.self_attn_layer_norm(x) | |
if pos_bias is not None: | |
pos_bias = self.norm_k(pos_bias) | |
x, _ = self.self_attn( | |
query=x, | |
key=x, | |
value=x, | |
key_padding_mask=encoder_padding_mask, | |
need_weights=False, | |
attn_mask=attn_mask, | |
position_bias=pos_bias, | |
) | |
x = self.dropout_module(x) | |
x = self.residual_connection(x, residual) | |
if not self.normalize_before: | |
x = self.self_attn_layer_norm(x) | |
residual = x | |
if self.normalize_before: | |
x = self.final_layer_norm(x) | |
x = self.activation_fn(self.fc1(x)) | |
x = self.activation_dropout_module(x) | |
x = self.fc2(x) | |
x = self.dropout_module(x) | |
x = self.residual_connection(x, residual) | |
if not self.normalize_before: | |
x = self.final_layer_norm(x) | |
return x | |
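# Note (not part of the original module): the normalize_before flag used above selects between
# the two layer orderings described in the class docstring, roughly:
#
#   post-norm (cfg.encoder.normalize_before = False, the original paper):
#       x = layer_norm(residual + dropout(sublayer(x)))
#   pre-norm  (cfg.encoder.normalize_before = True, the tensor2tensor variant):
#       x = residual + dropout(sublayer(layer_norm(x)))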
class TransformerEncoder(nn.Module): | |
""" | |
wav2vec-style transformer encoder. | |
""" | |
def __init__(self, args): | |
super().__init__() | |
        self.args = args  # kept so that max_positions() below can read args.max_positions
        self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim | |
self.required_seq_len_multiple = args.required_seq_len_multiple | |
self.pos_conv = nn.Conv1d( | |
self.embedding_dim, | |
self.embedding_dim, | |
kernel_size=args.conv_pos, | |
padding=args.conv_pos // 2, | |
groups=args.conv_pos_groups, | |
) | |
dropout = 0 | |
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) | |
nn.init.normal_(self.pos_conv.weight, mean=0, std=std) | |
nn.init.constant_(self.pos_conv.bias, 0) | |
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) | |
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) | |
layers = [] | |
self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False) | |
for _ in range(args.encoder_layers): | |
layer = TransformerSentenceEncoderLayer( | |
embedding_dim=self.embedding_dim, | |
ffn_embedding_dim=args.encoder_ffn_embed_dim, | |
num_attention_heads=args.encoder_attention_heads, | |
dropout=self.dropout, | |
attention_dropout=args.attention_dropout, | |
activation_dropout=args.activation_dropout, | |
activation_fn=args.activation_fn, | |
layer_norm_first=args.layer_norm_first, | |
has_relative_attention_bias=self.use_rel_pos_enc, | |
scaling_for_att=getattr(args, "scaling_for_att", 1.0) | |
) | |
if args.checkpoint_activations: | |
raise ValueError("We don't support checkpoint_activations for now! Please set checkpoint_activations=False.") | |
layers.append(layer) | |
self.layers = nn.ModuleList(layers) | |
self.layer_norm_first = args.layer_norm_first | |
self.layer_norm = LayerNorm(self.embedding_dim) | |
self.layerdrop = args.encoder_layerdrop | |
if self.use_rel_pos_enc: | |
self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160) | |
self.apply(init_bert_params) | |
def forward(self, x, padding_mask=None, layer=None, conv_pos=True): | |
x, layer_results = self.extract_features(x, padding_mask, layer, conv_pos) | |
if self.layer_norm_first and (layer is None or layer >= len(self.layers) - 1): | |
x = self.layer_norm(x) | |
return x, layer_results | |
def extract_features(self, x, padding_mask=None, tgt_layer=None, conv_pos=True): | |
if padding_mask is not None: | |
x = index_put(x, padding_mask, 0) | |
if conv_pos: | |
x_conv = self.pos_conv(x.transpose(1, 2)) | |
x_conv = x_conv.transpose(1, 2) | |
x = x + x_conv | |
if not self.layer_norm_first: | |
x = self.layer_norm(x) | |
# pad to the sequence length dimension | |
x, pad_length = pad_to_multiple( | |
x, self.required_seq_len_multiple, dim=-2, value=0 | |
) | |
if pad_length > 0 and padding_mask is None: | |
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool) | |
padding_mask[:, -pad_length:] = True | |
else: | |
padding_mask, _ = pad_to_multiple( | |
padding_mask, self.required_seq_len_multiple, dim=-1, value=True | |
) | |
x = F.dropout(x, p=self.dropout, training=self.training) | |
# B x T x C -> T x B x C | |
x = x.transpose(0, 1) | |
if self.use_rel_pos_enc: | |
x_len = x.shape[0] | |
pos_seq = torch.arange(0, x_len).long().to(x.device) | |
pos_seq = pos_seq[:, None] - pos_seq[None, :] | |
pos_k, pos_v = self.pos_emb(pos_seq) | |
else: | |
pos_k = None | |
layer_results = [] | |
r = None | |
for i, layer in enumerate(self.layers): | |
dropout_probability = np.random.random() | |
if not self.training or (dropout_probability > self.layerdrop): | |
x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k) | |
if tgt_layer is not None: | |
# unpad if needed | |
if pad_length > 0: | |
layer_results.append( | |
x[:-pad_length] | |
# ( | |
# x[:-pad_length], | |
# z[:, :-pad_length, :-pad_length] | |
# if z is not None | |
# else z, | |
# ) | |
) | |
else: | |
# layer_results.append((x, z)) | |
layer_results.append(x) | |
if i == tgt_layer: | |
r = x | |
break | |
if r is not None: | |
x = r | |
# T x B x C -> B x T x C | |
x = x.transpose(0, 1) | |
        # undo padding
if pad_length > 0: | |
x = x[:, :-pad_length] | |
return x, layer_results | |
def max_positions(self): | |
"""Maximum output length supported by the encoder.""" | |
return self.args.max_positions | |
def upgrade_state_dict_named(self, state_dict, name): | |
"""Upgrade a (possibly old) state dict for new versions of fairseq.""" | |
return state_dict | |
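# Illustrative usage sketch (not part of the original module): run the wav2vec-style encoder on
# a batch of frame features and read out the hidden states of a chosen layer. `args` is assumed
# to be a config namespace carrying the fields read in TransformerEncoder.__init__ above.
#
#   encoder = TransformerEncoder(args)
#   feats = torch.randn(2, 200, args.encoder_embed_dim)      # (batch, frames, dim)
#   padding_mask = torch.zeros(2, 200, dtype=torch.bool)     # True marks padded frames
#   x, layer_results = encoder(feats, padding_mask=padding_mask, layer=6)
#   # x: (batch, frames, dim) output of the layer with index 6;
#   # layer_results: per-layer hidden states, each (frames, batch, dim)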
class TransformerSentenceEncoderLayer(nn.Module): | |
""" | |
wav2vec-style transformer layer | |
""" | |
def __init__( | |
self, | |
embedding_dim: float = 768, | |
ffn_embedding_dim: float = 3072, | |
num_attention_heads: float = 8, | |
dropout: float = 0.1, | |
attention_dropout: float = 0.1, | |
activation_dropout: float = 0.1, | |
activation_fn: str = "relu", | |
layer_norm_first: bool = False, | |
has_relative_attention_bias: bool = False, | |
scaling_for_att: float = 1.0, | |
) -> None: | |
super().__init__() | |
# Initialize parameters | |
self.embedding_dim = embedding_dim | |
self.dropout = dropout | |
self.activation_dropout = activation_dropout | |
# Initialize blocks | |
self.activation_fn = get_activation_fn(activation_fn) | |
self.self_attn = MultiheadAttention( | |
self.embedding_dim, | |
num_attention_heads, | |
dropout=attention_dropout, | |
self_attention=True, | |
has_relative_attention_bias=has_relative_attention_bias, | |
scaling_for_att=scaling_for_att | |
) | |
self.dropout1 = nn.Dropout(dropout) | |
self.dropout2 = nn.Dropout(self.activation_dropout) | |
self.dropout3 = nn.Dropout(dropout) | |
self.layer_norm_first = layer_norm_first | |
# layer norm associated with the self attention layer | |
self.self_attn_layer_norm = LayerNorm(self.embedding_dim) | |
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) | |
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) | |
# layer norm associated with the position wise feed-forward NN | |
self.final_layer_norm = LayerNorm(self.embedding_dim) | |
if has_relative_attention_bias: | |
self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads) | |
def forward( | |
self, | |
x: torch.Tensor, | |
self_attn_mask: torch.Tensor = None, | |
self_attn_padding_mask: torch.Tensor = None, | |
need_weights: bool = False, | |
att_args=None, | |
pos_bias=None, | |
): | |
""" | |
LayerNorm is applied either before or after the self-attention/ffn | |
        modules similar to the original Transformer implementation.
""" | |
residual = x | |
if self.layer_norm_first: | |
x = self.self_attn_layer_norm(x) | |
if pos_bias is not None: | |
pos_bias = self.norm_k(pos_bias) | |
x, attn = self.self_attn( | |
query=x, | |
key=x, | |
value=x, | |
key_padding_mask=self_attn_padding_mask, | |
attn_mask=self_attn_mask, | |
position_bias=pos_bias, | |
) | |
x = self.dropout1(x) | |
x = residual + x | |
residual = x | |
x = self.final_layer_norm(x) | |
x = self.activation_fn(self.fc1(x)) | |
x = self.dropout2(x) | |
x = self.fc2(x) | |
x = self.dropout3(x) | |
x = residual + x | |
else: | |
x, attn = self.self_attn( | |
query=x, | |
key=x, | |
value=x, | |
key_padding_mask=self_attn_padding_mask, | |
position_bias=pos_bias, | |
) | |
x = self.dropout1(x) | |
x = residual + x | |
x = self.self_attn_layer_norm(x) | |
residual = x | |
x = self.activation_fn(self.fc1(x)) | |
x = self.dropout2(x) | |
x = self.fc2(x) | |
x = self.dropout3(x) | |
x = residual + x | |
x = self.final_layer_norm(x) | |
return x, attn | |
class FairseqDropout(nn.Module): | |
def __init__(self, p, module_name=None): | |
super().__init__() | |
self.p = p | |
self.module_name = module_name | |
self.apply_during_inference = False | |
def forward(self, x, inplace: bool = False): | |
if self.p > 0 and (self.training or self.apply_during_inference): | |
return F.dropout(x, p=self.p, training=True, inplace=inplace) | |
else: | |
return x | |
def make_generation_fast_( | |
self, | |
name: str, | |
retain_dropout: bool = False, | |
retain_dropout_modules: Optional[List[str]] = None, | |
**kwargs | |
): | |
if retain_dropout: | |
if retain_dropout_modules is not None and self.module_name is None: | |
logger.warning( | |
"Cannot enable dropout during inference for module {} " | |
"because module_name was not set".format(name) | |
) | |
elif ( | |
retain_dropout_modules is None # if None, apply to all modules | |
or self.module_name in retain_dropout_modules | |
): | |
logger.info( | |
"Enabling dropout during inference for module: {}".format(name) | |
) | |
self.apply_during_inference = True | |
else: | |
logger.info("Disabling dropout for module: {}".format(name)) | |
class LearnedPositionalEmbedding(nn.Embedding): | |
""" | |
This module learns positional embeddings up to a fixed maximum size. | |
Padding ids are ignored by either offsetting based on padding_idx | |
or by setting padding_idx to None and ensuring that the appropriate | |
position ids are passed to the forward function. | |
""" | |
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): | |
super().__init__(num_embeddings, embedding_dim, padding_idx) | |
self.onnx_trace = False | |
if self.padding_idx is not None: | |
self.max_positions = self.num_embeddings - self.padding_idx - 1 | |
else: | |
self.max_positions = self.num_embeddings | |
def forward( | |
self, | |
input: Tensor, | |
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, | |
positions: Optional[Tensor] = None, | |
): | |
"""Input is expected to be of size [bsz x seqlen].""" | |
assert (positions is None) or ( | |
self.padding_idx is None | |
), "If positions is pre-computed then padding_idx should not be set." | |
if positions is None: | |
if incremental_state is not None: | |
# positions is the same for every token when decoding a single step | |
# Without the int() cast, it doesn't work in some cases when exporting to ONNX | |
positions = torch.zeros( | |
(1, 1), device=input.device, dtype=input.dtype | |
).fill_(int(self.padding_idx + input.size(1))) | |
else: | |
positions = utils_make_positions( | |
input, self.padding_idx, onnx_trace=self.onnx_trace | |
) | |
positions = torch.clamp(positions, max=self.padding_idx + self.max_positions) | |
return F.embedding( | |
positions, | |
self.weight, | |
self.padding_idx, | |
self.max_norm, | |
self.norm_type, | |
self.scale_grad_by_freq, | |
self.sparse, | |
) | |
class SinusoidalPositionalEmbedding(nn.Module): | |
"""This module produces sinusoidal positional embeddings of any length. | |
Padding symbols are ignored. | |
""" | |
def __init__(self, embedding_dim, padding_idx, init_size=1024): | |
super().__init__() | |
self.embedding_dim = embedding_dim | |
self.padding_idx = padding_idx if padding_idx is not None else 0 | |
self.weights = SinusoidalPositionalEmbedding.get_embedding( | |
init_size, embedding_dim, padding_idx | |
) | |
self.onnx_trace = False | |
self.register_buffer("_float_tensor", torch.FloatTensor(1)) | |
self.max_positions = int(1e5) | |
def prepare_for_onnx_export_(self): | |
self.onnx_trace = True | |
    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
"""Build sinusoidal embeddings. | |
This matches the implementation in tensor2tensor, but differs slightly | |
from the description in Section 3.5 of "Attention Is All You Need". | |
""" | |
half_dim = embedding_dim // 2 | |
emb = math.log(10000) / (half_dim - 1) | |
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) | |
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze( | |
1 | |
) * emb.unsqueeze(0) | |
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view( | |
num_embeddings, -1 | |
) | |
if embedding_dim % 2 == 1: | |
# zero pad | |
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) | |
if padding_idx is not None: | |
emb[padding_idx, :] = 0 | |
return emb | |
def forward( | |
self, | |
input, | |
incremental_state: Optional[Any] = None, | |
timestep: Optional[Tensor] = None, | |
positions: Optional[Any] = None, | |
): | |
"""Input is expected to be of size [bsz x seqlen].""" | |
bspair = torch.onnx.operators.shape_as_tensor(input) | |
bsz, seq_len = bspair[0], bspair[1] | |
max_pos = self.padding_idx + 1 + seq_len | |
if self.weights is None or max_pos > self.weights.size(0): | |
# recompute/expand embeddings if needed | |
self.weights = SinusoidalPositionalEmbedding.get_embedding( | |
max_pos, self.embedding_dim, self.padding_idx | |
) | |
self.weights = self.weights.to(self._float_tensor) | |
if incremental_state is not None: | |
# positions is the same for every token when decoding a single step | |
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len | |
if self.onnx_trace: | |
return ( | |
self.weights.index_select(index=self.padding_idx + pos, dim=0) | |
.unsqueeze(1) | |
.repeat(bsz, 1, 1) | |
) | |
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) | |
positions = utils_make_positions( | |
input, self.padding_idx, onnx_trace=self.onnx_trace | |
) | |
if self.onnx_trace: | |
flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) | |
embedding_shape = torch.cat( | |
(bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long)) | |
) | |
embeddings = torch.onnx.operators.reshape_from_tensor_shape( | |
flat_embeddings, embedding_shape | |
) | |
return embeddings | |
return ( | |
self.weights.index_select(0, positions.view(-1)) | |
.view(bsz, seq_len, -1) | |
.detach() | |
) | |
try: | |
from apex.normalization import FusedLayerNorm as _FusedLayerNorm | |
has_fused_layernorm = True | |
class FusedLayerNorm(_FusedLayerNorm): | |
def forward(self, x): | |
if not x.is_cuda: | |
return super().forward(x) | |
else: | |
with torch.cuda.device(x.device): | |
return super().forward(x) | |
except ImportError: | |
has_fused_layernorm = False | |
class Fp32LayerNorm(nn.LayerNorm): | |
def __init__(self, *args, **kwargs): | |
super().__init__(*args, **kwargs) | |
def forward(self, input): | |
output = F.layer_norm( | |
input.float(), | |
self.normalized_shape, | |
self.weight.float() if self.weight is not None else None, | |
self.bias.float() if self.bias is not None else None, | |
self.eps, | |
) | |
return output.type_as(input) | |
class LayerDropModuleList(nn.ModuleList): | |
""" | |
A LayerDrop implementation based on :class:`torch.nn.ModuleList`. | |
We refresh the choice of which layers to drop every time we iterate | |
over the LayerDropModuleList instance. During evaluation we always | |
iterate over all layers. | |
Usage:: | |
layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) | |
for layer in layers: # this might iterate over layers 1 and 3 | |
x = layer(x) | |
for layer in layers: # this might iterate over all layers | |
x = layer(x) | |
for layer in layers: # this might not iterate over any layers | |
x = layer(x) | |
Args: | |
p (float): probability of dropping out each layer | |
modules (iterable, optional): an iterable of modules to add | |
""" | |
def __init__(self, p, modules=None): | |
super().__init__(modules) | |
self.p = p | |
def __iter__(self): | |
dropout_probs = torch.empty(len(self)).uniform_() | |
for i, m in enumerate(super().__iter__()): | |
if not self.training or (dropout_probs[i] > self.p): | |
yield m | |
class RelativePositionalEncoding(torch.nn.Module): | |
def __init__(self, d_model, maxlen=1000, embed_v=False): | |
super(RelativePositionalEncoding, self).__init__() | |
self.d_model = d_model | |
self.maxlen = maxlen | |
self.pe_k = torch.nn.Embedding(2*maxlen, d_model) | |
if embed_v: | |
self.pe_v = torch.nn.Embedding(2*maxlen, d_model) | |
self.embed_v = embed_v | |
def forward(self, pos_seq, incremental_state=None): | |
pos_seq[pos_seq < -self.maxlen] = -self.maxlen | |
pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1 | |
pos_seq = pos_seq + self.maxlen | |
if incremental_state is not None: | |
pos_seq = pos_seq[-1:] | |
if self.embed_v: | |
return self.pe_k(pos_seq), self.pe_v(pos_seq) | |
else: | |
return self.pe_k(pos_seq), None | |
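# Illustrative usage sketch (not part of the original module): build the relative-position key
# embeddings that the encoders above pass to MultiheadAttention as `position_bias`.
#
#   rel_pos = RelativePositionalEncoding(d_model=768 // 12, maxlen=160)
#   pos_seq = torch.arange(50)
#   pos_seq = pos_seq[:, None] - pos_seq[None, :]  # (50, 50) matrix of relative offsets
#   pos_k, pos_v = rel_pos(pos_seq)                # pos_k: (50, 50, 64); pos_v is None here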
class MultiheadAttention(nn.Module): | |
"""Multi-headed attention. | |
See "Attention Is All You Need" for more details. | |
""" | |
def __init__( | |
self, | |
embed_dim, | |
num_heads, | |
kdim=None, | |
vdim=None, | |
dropout=0.0, | |
bias=True, | |
add_bias_kv=False, | |
add_zero_attn=False, | |
self_attention=False, | |
encoder_decoder_attention=False, | |
q_noise=0.0, | |
qn_block_size=8, | |
has_relative_attention_bias=False, | |
scaling_for_att=1.0 | |
): | |
super().__init__() | |
self.embed_dim = embed_dim | |
self.kdim = kdim if kdim is not None else embed_dim | |
self.vdim = vdim if vdim is not None else embed_dim | |
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim | |
self.num_heads = num_heads | |
self.dropout_module = FairseqDropout( | |
dropout, module_name=self.__class__.__name__ | |
) | |
self.has_relative_attention_bias = has_relative_attention_bias | |
self.head_dim = embed_dim // num_heads | |
assert ( | |
self.head_dim * num_heads == self.embed_dim | |
), "embed_dim must be divisible by num_heads" | |
self.scaling = self.head_dim ** -0.5 | |
self.scaling_for_att = scaling_for_att | |
self.self_attention = self_attention | |
self.encoder_decoder_attention = encoder_decoder_attention | |
assert not self.self_attention or self.qkv_same_dim, ( | |
"Self-attention requires query, key and " "value to be of the same size" | |
) | |
self.k_proj = quant_noise( | |
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size | |
) | |
self.v_proj = quant_noise( | |
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size | |
) | |
self.q_proj = quant_noise( | |
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size | |
) | |
self.out_proj = quant_noise( | |
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size | |
) | |
if add_bias_kv: | |
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) | |
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) | |
else: | |
self.bias_k = self.bias_v = None | |
self.add_zero_attn = add_zero_attn | |
self.reset_parameters() | |
self.onnx_trace = False | |
def prepare_for_onnx_export_(self): | |
self.onnx_trace = True | |
def reset_parameters(self): | |
if self.qkv_same_dim: | |
# Empirically observed the convergence to be much better with | |
# the scaled initialization | |
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) | |
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) | |
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) | |
else: | |
nn.init.xavier_uniform_(self.k_proj.weight) | |
nn.init.xavier_uniform_(self.v_proj.weight) | |
nn.init.xavier_uniform_(self.q_proj.weight) | |
nn.init.xavier_uniform_(self.out_proj.weight) | |
if self.out_proj.bias is not None: | |
nn.init.constant_(self.out_proj.bias, 0.0) | |
if self.bias_k is not None: | |
nn.init.xavier_normal_(self.bias_k) | |
if self.bias_v is not None: | |
nn.init.xavier_normal_(self.bias_v) | |
def forward( | |
self, | |
query, | |
key: Optional[Tensor], | |
value: Optional[Tensor], | |
key_padding_mask: Optional[Tensor] = None, | |
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, | |
need_weights: bool = True, | |
static_kv: bool = False, | |
attn_mask: Optional[Tensor] = None, | |
before_softmax: bool = False, | |
need_head_weights: bool = False, | |
position_bias: Optional[Tensor] = None | |
) -> Tuple[Tensor, Optional[Tensor]]: | |
"""Input shape: Time x Batch x Channel | |
Args: | |
key_padding_mask (ByteTensor, optional): mask to exclude | |
keys that are pads, of shape `(batch, src_len)`, where | |
padding elements are indicated by 1s. | |
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
attn_mask (ByteTensor, optional): typically used to | |
implement causal attention, where the mask prevents the | |
attention from looking forward in time (default: None). | |
before_softmax (bool, optional): return the raw attention | |
weights and values before the attention softmax. | |
need_head_weights (bool, optional): return the attention | |
weights for each head. Implies *need_weights*. Default: | |
return the average attention weights over all heads. | |
""" | |
if need_head_weights: | |
need_weights = True | |
is_tpu = query.device.type == "xla" | |
tgt_len, bsz, embed_dim = query.size() | |
src_len = tgt_len | |
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" | |
assert list(query.size()) == [tgt_len, bsz, embed_dim] | |
if key is not None: | |
src_len, key_bsz, _ = key.size() | |
if not torch.jit.is_scripting(): | |
assert key_bsz == bsz | |
assert value is not None | |
                assert (src_len, bsz) == value.shape[:2]
if ( | |
not self.onnx_trace | |
and not is_tpu # don't use PyTorch version on TPUs | |
and incremental_state is None | |
and not static_kv | |
# A workaround for quantization to work. Otherwise JIT compilation | |
# treats bias in linear module as method. | |
and not torch.jit.is_scripting() | |
and not self.has_relative_attention_bias | |
): | |
assert key is not None and value is not None | |
return F.multi_head_attention_forward( | |
query, | |
key, | |
value, | |
self.embed_dim, | |
self.num_heads, | |
torch.empty([0]), | |
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), | |
self.bias_k, | |
self.bias_v, | |
self.add_zero_attn, | |
self.dropout_module.p, | |
self.out_proj.weight, | |
self.out_proj.bias, | |
self.training or self.dropout_module.apply_during_inference, | |
key_padding_mask, | |
need_weights, | |
attn_mask, | |
use_separate_proj_weight=True, | |
q_proj_weight=self.q_proj.weight, | |
k_proj_weight=self.k_proj.weight, | |
v_proj_weight=self.v_proj.weight, | |
) | |
if incremental_state is not None: | |
saved_state = self._get_input_buffer(incremental_state) | |
if saved_state is not None and "prev_key" in saved_state: | |
# previous time steps are cached - no need to recompute | |
# key and value if they are static | |
if static_kv: | |
assert self.encoder_decoder_attention and not self.self_attention | |
key = value = None | |
else: | |
saved_state = None | |
if self.self_attention: | |
q = self.q_proj(query) | |
k = self.k_proj(query) | |
v = self.v_proj(query) | |
elif self.encoder_decoder_attention: | |
# encoder-decoder attention | |
q = self.q_proj(query) | |
if key is None: | |
assert value is None | |
k = v = None | |
else: | |
k = self.k_proj(key) | |
v = self.v_proj(key) | |
else: | |
assert key is not None and value is not None | |
q = self.q_proj(query) | |
k = self.k_proj(key) | |
v = self.v_proj(value) | |
q *= self.scaling | |
q *= (1 / self.scaling_for_att) | |
if self.bias_k is not None: | |
assert self.bias_v is not None | |
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) | |
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) | |
if attn_mask is not None: | |
attn_mask = torch.cat( | |
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 | |
) | |
if key_padding_mask is not None: | |
key_padding_mask = torch.cat( | |
[ | |
key_padding_mask, | |
key_padding_mask.new_zeros(key_padding_mask.size(0), 1), | |
], | |
dim=1, | |
) | |
q = ( | |
q.contiguous() | |
.view(tgt_len, bsz * self.num_heads, self.head_dim) | |
.transpose(0, 1) | |
) | |
if k is not None: | |
k = ( | |
k.contiguous() | |
.view(-1, bsz * self.num_heads, self.head_dim) | |
.transpose(0, 1) | |
) | |
if v is not None: | |
v = ( | |
v.contiguous() | |
.view(-1, bsz * self.num_heads, self.head_dim) | |
.transpose(0, 1) | |
) | |
if saved_state is not None: | |
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim) | |
if "prev_key" in saved_state: | |
_prev_key = saved_state["prev_key"] | |
assert _prev_key is not None | |
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) | |
if static_kv: | |
k = prev_key | |
else: | |
assert k is not None | |
k = torch.cat([prev_key, k], dim=1) | |
src_len = k.size(1) | |
if "prev_value" in saved_state: | |
_prev_value = saved_state["prev_value"] | |
assert _prev_value is not None | |
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) | |
if static_kv: | |
v = prev_value | |
else: | |
assert v is not None | |
v = torch.cat([prev_value, v], dim=1) | |
prev_key_padding_mask: Optional[Tensor] = None | |
if "prev_key_padding_mask" in saved_state: | |
prev_key_padding_mask = saved_state["prev_key_padding_mask"] | |
assert k is not None and v is not None | |
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( | |
key_padding_mask=key_padding_mask, | |
prev_key_padding_mask=prev_key_padding_mask, | |
batch_size=bsz, | |
src_len=k.size(1), | |
static_kv=static_kv, | |
) | |
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) | |
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) | |
saved_state["prev_key_padding_mask"] = key_padding_mask | |
# In this branch incremental_state is never None | |
assert incremental_state is not None | |
incremental_state = self._set_input_buffer(incremental_state, saved_state) | |
assert k is not None | |
assert k.size(1) == src_len | |
# This is part of a workaround to get around fork/join parallelism | |
# not supporting Optional types. | |
if key_padding_mask is not None and key_padding_mask.dim() == 0: | |
key_padding_mask = None | |
if key_padding_mask is not None: | |
assert key_padding_mask.size(0) == bsz | |
assert key_padding_mask.size(1) == src_len | |
if self.add_zero_attn: | |
assert v is not None | |
src_len += 1 | |
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) | |
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) | |
if attn_mask is not None: | |
attn_mask = torch.cat( | |
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 | |
) | |
if key_padding_mask is not None: | |
key_padding_mask = torch.cat( | |
[ | |
key_padding_mask, | |
torch.zeros(key_padding_mask.size(0), 1).type_as( | |
key_padding_mask | |
), | |
], | |
dim=1, | |
) | |
attn_weights = torch.bmm(q, k.transpose(1, 2)) | |
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) | |
        if position_bias is not None:
            # add the first-order relative-position term
            # position_bias: (tgt_len, src_len, head_dim)
            # attn_weights:  (bsz * num_heads, tgt_len, src_len)
            reshape_q = (
                q.contiguous()
                .view(bsz * self.num_heads, -1, self.head_dim)
                .transpose(0, 1)
            )  # (tgt_len, bsz * num_heads, head_dim)
            B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
            # B: (tgt_len, bsz * num_heads, src_len)
            B = B.transpose(0, 1).view(
                bsz * self.num_heads, position_bias.size(0), position_bias.size(1)
            )
            attn_weights += B
attn_weights *= self.scaling_for_att | |
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] | |
if attn_mask is not None: | |
attn_mask = attn_mask.unsqueeze(0) | |
if self.onnx_trace: | |
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) | |
attn_weights += attn_mask | |
if key_padding_mask is not None: | |
# don't attend to padding symbols | |
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) | |
if not is_tpu: | |
attn_weights = attn_weights.masked_fill( | |
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), | |
float("-inf"), | |
) | |
else: | |
attn_weights = attn_weights.transpose(0, 2) | |
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) | |
attn_weights = attn_weights.transpose(0, 2) | |
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) | |
        if self.scaling_for_att > 1.0:
            # subtracting the row-wise max leaves the softmax unchanged while keeping
            # the rescaled logits numerically bounded
            attn_weights = attn_weights - attn_weights.detach().max(dim=-1, keepdim=True)[0]
if before_softmax: | |
return attn_weights, v | |
attn_weights_float = softmax( | |
attn_weights, dim=-1, onnx_trace=self.onnx_trace | |
) | |
attn_weights = attn_weights_float.type_as(attn_weights) | |
attn_probs = self.dropout_module(attn_weights) | |
assert v is not None | |
attn = torch.bmm(attn_probs, v) | |
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] | |
if self.onnx_trace and attn.size(1) == 1: | |
# when ONNX tracing a single decoder step (sequence length == 1) | |
# the transpose is a no-op copy before view, thus unnecessary | |
attn = attn.contiguous().view(tgt_len, bsz, embed_dim) | |
else: | |
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) | |
attn = self.out_proj(attn) | |
attn_weights: Optional[Tensor] = None | |
if need_weights: | |
attn_weights = attn_weights_float.view( | |
bsz, self.num_heads, tgt_len, src_len | |
).transpose(1, 0) | |
if not need_head_weights: | |
# average attention weights over heads | |
attn_weights = attn_weights.mean(dim=0) | |
return attn, attn_weights | |
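    # Shape walkthrough for `forward` above (a sketch, assuming the time-first
    # (T, B, C) layout used throughout this file):
    #   query              : (tgt_len, bsz, embed_dim)
    #   q / k / v          : (bsz * num_heads, tgt_len or src_len, head_dim)
    #   attn_weights       : (bsz * num_heads, tgt_len, src_len)
    #   attn (pre-out_proj): (bsz * num_heads, tgt_len, head_dim)
    #   returned attn      : (tgt_len, bsz, embed_dim)
    #   returned weights   : (num_heads, bsz, tgt_len, src_len), averaged over heads
    #                        to (bsz, tgt_len, src_len) unless need_head_weights is set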
    @staticmethod
    def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor], | |
prev_key_padding_mask: Optional[Tensor], | |
batch_size: int, | |
src_len: int, | |
static_kv: bool, | |
) -> Optional[Tensor]: | |
# saved key padding masks have shape (bsz, seq_len) | |
if prev_key_padding_mask is not None and static_kv: | |
new_key_padding_mask = prev_key_padding_mask | |
elif prev_key_padding_mask is not None and key_padding_mask is not None: | |
new_key_padding_mask = torch.cat( | |
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 | |
) | |
# During incremental decoding, as the padding token enters and | |
# leaves the frame, there will be a time when prev or current | |
# is None | |
elif prev_key_padding_mask is not None: | |
if src_len > prev_key_padding_mask.size(1): | |
filler = torch.zeros( | |
(batch_size, src_len - prev_key_padding_mask.size(1)), | |
device=prev_key_padding_mask.device, | |
) | |
new_key_padding_mask = torch.cat( | |
[prev_key_padding_mask.float(), filler.float()], dim=1 | |
) | |
else: | |
new_key_padding_mask = prev_key_padding_mask.float() | |
elif key_padding_mask is not None: | |
if src_len > key_padding_mask.size(1): | |
filler = torch.zeros( | |
(batch_size, src_len - key_padding_mask.size(1)), | |
device=key_padding_mask.device, | |
) | |
new_key_padding_mask = torch.cat( | |
[filler.float(), key_padding_mask.float()], dim=1 | |
) | |
else: | |
new_key_padding_mask = key_padding_mask.float() | |
else: | |
new_key_padding_mask = prev_key_padding_mask | |
return new_key_padding_mask | |
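    # Worked example of the mask bookkeeping above (a sketch): with batch_size=2, a
    # cached mask of shape (2, 3) and a current-step mask of shape (2, 1) are
    # concatenated along dim=1 into (2, 4); if either side is missing, zeros
    # ("not padding") are used as filler so the result always has shape
    # (batch_size, src_len).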
def reorder_incremental_state( | |
self, | |
incremental_state: Dict[str, Dict[str, Optional[Tensor]]], | |
new_order: Tensor, | |
): | |
"""Reorder buffered internal state (for incremental generation).""" | |
input_buffer = self._get_input_buffer(incremental_state) | |
if input_buffer is not None: | |
for k in input_buffer.keys(): | |
input_buffer_k = input_buffer[k] | |
if input_buffer_k is not None: | |
if self.encoder_decoder_attention and input_buffer_k.size( | |
0 | |
) == new_order.size(0): | |
break | |
input_buffer[k] = input_buffer_k.index_select(0, new_order) | |
incremental_state = self._set_input_buffer(incremental_state, input_buffer) | |
return incremental_state | |
def _get_input_buffer( | |
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] | |
) -> Dict[str, Optional[Tensor]]: | |
result = self.get_incremental_state(incremental_state, "attn_state") | |
if result is not None: | |
return result | |
else: | |
empty_result: Dict[str, Optional[Tensor]] = {} | |
return empty_result | |
def _set_input_buffer( | |
self, | |
incremental_state: Dict[str, Dict[str, Optional[Tensor]]], | |
buffer: Dict[str, Optional[Tensor]], | |
): | |
return self.set_incremental_state(incremental_state, "attn_state", buffer) | |
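    # Layout of the incremental state handled by the two helpers above (a sketch;
    # the exact key prefix is generated by fairseq's incremental-state utilities):
    #   incremental_state = {
    #       "<module_id>.attn_state": {
    #           "prev_key":              (bsz, num_heads, seq_len, head_dim),
    #           "prev_value":            (bsz, num_heads, seq_len, head_dim),
    #           "prev_key_padding_mask": (bsz, seq_len) or None,
    #       },
    #   }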
    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        # hook for sparse-attention variants; the base implementation leaves the
        # weights unchanged
        return attn_weights
def upgrade_state_dict_named(self, state_dict, name): | |
prefix = name + "." if name != "" else "" | |
items_to_add = {} | |
keys_to_remove = [] | |
for k in state_dict.keys(): | |
if k.endswith(prefix + "in_proj_weight"): | |
# in_proj_weight used to be q + k + v with same dimensions | |
dim = int(state_dict[k].shape[0] / 3) | |
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] | |
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] | |
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] | |
keys_to_remove.append(k) | |
k_bias = prefix + "in_proj_bias" | |
if k_bias in state_dict.keys(): | |
dim = int(state_dict[k].shape[0] / 3) | |
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] | |
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ | |
dim : 2 * dim | |
] | |
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] | |
keys_to_remove.append(prefix + "in_proj_bias") | |
for k in keys_to_remove: | |
del state_dict[k] | |
for key, value in items_to_add.items(): | |
state_dict[key] = value | |
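# Sketch of what `upgrade_state_dict_named` above does (illustrative only; the key
# name "attn" is hypothetical): old checkpoints store a single fused projection,
#
#     fused = state_dict["attn.in_proj_weight"]   # (3 * embed_dim, embed_dim)
#     dim = fused.shape[0] // 3
#     q_w, k_w, v_w = fused[:dim], fused[dim:2 * dim], fused[2 * dim:]
#
# which is split into q_proj/k_proj/v_proj weights (and likewise for the bias),
# after which the fused keys are deleted from the state dict.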
class ConvFeatureExtractionModel(nn.Module): | |
def __init__( | |
self, | |
conv_layers: List[Tuple[int, int, int]], | |
dropout: float = 0.0, | |
mode: str = "default", | |
conv_bias: bool = False, | |
): | |
super().__init__() | |
assert mode in {"default", "layer_norm"} | |
def block( | |
n_in, | |
n_out, | |
k, | |
stride, | |
is_layer_norm=False, | |
is_group_norm=False, | |
conv_bias=False, | |
): | |
def make_conv(): | |
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) | |
nn.init.kaiming_normal_(conv.weight) | |
return conv | |
            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"
            # note: `dim` below is resolved from the enclosing loop in __init__
            # (it equals n_out for the layer currently being built)
            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(dim, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(dim, dim, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
in_d = 1 | |
self.conv_layers = nn.ModuleList() | |
for i, cl in enumerate(conv_layers): | |
assert len(cl) == 3, "invalid conv definition: " + str(cl) | |
(dim, k, stride) = cl | |
self.conv_layers.append( | |
block( | |
in_d, | |
dim, | |
k, | |
stride, | |
is_layer_norm=mode == "layer_norm", | |
is_group_norm=mode == "default" and i == 0, | |
conv_bias=conv_bias, | |
) | |
) | |
in_d = dim | |
def forward(self, x): | |
# BxT -> BxCxT | |
x = x.unsqueeze(1) | |
for conv in self.conv_layers: | |
x = conv(x) | |
return x | |
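# Usage sketch for ConvFeatureExtractionModel (illustrative only; the conv spec
# below is an assumption matching the common wav2vec2-style front end, not a value
# taken from this file): 7 layers with a total stride of 320, i.e. one 512-dim
# frame per 20 ms of 16 kHz audio.
def _conv_feature_extractor_usage_sketch():
    conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    extractor = ConvFeatureExtractionModel(conv_layers, dropout=0.0, mode="default")
    wav = torch.randn(1, 16000)  # (batch, samples): one second of 16 kHz audio
    feats = extractor(wav)       # (batch, 512, ~49 frames)
    return feats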
class TransposeLast(nn.Module): | |
def __init__(self, deconstruct_idx=None): | |
super().__init__() | |
self.deconstruct_idx = deconstruct_idx | |
def forward(self, x): | |
if self.deconstruct_idx is not None: | |
x = x[self.deconstruct_idx] | |
return x.transpose(-2, -1) | |
class Fp32GroupNorm(nn.GroupNorm): | |
def __init__(self, *args, **kwargs): | |
super().__init__(*args, **kwargs) | |
def forward(self, input): | |
output = F.group_norm( | |
input.float(), | |
self.num_groups, | |
self.weight.float() if self.weight is not None else None, | |
self.bias.float() if self.bias is not None else None, | |
self.eps, | |
) | |
return output.type_as(input) | |
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by `scale` in the backward pass."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None
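# Usage sketch for GradMultiply (illustrative only): the usual pattern is to
# down-weight gradients flowing into a sub-network (e.g. the conv feature
# extractor) while leaving the forward pass unchanged.
def _grad_multiply_usage_sketch():
    x = torch.randn(2, 4, requires_grad=True)
    y = GradMultiply.apply(x, 0.1)  # forward: identity; backward: grads scaled by 0.1
    y.sum().backward()
    assert torch.allclose(x.grad, torch.full_like(x, 0.1))
    return x.grad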
class Rotate3D(nn.Module): | |
""" | |
(T, B, D) --> (B, D, T) --> (D, T, B) --> (T, B, D) | |
""" | |
def __init__(self): | |
super().__init__() | |
def forward(self, x): | |
return x.permute(1, 2, 0) | |
class SamePad(nn.Module): | |
def __init__(self, kernel_size, causal=False): | |
super().__init__() | |
        if causal:
            # for a causal conv (padding = kernel_size - 1), trim the trailing
            # kernel_size - 1 frames so each step only sees past context
            self.remove = kernel_size - 1
        else:
            # with "same" padding of kernel_size // 2, an even kernel produces one
            # extra trailing frame
            self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x): | |
if self.remove > 0: | |
x = x[:, :, : -self.remove] | |
return x | |
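# Usage sketch for SamePad (illustrative only; the kernel size and channel count
# are assumptions): it is typically placed after a Conv1d padded by
# kernel_size // 2 -- e.g. a wav2vec2-style convolutional positional embedding --
# so the output keeps the input length even for even kernel sizes.
def _same_pad_usage_sketch():
    kernel_size = 128
    conv = nn.Conv1d(16, 16, kernel_size, padding=kernel_size // 2, groups=16)
    pos_conv = nn.Sequential(conv, SamePad(kernel_size), nn.GELU())
    x = torch.randn(2, 16, 50)   # (batch, channels, time)
    y = pos_conv(x)              # trailing extra frame removed by SamePad
    assert y.shape == x.shape
    return y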