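"""Audio encoder wrappers (CLAP/HTSAT, Whisper, MERT) and a factory that wires an
audio encoder and a causal language model into a Flamingo-style audio-language model.
"""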
import sys

sys.path.append('../')

from typing import Optional
from copy import deepcopy

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor, WhisperFeatureExtractor, WhisperModel
# from .modeling_whisper import WhisperModel
from my_laion_clap.CLAP.src.laion_clap.clap_module.htsat import create_htsat_model

import torch
import torch.nn.functional as F  # used by CLAP.get_audio_features for padding
import torchaudio
import torchaudio.transforms as T
import numpy as np
from torch import nn
import torchvision.transforms

from contextlib import suppress

try:
    from .flamingo import Flamingo
    from .flamingo_lm import FlamingoLMMixin
    from .utils import extend_instance
except ImportError:
    from flamingo import Flamingo
    from flamingo_lm import FlamingoLMMixin
    from utils import extend_instance
# numpy and torch variants of int16 <-> float32 conversion; CLAP.load_audio
# round-trips waveforms through int16 before feeding the encoder.
def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)


def int16_to_float32_torch(x):
    return (x / 32767.0).type(torch.float32)


def float32_to_int16_torch(x):
    x = torch.clamp(x, min=-1., max=1.)
    return (x * 32767.).type(torch.int16)
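# Audio configuration consumed by create_htsat_model. clip_samples = 160000 is
# 10 seconds at the 16 kHz sample rate used throughout this module; the window,
# hop and mel-bin settings match the mel spectrogram computed in CLAP.get_mel().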
class CLAPAudioCfp:
    model_type: str = "HTSAT"
    model_name: str = "large"
    sample_rate: int = 16000
    audio_length: int = 1024
    window_size: int = 1024
    hop_size: int = 160
    fmin: int = 50
    fmax: int = 14000
    class_num: int = 527
    mel_bins: int = 64
    clip_samples: int = 160000
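# CLAP audio encoder: builds an HTSAT model and loads only the audio-branch weights
# from a LAION-CLAP-style checkpoint (keys under 'module.audio_branch.'). The encoder
# is frozen unless clap_config['finetune'] is set.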
class CLAP(nn.Module):
    def __init__(self, clap_config):
        super(CLAP, self).__init__()
        self.clap_config = clap_config
        self.method = clap_config["method"]
        device_id = f'cuda:{torch.cuda.current_device()}'

        if ('finetune' in clap_config) and clap_config['finetune']:
            self.finetune = True
            print('Finetuning CLAP encoder as well!')
        else:
            self.finetune = False

        audio_cfg = CLAPAudioCfp()
        enable_fusion = True
        fusion_type = "aff_2d"

        self.nvclap = create_htsat_model(audio_cfg, enable_fusion, fusion_type)
        clap_state_dict = torch.load(clap_config["checkpoint"], map_location='cpu')
        clap_state_dict_copy = clap_state_dict['state_dict'].copy()
        for key in list(clap_state_dict['state_dict'].keys()):
            if 'audio' in key:
                clap_state_dict_copy[key.replace('module.audio_branch.', '')] = clap_state_dict_copy[key]
                del clap_state_dict_copy[key]
            else:
                del clap_state_dict_copy[key]
        self.nvclap.load_state_dict(clap_state_dict_copy, strict=False)
        self.nvclap = self.nvclap.to(device_id)

        for param in self.nvclap.parameters():
            param.requires_grad = self.finetune

        if self.finetune:
            self.nvclap.train()
        else:
            self.nvclap.eval()

        print('loaded NVCLAP model: {}'.format(clap_config["checkpoint"]))
    def get_mel(self, audio_data):
        # mel shape: (n_mels, T)
        mel_tf = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=1024,
            win_length=1024,
            hop_length=160,
            center=True,
            pad_mode="reflect",
            power=2.0,
            norm=None,
            onesided=True,
            n_mels=64,
            f_min=50,
            f_max=14000
        ).to(audio_data.device)

        mel = mel_tf(audio_data)

        # we use log mel spectrogram as input
        mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
        return mel.T  # (T, n_mels)
    def get_audio_features(self, sample, audio_data, max_len, data_truncating, data_filling, require_grad=False):
        grad_fn = suppress if require_grad else torch.no_grad
        with grad_fn():
            if len(audio_data) > max_len:
                if data_truncating == "rand_trunc":
                    longer = torch.tensor([True])
                elif data_truncating == "fusion":
                    # fusion
                    mel = self.get_mel(audio_data)
                    # split to three parts
                    chunk_frames = max_len // 160 + 1  # the +1 related to how the spectrogram is computed
                    total_frames = mel.shape[0]
                    if chunk_frames == total_frames:
                        # there is a corner case where the audio length is
                        # larger than max_len but smaller than max_len+hop_size.
                        # In this case, we just use the whole audio.
                        mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                        sample["mel_fusion"] = mel_fusion
                        longer = torch.tensor([False])
                    else:
                        ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
                        if len(ranges[1]) == 0:
                            # if the audio is too short, we just use the first chunk
                            ranges[1] = [0]
                        if len(ranges[2]) == 0:
                            # if the audio is too short, we just use the first chunk
                            ranges[2] = [0]
                        # randomly choose index for each part
                        idx_front = np.random.choice(ranges[0])
                        idx_middle = np.random.choice(ranges[1])
                        idx_back = np.random.choice(ranges[2])
                        # select mel
                        mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
                        mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
                        mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
                        # shrink the mel
                        mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(mel[None])[0]
                        # logging.info(f"mel_shrink.shape: {mel_shrink.shape}")

                        # stack
                        mel_fusion = torch.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
                        sample["mel_fusion"] = mel_fusion
                        longer = torch.tensor([True])
                else:
                    raise NotImplementedError(
                        f"data_truncating {data_truncating} not implemented"
                    )

                # random crop to max_len (for compatibility)
                overflow = len(audio_data) - max_len
                idx = np.random.randint(0, overflow + 1)
                audio_data = audio_data[idx: idx + max_len]

            else:  # padding if too short
                if len(audio_data) < max_len:  # do nothing if equal
                    if data_filling == "repeatpad":
                        n_repeat = int(max_len / len(audio_data))
                        audio_data = audio_data.repeat(n_repeat)
                        # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0)
                        # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0]
                        audio_data = F.pad(
                            audio_data,
                            (0, max_len - len(audio_data)),
                            mode="constant",
                            value=0,
                        )
                    elif data_filling == "pad":
                        audio_data = F.pad(
                            audio_data,
                            (0, max_len - len(audio_data)),
                            mode="constant",
                            value=0,
                        )
                    elif data_filling == "repeat":
                        n_repeat = int(max_len / len(audio_data))
                        audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
                    else:
                        raise NotImplementedError(
                            f"data_filling {data_filling} not implemented"
                        )
                if data_truncating == 'fusion':
                    mel = self.get_mel(audio_data)
                    mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([False])

        sample["longer"] = longer
        sample["waveform"] = audio_data

        return sample
    def load_audio(self, clips):
        # waveform, sr = torchaudio.load(filename)
        # waveform = torchaudio.functional.resample(waveform, orig_freq=self.clap_config['sampling_rate'], new_freq=16000)
        processed_clips = []
        for clip in clips:
            audio_data = int16_to_float32_torch(float32_to_int16_torch(clip))
            sample = self.get_audio_features({}, audio_data, 160000, "fusion", "repeatpad")
            processed_clips.append(sample)

        waveforms = {}
        waveforms["mel_fusion"] = torch.stack([item["mel_fusion"] for item in processed_clips], dim=0)
        waveforms["longer"] = torch.stack([item["longer"] for item in processed_clips], dim=0)
        waveforms["waveform"] = torch.stack([item["waveform"] for item in processed_clips], dim=0)

        return waveforms
    def forward(self, audio_clips):
        # Handles multiple segments per audio: input is [B x n_segments x time].
        # Expand the batch dimension during inference if a single example is passed.
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for audio_clip in audio_clips:
            audio = self.load_audio(audio_clip)
            audio_embed = self.nvclap(audio)  # .reshape(-1, self.clap_config["audio_embed_dim"])
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        # audio_embeds.requires_grad = self.finetune

        return audio_embeds
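# Whisper audio encoder: uses the encoder of a pretrained WhisperModel together with
# its feature extractor; frozen unless whisper_config['finetune'] is set.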
class Whisper(nn.Module):
    def __init__(self, whisper_config):
        super(Whisper, self).__init__()
        self.whisper_config = whisper_config
        self.method = self.whisper_config["method"]
        device_id = f'cuda:{torch.cuda.current_device()}'

        if ('finetune' in self.whisper_config) and self.whisper_config['finetune']:
            self.finetune = True
            print('Finetuning Whisper encoder as well!')
        else:
            self.finetune = False

        self.whisper = WhisperModel.from_pretrained(self.whisper_config['path']).encoder
        self.whisper = self.whisper.to(device_id)
        self.wav_processor = WhisperFeatureExtractor.from_pretrained(self.whisper_config['path'])

        for param in self.whisper.parameters():
            param.requires_grad = self.finetune

        if self.finetune:
            self.whisper.train()
        else:
            self.whisper.eval()

        print('loaded Whisper model: {}'.format(self.whisper_config['path']))

    def load_audio(self, clips):
        device_id = f'cuda:{torch.cuda.current_device()}'
        sample = self.wav_processor(clips.cpu().numpy(), sampling_rate=self.whisper_config['sampling_rate'], return_tensors="pt")["input_features"].to(device_id)
        return sample

    def forward(self, audio_clips):
        # Handles multiple segments per audio: input is [batch x n_segments x time].
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for audio_clip in audio_clips:
            audio = self.load_audio(audio_clip)
            audio_embed = self.whisper(audio).last_hidden_state  # .reshape(-1, self.whisper_config["audio_embed_dim"])
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        # audio_embeds.requires_grad = self.finetune

        return audio_embeds
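# MERT music encoder: input waveforms (16 kHz) are resampled to the rate the
# pretrained MERT model expects (mert_config['sampling_rate']) before feature
# extraction; frozen unless mert_config['finetune'] is set.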
class MERT(nn.Module):
    def __init__(self, mert_config):
        super(MERT, self).__init__()
        self.mert_config = mert_config
        self.method = mert_config["method"]
        device_id = f'cuda:{torch.cuda.current_device()}'

        if ('finetune' in mert_config) and mert_config['finetune']:
            self.finetune = True
            print('Finetuning MERT encoder as well!')
        else:
            self.finetune = False

        self.mert = AutoModel.from_pretrained(mert_config['path'], trust_remote_code=True)
        self.mert = self.mert.to(device_id)
        self.resampler = T.Resample(16000, mert_config['sampling_rate']).to(device_id)
        self.wav_processor = Wav2Vec2FeatureExtractor.from_pretrained(mert_config['path'], trust_remote_code=True)

        for param in self.mert.parameters():
            param.requires_grad = self.finetune

        if self.finetune:
            self.mert.train()
        else:
            self.mert.eval()

        print('loaded MERT model: {}'.format(mert_config['path']))

    def load_audio(self, clips):
        device_id = f'cuda:{torch.cuda.current_device()}'
        clips = self.resampler(clips.float()).float()
        sample = self.wav_processor(clips, sampling_rate=self.mert_config['sampling_rate'], return_tensors="pt")["input_values"]
        if len(sample.shape) == 1:
            sample = sample.unsqueeze(0)
        return sample.to(device_id)

    def forward(self, audio_clips):
        # Handles multiple segments per audio: input is [batch x n_segments x time].
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for audio_clip in audio_clips:
            audio = self.load_audio(audio_clip).to(torch.bfloat16)  # preprocessing is done in float32, then cast to bfloat16
            if len(audio.shape) > 2:
                audio = audio.squeeze(0)
            audio_embed = self.mert(audio, output_hidden_states=True).last_hidden_state  # .reshape(-1, self.mert_config["audio_embed_dim"])
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        audio_embeds.requires_grad = self.finetune

        return audio_embeds
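# Factory: builds the CLAP encoder and a tokenizer extended with <audio>,
# <|endofchunk|>, <|PAD_TOKEN|> (and <SEP> if no sep token exists), wraps the causal LM
# with FlamingoLMMixin, and returns a Flamingo model in which only the audio
# transformer, the gated cross-attention layers and (optionally) the input embeddings,
# the full LM, or CLAP are trainable.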
def create_model_and_transforms(
    clap_config: dict,
    lang_encoder_path: str,
    tokenizer_path: str,
    audio_transformer_kwargs: dict,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: Optional[str] = None,
    freeze_lm_embeddings: bool = False,
    unfreeze_full_lm: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
    clap = CLAP(clap_config)

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>", "<|PAD_TOKEN|>"]}
    )
    text_tokenizer.pad_token = None
    text_tokenizer.pad_token_id = None
    text_tokenizer.pad_token = "<|PAD_TOKEN|>"
    text_tokenizer.pad_token_id = text_tokenizer.encode("<|PAD_TOKEN|>")[-1]
    if text_tokenizer.sep_token is None:
        text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    if ('finetune' in clap_config) and clap_config['finetune']:
        unfreeze_clap = True
    else:
        unfreeze_clap = False

    model = Flamingo(
        clap,
        unfreeze_clap,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<audio>")[-1],
        text_tokenizer.sep_token_id,
        clap_embed_dim=clap_config["audio_embed_dim"],
        audio_transformer_kwargs=audio_transformer_kwargs,
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze everything, then selectively unfreeze the trainable modules.
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    model.audio_transformer_clap.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers_sound.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
    if unfreeze_full_lm:
        model.lang_encoder.requires_grad_(True)
    if unfreeze_clap:
        model.clap.requires_grad_(True)

    print("Flamingo model initialized with {:,} trainable parameters (audio transformer has {:,}, LM has {:,})".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad),
        sum(p.numel() for p in model.audio_transformer_clap.parameters() if p.requires_grad),
        sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad),
    ))

    return model, text_tokenizer
def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
    "qwen": "model.layers",
}
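# To support another decoder family, add an entry whose key is a lowercase substring of
# the Hugging Face model class name and whose value is the dotted path to the decoder's
# nn.ModuleList of transformer blocks. The entry below is only an illustrative sketch
# (not part of the original mapping); verify the attribute path for your model first:
#
#   __KNOWN_DECODER_LAYERS_ATTR_NAMES["gemma"] = "model.layers"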
if __name__ == '__main__':
    import torch
    torch.set_printoptions(profile="full")  # only in debug mode

    import sys
    sys.path.append('../')

    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    import yaml
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../configs/config.yaml', help='yaml config path')
    args = parser.parse_args()

    config = yaml.load(open(args.config), Loader=yaml.FullLoader)
    data_config = config['data_config']
    model_config = config["model_config"]
    clap_config = config["clap_config"]

    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=False,
        gradient_checkpointing=True,
        freeze_lm_embeddings=True
    )
    model = model.cuda()

    from data.data import AudioTextData, DataCollator
    from torch.utils.data import DataLoader

    batch_size = 8
    trainset = AudioTextData(
        **data_config, clap_config=clap_config, tokenizer=tokenizer,
        epoch=1, force_reblend=True
    )
    data_collator = DataCollator(tokenizer)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, num_workers=4)

    for step, batch in enumerate(trainloader):
        audio_clips = batch["audio_clips"].cuda()
        audio_embed_mask = batch["audio_embed_mask"].cuda()
        input_ids = batch["input_ids"].cuda()
        attention_mask = batch["attention_mask"].cuda()
        print('batch {}:'.format(step + 1), audio_clips.shape, audio_embed_mask.shape, input_ids.shape, attention_mask.shape)

        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, :2] = -100
        labels[labels == tokenizer.encode("<audio>")[-1]] = -100

        sep_locations = labels == tokenizer.sep_token_id
        endofchunk_token_id = tokenizer.encode("<|endofchunk|>")[-1]
        eoc_locations = labels == endofchunk_token_id

        if not all(sep_locations.sum(dim=1) == eoc_locations.sum(dim=1)):
            print("Warning: sep loc {} but eoc loc {}".format(sep_locations.sum(dim=1), eoc_locations.sum(dim=1)))
            for input_id in labels:
                input_id[input_id == -100] = tokenizer.encode("-")[-1]
                print(input_id, '\n', tokenizer.decode(input_id))

        for i in range(labels.shape[0]):
            shouldmask = True
            for j in range(labels.shape[1]):
                if shouldmask and (labels[i][j] != tokenizer.eos_token_id):
                    masked_value = -100
                else:
                    masked_value = labels[i][j]

                if labels[i][j] == tokenizer.sep_token_id:
                    shouldmask = False
                elif labels[i][j] == endofchunk_token_id:
                    shouldmask = True

                labels[i][j] = masked_value

            if labels[i][-1] not in [-100, tokenizer.eos_token_id, tokenizer.pad_token_id, endofchunk_token_id]:
                debug_masked_labels_in_the_end = []
                for j in range(labels.shape[1] - 1, -1, -1):
                    if labels[i][j] not in [-100, tokenizer.eos_token_id, endofchunk_token_id]:
                        debug_masked_labels_in_the_end.insert(0, deepcopy(labels[i][j].item()))
                        labels[i][j] = -100
                    else:
                        break

                print('hit max_token and masking ids from the end:',
                      tokenizer.decode(torch.LongTensor(debug_masked_labels_in_the_end).to(labels.device)))

        if step == 50:
            break