import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .modules import AudioEncoder
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig

class BartCaptionModel(nn.Module):
    def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128,
                 label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
        super(BartCaptionModel, self).__init__()
        # Non-finetuning case: BART is built from the config only (randomly initialized weights).
        bart_config = BartConfig.from_pretrained(bart_type)
        self.tokenizer = BartTokenizer.from_pretrained(bart_type)
        self.bart = BartForConditionalGeneration(bart_config)

        self.n_sample = sr * duration
        self.hop_length = int(0.01 * sr)  # hard-coded 10 ms hop size
        self.n_frames = int(self.n_sample // self.hop_length)
        self.num_of_stride_conv = num_of_conv - 1
        # With the defaults (sr=16000, duration=10, num_of_conv=6):
        # n_sample=160000, hop_length=160, n_frames=1000, n_ctx=1000 // 2**5 + 1 = 32
        self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
        self.audio_encoder = AudioEncoder(
            n_mels=n_mels,  # hard-coded number of mel bins
            n_ctx=self.n_ctx,
            audio_dim=audio_dim,
            text_dim=self.bart.config.hidden_size,
            num_of_stride_conv=self.num_of_stride_conv,
        )

        self.max_length = max_length
        self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)

    @property
    def device(self):
        # Exposed as a property because it is accessed as `self.device` below.
        return list(self.parameters())[0].device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.
        """
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids
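
    # Worked example for shift_tokens_right with hypothetical ids
    # (pad_token_id=1, decoder_start_token_id=2):
    #   labels              [[0, 100, -100, -100]]
    #   after the shift     [[2,   0,  100, -100]]
    #   after -100 -> pad   [[2,   0,  100,    1]]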

    def forward_encoder(self, audio):
        # Project the audio into the BART embedding space and run the BART encoder on it.
        audio_embs = self.audio_encoder(audio)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embs,
            return_dict=True
        )["last_hidden_state"]
        return encoder_outputs, audio_embs

    def forward_decoder(self, text, encoder_outputs):
        text = self.tokenizer(text,
                              padding='longest',
                              truncation=True,
                              max_length=self.max_length,
                              return_tensors="pt")
        input_ids = text["input_ids"].to(self.device)
        attention_mask = text["attention_mask"].to(self.device)

        # Pad positions are set to -100 so CrossEntropyLoss (ignore_index=-100) skips them.
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )

        decoder_input_ids = self.shift_tokens_right(
            decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
        )

        decoder_outputs = self.bart(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_outputs,),
            return_dict=True
        )
        lm_logits = decoder_outputs["logits"]
        loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size), decoder_targets.view(-1))
        return loss

    def forward(self, audio, text):
        encoder_outputs, _ = self.forward_encoder(audio)
        loss = self.forward_decoder(text, encoder_outputs)
        return loss

    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=5,
                 max_length=128,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):
        # self.bart.force_bos_token_to_be_generated = True
        audio_embs = self.audio_encoder(samples)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            attention_mask=None,
            head_mask=None,
            inputs_embeds=audio_embs,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True)

        # Seed the decoder with the decoder start token for every item in the batch.
        input_ids = torch.zeros((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
        input_ids[:, 0] = self.bart.config.decoder_start_token_id
        decoder_attention_mask = torch.ones((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)

        if use_nucleus_sampling:
            outputs = self.bart.generate(
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=encoder_outputs,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=1.1)
        else:
            outputs = self.bart.generate(input_ids=None,
                                         attention_mask=None,
                                         decoder_input_ids=input_ids,
                                         decoder_attention_mask=decoder_attention_mask,
                                         encoder_outputs=encoder_outputs,
                                         head_mask=None,
                                         decoder_head_mask=None,
                                         inputs_embeds=None,
                                         decoder_inputs_embeds=None,
                                         use_cache=None,
                                         output_attentions=None,
                                         output_hidden_states=None,
                                         max_length=max_length,
                                         min_length=min_length,
                                         num_beams=num_beams,
                                         repetition_penalty=repetition_penalty)

        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions
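

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It assumes the
    # AudioEncoder from .modules accepts raw waveforms of shape (batch, sr * duration)
    # and that the "facebook/bart-base" config and tokenizer can be downloaded;
    # the BART weights themselves are randomly initialized here (see __init__).
    model = BartCaptionModel(bart_type="facebook/bart-base")
    model.eval()

    # Two dummy 10-second clips at 16 kHz.
    dummy_audio = torch.randn(2, 16000 * 10)
    dummy_text = ["a calm piano melody", "an upbeat electronic track"]

    with torch.no_grad():
        loss = model(dummy_audio, dummy_text)           # training-style loss
        captions = model.generate(samples=dummy_audio)  # beam-search captions
    print(loss.item(), captions)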