Polos-Demo / validate /validate_cvpr.py
yuwd's picture
init
03f6091
import json
import argparse
import json
from utils import *
from dataset import *
from metrics import *
from compute_correlations import compute_flickr
from compute_pascal50s import compute_pascal50S
from compute_foil import compute_foil
def collect_coef(memory, dataset_name, method, coef_tensor):
memory.setdefault(dataset_name, {})
coef = {k : round(float(v.numpy() if not isinstance(v,float) else v),4) for k, v in coef_tensor.items()}
memory[dataset_name].update({method : coef})
gprint(f"[{dataset_name}]",method,coef)
def compute_coef(args, memory, tops):
    """Evaluate metric/human-judgement correlation on the Polaris test split.

    Loads ``data_en/polaris/polaris_test.csv`` and, when ``args.polos`` is set,
    records Polos coefficients into ``memory`` under the "test" key.

    Returns:
        The ``(memory, tops)`` pair; ``tops`` is passed through untouched.
    """
    split = "test"
    csv_path = f"data_en/polaris/polaris_{split}.csv"
    yprint(f"Processing {split} ... (path: {csv_path})")
    dataset = get_dataset(csv_path)

    if not args.polos:
        return memory, tops

    # kendall_type='c' is forwarded to compute_polos_coef (presumably
    # Kendall's tau-c -- confirm in metrics).
    polos_result = compute_polos_coef(args, dataset, split, kendall_type='c')
    collect_coef(memory, split, "Polos", polos_result)
    return memory, tops
def main(args):
    """Run every benchmark enabled on the command line, save and print results.

    Results are threaded through the ``compute_*`` helpers as a pair of dicts
    (``memory`` for all scores, ``tops`` for the best entries). ``memory`` is
    written to ``zeroshot_test_results.json`` and a summary is printed.
    """
    results, best = {}, {}

    # Ordered dispatch table: (args attribute name, runner). Flags are read
    # with getattr at loop time, so each one is checked right before its step.
    benchmarks = (
        ("flickr", lambda mem, top: compute_flickr(args, args.model, mem, top)),
        ("coef", lambda mem, top: compute_coef(args, mem, top)),
        ("pascal", lambda mem, top: compute_pascal50S(args, mem, top)),
        ("foil", lambda mem, top: compute_foil(args, mem, top)),
    )
    for flag, run in benchmarks:
        if getattr(args, flag):
            results, best = run(results, best)

    with open("zeroshot_test_results.json", "w") as fh:
        json.dump(results, fh, indent=4)

    yprint("[RESULTS]")
    gprint(json.dumps(results, indent=4))

    rprint("[TOP]")
    for name, entry in best.items():
        rprint(f"> {name}")
        if not isinstance(entry, dict):
            # accuracy benchmark: a (method, accuracy) pair
            top_method, accuracy = entry
            rprint(f"{top_method} ({accuracy})")
            continue
        # correlation benchmark: coefficient kind -> indexable pair
        for kind, pair in entry.items():
            rprint(f"{kind}: {pair[0]} ({pair[1]})")
if __name__ == "__main__":
    # Command-line entry point: parse model and benchmark options, then run.
    cli = argparse.ArgumentParser()

    # models
    cli.add_argument('--model', default=None)
    cli.add_argument('--hparams', default=None)
    cli.add_argument('--polos', action='store_true')

    # benchmarks (boolean switches, registered in the original order)
    for switch in ('--coef', '--flickr', '--pascal', '--foil'):
        cli.add_argument(switch, action='store_true')

    main(cli.parse_args())