# File: curate-me-a-playlist/src/laion_clap/evaluate/eval_zeroshot_classification.py
# Author: berkaygkv54 — commit 19759e2 ("first push")
import os.path
import glob
import random
import numpy as np
import logging
import wandb
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from clap_module import create_model
from clap_module import tokenize
from training.logger import setup_logging
from training.data import get_data
from training.train import evaluate
from clap_module.utils import get_tar_path_from_dataset_name, dataset_split
from training.params import parse_args
def find_params_value(file, key):
    """Return the value recorded for ``key`` in a ``params.txt``-style file.

    Each line is expected to look like ``name: value``. The first line whose
    name is exactly ``key`` wins, and its value is returned with surrounding
    whitespace removed.

    Args:
        file: path to the params file.
        key: parameter name to look up (e.g. ``'amodel'``).

    Returns:
        The value as a string (possibly empty), or ``None`` if no line has
        that name.
    """
    with open(file, 'r') as f:
        for line in f:
            # Compare the whole name rather than testing for a substring: a
            # plain `in` check lets key 'amodel' match an 'xamodel: ...' line.
            name, sep, value = line.lstrip().partition(': ')
            if sep and name == key:
                # partition splits only at the first ': ', so values that
                # themselves contain ': ' (e.g. dict reprs) are kept whole.
                return value.strip()
    return None
def evaluate_zeroshot(model, data, start_epoch, args, writer):
    """Run zero-shot classification over ``data["val"]`` and log ranking metrics.

    Every validation clip is embedded with the model's audio path, and every
    class name (the keys of ``args.class_index_dict``) is wrapped in a text
    prompt and embedded with the text path. For each clip the classes are
    sorted by descending logit; the 0-based position of the true class yields
    mean/median rank, R@1/5/10 and mAP@10, keyed by ``args.datasetnames[0]``.

    Args:
        model: CLAP model. Called as ``model(batch, None, device)`` for audio
            embeddings, ``model(None, tokens, device)`` for text embeddings and
            ``model(None, None, device)`` for a pair of logit scales (only the
            first, the audio one, is used).
        data: mapping whose ``"val"`` entry exposes a ``.dataloader`` yielding
            batch dicts with a ``"class_label"`` tensor (presumably one-hot or
            per-class scores -- argmax over dim 1 gives the label; confirm).
        start_epoch: epoch number reported alongside the metrics.
        args: run config; reads ``device``, ``val_dataset_names``,
            ``class_index_dict``, ``tmodel``, ``datasetnames`` and ``wandb``.
        writer: unused; kept for interface compatibility.
    """
    loader = data["val"].dataloader
    device = torch.device(args.device)
    model.eval()
    metrics = {"epoch": start_epoch}

    audio_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch in loader:
            # the batch dict contains mel_spec, waveform and the "longer" list
            embeddings = F.normalize(model(batch, None, device), dim=-1)
            audio_chunks.append(embeddings.detach().cpu())
            label_chunks.append(torch.argmax(batch["class_label"], 1).long())
        audio_features = torch.cat(audio_chunks, dim=0)
        class_labels = torch.cat(label_chunks, dim=0)
        metrics["num_samples"] = audio_features.shape[0]

        # One prompt per class name; GTZAN gets a song-style template.
        if args.val_dataset_names == ['GTZAN']:
            prompts = [f"This is a {t} song." for t in args.class_index_dict.keys()]
        else:
            prompts = [f"This is a sound of {t}." for t in args.class_index_dict.keys()]
        logging.info(f'class label prompts: {prompts}')

        # (yusong): a hack, can make it better -- tokenizer depends on tmodel
        if args.tmodel == "transformer":
            from clap_module.tokenizer import tokenize
            text_tokens = tokenize(prompts)
        else:
            from training.data import tokenizer
            text_tokens = tokenizer(prompts)
        text_features = F.normalize(model(None, text_tokens, device), dim=-1).detach().cpu()

        # Scaled audio-to-text similarity: one row per clip, one column per class.
        logit_scale_a, _ = model(None, None, device)
        logits_per_audio = (logit_scale_a.cpu() * audio_features @ text_features.t()).detach().cpu()

        # Sort classes by descending logit and locate the true class in each
        # row; the column index of that hit is the sample's 0-based rank.
        order = torch.argsort(logits_per_audio, descending=True)
        hits = order == class_labels.view(-1, 1)
        # (yusong) this lookup is slow because it uses a single thread
        ranks = hits.nonzero(as_tuple=True)[1].detach().cpu().numpy()

        prefix = args.datasetnames[0]
        metrics[f"{prefix}_mean_rank"] = ranks.mean() + 1
        metrics[f"{prefix}_median_rank"] = np.floor(np.median(ranks)) + 1
        metrics.update({f"{prefix}_R@{k}": np.mean(ranks < k) for k in (1, 5, 10)})
        # map@10: reciprocal rank when the true class is in the top 10, else 0
        metrics[f"{prefix}_mAP@10"] = np.mean(np.where(ranks < 10, 1 / (ranks + 1), 0.0))

    summary = "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    logging.info(f"Eval Epoch: {start_epoch} " + summary)

    if args.wandb:
        assert wandb is not None, "Please install wandb."
        for name, val in metrics.items():
            wandb.log({f"val/{name}": val, "epoch": start_epoch})
def _resolve_log_dir(pretrained):
    """Return the experiment log directory implied by ``--pretrained``.

    ``pretrained`` is expected to be either the run's ``checkpoints``
    directory (every ``*.pt`` in it is evaluated) or a single checkpoint file
    inside it. Either way the log directory -- where ``params.txt`` is read
    and ``out.log`` is written -- is the parent of ``checkpoints``.
    """
    # normpath drops a trailing separator; without it
    # dirname("exp/checkpoints/") returns "exp/checkpoints" instead of "exp".
    path = os.path.normpath(pretrained)
    if os.path.isdir(path):
        return os.path.dirname(path)
    return os.path.dirname(os.path.dirname(path))


def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with a leading ``module.`` removed from its keys.

    Stripping happens only when the first key carries the prefix (as is
    typical for weights saved from a DistributedDataParallel-wrapped model);
    otherwise the dict is returned unchanged.
    """
    prefix = "module."
    if state_dict and next(iter(state_dict)).startswith(prefix):
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    return state_dict


def _build_model(amodel, tmodel, pretrained, args, device):
    """Create a fresh fp32 CLAP model; returns ``create_model``'s ``(model, model_cfg)``."""
    return create_model(
        amodel,
        tmodel,
        pretrained,
        precision='fp32',
        device=device,
        jit=False,
        force_quick_gelu=False,
        openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
        skip_params=False,
        enable_fusion=args.enable_fusion,
        fusion_type=args.fusion_type,
    )


def main():
    """Evaluate zero-shot classification for one checkpoint or a whole run.

    The audio/text model types are read from the run's ``params.txt``. Each
    checkpoint is loaded into a freshly built model and passed to
    ``evaluate_zeroshot``. Logs go to ``out.log`` in the run directory and,
    when ``args.report_to == 'wandb'``, to Weights & Biases.
    """
    # (yusong) repeated run might have different metric results.
    # This is because we randomly select crop 10s for each audio.
    args = parse_args()
    log_dir = _resolve_log_dir(args.pretrained)
    args.log_level = logging.DEBUG if args.debug else logging.INFO
    setup_logging(os.path.join(log_dir, 'out.log'), args.log_level)
    params_file = os.path.join(log_dir, 'params.txt')

    seed = 3407
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # cudnn autotuning stays on and determinism off, so runs are not
    # bit-for-bit reproducible even with the seeds above.
    cudnn.benchmark = True
    cudnn.deterministic = False

    pretrained = 'openai'
    amodel = find_params_value(params_file, 'amodel')
    tmodel = find_params_value(params_file, 'tmodel')
    if amodel is None or tmodel is None:
        raise ValueError('model type not found in params file')
    # The model is always built from the params-file types, so make args agree:
    # evaluate_zeroshot picks its tokenizer from args.tmodel.
    args.amodel = amodel
    args.tmodel = tmodel

    # set up dummy values for args
    args.parallel_eval = False
    args.rank = 0
    args.local_rank = 0
    args.world_size = 1
    args.val_frequency = 1
    args.epochs = 1
    args.precision = 'fp32'
    args.save_logs = True
    args.wandb = args.report_to == 'wandb'
    args.class_index_dict = None

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    args.device = device

    if args.remotedata:
        for dataset_name in args.datasetnames:
            for split in dataset_split[dataset_name]:
                local_dir = f"./json_files/{dataset_name}/{split}"
                os.makedirs(local_dir, exist_ok=True)
                os.system(
                    f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json {local_dir}/sizes.json"
                )

    if args.datasetinfos is None:
        args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
    if args.dataset_type == "webdataset":
        args.train_data = get_tar_path_from_dataset_name(
            args.datasetnames,
            args.datasetinfos,
            islocal=not args.remotedata,
            proportion=args.dataset_proportion,
            dataset_path=args.datasetpath,
        )
        args.val_data = get_tar_path_from_dataset_name(
            args.datasetnames,
            ["valid", "test", "eval"],
            islocal=not args.remotedata,
            proportion=1,
            dataset_path=args.datasetpath,
        )

    # A throwaway model is built only to obtain model_cfg for get_data; index
    # [1] so the model itself is not kept alive.
    model_cfg = _build_model(amodel, tmodel, pretrained, args, device)[1]
    data = get_data(args, model_cfg=model_cfg)  # (yusong): hack: no model_cfg needed to get data
    writer = None  # if use tensorboard, initialize writer here

    if args.wandb:
        assert wandb is not None, "Please install wandb."
        wandb_notes = args.wandb_notes
        logging.debug("Starting wandb.")
        # get_data only creates the splits whose paths are set (train paths are
        # only set for webdataset above), so guard each lookup.
        if "train" in data:
            args.train_sz = data["train"].dataloader.num_samples
        if "val" in data:
            args.val_sz = data["val"].dataloader.num_samples
        # you will have to configure this for your project!
        if args.wandb_id is not None:
            wandb.init(project="clap", id=args.wandb_id, resume=True)
        else:
            wandb.init(
                project="clap",
                notes=wandb_notes,
                name=wandb_notes,
                tags=[],
                config=vars(args),
            )
        logging.debug("Finished loading wandb.")

    if os.path.isdir(args.pretrained):
        checkpoint_dir = os.path.join(log_dir, 'checkpoints')
        # oldest first, so epochs are evaluated in training order
        all_model_checkpoints = sorted(
            glob.glob(os.path.join(checkpoint_dir, '*.pt')), key=os.path.getmtime
        )
        if not all_model_checkpoints:
            logging.warning(f"no checkpoints found in {checkpoint_dir}")
    else:
        all_model_checkpoints = [args.pretrained]

    for model_path in all_model_checkpoints:
        args.checkpoint_path = os.path.dirname(model_path)
        model, _ = _build_model(amodel, tmodel, pretrained, args, device)
        checkpoint = torch.load(model_path, map_location=device)
        if "epoch" in checkpoint:
            # a training checkpoint: epoch + optimizer state + state_dict
            start_epoch = checkpoint["epoch"]
            model.load_state_dict(_strip_module_prefix(checkpoint["state_dict"]))
            logging.info(
                f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})"
            )
        else:
            # a bare (model-only) state dict
            model.load_state_dict(_strip_module_prefix(checkpoint))
            start_epoch = 0
        del checkpoint  # release the loaded tensors before evaluation
        model.to(device)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
        evaluate_zeroshot(model, data, start_epoch, args, writer)


if __name__ == '__main__':
    main()