import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from open_clip import create_model
from open_clip import tokenize
import glob
import json
import librosa
from tqdm import tqdm
import numpy as np
import os
from laion_clap.training.params import parse_args


def get_output_from_single_audio(audio, text, model, device):
    # Alternative sliding-window inference path, kept for reference:
    # audio_embedding = model.audio_infer(audio, hopsize=5 * 48000, key="embedding", device=device)['embedding']
    # if audio_embedding.ndim > 1:
    #     audio_embedding = audio_embedding.mean(dim=0, keepdim=True)
    # else:
    #     audio_embedding = audio_embedding.unsqueeze(0)

    # Embed the audio clip and its captions with CLAP, then L2-normalize each embedding.
    audio_features = model(audio, None, device)
    audio_features = F.normalize(audio_features, dim=-1)
    text_features = model(None, text, device=device)
    text_features = F.normalize(text_features, dim=-1)

    # TODO: decide whether the MLP projections should be applied before or after normalization.
    audio_features_mlp = model.audio_transform(audio_features)
    text_features_mlp = model.text_transform(text_features)
    return (
        audio_features,
        text_features,
        audio_features_mlp,
        text_features_mlp,
        model.logit_scale_a.exp(),
        model.logit_scale_t.exp(),
    )


def get_metrics(text_to_audio_logits):
    metrics = {}

    # Clotho has 5 captions per audio, so caption i belongs to audio i // 5.
    num_texts = text_to_audio_logits.shape[0]
    ground_truth = torch.repeat_interleave(torch.arange(num_texts // 5), 5).view(-1, 1)

    ranking = torch.argsort(text_to_audio_logits, descending=True)
    # (yusong) this line is slow because it uses a single thread
    preds = torch.where(ranking == ground_truth)[1]
    preds = preds.detach().cpu().numpy()
    metrics["mean_rank"] = preds.mean() + 1
    metrics["median_rank"] = np.floor(np.median(preds)) + 1
    for k in [1, 5, 10]:
        metrics[f"R@{k}"] = np.mean(preds < k)
    # mAP@10
    metrics["mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
    return metrics


if __name__ == '__main__':
    args = parse_args()

    model_path = args.pretrained
    clotho_test_preprocessed_dir = "/fsx/yusong/clotho_test_set/test"

    cudnn.benchmark = True
    cudnn.deterministic = False

    # Placeholders for checkpoint-ensemble evaluation (not used below).
    audio_features_ensemble_all = []
    text_features_ensemble_all = []
    audio_features_mlp_ensemble_all = []
    text_features_mlp_ensemble_all = []
    logit_scale_a_ensemble_all = []
    logit_scale_t_ensemble_all = []

    device = torch.device('cuda')
    model, clap_model_cfg = create_model(
        args.amodel,
        args.tmodel,
        args.pretrained,
        precision=args.precision,
        device=device,
        jit=args.torchscript,
        force_quick_gelu=args.force_quick_gelu,
        openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
        skip_params=False,
    )

    # Load the checkpoint weights.
    checkpoint = torch.load(model_path, map_location=device)
    if "epoch" in checkpoint:
        # Resuming a training checkpoint w/ epoch and optimizer state.
        start_epoch = checkpoint["epoch"]
        sd = checkpoint["state_dict"]
        if next(iter(sd.items()))[0].startswith("module"):
            # Strip the DistributedDataParallel "module." prefix.
            sd = {k[len("module."):]: v for k, v in sd.items()}
        model.load_state_dict(sd)
    else:
        # Loading a bare (model only) checkpoint for fine-tuning or evaluation.
        model.load_state_dict(checkpoint)

    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    # One .flac per audio clip; the matching .json carries its 5 Clotho captions.
    test_file_list = sorted(glob.glob(f"{clotho_test_preprocessed_dir}/*.flac"))

    audio_features_all = []
    text_features_all = []
    audio_features_mlp_all = []
    text_features_mlp_all = []
    logit_scale_a_all = []
    logit_scale_t_all = []

    with torch.no_grad():
        for file_path in tqdm(test_file_list):
            json_path = file_path.replace(".flac", ".json")
            with open(json_path, "r") as f:
                json_data = json.load(f)

            audio, sr = librosa.load(file_path, sr=48000, mono=True)
            audio = torch.from_numpy(audio).to(device)
            audio = {'waveform': audio.unsqueeze(0), 'sample_rate': sr}

            text = json_data["text"]
            if args.tmodel == "transformer":
                text = tokenize(text)
            else:
                from laion_clap.training.data import tokenizer
                text = tokenizer(text, tmodel=args.tmodel)  # 5 captions for each audio

            audio_features, text_features, audio_features_mlp, text_features_mlp, logit_scale_a, logit_scale_t = \
                get_output_from_single_audio(audio, text, model, device)
            audio_features_all.append(audio_features.detach().cpu())
            text_features_all.append(text_features.detach().cpu())
            audio_features_mlp_all.append(audio_features_mlp.detach().cpu())
            text_features_mlp_all.append(text_features_mlp.detach().cpu())
            logit_scale_a_all.append(logit_scale_a.detach().cpu())
            logit_scale_t_all.append(logit_scale_t.detach().cpu())

    # Stack per-file embeddings: N audio rows vs. 5N caption rows.
    audio_features = torch.cat(audio_features_all)
    text_features = torch.cat(text_features_all)
    # The logit scale is a learned scalar shared by every example; take the first copy.
    logit_scale_a = logit_scale_a_all[0]

    logits_per_audio = (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
    logits_per_text = logits_per_audio.t().detach().cpu()

    metrics = get_metrics(logits_per_text)
    print(metrics)
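
# Example invocation (a sketch, not a verified command: the script name is
# illustrative, the flag spellings are assumed from the attributes read off
# parse_args() above -- args.amodel, args.tmodel, args.pretrained -- and the
# model names are only examples; check laion_clap/training/params.py for the
# exact options and defaults):
#
#   python eval_retrieval.py \
#       --amodel HTSAT-tiny \
#       --tmodel roberta \
#       --pretrained /path/to/checkpoint.pt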