jiuku's picture
Duplicate from haoheliu/audioldm-text-to-audio-generation
4039be3
raw
history blame
36.3 kB
import json
import logging
import math
import os
import time
from contextlib import suppress
import numpy as np
import torch
import torch.nn.functional as F
try:
import wandb
except ImportError:
wandb = None
from open_clip import ClipLoss, gather_features
from .distributed import is_master
from .zero_shot import zero_shot_eval
class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric.

    Attributes (all set to 0 by ``reset``):
        val: the value passed to the latest ``update`` call.
        sum: the weighted sum ``sum(val * n)`` over all updates since the last reset.
        count: the total weight ``sum(n)`` over all updates since the last reset.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times (e.g. a batch mean and its batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch right after a reset) would
        # otherwise raise ZeroDivisionError; report an average of 0 instead.
        self.avg = self.sum / self.count if self.count else 0
def unwrap_model(model):
    """Return ``model.module`` when that attribute exists, otherwise ``model`` itself."""
    return getattr(model, "module", model)
def train_one_epoch(
    model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
):
    """Train ``model`` for one pass over ``data["train"].dataloader``.

    Args:
        model: called as ``model(audios, texts, device)``; must return the tuple
            ``(audio_features, text_features, audio_features_mlp,
            text_features_mlp, logit_scale_a, logit_scale_t)``.
        data: mapping whose ``"train"`` entry exposes ``dataloader`` and ``sampler``.
        epoch: epoch index; combined with the batch index to form the global step.
        optimizer: a single optimizer, or a dict of optimizers that are all
            zeroed and stepped on every batch.
        scaler: gradient scaler used for the backward/step calls, or ``None``
            to call ``backward()`` / ``step()`` directly.
        scheduler: a callable taking the global step, or a dict of such callables.
        args: run configuration. Reads ``device``, ``precision``, ``local_loss``,
            ``gather_with_grad``, ``rank``, ``world_size``, ``horovod``,
            ``clap_mlploss``, ``kappa``, ``distributed``, ``dataset_type`` and
            ``wandb``.
        tb_writer: optional writer exposing ``add_scalar(name, value, step)``.

    Returns:
        None. Progress is reported through ``logging``, ``tb_writer`` and wandb.
    """
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    model.train()
    loss = ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
        mlp_loss=args.clap_mlploss,
        weight_loss_kappa=args.kappa,
    )

    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)
    num_batches_per_epoch = dataloader.num_batches
    # Field width used to right-align the running sample count in the log line.
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    # for toy dataset
    if args.dataset_type == "toy":
        dataloader.dataset.generate_queue()

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i
        if isinstance(scheduler, dict):
            for s in scheduler.values():
                s(step)
        else:
            scheduler(step)
        audios = batch  # contains mel_spec, waveform, and longer list
        texts = batch["text"]

        data_time_m.update(time.time() - end)
        if isinstance(optimizer, dict):
            for o_ in optimizer.values():
                o_.zero_grad()
        else:
            optimizer.zero_grad()

        with autocast():
            (
                audio_features,
                text_features,
                audio_features_mlp,
                text_features_mlp,
                logit_scale_a,
                logit_scale_t,
            ) = model(audios, texts, device)

            if args.clap_mlploss:
                total_loss = loss(
                    audio_features=audio_features,
                    text_features=text_features,
                    logit_scale_a=logit_scale_a,
                    logit_scale_t=logit_scale_t,
                    audio_features_mlp=audio_features_mlp,
                    text_features_mlp=text_features_mlp,
                )
            else:
                total_loss = loss(
                    audio_features=audio_features,
                    text_features=text_features,
                    logit_scale_a=logit_scale_a,
                )

        if isinstance(optimizer, dict):
            if scaler is not None:
                scaler.scale(total_loss).backward()
                for o_ in optimizer.values():
                    if args.horovod:
                        o_.synchronize()
                        scaler.unscale_(o_)
                        with o_.skip_synchronize():
                            scaler.step(o_)
                    else:
                        scaler.step(o_)
                scaler.update()
            else:
                total_loss.backward()
                for o_ in optimizer.values():
                    o_.step()
        else:
            if scaler is not None:
                scaler.scale(total_loss).backward()
                if args.horovod:
                    optimizer.synchronize()
                    scaler.unscale_(optimizer)
                    with optimizer.skip_synchronize():
                        scaler.step(optimizer)
                else:
                    scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).logit_scale_a.clamp_(0, math.log(100))
            if args.clap_mlploss:
                unwrap_model(model).logit_scale_t.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1

        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            if isinstance(audios, dict):
                batch_size = len(audios["waveform"])
            else:
                batch_size = len(audios)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE loss is coarsely sampled, just master node and per log update
            loss_m.update(total_loss.item(), batch_size)
            logit_scale_scalar_a = logit_scale_a.item()
            logit_scale_scalar_t = logit_scale_t.item()

            # With several optimizers the learning rate is reported as a list
            # (one entry per optimizer); otherwise as a single formatted float.
            if isinstance(optimizer, dict):
                lr_value = [o_.param_groups[0]["lr"] for o_ in optimizer.values()]
                lr_text = f"{lr_value}"
            else:
                lr_value = optimizer.param_groups[0]["lr"]
                lr_text = f"{lr_value:5f}"

            message = (
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f} "
                f"LR: {lr_text} "
                f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
            )
            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "loss": loss_m.val,
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "scale_audio": logit_scale_scalar_a,
            }
            if args.clap_mlploss:
                # Leading space keeps the two scale fields apart in the log line.
                message += f" Logit Scale Text: {logit_scale_scalar_t:.3f}"
                log_data["scale_text"] = logit_scale_scalar_t
            log_data["lr"] = lr_value
            logging.info(message)

            # NOTE(review): with a dict of optimizers "lr" is a list and is
            # passed to add_scalar unchanged -- confirm the writer accepts it.
            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, "Please install wandb."
                    wandb.log({name: val, "step": step})

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()
    # end for
def _new_eval_bucket(with_mlp):
    """Return an empty per-dataset accumulator for validation features."""
    bucket = {
        "cumulative_loss": 0.0,
        "num_samples": 0,
        "all_audio_features": [],
        "all_text_features": [],
    }
    if with_mlp:
        bucket["all_audio_features_mlp"] = []
        bucket["all_text_features_mlp"] = []
    return bucket


def evaluate(model, data, epoch, args, tb_writer=None):
    """Run validation and return a flat dict of metrics.

    Two modes are supported:

    * ``args.val_dataset_names == ["Clotho", "audiocaps"]``: delegate to
      ``evaluate_clotho_audiocaps`` (five captions per audio) and keep track of
      the best result so far via ``select_top_metric_clotho_audiocaps``.
    * otherwise, when ``"val"`` is in ``data`` and the epoch matches
      ``args.val_frequency`` (or equals ``args.epochs``): collect features for
      every batch, grouped per source dataset plus an ``"all"`` group, and score
      each group with ``get_metrics``.

    Args:
        model: called as ``model(audios, texts, device)``.
        data: mapping whose ``"val"`` entry exposes ``dataloader``.
        epoch: epoch index used for scheduling and as the logging step.
        args: run configuration. Reads ``parallel_eval``, ``device``,
            ``precision``, ``val_dataset_names``, ``val_frequency``, ``epochs``,
            ``clap_mlploss``, ``rank``, ``world_size``, ``horovod``,
            ``save_logs``, ``checkpoint_path`` and ``wandb``.
        tb_writer: optional writer exposing ``add_scalar(name, value, step)``.

    Returns:
        dict mapping metric name to value; empty when nothing was evaluated or
        when called on a non-master process.

    Raises:
        NotImplementedError: if parallel evaluation is requested together with
            the Clotho/audiocaps-only mode.
    """
    metrics = {}
    if not args.parallel_eval:
        if not is_master(args):
            return metrics
    device = torch.device(args.device)
    model.eval()

    if is_master(args):
        print("Evaluating...")
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress

    if args.val_dataset_names == ["Clotho", "audiocaps"]:
        # if only clotho and audiocaps are used, then we will use a different evaluation function.
        # This is because in the Clotho and audiocaps valid and test set, there are 5 text for 1 audio.
        if args.parallel_eval:
            # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps.
            raise NotImplementedError(
                "Parallel evaluation not supported for eval only Clotho and audiocaps."
            )
        val_metrics_per_dataset = evaluate_clotho_audiocaps(
            model, data, epoch, args, autocast, device, tb_writer
        )
        for m in val_metrics_per_dataset.values():
            metrics.update(m)
        if "epoch" not in metrics:
            metrics.update({"epoch": epoch})
        metrics = select_top_metric_clotho_audiocaps(
            metrics, val_metrics_per_dataset, args
        )
    elif "val" in data and (
        args.val_frequency
        and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
    ):
        dataloader = data["val"].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples

        # FIXME this does not scale past small eval datasets
        # all_audio_features @ all_text_features will blow up memory and compute very quickly
        eval_info = {"all": _new_eval_bucket(args.clap_mlploss)}
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                audios = batch  # contains mel_spec, waveform, and longer list
                texts = batch["text"]

                # Dataset label of every sample, derived from two path
                # components of its "__url__"; computed once per batch.
                sample_names = [
                    "-".join(b.split("/")[-3:-1]) for b in batch["__url__"]
                ]
                name_array = np.array(sample_names)
                all_names = list(set(sample_names))
                for name in all_names:
                    if name not in eval_info:
                        eval_info[name] = _new_eval_bucket(args.clap_mlploss)

                with autocast():
                    (
                        audio_features,
                        text_features,
                        audio_features_mlp,
                        text_features_mlp,
                        logit_scale_a,
                        logit_scale_t,
                    ) = model(audios, texts, device)

                    if args.parallel_eval:
                        # multi-GPU eval
                        if args.clap_mlploss:
                            (
                                audio_features,
                                text_features,
                                audio_features_mlp,
                                text_features_mlp,
                            ) = gather_features(
                                audio_features=audio_features,
                                text_features=text_features,
                                audio_features_mlp=audio_features_mlp,
                                text_features_mlp=text_features_mlp,
                                local_loss=False,
                                gather_with_grad=False,
                                rank=args.rank,
                                world_size=args.world_size,
                                use_horovod=args.horovod,
                                mlp_loss=args.clap_mlploss,
                            )
                        else:
                            (audio_features, text_features,) = gather_features(
                                audio_features=audio_features,
                                text_features=text_features,
                                local_loss=False,
                                gather_with_grad=False,
                                rank=args.rank,
                                world_size=args.world_size,
                                use_horovod=args.horovod,
                                mlp_loss=args.clap_mlploss,
                            )

                    if is_master(args):
                        num_samples += audio_features.shape[0]
                        batch_features = {
                            "all_audio_features": audio_features.cpu(),
                            "all_text_features": text_features.cpu(),
                        }
                        if args.clap_mlploss:
                            batch_features["all_audio_features_mlp"] = (
                                audio_features_mlp.cpu()
                            )
                            batch_features["all_text_features_mlp"] = (
                                text_features_mlp.cpu()
                            )
                        for n in [*all_names, "all"]:
                            if n == "all":
                                for key, feats in batch_features.items():
                                    eval_info[n][key].append(feats)
                            else:
                                # NOTE(review): these indices come from the local
                                # batch's URLs; under parallel eval the features
                                # were gathered first -- confirm they still align.
                                idx = torch.tensor(
                                    np.where(name_array == n)[0]
                                ).long()
                                for key, feats in batch_features.items():
                                    eval_info[n][key].append(
                                        feats.index_select(0, idx)
                                    )

                if is_master(args) and (i % 100) == 0:  # and i != 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
                    )

            if is_master(args):
                val_metrics_per_dataset = {}
                for n in eval_info.keys():
                    bucket = eval_info[n]
                    metric_kwargs = {
                        "audio_features": torch.cat(bucket["all_audio_features"]),
                        "text_features": torch.cat(bucket["all_text_features"]),
                        # logit scales come from the last evaluated batch
                        "logit_scale_a": logit_scale_a.cpu(),
                        "mlp_loss": args.clap_mlploss,
                    }
                    if args.clap_mlploss:
                        metric_kwargs["audio_features_mlp"] = torch.cat(
                            bucket["all_audio_features_mlp"]
                        )
                        metric_kwargs["text_features_mlp"] = torch.cat(
                            bucket["all_text_features_mlp"]
                        )
                        metric_kwargs["logit_scale_t"] = logit_scale_t.cpu()
                    metrics_single_dataset = get_metrics(**metric_kwargs)
                    val_metrics_per_dataset[n] = {
                        n + "/" + k: v for k, v in metrics_single_dataset.items()
                    }
                    metrics.update(val_metrics_per_dataset[n])
                    if "epoch" not in metrics:
                        metrics.update({"epoch": epoch})

    if is_master(args):
        if not metrics:
            # Nothing was evaluated this epoch; skip all reporting.
            return metrics

        logging.info(
            f"Eval Epoch: {epoch} "
            + "\n".join(
                [
                    "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()])
                    for m in val_metrics_per_dataset.values()
                ]
            )
        )

        if args.save_logs:
            for name, val in metrics.items():
                if tb_writer is not None:
                    tb_writer.add_scalar(f"val/{name}", val, epoch)

            # One JSON object per evaluation, appended to results.jsonl.
            with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
                f.write(json.dumps(metrics))
                f.write("\n")

        if args.wandb:
            assert wandb is not None, "Please install wandb."
            for name, val in metrics.items():
                wandb.log({f"val/{name}": val, "epoch": epoch})

    return metrics
def get_metrics(
    audio_features,
    text_features,
    logit_scale_a,
    audio_features_mlp=None,
    text_features_mlp=None,
    logit_scale_t=None,
    mlp_loss=False,
):
    """Compute contrastive loss and retrieval metrics for paired features.

    Row ``i`` of ``audio_features`` is treated as the match of row ``i`` of
    ``text_features`` (labels are ``arange(N)``).

    Args:
        audio_features: 2-D tensor with one row per audio sample.
        text_features: 2-D tensor with one row per text sample.
        logit_scale_a: scalar multiplier applied to the audio-side similarities.
        audio_features_mlp: MLP-projected audio features; required when
            ``mlp_loss`` is true.
        text_features_mlp: MLP-projected text features; required when
            ``mlp_loss`` is true.
        logit_scale_t: scalar multiplier for the text-side similarities;
            required when ``mlp_loss`` is true.
        mlp_loss: when true, use the four-term loss and average the two
            similarity matrices before ranking.

    Returns:
        dict with ``cumulative_loss``, ``num_samples`` and, for each of
        ``audio_to_text`` / ``text_to_audio``: ``_mean_rank``, ``_median_rank``
        (both 1-based), ``_R@1``, ``_R@5``, ``_R@10`` and ``_mAP@10``.

    Raises:
        ValueError: if ``mlp_loss`` is true but any MLP input is ``None``.
    """
    metrics = {}
    labels = torch.arange(audio_features.shape[0]).long()

    if mlp_loss:
        if (
            audio_features_mlp is None
            or text_features_mlp is None
            or logit_scale_t is None
        ):
            raise ValueError(
                "mlp_loss=True requires audio_features_mlp, text_features_mlp "
                "and logit_scale_t."
            )
        # Audio-to-text and text-to-audio similarity matrices.
        a_logits_per_audio = (
            (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu()
        )
        a_logits_per_text = a_logits_per_audio.t()
        t_logits_per_audio = (
            (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu()
        )
        t_logits_per_text = t_logits_per_audio.t()

        # Four-term loss: 2x2 combination of cross-entropy terms.
        total_loss = (
            F.cross_entropy(a_logits_per_audio, labels)
            + F.cross_entropy(a_logits_per_text, labels)
            + F.cross_entropy(t_logits_per_audio, labels)
            + F.cross_entropy(t_logits_per_text, labels)
        ) / 4

        logits = {
            "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2,
            "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2,
        }
    else:
        logits_per_audio = (
            (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
        )
        logits_per_text = logits_per_audio.t()

        # Symmetric two-term cross-entropy loss.
        total_loss = (
            F.cross_entropy(logits_per_audio, labels)
            + F.cross_entropy(logits_per_text, labels)
        ) / 2

        logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text}

    metrics["cumulative_loss"] = total_loss.item()
    metrics["num_samples"] = audio_features.shape[0]
    ground_truth = torch.arange(len(text_features)).view(-1, 1)

    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        # Column at which each row's own index appears = its 0-based rank.
        preds = torch.where(ranking == ground_truth)[
            1
        ]  # (yusong) this line is slow because it uses single thread
        preds = preds.detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)
        # map@10: reciprocal rank, counted as 0 when the match is outside the top 10
        metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))

    return metrics
def evaluate_clotho_audiocaps(
    model, data, epoch, args, autocast, device, tb_writer=None
):
    """Evaluate retrieval on datasets that pair each audio with 5 captions.

    Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py.
    1. for text-to-audio retrieval, do 5 times and average the results
    2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text
    3. for map@10 in audio-to-text retrieval:
        3.1: sort the rank of 5 text
        3.2: exclude the rank >=10 (0-index)
        3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks).
        (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth.
        (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc.

    Args:
        model: called as ``model(audios, None, device)`` for audio features,
            ``model(None, texts, device)`` for text features and
            ``model(None, None, device)`` for the two logit scales.
        data: mapping whose ``"val"`` entry exposes ``dataloader``; each batch
            provides ``"full_text"`` (5 captions per audio) and ``"__url__"``.
        epoch: unused here; kept for a uniform evaluation signature.
        args: run configuration; reads ``tmodel`` to pick the tokenizer.
        autocast: context-manager factory wrapped around the forward passes.
        device: device forwarded to ``model``.
        tb_writer: unused here; kept for a uniform evaluation signature.

    Returns:
        dict mapping dataset name to a dict of ``"<dataset>/<metric>"`` values.
    """
    # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now.
    dataloader = data["val"].dataloader
    with torch.no_grad():
        eval_info = {}
        for batch in dataloader:
            audios = batch  # contains mel_spec, waveform, and longer list

            # each item in the list has 5 texts
            if args.tmodel == "transformer":
                from open_clip import tokenize

                texts = [tokenize(t) for t in batch["full_text"]]
                texts = torch.cat(texts)
            else:
                from .data import tokenizer

                texts = [
                    tokenizer(t) for t in batch["full_text"]
                ]  # 5 texts for each audio
                texts = {
                    k: torch.cat([t[k] for t in texts]) for k in texts[0].keys()
                }  # 5 x batch

            # Dataset label of every sample, derived from its "__url__".
            sample_names = ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]
            name_array = np.array(sample_names)
            all_names = list(set(sample_names))
            for name in all_names:
                if name not in eval_info:
                    # we will not use mlp outputs even if args.clap_mlploss=True
                    eval_info[name] = {
                        "cumulative_loss": 0.0,
                        "num_samples": 0,
                        "all_audio_features": [],
                        "all_text_features": [],
                    }

            with autocast():
                audio_features = model(audios, None, device)
                text_features = model(None, texts, device)
                audio_features = F.normalize(audio_features, dim=-1)
                text_features = F.normalize(text_features, dim=-1)

            feature_dim = text_features.shape[1]
            for n in all_names:
                idx = torch.tensor(np.where(name_array == n)[0]).long()
                eval_info[n]["all_audio_features"].append(
                    audio_features.cpu().index_select(0, idx)
                )
                # (yusong) please double-check. This is for selecting 5 text features at once.
                # because idx is a list of indices in size of num_samples,
                # and text_features is a tensor of size (5*num_samples, dim)
                # so we need to select 5 consecutive indices at once for a single index in idx.
                eval_info[n]["all_text_features"].append(
                    text_features.cpu()
                    .reshape([-1, 5, feature_dim])
                    .index_select(0, idx)
                    .reshape([-1, feature_dim])
                )

        val_metrics_all = {}

        for n in eval_info.keys():
            logit_scale_a, logit_scale_t = model(None, None, device)
            logit_scale_a = logit_scale_a.cpu()

            audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0)
            text_features = torch.cat(eval_info[n]["all_text_features"], dim=0)

            logits_per_audio = (
                (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
            )
            logits_per_text = logits_per_audio.t().detach().cpu()

            # logits_per_audio shape: [num_samples, num_samples*5]
            # logits_per_text shape: [num_samples*5, num_samples]
            logging.info(
                f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, "
                f"logits_per_text shape: {logits_per_text.shape}"
            )

            metrics = {}
            num_samples = audio_features.shape[0]
            metrics["num_samples"] = num_samples

            # (yusong) the following code is very important, please double-check:
            # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d]
            # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
            # Those two are retrieving one of the 5 text for each audio.
            audio_by_caption = logits_per_audio.reshape(num_samples, num_samples, 5)
            text_by_caption = logits_per_text.reshape(num_samples, 5, num_samples)
            labels = torch.arange(audio_features.shape[0]).long()
            audio_to_text_loss = [
                F.cross_entropy(audio_by_caption[:, :, d], labels) for d in range(5)
            ]
            text_to_audio_loss = [
                F.cross_entropy(text_by_caption[:, d, :], labels) for d in range(5)
            ]
            total_loss = (np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)) / 2

            metrics["cumulative_loss"] = total_loss.item()

            # text to audio: do 5 times
            pred_text = []
            for d in range(5):
                logit = text_by_caption[:, d, :]
                ground_truth = torch.arange(len(logit)).view(-1, 1)
                ranking = torch.argsort(
                    logit, descending=True
                )  # [num_samples, num_samples]
                preds = torch.where(ranking == ground_truth)[1]
                pred_text.append(preds.detach().cpu().numpy())
            pred_text_concat = np.concatenate(pred_text, axis=0)  # [5*num_samples]
            metrics["text_to_audio_mean_rank"] = pred_text_concat.mean() + 1
            metrics["text_to_audio_median_rank"] = (
                np.floor(np.median(pred_text_concat)) + 1
            )
            for k in [1, 5, 10]:
                metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k)
            # map@10
            metrics["text_to_audio_mAP@10"] = np.mean(
                np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0)
            )

            # audio to text: take the best result
            # for audio to text map 10, sort and assign descending ground truth.
            # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103
            # map@10
            map_all = []
            pred_audio_all = []
            for d in range(num_samples):
                # logits_per_audio: [num_samples, num_samples*5]
                logit_single = logits_per_audio[d, :]  # [5*num_samples]
                # ranking: the index of first match, second match, ...
                ranking = torch.argsort(
                    logit_single, descending=True
                )  # [5*num_samples]
                # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4]
                ground_truth = torch.arange(d * 5, d * 5 + 5)
                # One row per caption; the matching column is that caption's
                # 0-based rank, returned in caption order (not rank order).
                all_pred = torch.where(
                    ranking[None, :] == ground_truth.view(-1, 1)
                )[1]
                min_pred = torch.min(all_pred)
                pred_audio_all.append(min_pred.detach().cpu().numpy())
                # Step 3.1 of the docstring: the precision terms i / rank_i are
                # only valid with ranks in ascending order, so sort them here.
                all_pred_filter = np.sort(
                    all_pred[all_pred < 10].detach().cpu().numpy()
                )
                # /5 because we have 5 text, so it means for the text rank >=10 we count as 0.
                map_single = (
                    np.sum(
                        (np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1))
                    )
                    / 5
                )
                map_all.append(map_single)
            metrics["audio_to_text_mAP@10"] = np.mean(map_all)
            for k in [1, 5, 10]:
                metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k)

            val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()}
    return val_metrics_all
def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset):
    """Return the model-selection score for Clotho+AudioCaps.

    For every dataset the audio-to-text and text-to-audio mAP@10 values are
    averaged; the result is the mean of those per-dataset averages.
    """
    per_dataset_scores = [
        (
            results[f"{dataset}/audio_to_text_mAP@10"]
            + results[f"{dataset}/text_to_audio_mAP@10"]
        )
        / 2
        for dataset, results in val_metrics_per_dataset.items()
    ]
    return np.mean(per_dataset_scores)
def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args):
    """Merge the best-so-far Clotho/AudioCaps metrics into ``metrics``.

    Args:
        metrics: dict, key: metric name, value: metric value. Must contain
            ``"epoch"``; it is updated in place and returned.
        val_metrics_per_dataset: dict, key: dataset name, value: dict of
            ``"<dataset>/<metric>"`` -> value for the current evaluation.
        args: namespace used (as a hack) to remember the best result between
            calls via ``args.top_selection_performance`` and ``args.top_metric``.

    Returns:
        ``metrics``, extended with ``"<dataset>-top/<metric>"`` entries,
        ``"top_selection_performance"`` and ``"top-selection-epoch"``. These
        come from the current evaluation when it is the first one or strictly
        better than the stored best, and from the stored best otherwise.
    """
    current_score = calculate_selection_performance_clotho_audiocaps(
        val_metrics_per_dataset
    )
    is_first_evaluation = not hasattr(args, "top_selection_performance")

    if is_first_evaluation or current_score > args.top_selection_performance:
        top_snapshot = {}
        for dataset_metrics in val_metrics_per_dataset.values():
            for full_key, value in dataset_metrics.items():
                key_parts = full_key.split("/")
                top_snapshot[key_parts[0] + "-top" + "/" + key_parts[1]] = value
        top_snapshot["top_selection_performance"] = current_score
        top_snapshot["top-selection-epoch"] = metrics["epoch"]
        metrics.update(top_snapshot)
        # Remember the new best on args so later calls can compare against it.
        args.top_metric = top_snapshot
        args.top_selection_performance = current_score
    else:
        metrics.update(args.top_metric)
    return metrics