"""Wrapper utilities for computing CLAP audio/text embeddings and scores."""
import collections
import collections.abc
import math
import os
import random
import re

import laion_clap
import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
from six import string_types as string_classes
from transformers import AutoTokenizer

from wav_evaluation.models.clap import CLAP
from wav_evaluation.models.utils import read_config_as_args

# Checkpoint that was previously hard-coded in ``CLAPWrapper.__init__``; it is
# still the default so existing callers behave exactly as before.
_DEFAULT_LAION_CKPT = '/root/autodl-tmp/liuhuadai/CLAP/music_audioset_epoch_15_esc_90.14.pt'


def int16_to_float32(x):
    """Scale an int16-valued array by 1/32767 and return it as float32."""
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    """Clip a float array to [-1, 1] and quantize it to int16."""
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)


class CLAPWrapper():
    """
    A class for interfacing CLAP model.

    Two models are loaded: the project ``CLAP`` model (``self.clap``, used by
    the ``_get_*_embeddings`` helpers) and a LAION ``CLAP_Module``
    (``self.model``, used by ``get_text_embeddings``, ``get_audio_embeddings``
    and ``cal_clap_score``).
    """

    def __init__(self, model_fp, config_path, use_cuda=False, laion_ckpt_path=None):
        r"""Read the config, load the project CLAP model and the LAION CLAP model.

        Args:
            model_fp: Path of the checkpoint passed to ``torch.load``; its
                ``'model'`` entry is loaded into the project CLAP model.
            config_path: Path of the config file; its text is handed to
                ``read_config_as_args``.
            use_cuda: Move the project CLAP model and preprocessed inputs to
                the GPU when CUDA is available.
            laion_ckpt_path: Checkpoint for the LAION CLAP module. Defaults to
                the path that used to be hard-coded here.
        """
        self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
        self.file_path = os.path.realpath(__file__)
        self.default_collate_err_msg_format = (
            "default_collate: batch must contain tensors, numpy arrays, numbers, "
            "dicts or lists; found {}")
        with open(config_path, 'r') as f:
            self.config_as_str = f.read()
        self.model_fp = model_fp
        self.use_cuda = use_cuda
        self.clap, self.tokenizer, self.args = self.load_clap()

        if laion_ckpt_path is None:
            laion_ckpt_path = _DEFAULT_LAION_CKPT
        self.model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base')
        self.model.load_ckpt(laion_ckpt_path)

    def load_clap(self):
        r"""Load CLAP model with args from config file.

        Also sets ``self.token_keys`` according to whether ``'bert'`` appears
        in ``args.text_model``.

        Returns:
            Tuple ``(clap, tokenizer, args)``.
        """
        args = read_config_as_args(self.config_as_str, is_config_str=True)

        if 'bert' in args.text_model:
            self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
        else:
            self.token_keys = ['input_ids', 'attention_mask']

        clap = CLAP(
            audioenc_name=args.audioenc_name,
            sample_rate=args.sampling_rate,
            window_size=args.window_size,
            hop_size=args.hop_size,
            mel_bins=args.mel_bins,
            fmin=args.fmin,
            fmax=args.fmax,
            classes_num=args.num_classes,
            out_emb=args.out_emb,
            text_model=args.text_model,
            transformer_embed_dim=args.transformer_embed_dim,
            d_proj=args.d_proj
        )

        # Load pretrained weights for model (always onto the CPU first).
        model_state_dict = torch.load(
            self.model_fp, map_location=torch.device('cpu'))['model']
        # strict=False: missing or unexpected keys are tolerated.
        clap.load_state_dict(model_state_dict, strict=False)

        clap.eval()  # set clap in eval mode
        tokenizer = AutoTokenizer.from_pretrained(args.text_model)

        if self.use_cuda and torch.cuda.is_available():
            clap = clap.cuda()

        return clap, tokenizer, args

    def default_collate(self, batch):
        r"""Puts each data field into a tensor with outer dimension batch size.

        Dispatches on the type of the first element of ``batch``: tensors are
        stacked, numpy arrays/scalars are converted to tensors, numbers become
        tensors, strings are returned unchanged, and mappings, namedtuples and
        sequences are collated recursively.

        Raises:
            TypeError: For unsupported element types, including numpy arrays
                whose dtype matches the string/object pattern.
            RuntimeError: If the sequences in ``batch`` have unequal lengths.
        """
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum([x.numel() for x in batch])
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
                # array of string classes and object
                if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    raise TypeError(
                        self.default_collate_err_msg_format.format(elem.dtype))

                return self.default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():  # scalars
                return torch.as_tensor(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, string_classes):
            return batch
        elif isinstance(elem, collections.abc.Mapping):
            return {key: self.default_collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
            return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, collections.abc.Sequence):
            # check to make sure that the elements in batch have consistent size
            it = iter(batch)
            elem_size = len(next(it))
            if not all(len(item) == elem_size for item in it):
                raise RuntimeError(
                    'each element in list of batch should be of equal size')
            transposed = zip(*batch)
            return [self.default_collate(samples) for samples in transposed]

        raise TypeError(self.default_collate_err_msg_format.format(elem_type))

    def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
        r"""Loads audio file and returns raw audio of a fixed duration.

        The loaded audio is flattened to one dimension (note: this also
        flattens any channel dimension). Clips shorter than the target are
        repeated and cut to length; longer clips are cropped at a random
        offset.

        Args:
            audio_path: Path handed to ``torchaudio.load``.
            audio_duration: Target duration in seconds.
            resample: Resample to ``self.args.sampling_rate`` when True.

        Returns:
            1-D float tensor with ``int(audio_duration * sample_rate)`` samples,
            where ``sample_rate`` is the rate of the returned audio.

        Raises:
            ValueError: If the file contains no samples.
        """
        audio_time_series, sample_rate = torchaudio.load(audio_path)
        resample_rate = self.args.sampling_rate
        if resample and resample_rate != sample_rate:
            resampler = T.Resample(sample_rate, resample_rate)
            audio_time_series = resampler(audio_time_series)
            # The length arithmetic below must use the rate the samples now
            # have, otherwise the returned clip has the wrong duration.
            sample_rate = resample_rate
        audio_time_series = audio_time_series.reshape(-1)

        num_samples = audio_time_series.shape[0]
        if num_samples == 0:
            raise ValueError('audio file contains no samples: {}'.format(audio_path))
        # int() keeps slicing valid even if the configured duration is a float.
        target_len = int(audio_duration * sample_rate)

        if target_len >= num_samples:
            # audio_time_series is shorter than predefined audio duration,
            # so audio_time_series is extended
            repeat_factor = int(np.ceil(target_len / num_samples))
            # Repeat audio_time_series by repeat_factor to match audio_duration
            audio_time_series = audio_time_series.repeat(repeat_factor)
            # remove excess part of audio_time_series
            audio_time_series = audio_time_series[0:target_len]
        else:
            # audio_time_series is longer than predefined audio duration,
            # so audio_time_series is trimmed at a random start offset
            start_index = random.randrange(num_samples - target_len)
            audio_time_series = audio_time_series[start_index:start_index + target_len]
        return torch.FloatTensor(audio_time_series)

    def preprocess_audio(self, audio_files, resample):
        r"""Load list of audio files and return raw audio.

        Each file is loaded with ``self.args.duration`` as target duration,
        reshaped to ``(1, -1)``, moved to the GPU when enabled and available,
        and the list is collated with ``default_collate``.
        """
        on_gpu = self.use_cuda and torch.cuda.is_available()
        audio_tensors = []
        for audio_file in audio_files:
            audio_tensor = self.load_audio_into_tensor(
                audio_file, self.args.duration, resample)
            audio_tensor = audio_tensor.reshape(1, -1)
            if on_gpu:
                audio_tensor = audio_tensor.cuda()
            audio_tensors.append(audio_tensor)
        return self.default_collate(audio_tensors)

    def preprocess_text(self, text_queries):
        r"""Load list of class labels and return tokenized text.

        Each query is tokenized and padded to ``self.args.text_len``; the
        tensors under ``self.token_keys`` are flattened (and moved to the GPU
        when enabled and available) before collation.
        """
        on_gpu = self.use_cuda and torch.cuda.is_available()
        tokenized_texts = []
        for ttext in text_queries:
            tok = self.tokenizer.encode_plus(
                text=ttext, add_special_tokens=True,
                max_length=self.args.text_len, padding="max_length",
                return_tensors="pt")
            for key in self.token_keys:
                flat = tok[key].reshape(-1)
                tok[key] = flat.cuda() if on_gpu else flat
            tokenized_texts.append(tok)
        return self.default_collate(tokenized_texts)

    def get_text_embeddings(self, class_labels):
        r"""Load list of class labels and return text embeddings.

        Uses the LAION CLAP module (``self.model``) and L2-normalizes the
        result along the last dimension. Prints progress and the labels.
        """
        print('loading text embeddings')
        print(class_labels)
        text_embed = self.model.get_text_embedding(class_labels, use_tensor=True)
        text_embed = text_embed/torch.norm(text_embed, dim=-1, keepdim=True)
        return text_embed

    def get_audio_embeddings(self, audio_files, resample):
        r"""Load audio and return a normalized LAION CLAP audio embedding.

        NOTE: only ``audio_files[0]`` is read, and ``resample`` is not used;
        the file is always loaded by librosa at 48000 Hz.
        """
        print('loading audio embeddings')
        audio_data, _ = librosa.load(audio_files[0], sr=48000)  # sample rate should be 48000
        audio_data = audio_data.reshape(1, -1)  # Make it (1,T) or (N,T)
        # quantize (float32 -> int16 -> float32) before sending it to the model
        audio_data = torch.from_numpy(
            int16_to_float32(float32_to_int16(audio_data))).float()
        audio_embed = self.model.get_audio_embedding_from_data(x=audio_data, use_tensor=True)
        audio_embed = audio_embed/torch.norm(audio_embed, dim=-1, keepdim=True)
        print(audio_embed[:, -20:])
        print(audio_embed.shape)
        return audio_embed

    def _get_text_embeddings(self, preprocessed_text):
        r"""Load preprocessed text and return text embeddings.

        Runs ``self.clap.caption_encoder`` without gradients and L2-normalizes
        the output along the last dimension.
        """
        with torch.no_grad():
            text_embeddings = self.clap.caption_encoder(preprocessed_text)
            text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
            return text_embeddings

    def _get_audio_embeddings(self, preprocessed_audio):
        r"""Load preprocessed audio and return a audio embeddings.

        The input is reshaped from its dimensions 0 and 2 to a 2-D tensor
        (dropping dimension 1) before being passed to
        ``self.clap.audio_encoder``.
        """
        with torch.no_grad():
            preprocessed_audio = preprocessed_audio.reshape(
                preprocessed_audio.shape[0], preprocessed_audio.shape[2])
            # Index [0] is the audio embedding, [1] has output class probabilities
            audio_embeddings = self.clap.audio_encoder(preprocessed_audio)[0]
            audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
            return audio_embeddings

    def compute_similarity(self, audio_embeddings, text_embeddings, use_logit_scale=True):
        r"""Compute similarity between text and audio embeddings.

        NOTE: ``use_logit_scale`` is accepted for interface compatibility but
        ignored; the result is the plain ``F.cosine_similarity`` of the two
        inputs (default ``dim``), which is also printed.
        """
        similarity = F.cosine_similarity(text_embeddings, audio_embeddings)
        print(similarity)
        return similarity

    def cal_clap_score(self, txt, audio_path):
        r"""Return the cosine similarity between one text and one audio file.

        Returns:
            The similarity as a numpy value (squeezed, moved to the CPU).
        """
        # Text embeddings are already L2-normalized.
        text_embeddings = self.get_text_embeddings([txt])
        # This step is comparatively slow: it reads the audio file and loads it
        # at 48000 Hz.
        audio_embeddings = self.get_audio_embeddings([audio_path], resample=True)
        similarity = self.compute_similarity(
            audio_embeddings, text_embeddings, use_logit_scale=False)
        # detach() so the numpy conversion also works if the embeddings carry
        # autograd history.
        score = similarity.squeeze().detach().cpu().numpy()
        return score

    def _generic_batch_inference(self, func, *args):
        r"""Process audio and/or text per batch.

        ``args`` is either ``(items, batch_size)`` or
        ``(audio_files, class_labels, batch_size)``. This is a generator: it
        yields ``func(...)`` once per batch of ``args[0]``, including the final
        (possibly shorter) batch. In the three-argument form the text
        embeddings are computed once and passed to ``func`` as a third
        positional argument.
        """
        input_tmp = args[0]
        batch_size = args[-1]
        # args[0] has audio_files, args[1] has class_labels
        inputs = [args[0], args[1]] if len(args) == 3 else [args[0]]
        args0_len = len(args[0])
        # compute text_embeddings once for all the audio_files batches
        if len(inputs) == 2:
            text_embeddings = self.get_text_embeddings(args[1])
            inputs = [args[0], args[1], text_embeddings]
        dataset_idx = 0
        for _ in range(math.ceil(args0_len/batch_size)):
            next_batch_idx = dataset_idx + batch_size
            # Slicing past the end is safe, so the last batch simply comes out
            # shorter. Every batch must be yielded: a ``return value`` inside a
            # generator would silently drop the final batch.
            inputs[0] = input_tmp[dataset_idx:next_batch_idx]
            yield func(*tuple(inputs))
            dataset_idx = next_batch_idx

    def get_audio_embeddings_per_batch(self, audio_files, batch_size):
        r"""Load preprocessed audio and return a audio embeddings per batch.

        NOTE(review): ``get_audio_embeddings`` is declared with a required
        ``resample`` parameter, but the batching helper calls it with the batch
        only — confirm the intended call signature.
        """
        return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size)

    def get_text_embeddings_per_batch(self, class_labels, batch_size):
        r"""Load preprocessed text and return text embeddings per batch"""
        return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)

    def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
        r"""Compute classification probabilities for each audio recording in a batch and each class label.

        NOTE(review): ``self.classify_audio_files`` is not defined in the
        visible part of this class — verify it exists before relying on this.
        """
        return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)