ZeyuXie's picture
Upload 167 files
8c1bf05 verified
raw
history blame
36.3 kB
import json
import logging
import math
import os
import time
from contextlib import suppress
import numpy as np
import torch
import torch.nn.functional as F
try:
import wandb
except ImportError:
wandb = None
from open_clip import ClipLoss, gather_features
from .distributed import is_master
from .zero_shot import zero_shot_eval
class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric.

    Attributes (all set to 0 by ``reset``):
        val: the last value passed to ``update``.
        sum: accumulated ``val * n`` over all ``update`` calls.
        count: accumulated ``n`` over all ``update`` calls.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new observation.
            n: the weight of the observation (e.g. a batch size). Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch passed with n == 0) must not
        # raise ZeroDivisionError; report an average of 0 until there is data.
        self.avg = self.sum / self.count if self.count else 0
def unwrap_model(model):
    """Return ``model.module`` when the object exposes a ``module`` attribute, else ``model`` itself."""
    return getattr(model, "module", model)
def train_one_epoch(
    model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
):
    """Run one training epoch over ``data["train"]``.

    Args:
        model: called as ``model(audios, texts, device)``; must return the tuple
            ``(audio_features, text_features, audio_features_mlp,
            text_features_mlp, logit_scale_a, logit_scale_t)``.
        data: mapping whose ``"train"`` entry exposes ``dataloader`` and ``sampler``.
        epoch: index of the current epoch; used to derive the global step.
        optimizer: a single optimizer, or a dict of optimizers that are all stepped.
        scaler: gradient scaler (``scale`` / ``unscale_`` / ``step`` / ``update``)
            or ``None`` to run a plain backward/step.
        scheduler: a callable taking the global step, or a dict of such callables.
        args: run configuration; the attributes read here are ``device``,
            ``precision``, ``local_loss``, ``gather_with_grad``, ``rank``,
            ``world_size``, ``horovod``, ``clap_mlploss``, ``kappa``,
            ``distributed``, ``dataset_type`` and ``wandb``.
        tb_writer: optional writer with an ``add_scalar(name, value, step)`` method.

    Returns:
        None. Progress is reported through ``logging``, ``tb_writer`` and wandb.
    """
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    model.train()
    loss = ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
        mlp_loss=args.clap_mlploss,
        weight_loss_kappa=args.kappa,
    )

    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)
    num_batches_per_epoch = dataloader.num_batches
    # Width used to right-align the running sample counter in the log line.
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    # for toy dataset
    if args.dataset_type == "toy":
        dataloader.dataset.generate_queue()

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    for i, batch in enumerate(dataloader):
        # Global step across epochs; drives the LR schedule and the loggers.
        step = num_batches_per_epoch * epoch + i
        if isinstance(scheduler, dict):
            for s in scheduler.values():
                s(step)
        else:
            scheduler(step)
        audios = batch  # contains mel_spec, wavform, and longer list
        texts = batch["text"]

        data_time_m.update(time.time() - end)
        if isinstance(optimizer, dict):
            for o_ in optimizer.values():
                o_.zero_grad()
        else:
            optimizer.zero_grad()

        with autocast():
            (
                audio_features,
                text_features,
                audio_features_mlp,
                text_features_mlp,
                logit_scale_a,
                logit_scale_t,
            ) = model(audios, texts, device)

            if args.clap_mlploss:
                total_loss = loss(
                    audio_features=audio_features,
                    text_features=text_features,
                    logit_scale_a=logit_scale_a,
                    logit_scale_t=logit_scale_t,
                    audio_features_mlp=audio_features_mlp,
                    text_features_mlp=text_features_mlp,
                )
            else:
                total_loss = loss(
                    audio_features=audio_features,
                    text_features=text_features,
                    logit_scale_a=logit_scale_a,
                )

        if isinstance(optimizer, dict):
            if scaler is not None:
                scaler.scale(total_loss).backward()
                for o_ in optimizer.values():
                    if args.horovod:
                        o_.synchronize()
                        scaler.unscale_(o_)
                        with o_.skip_synchronize():
                            scaler.step(o_)
                    else:
                        scaler.step(o_)
                scaler.update()
            else:
                total_loss.backward()
                for o_ in optimizer.values():
                    o_.step()
        else:
            if scaler is not None:
                scaler.scale(total_loss).backward()
                if args.horovod:
                    optimizer.synchronize()
                    scaler.unscale_(optimizer)
                    with optimizer.skip_synchronize():
                        scaler.step(optimizer)
                else:
                    scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).logit_scale_a.clamp_(0, math.log(100))
            if args.clap_mlploss:
                unwrap_model(model).logit_scale_t.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1

        should_log = is_master(args) and (
            i % 100 == 0 or batch_count == num_batches_per_epoch
        )
        if not should_log:
            continue

        if isinstance(audios, dict):
            batch_size = len(audios["waveform"])
        else:
            batch_size = len(audios)
        num_samples = batch_count * batch_size * args.world_size
        samples_per_epoch = dataloader.num_samples
        percent_complete = 100.0 * batch_count / num_batches_per_epoch

        # NOTE loss is coarsely sampled, just master node and per log update
        loss_m.update(total_loss.item(), batch_size)
        logit_scale_scalar_a = logit_scale_a.item()
        logit_scale_scalar_t = logit_scale_t.item()

        # With several optimizers the learning rate is reported as a list,
        # otherwise as a single formatted float.
        if isinstance(optimizer, dict):
            current_lr = [o_.param_groups[0]["lr"] for o_ in optimizer.values()]
            lr_msg = f"LR: {current_lr} "
        else:
            current_lr = optimizer.param_groups[0]["lr"]
            lr_msg = f"LR: {current_lr:5f} "

        scale_msg = f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
        if args.clap_mlploss:
            # Leading space keeps the two scale fields from running together.
            scale_msg += f" Logit Scale Text: {logit_scale_scalar_t:.3f}"

        logging.info(
            f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
            f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
            f"Data (t): {data_time_m.avg:.3f} "
            f"Batch (t): {batch_time_m.avg:.3f} "
            + lr_msg
            + scale_msg
        )

        # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
        log_data = {
            "loss": loss_m.val,
            "data_time": data_time_m.val,
            "batch_time": batch_time_m.val,
            "scale_audio": logit_scale_scalar_a,
        }
        if args.clap_mlploss:
            log_data["scale_text"] = logit_scale_scalar_t
        # NOTE(review): with a dict of optimizers "lr" is a list and is passed
        # unchanged to tb_writer.add_scalar, as before -- confirm the writer accepts it.
        log_data["lr"] = current_lr

        for name, val in log_data.items():
            name = "train/" + name
            if tb_writer is not None:
                tb_writer.add_scalar(name, val, step)
            if args.wandb:
                assert wandb is not None, "Please install wandb."
                wandb.log({name: val, "step": step})

        # resetting batch / data time meters per log window
        batch_time_m.reset()
        data_time_m.reset()
    # end for
def evaluate(model, data, epoch, args, tb_writer=None):
    """Evaluate ``model`` on the validation data and return a flat metrics dict.

    Two paths exist:
      * when ``args.val_dataset_names == ["Clotho", "audiocaps"]`` the work is
        delegated to ``evaluate_clotho_audiocaps`` and the best-so-far metrics
        are merged in by ``select_top_metric_clotho_audiocaps``;
      * otherwise, if ``"val"`` is in ``data`` and the epoch matches
        ``args.val_frequency`` (or is the final epoch), features are collected
        per dataset (and under the key ``"all"``) and scored with ``get_metrics``.

    Args:
        model: called as ``model(audios, texts, device)``; returns the six-tuple
            of features and logit scales.
        data: mapping with a ``"val"`` entry exposing ``dataloader``.
        epoch: current epoch, stored in the metrics under ``"epoch"``.
        args: run configuration (``parallel_eval``, ``device``, ``precision``,
            ``val_dataset_names``, ``val_frequency``, ``epochs``, ``clap_mlploss``,
            ``rank``, ``world_size``, ``horovod``, ``save_logs``,
            ``checkpoint_path``, ``wandb``).
        tb_writer: optional writer with ``add_scalar``.

    Returns:
        dict of metric name -> value; empty on non-master ranks or when no
        evaluation ran.

    Raises:
        NotImplementedError: Clotho/audiocaps-only evaluation with ``parallel_eval``.
    """
    metrics = {}
    # Without parallel eval only the master process evaluates.
    if not args.parallel_eval and not is_master(args):
        return metrics
    device = torch.device(args.device)
    model.eval()

    # CHANGE
    # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
    # metrics.update(zero_shot_metrics)
    if is_master(args):
        print("Evaluating...")
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress

    use_mlp = args.clap_mlploss

    def _empty_record():
        # Fresh per-dataset accumulator; the mlp lists exist only with clap_mlploss.
        record = {
            "cumulative_loss": 0.0,
            "num_samples": 0,
            "all_audio_features": [],
            "all_text_features": [],
        }
        if use_mlp:
            record["all_audio_features_mlp"] = []
            record["all_text_features_mlp"] = []
        return record

    if args.val_dataset_names == ["Clotho", "audiocaps"]:
        # if only clotho and audiocaps are used, then we will use a different evaluation function.
        # This is because in the Clotho and audiocaps valid and test set, there are 5 text for 1 audio.
        if args.parallel_eval:
            # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps.
            raise NotImplementedError(
                "Parallel evaluation not supported for eval only Clotho and audiocaps."
            )
        val_metrics_per_dataset = evaluate_clotho_audiocaps(
            model, data, epoch, args, autocast, device, tb_writer
        )
        for dataset_metrics in val_metrics_per_dataset.values():
            metrics.update(dataset_metrics)
        if "epoch" not in metrics:
            metrics["epoch"] = epoch
        metrics = select_top_metric_clotho_audiocaps(
            metrics, val_metrics_per_dataset, args
        )
    elif "val" in data and (
        args.val_frequency
        and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
    ):
        dataloader = data["val"].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples

        # FIXME this does not scale past small eval datasets
        # all_audio_features @ all_text_features will blow up memory and compute very quickly
        eval_info = {"all": _empty_record()}
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                audios = batch  # contains mel_spec, wavform, and longer list
                texts = batch["text"]

                # Dataset name of every sample, built from its url path components.
                batch_names = [
                    "-".join(url.split("/")[-3:-1]) for url in batch["__url__"]
                ]
                all_names = list(set(batch_names))
                for name in all_names:
                    if name not in eval_info:
                        eval_info[name] = _empty_record()

                with autocast():
                    (
                        audio_features,
                        text_features,
                        audio_features_mlp,
                        text_features_mlp,
                        logit_scale_a,
                        logit_scale_t,
                    ) = model(audios, texts, device)

                    if args.parallel_eval:
                        # multi-GPU eval
                        gather_options = dict(
                            local_loss=False,
                            gather_with_grad=False,
                            rank=args.rank,
                            world_size=args.world_size,
                            use_horovod=args.horovod,
                            mlp_loss=args.clap_mlploss,
                        )
                        if use_mlp:
                            (
                                audio_features,
                                text_features,
                                audio_features_mlp,
                                text_features_mlp,
                            ) = gather_features(
                                audio_features=audio_features,
                                text_features=text_features,
                                audio_features_mlp=audio_features_mlp,
                                text_features_mlp=text_features_mlp,
                                **gather_options,
                            )
                        else:
                            (audio_features, text_features,) = gather_features(
                                audio_features=audio_features,
                                text_features=text_features,
                                **gather_options,
                            )

                    if is_master(args):
                        num_samples += audio_features.shape[0]
                        collected = {
                            "all_audio_features": audio_features,
                            "all_text_features": text_features,
                        }
                        if use_mlp:
                            collected["all_audio_features_mlp"] = audio_features_mlp
                            collected["all_text_features_mlp"] = text_features_mlp

                        for n in [*all_names, "all"]:
                            if n == "all":
                                for key, feats in collected.items():
                                    eval_info[n][key].append(feats.cpu())
                            else:
                                # Row indices of this batch that belong to dataset n.
                                rows = np.where(np.array(batch_names) == n)[0]
                                for key, feats in collected.items():
                                    eval_info[n][key].append(
                                        feats.cpu().index_select(
                                            0, torch.tensor(rows).long()
                                        )
                                    )

                if is_master(args) and (i % 100) == 0:  # and i != 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
                    )

            if is_master(args):
                val_metrics_per_dataset = {}
                for n in eval_info.keys():
                    record = eval_info[n]
                    metric_inputs = dict(
                        audio_features=torch.cat(record["all_audio_features"]),
                        text_features=torch.cat(record["all_text_features"]),
                        logit_scale_a=logit_scale_a.cpu(),
                    )
                    if use_mlp:
                        metric_inputs.update(
                            audio_features_mlp=torch.cat(
                                record["all_audio_features_mlp"]
                            ),
                            text_features_mlp=torch.cat(
                                record["all_text_features_mlp"]
                            ),
                            logit_scale_t=logit_scale_t.cpu(),
                        )
                    metrics_single_dataset = get_metrics(
                        mlp_loss=args.clap_mlploss, **metric_inputs
                    )
                    # Prefix every metric with its dataset name.
                    val_metrics_per_dataset[n] = {
                        n + "/" + k: v for k, v in metrics_single_dataset.items()
                    }
                    metrics.update(val_metrics_per_dataset[n])
                    if "epoch" not in metrics:
                        metrics["epoch"] = epoch

    if not is_master(args):
        return metrics
    if not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\n".join(
            [
                "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()])
                for m in val_metrics_per_dataset.values()
            ]
        )
    )

    if args.save_logs:
        for name, val in metrics.items():
            if tb_writer is not None:
                tb_writer.add_scalar(f"val/{name}", val, epoch)

        # One JSON object per evaluation, appended to results.jsonl.
        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, "Please install wandb."
        for name, val in metrics.items():
            wandb.log({f"val/{name}": val, "epoch": epoch})

    return metrics
def get_metrics(
    audio_features,
    text_features,
    logit_scale_a,
    audio_features_mlp=None,
    text_features_mlp=None,
    logit_scale_t=None,
    mlp_loss=False,
):
    """Compute the contrastive loss and retrieval metrics for paired features.

    Row ``i`` of the audio features is treated as the match of row ``i`` of the
    text features.

    Args:
        audio_features, text_features: 2-D feature tensors with one row per pair.
        logit_scale_a: scale applied to the audio-side similarity.
        audio_features_mlp, text_features_mlp, logit_scale_t: additional inputs
            used only when ``mlp_loss`` is true.
        mlp_loss: when true, average the two cross similarities
            (audio x text_mlp and audio_mlp x text) instead of using a single one.

    Returns:
        dict with ``cumulative_loss``, ``num_samples`` and, for each of
        ``audio_to_text`` / ``text_to_audio``: ``_mean_rank``, ``_median_rank``,
        ``_R@1``, ``_R@5``, ``_R@10`` and ``_mAP@10``.
    """

    def _similarity(scale, left, right):
        # Scaled similarity matrix and its transpose, both detached on CPU.
        per_audio = (scale * left @ right.t()).detach().cpu()
        return per_audio, per_audio.t().detach().cpu()

    metrics = {}
    if mlp_loss:
        # Set up audio to text & text to audio similary matrice
        a_per_audio, a_per_text = _similarity(
            logit_scale_a, audio_features, text_features_mlp
        )
        t_per_audio, t_per_text = _similarity(
            logit_scale_t, audio_features_mlp, text_features
        )
        # Four-term (2x2) cross-entropy loss.
        loss_inputs = [a_per_audio, a_per_text, t_per_audio, t_per_text]
        logits = {
            "audio_to_text": (a_per_audio + t_per_audio) / 2,
            "text_to_audio": (a_per_text + t_per_text) / 2,
        }
    else:
        per_audio, per_text = _similarity(logit_scale_a, audio_features, text_features)
        # Two-term cross-entropy loss.
        loss_inputs = [per_audio, per_text]
        logits = {"audio_to_text": per_audio, "text_to_audio": per_text}

    targets = torch.arange(audio_features.shape[0]).long()
    loss_terms = [F.cross_entropy(matrix, targets) for matrix in loss_inputs]
    total_loss = loss_terms[0]
    for term in loss_terms[1:]:
        total_loss = total_loss + term
    total_loss = total_loss / len(loss_terms)

    metrics["cumulative_loss"] = total_loss.item()
    metrics["num_samples"] = audio_features.shape[0]

    ground_truth = torch.arange(len(text_features)).view(-1, 1)
    for direction, matrix in logits.items():
        order = torch.argsort(matrix, descending=True)
        # Column at which each row's own index appears, i.e. its 0-based rank.
        # (yusong) this line is slow because it uses single thread
        hits = torch.where(order == ground_truth)[1]
        hits = hits.detach().cpu().numpy()
        metrics[f"{direction}_mean_rank"] = hits.mean() + 1
        metrics[f"{direction}_median_rank"] = np.floor(np.median(hits)) + 1
        for k in (1, 5, 10):
            metrics[f"{direction}_R@{k}"] = np.mean(hits < k)
        # map@10
        metrics[f"{direction}_mAP@10"] = np.mean(
            np.where(hits < 10, 1 / (hits + 1), 0.0)
        )

    return metrics
def evaluate_clotho_audiocaps(
    model, data, epoch, args, autocast, device, tb_writer=None
):
    """
    Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py.
    1. for text-to-audio retrieval, do 5 times and average the results
    2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text
    3. for map@10 in audio-to-text retrieval:
        3.1: sort the rank of 5 text
        3.2: exclude the rank >=10 (0-index)
        3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks).
        (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth.
        (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc.

    Args:
        model: called three ways: ``model(audios, None, device)`` for audio
            features, ``model(None, texts, device)`` for text features and
            ``model(None, None, device)`` for ``(logit_scale_a, logit_scale_t)``.
        data: mapping whose ``"val"`` entry exposes ``dataloader``; every batch
            provides ``"full_text"`` and ``"__url__"``.
        epoch: unused here; kept for interface compatibility.
        args: only ``args.tmodel`` is read, to choose the tokenizer.
        autocast: context-manager factory wrapped around the model calls.
        device: forwarded to the model.
        tb_writer: unused here; kept for interface compatibility.

    Returns:
        dict mapping dataset name -> dict of ``"<dataset>/<metric>"`` -> value.
    """
    # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now.
    dataloader = data["val"].dataloader
    with torch.no_grad():
        eval_info = {}
        for batch in dataloader:
            audios = batch  # contains mel_spec, wavform, and longer list
            # each item in the list has 5 texts
            if args.tmodel == "transformer":
                from open_clip import tokenize

                texts = [tokenize(t) for t in batch["full_text"]]
                texts = torch.cat(texts)
            else:
                from .data import tokenizer

                texts = [
                    tokenizer(t) for t in batch["full_text"]
                ]  # 5 texts for each audio
                texts = {
                    k: torch.cat([t[k] for t in texts]) for k in texts[0].keys()
                }  # 5 x batch

            # Dataset name of every sample, built from its url path components.
            # Computed once per batch and reused for grouping below.
            batch_names = ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]
            all_names = list(set(batch_names))
            for name in all_names:
                if name not in eval_info:
                    # we will not use mlp outputs even if args.clap_mlploss=True
                    eval_info[name] = {
                        "cumulative_loss": 0.0,
                        "num_samples": 0,
                        "all_audio_features": [],
                        "all_text_features": [],
                    }

            with autocast():
                audio_features = model(audios, None, device)
                text_features = model(None, texts, device)
                audio_features = F.normalize(audio_features, dim=-1)
                text_features = F.normalize(text_features, dim=-1)

            for n in all_names:
                idx = np.where(np.array(batch_names) == n)[0]
                eval_info[n]["all_audio_features"].append(
                    audio_features.cpu().index_select(0, torch.tensor(idx).long())
                )
                # (yusong) please double-check. This is for selecting 5 text features at once.
                # because idx is a list of indices in size of num_samples,
                # and text_features is a tensor of size (5*num_samples, dim)
                # so we need to select 5 consecutive indices at once for a single index in idx.
                eval_info[n]["all_text_features"].append(
                    text_features.cpu()
                    .reshape([-1, 5, text_features.shape[1]])
                    .index_select(0, torch.tensor(idx).long())
                    .reshape([-1, text_features.shape[1]])
                )

        val_metrics_all = {}

        for n in eval_info.keys():
            logit_scale_a, logit_scale_t = model(None, None, device)
            logit_scale_a = logit_scale_a.cpu()

            audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0)
            text_features = torch.cat(eval_info[n]["all_text_features"], dim=0)

            logits_per_audio = (
                (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
            )
            logits_per_text = logits_per_audio.t().detach().cpu()

            # logits_per_audio shape: [num_samples, num_samples*5]
            # logits_per_text shape: [num_samples*5, num_samples]

            logging.info(
                f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, "
                f"logits_per_text shape: {logits_per_text.shape}"
            )

            metrics = {}
            num_samples = audio_features.shape[0]
            metrics["num_samples"] = num_samples

            # (yusong) the following code is very important, please double-check:
            # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d]
            # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
            # Those two are retrieving one of the 5 text for each audio.
            labels = torch.arange(audio_features.shape[0]).long()
            audio_to_text_loss = [
                F.cross_entropy(
                    logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d],
                    labels,
                )
                for d in range(5)
            ]
            text_to_audio_loss = [
                F.cross_entropy(
                    logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :],
                    labels,
                )
                for d in range(5)
            ]
            total_loss = (np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)) / 2

            metrics["cumulative_loss"] = total_loss.item()

            # text to audio: do 5 times
            pred_text = []
            for d in range(5):
                logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
                ground_truth = torch.arange(len(logit)).view(-1, 1)
                ranking = torch.argsort(
                    logit, descending=True
                )  # [num_samples, num_samples]
                preds = torch.where(ranking == ground_truth)[1]
                pred_text.append(preds.detach().cpu().numpy())
            pred_text_concat = np.concatenate(pred_text, axis=0)  # [5*num_samples]
            metrics["text_to_audio_mean_rank"] = pred_text_concat.mean() + 1
            metrics["text_to_audio_median_rank"] = (
                np.floor(np.median(pred_text_concat)) + 1
            )
            for k in [1, 5, 10]:
                metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k)
            # map@10
            metrics["text_to_audio_mAP@10"] = np.mean(
                np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0)
            )

            # audio to text: take the best result
            # for audio to text map 10, sort and assign descending ground truth.
            # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103
            # map@10
            map_all = []
            pred_audio_all = []
            for d in range(num_samples):
                # logits_per_audio: [num_samples, num_samples*5]
                logit_single = logits_per_audio[d, :]  # [5*num_samples]
                # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4]
                ranking = torch.argsort(
                    logit_single, descending=True
                )  # [5*num_samples]
                # ranking: the index of first match, second match, ...
                ground_truth = torch.arange(d * 5, d * 5 + 5)[None]
                # all_pred[j] is the 0-based rank of caption j, in caption order.
                all_pred = torch.where(
                    torch.stack([ranking] * 5) == ground_truth.view(-1, 1)
                )[1]
                min_pred = torch.min(all_pred)
                pred_audio_all.append(min_pred.detach().cpu().numpy())
                # Step 3.1 of the docstring: the retained ranks must be in
                # ascending order before being paired with 1, 2, 3, ...;
                # otherwise a single precision term can exceed 1.
                all_pred_filter = np.sort(
                    all_pred[all_pred < 10].detach().cpu().numpy()
                )
                # /5 because we have 5 text, so it means for the text rank >=10 we count as 0.
                map_single = (
                    np.sum(
                        (np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1))
                    )
                    / 5
                )
                map_all.append(map_single)
            metrics["audio_to_text_mAP@10"] = np.mean(map_all)
            for k in [1, 5, 10]:
                metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k)

            val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()}
    return val_metrics_all
def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset):
    """
    Calculate performance for Clotho+AudioCaps for model selection.

    For every dataset the audio-to-text and text-to-audio mAP@10 are averaged;
    the returned value is the mean of those per-dataset scores.
    """
    per_dataset_scores = [
        (
            dataset_metrics[f"{dataset}/audio_to_text_mAP@10"]
            + dataset_metrics[f"{dataset}/text_to_audio_mAP@10"]
        )
        / 2
        for dataset, dataset_metrics in val_metrics_per_dataset.items()
    ]
    return np.mean(per_dataset_scores)
def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args):
    """Merge the best-so-far ("top") Clotho/AudioCaps metrics into ``metrics``.

    Args:
        metrics: dict, key: metric name, value: metric value. Must contain
            ``"epoch"``. Updated in place and returned.
        val_metrics_per_dataset: dict, key: dataset name, value: dict of
            ``"<dataset>/<metric>"`` -> metric value.
        args: used as storage for the best result (``top_metric`` and
            ``top_selection_performance`` attributes are set on it).

    Returns:
        ``metrics`` extended with ``"<dataset>-top/<metric>"`` entries,
        ``"top_selection_performance"`` and ``"top-selection-epoch"``.
    """
    # Hack: use args to save the top performance
    current_score = calculate_selection_performance_clotho_audiocaps(
        val_metrics_per_dataset
    )
    first_evaluation = not hasattr(args, "top_selection_performance")

    if first_evaluation or current_score > args.top_selection_performance:
        # New best (or first) result: snapshot it under "-top" keys.
        best = {}
        for dataset_metrics in val_metrics_per_dataset.values():
            for key, value in dataset_metrics.items():
                parts = key.split("/")
                best[parts[0] + "-top" + "/" + parts[1]] = value
        best["top_selection_performance"] = current_score
        best["top-selection-epoch"] = metrics["epoch"]
        metrics.update(best)
        args.top_metric = best
        args.top_selection_performance = current_score
    else:
        # Not an improvement: report the previously stored best metrics.
        metrics.update(args.top_metric)
    return metrics