# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum, auto
import math
import numpy as np
from typing import Tuple, List, Optional, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd

from fairseq import checkpoint_utils, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.modules import (
    SamePad,
    TransposeLast,
)
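

# Segmentation: the enum and config below select how the input feature
# sequence is subsampled before the generator (pre_segment) and/or how the
# generator outputs are merged afterwards (logit_segment). Each option is
# implemented by one of the Segmenter subclasses further down.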
class SegmentationType(Enum):
    NONE = auto()
    RANDOM = auto()
    UNIFORM_RANDOM = auto()
    UNIFORM_RANDOM_JOIN = auto()
    JOIN = auto()


@dataclass
class SegmentationConfig(FairseqDataclass):
    type: SegmentationType = SegmentationType.NONE
    subsample_rate: float = 0.25
    mean_pool: bool = True
    mean_pool_join: bool = False
    remove_zeros: bool = False


@dataclass
class Wav2vec_UConfig(FairseqDataclass):
    discriminator_kernel: int = 3
    discriminator_dilation: int = 1
    discriminator_dim: int = 256
    discriminator_causal: bool = True
    discriminator_linear_emb: bool = False
    discriminator_depth: int = 1
    discriminator_max_pool: bool = False
    discriminator_act_after_linear: bool = False
    discriminator_dropout: float = 0.0
    discriminator_spectral_norm: bool = False
    discriminator_weight_norm: bool = False

    generator_kernel: int = 4
    generator_dilation: int = 1
    generator_stride: int = 1
    generator_bias: bool = False
    generator_dropout: float = 0.0

    blank_weight: float = 0
    blank_mode: str = "add"
    blank_is_sil: bool = False
    no_softmax: bool = False

    smoothness_weight: float = 0.0
    smoothing: float = 0.0
    smoothing_one_sided: bool = False
    gradient_penalty: float = 0.0
    probabilistic_grad_penalty_slicing: bool = False
    code_penalty: float = 0.0
    gumbel: bool = False
    hard_gumbel: bool = True
    temp: Tuple[float, float, float] = (2, 0.1, 0.99995)
    input_dim: int = 128

    segmentation: SegmentationConfig = SegmentationConfig()
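

# Base segmenter: the default (SegmentationType.NONE) is a pass-through that
# leaves both the dense features and the padding mask untouched.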
class Segmenter(nn.Module):
    cfg: SegmentationConfig

    def __init__(self, cfg: SegmentationConfig):
        super().__init__()
        self.cfg = cfg
        self.subsample_rate = cfg.subsample_rate

    def pre_segment(self, dense_x, dense_padding_mask):
        return dense_x, dense_padding_mask

    def logit_segment(self, logits, padding_mask):
        return logits, padding_mask


class RandomSegmenter(Segmenter):
    def pre_segment(self, dense_x, dense_padding_mask):
        target_num = math.ceil(dense_x.size(1) * self.subsample_rate)
        ones = torch.ones(dense_x.shape[:-1], device=dense_x.device)
        indices, _ = ones.multinomial(target_num).sort(dim=-1)
        indices_ld = indices.unsqueeze(-1).expand(-1, -1, dense_x.size(-1))
        dense_x = dense_x.gather(1, indices_ld)
        dense_padding_mask = dense_padding_mask.gather(1, index=indices)
        return dense_x, dense_padding_mask
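

# Subsamples the time axis to roughly subsample_rate of its length by either
# mean-pooling each bucket of consecutive frames (mean_pool=True) or picking
# one random frame per bucket.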
class UniformRandomSegmenter(Segmenter):
    def pre_segment(self, dense_x, dense_padding_mask):
        bsz, tsz, fsz = dense_x.shape

        target_num = math.ceil(tsz * self.subsample_rate)

        rem = tsz % target_num

        if rem > 0:
            dense_x = F.pad(dense_x, [0, 0, 0, target_num - rem])
            dense_padding_mask = F.pad(
                dense_padding_mask, [0, target_num - rem], value=True
            )

        dense_x = dense_x.view(bsz, target_num, -1, fsz)
        dense_padding_mask = dense_padding_mask.view(bsz, target_num, -1)

        if self.cfg.mean_pool:
            dense_x = dense_x.mean(dim=-2)
            dense_padding_mask = dense_padding_mask.all(dim=-1)
        else:
            ones = torch.ones((bsz, dense_x.size(2)), device=dense_x.device)
            indices = ones.multinomial(1)
            indices = indices.unsqueeze(-1).expand(-1, target_num, -1)
            indices_ld = indices.unsqueeze(-1).expand(-1, -1, -1, fsz)
            dense_x = dense_x.gather(2, indices_ld).reshape(bsz, -1, fsz)
            dense_padding_mask = dense_padding_mask.gather(2, index=indices).reshape(
                bsz, -1
            )

        return dense_x, dense_padding_mask
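

# Collapses runs of consecutive identical argmax predictions in the generator
# output: during training (unless mean_pool_join is set) one frame is sampled
# from each run, otherwise the logits within a run are averaged. Runs marked
# as padding (and, optionally, the zero class) are dropped.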
class JoinSegmenter(Segmenter):
    def logit_segment(self, logits, padding_mask):
        preds = logits.argmax(dim=-1)

        if padding_mask.any():
            preds[padding_mask] = -1  # mark pad
        uniques = []

        bsz, tsz, csz = logits.shape

        for p in preds:
            uniques.append(
                p.cpu().unique_consecutive(return_inverse=True, return_counts=True)
            )

        new_tsz = max(u[0].numel() for u in uniques)
        new_logits = logits.new_zeros(bsz, new_tsz, csz)
        new_pad = padding_mask.new_zeros(bsz, new_tsz)

        for b in range(bsz):
            u, idx, c = uniques[b]
            keep = u != -1

            if self.cfg.remove_zeros:
                keep.logical_and_(u != 0)

            if self.training and not self.cfg.mean_pool_join:
                u[0] = 0
                u[1:] = c.cumsum(0)[:-1]
                m = c > 1
                r = torch.rand(m.sum())
                o = (c[m] * r).long()
                u[m] += o
                new_logits[b, : u.numel()] = logits[b, u]
            else:
                new_logits[b].index_add_(
                    dim=0, index=idx.to(new_logits.device), source=logits[b]
                )
                new_logits[b, : c.numel()] /= c.unsqueeze(-1).to(new_logits.device)

            new_sz = keep.sum()
            if not keep.all():
                kept_logits = new_logits[b, : c.numel()][keep]
                new_logits[b, :new_sz] = kept_logits

            if new_sz < new_tsz:
                pad = new_tsz - new_sz
                new_logits[b, -pad:] = 0
                new_pad[b, -pad:] = True

        return new_logits, new_pad


class UniformRandomJoinSegmenter(UniformRandomSegmenter, JoinSegmenter):
    pass


SEGMENT_FACTORY = {
    SegmentationType.NONE: Segmenter,
    SegmentationType.RANDOM: RandomSegmenter,
    SegmentationType.UNIFORM_RANDOM: UniformRandomSegmenter,
    SegmentationType.UNIFORM_RANDOM_JOIN: UniformRandomJoinSegmenter,
    SegmentationType.JOIN: JoinSegmenter,
}
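

# Convolutional discriminator: a stack of 1d convolutions over the class
# distribution sequence that ends in a single output channel, followed by
# max- or mean-pooling over time to produce one score per utterance.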
class Discriminator(nn.Module):
    def __init__(self, dim, cfg: Wav2vec_UConfig):
        super().__init__()

        inner_dim = cfg.discriminator_dim
        kernel = cfg.discriminator_kernel
        dilation = cfg.discriminator_dilation
        self.max_pool = cfg.discriminator_max_pool

        if cfg.discriminator_causal:
            padding = kernel - 1
        else:
            padding = kernel // 2

        def make_conv(in_d, out_d, k, p=0, has_dilation=True):
            conv = nn.Conv1d(
                in_d,
                out_d,
                kernel_size=k,
                padding=p,
                dilation=dilation if has_dilation else 1,
            )
            if cfg.discriminator_spectral_norm:
                conv = nn.utils.spectral_norm(conv)
            elif cfg.discriminator_weight_norm:
                conv = nn.utils.weight_norm(conv)
            return conv

        inner_net = [
            nn.Sequential(
                make_conv(inner_dim, inner_dim, kernel, padding),
                SamePad(kernel_size=kernel, causal=cfg.discriminator_causal),
                nn.Dropout(cfg.discriminator_dropout),
                nn.GELU(),
            )
            for _ in range(cfg.discriminator_depth - 1)
        ] + [
            make_conv(inner_dim, 1, kernel, padding, has_dilation=False),
            SamePad(kernel_size=kernel, causal=cfg.discriminator_causal),
        ]

        if cfg.discriminator_linear_emb:
            emb_net = [make_conv(dim, inner_dim, 1)]
        else:
            emb_net = [
                make_conv(dim, inner_dim, kernel, padding),
                SamePad(kernel_size=kernel, causal=cfg.discriminator_causal),
            ]

        if cfg.discriminator_act_after_linear:
            emb_net.append(nn.GELU())

        self.net = nn.Sequential(
            *emb_net,
            nn.Dropout(cfg.discriminator_dropout),
            *inner_net,
        )

    def forward(self, x, padding_mask):
        x = x.transpose(1, 2)  # BTC -> BCT
        x = self.net(x)
        x = x.transpose(1, 2)
        x_sz = x.size(1)

        if padding_mask is not None and padding_mask.any() and padding_mask.dim() > 1:
            padding_mask = padding_mask[:, : x.size(1)]
            x[padding_mask] = float("-inf") if self.max_pool else 0
            x_sz = x_sz - padding_mask.sum(dim=-1)

        x = x.squeeze(-1)

        if self.max_pool:
            x, _ = x.max(dim=-1)
        else:
            x = x.sum(dim=-1)
            x = x / x_sz

        return x
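

# Generator: a single (dropout + 1d convolution) layer that maps the segmented
# input features to per-frame logits over the output vocabulary. When real
# token sequences are supplied, it also builds their one-hot representation so
# the discriminator sees real and generated inputs in the same format.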
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim, cfg: Wav2vec_UConfig):
        super().__init__()

        self.cfg = cfg
        self.output_dim = output_dim
        self.stride = cfg.generator_stride
        self.dropout = nn.Dropout(cfg.generator_dropout)

        padding = cfg.generator_kernel // 2
        self.proj = nn.Sequential(
            TransposeLast(),
            nn.Conv1d(
                input_dim,
                output_dim,
                kernel_size=cfg.generator_kernel,
                stride=cfg.generator_stride,
                dilation=cfg.generator_dilation,
                padding=padding,
                bias=cfg.generator_bias,
            ),
            TransposeLast(),
        )

    def forward(self, dense_x, tokens, dense_padding_mask):
        dense_x = self.dropout(dense_x)

        dense_x = self.proj(dense_x)
        if self.stride > 1:
            dense_padding_mask = dense_padding_mask[:, :: self.stride]

        if dense_padding_mask.size(1) != dense_x.size(1):
            new_padding = dense_padding_mask.new_zeros(dense_x.shape[:-1])
            diff = new_padding.size(1) - dense_padding_mask.size(1)
            assert (
                diff > 0
            ), f"{new_padding.shape}, {dense_padding_mask.shape}, {dense_x.shape}, {diff}"
            if diff > 0:
                new_padding[:, diff:] = dense_padding_mask
            else:
                assert diff < 0
                new_padding = dense_padding_mask[:, :diff]
            dense_padding_mask = new_padding

        result = {}

        token_x = None
        if tokens is not None:
            token_x = dense_x.new_zeros(tokens.numel(), self.output_dim)
            token_x.scatter_(1, tokens.view(-1, 1).long(), 1)
            token_x = token_x.view(tokens.shape + (self.output_dim,))

        result["dense_x"] = dense_x
        result["token_x"] = token_x
        result["dense_padding_mask"] = dense_padding_mask

        return result
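

# Full wav2vec-U model: a generator/discriminator pair trained adversarially,
# with optional gradient penalty, code-diversity penalty and smoothness loss.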
@register_model("wav2vec_u", dataclass=Wav2vec_UConfig)
class Wav2vec_U(BaseFairseqModel):
    def calc_gradient_penalty(self, real_data, fake_data):

        b_size = min(real_data.size(0), fake_data.size(0))
        t_size = min(real_data.size(1), fake_data.size(1))

        if self.cfg.probabilistic_grad_penalty_slicing:

            def get_slice(data, dim, target_size):

                size = data.size(dim)
                diff = size - target_size
                if diff <= 0:
                    return data

                start = np.random.randint(0, diff + 1)
                return data.narrow(dim=dim, start=start, length=target_size)

            real_data = get_slice(real_data, 0, b_size)
            real_data = get_slice(real_data, 1, t_size)
            fake_data = get_slice(fake_data, 0, b_size)
            fake_data = get_slice(fake_data, 1, t_size)

        else:
            real_data = real_data[:b_size, :t_size]
            fake_data = fake_data[:b_size, :t_size]

        alpha = torch.rand(real_data.size(0), 1, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.to(real_data.device)

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)

        disc_interpolates = self.discriminator(interpolates, None)

        gradients = autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size(), device=real_data.device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        gradient_penalty = (gradients.norm(2, dim=1) - 1) ** 2
        return gradient_penalty

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)
        self.update_num = num_updates
        self.curr_temp = max(
            self.max_temp * self.temp_decay ** num_updates, self.min_temp
        )

    def discrim_step(self, num_updates):
        return num_updates % 2 == 1

    def get_groups_for_update(self, num_updates):
        return "discriminator" if self.discrim_step(num_updates) else "generator"

    def __init__(self, cfg: Wav2vec_UConfig, target_dict):
        super().__init__()

        self.cfg = cfg
        self.zero_index = target_dict.index("<SIL>") if "<SIL>" in target_dict else 0
        self.smoothness_weight = cfg.smoothness_weight

        output_size = len(target_dict)
        self.pad = target_dict.pad()
        self.eos = target_dict.eos()
        self.smoothing = cfg.smoothing
        self.smoothing_one_sided = cfg.smoothing_one_sided
        self.no_softmax = cfg.no_softmax
        self.gumbel = cfg.gumbel
        self.hard_gumbel = cfg.hard_gumbel
        self.last_acc = None

        self.gradient_penalty = cfg.gradient_penalty
        self.code_penalty = cfg.code_penalty
        self.blank_weight = cfg.blank_weight
        self.blank_mode = cfg.blank_mode
        self.blank_index = target_dict.index("<SIL>") if cfg.blank_is_sil else 0
        assert self.blank_index != target_dict.unk()

        self.discriminator = Discriminator(output_size, cfg)
        for p in self.discriminator.parameters():
            p.param_group = "discriminator"

        self.pca_A = self.pca_b = None
        d = cfg.input_dim

        self.segmenter = SEGMENT_FACTORY[cfg.segmentation.type](cfg.segmentation)

        self.generator = Generator(d, output_size, cfg)

        for p in self.generator.parameters():
            p.param_group = "generator"

        for p in self.segmenter.parameters():
            p.param_group = "generator"

        self.max_temp, self.min_temp, self.temp_decay = cfg.temp
        self.curr_temp = self.max_temp
        self.update_num = 0

    @classmethod
    def build_model(cls, cfg, task):
        return cls(cfg, task.target_dictionary)
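
    # Decoding helpers: get_logits applies the configured blank weighting and
    # forces padded positions to predict the blank index before (optionally)
    # log-softmax normalisation.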
    def get_logits(
        self,
        net_output: Optional[Dict[str, List[Optional[torch.Tensor]]]],
        normalize: bool = False,
    ):
        logits = net_output["logits"]

        if self.blank_weight != 0:
            if self.blank_mode == "add":
                logits[..., self.blank_index] += self.blank_weight
            elif self.blank_mode == "set":
                logits[..., self.blank_index] = self.blank_weight
            else:
                raise Exception(f"invalid blank mode {self.blank_mode}")

        padding = net_output["padding_mask"]
        if padding.any():
            logits[padding] = float("-inf")
            logits[padding][..., self.blank_index] = float("inf")

        if normalize:
            logits = utils.log_softmax(logits.float(), dim=-1)

        return logits.transpose(0, 1)

    def get_normalized_probs(
        self,
        net_output: Tuple[
            torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]]
        ],
        log_probs: bool,
        sample: Optional[Dict[str, torch.Tensor]] = None,
    ):
        logits = self.get_logits(net_output)

        probs = super().get_normalized_probs(logits, log_probs, sample)
        # BTC -> TBC for ctc
        probs = probs.transpose(0, 1)
        return probs

    def normalize(self, dense_x):

        bsz, tsz, csz = dense_x.shape

        if dense_x.numel() == 0:
            raise Exception(dense_x.shape)
        _, k = dense_x.max(-1)
        hard_x = (
            dense_x.new_zeros(bsz * tsz, csz)
            .scatter_(-1, k.view(-1, 1), 1.0)
            .view(-1, csz)
        )
        hard_probs = torch.mean(hard_x.float(), dim=0)
        code_perplexity = torch.exp(
            -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
        )

        avg_probs = torch.softmax(dense_x.reshape(-1, csz).float(), dim=-1).mean(dim=0)
        prob_perplexity = torch.exp(
            -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
        )

        if not self.no_softmax:
            if self.training and self.gumbel:
                dense_x = F.gumbel_softmax(
                    dense_x.float(), tau=self.curr_temp, hard=self.hard_gumbel
                ).type_as(dense_x)
            else:
                dense_x = dense_x.softmax(-1)

        return dense_x, code_perplexity, prob_perplexity
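
    # Training entry point. Generator and discriminator updates alternate:
    # discrim_step() decides which of the two parameter groups the losses in
    # the returned dict apply to on the current update.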
    def forward(
        self,
        features,
        padding_mask,
        random_label=None,
        dense_x_only=False,
        segment=True,
    ):
        if segment:
            features, padding_mask = self.segmenter.pre_segment(features, padding_mask)

        orig_size = features.size(0) * features.size(1) - padding_mask.sum()

        gen_result = self.generator(features, random_label, padding_mask)

        orig_dense_x, token_x = gen_result["dense_x"], gen_result["token_x"]
        orig_dense_padding_mask = gen_result["dense_padding_mask"]

        if segment:
            dense_x, dense_padding_mask = self.segmenter.logit_segment(
                orig_dense_x, orig_dense_padding_mask
            )
        else:
            dense_x = orig_dense_x
            dense_padding_mask = orig_dense_padding_mask

        dense_logits = dense_x
        prob_perplexity = None
        code_perplexity = None

        if not (self.no_softmax and dense_x_only):
            dense_x, code_perplexity, prob_perplexity = self.normalize(dense_logits)

        if dense_x_only or self.discriminator is None:
            return {
                "logits": dense_x,
                "padding_mask": dense_padding_mask,
            }

        token_padding_mask = random_label == self.pad

        dense_y = self.discriminator(dense_x, dense_padding_mask)
        token_y = self.discriminator(token_x, token_padding_mask)

        sample_size = features.size(0)

        d_step = self.discrim_step(self.update_num)

        fake_smooth = self.smoothing
        real_smooth = self.smoothing
        if self.smoothing_one_sided:
            fake_smooth = 0

        zero_loss = None
        smoothness_loss = None
        code_pen = None

        if d_step:
            loss_dense = F.binary_cross_entropy_with_logits(
                dense_y,
                dense_y.new_ones(dense_y.shape) - fake_smooth,
                reduction="sum",
            )
            loss_token = F.binary_cross_entropy_with_logits(
                token_y,
                token_y.new_zeros(token_y.shape) + real_smooth,
                reduction="sum",
            )
            if self.training and self.gradient_penalty > 0:
                grad_pen = self.calc_gradient_penalty(token_x, dense_x)
                grad_pen = grad_pen.sum() * self.gradient_penalty
            else:
                grad_pen = None
        else:
            grad_pen = None
            loss_token = None
            loss_dense = F.binary_cross_entropy_with_logits(
                dense_y,
                dense_y.new_zeros(dense_y.shape) + fake_smooth,
                reduction="sum",
            )
            num_vars = dense_x.size(-1)
            if prob_perplexity is not None:
                code_pen = (num_vars - prob_perplexity) / num_vars
                code_pen = code_pen * sample_size * self.code_penalty

            if self.smoothness_weight > 0:
                smoothness_loss = F.mse_loss(
                    dense_logits[:, :-1], dense_logits[:, 1:], reduction="none"
                )
                smoothness_loss[dense_padding_mask[:, 1:]] = 0
                smoothness_loss = (
                    smoothness_loss.mean() * sample_size * self.smoothness_weight
                )

        result = {
            "losses": {
                "grad_pen": grad_pen,
                "code_pen": code_pen,
                "smoothness": smoothness_loss,
            },
            "temp": self.curr_temp,
            "code_ppl": code_perplexity,
            "prob_ppl": prob_perplexity,
            "d_steps": int(d_step),
            "sample_size": sample_size,
        }

        suff = "_d" if d_step else "_g"
        result["losses"]["dense" + suff] = loss_dense
        result["losses"]["token" + suff] = loss_token

        return result
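

# Minimal smoke-test sketch: builds the Generator and Discriminator directly
# with the default config and runs random features through them. The
# vocabulary size (50) and tensor shapes below are arbitrary example values,
# not anything prescribed by the model.
if __name__ == "__main__":
    cfg = Wav2vec_UConfig()
    vocab_size = 50  # assumed example value; normally len(task.target_dictionary)
    gen = Generator(cfg.input_dim, vocab_size, cfg)
    disc = Discriminator(vocab_size, cfg)

    feats = torch.randn(2, 100, cfg.input_dim)  # (batch, time, feature)
    pad = torch.zeros(2, 100, dtype=torch.bool)  # no padded positions

    out = gen(feats, None, pad)  # per-frame logits over the vocabulary
    score = disc(out["dense_x"].softmax(-1), out["dense_padding_mask"])
    print(out["dense_x"].shape, score.shape)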