"""Analysis & figure/table generation for the varpm level-generation experiments.

Every public function either prints a LaTeX table fragment to stdout or renders
a matplotlib figure.  Data is read from ``test_data/<algo>/<task>/t<i>/...``
produced by the training/testing pipeline; nothing here modifies that data
except the cached distance matrix (``make_tsne``) and selection-frequency JSON
(``calc_selection_freqs``).
"""
import glob
import json
import os
import re
import numpy as np
import pandas as pds
import matplotlib
import matplotlib.pyplot as plt
from math import sqrt
import torch
from root import PRJROOT
from sklearn.manifold import TSNE
from itertools import product, chain
# from src.drl.drl_uses import load_cfgs
from src.gan.gankits import get_decoder, process_onehot
from src.gan.gans import nz
from src.smb.level import load_batch, hamming_dis, lvlhcat
from src.utils.datastruct import RingQueue
from src.utils.filesys import load_dict_json, getpath
from src.utils.img import make_img_sheet
from torch.distributions import Normal

matplotlib.rcParams["axes.formatter.limits"] = (-5, 5)


def _print_mean_std_rows(_data, minimise=False):
    """Print two LaTeX table rows (means, then stds) for ``_data`` (rows=methods, cols=runs).

    The best column (max mean, or min mean when ``minimise``) is \\textbf, the
    worst is \\textit.
    """
    means = _data.mean(axis=-1)
    max_i, min_i = np.argmax(means), np.argmin(means)
    best_i, worst_i = (min_i, max_i) if minimise else (max_i, min_i)
    mean_strs = ['%.4g' % x for x in means]
    std_strs = [r'$\pm$%.3g' % x for x in _data.std(axis=-1)]
    for strs in (mean_strs, std_strs):
        strs[best_i] = r'\textbf{%s}' % strs[best_i]
        strs[worst_i] = r'\textit{%s}' % strs[worst_i]
    print(' &', ' & '.join(mean_strs), r'\\')
    print(' & &', ' & '.join(std_strs), r'\\')


def _collect_reward_div(fds):
    """Collect per-run (reward, diversity) for each folder in ``fds``.

    Globs ``test_data/<fd>/**/performance.csv`` recursively; returns two arrays
    of shape (len(fds), runs-per-folder).
    """
    rewards, divs = [], []
    for fd in fds:
        rewards.append([])
        divs.append([])
        for path in glob.glob(getpath('test_data', fd, '**', 'performance.csv'), recursive=True):
            reward, div = load_dict_json(path, 'reward', 'diversity')
            rewards[-1].append(reward)
            divs[-1].append(div)
    return np.array(rewards), np.array(divs)


def print_compare_tab():
    """Print the main LaTeX comparison table (RL baselines vs varpm, both tasks)."""
    # Random-policy reference numbers; currently unused but kept so a missing
    # file is noticed early.
    rand_lgp, rand_fhp, rand_divs = load_dict_json(
        'test_data/rand_policy/performance.csv', 'lgp', 'fhp', 'diversity'
    )
    rand_performance = {'lgp': rand_lgp, 'fhp': rand_fhp, 'diversity': rand_divs}

    def _print_block(_task):
        # 6 baselines + 6 lambda settings -> 12 table columns.
        fds = [
            f'sac/{_task}', f'egsac/{_task}', f'asyncsac/{_task}',
            f'pmoe/{_task}', f'dvd/{_task}', f'sunrise/{_task}',
            f'varpm-{_task}/l0.0_m5', f'varpm-{_task}/l0.1_m5',
            f'varpm-{_task}/l0.2_m5', f'varpm-{_task}/l0.3_m5',
            f'varpm-{_task}/l0.4_m5', f'varpm-{_task}/l0.5_m5'
        ]
        rewards, divs = _collect_reward_div(fds)
        print(' & \\multirow{2}{*}{Reward}')
        _print_mean_std_rows(rewards)
        print(' \\cline{2-14}')
        print(' & \\multirow{2}{*}{Diversity}')
        _print_mean_std_rows(divs)
        print(' \\cline{2-14}')
        print(' & \\multirow{2}{*}{G-mean}')
        gmean = np.sqrt(rewards * divs)
        _print_mean_std_rows(gmean)
        print(' \\cline{2-14}')
        print(' & \\multirow{2}{*}{N-rank}')
        # Rank every individual run (12 methods x 5 runs) on reward and on
        # diversity; N-rank is the normalised average of the two ranks
        # (lower is better, hence minimise=True).
        r_rank = np.zeros_like(rewards.flatten())
        r_rank[np.argsort(-rewards.flatten())] = np.linspace(1, len(r_rank), len(r_rank))
        d_rank = np.zeros_like(divs.flatten())
        d_rank[np.argsort(-divs.flatten())] = np.linspace(1, len(d_rank), len(d_rank))
        n_rank = (r_rank.reshape([12, 5]) + d_rank.reshape([12, 5])) / (2 * 5)
        _print_mean_std_rows(n_rank, True)

    print(' \\multirow{8}{*}{MarioPuzzle}')
    _print_block('fhp')
    print(' \\midrule')
    print(' \\multirow{8}{*}{MultiFacet}')
    _print_block('lgp')


def print_compare_tab_nonrl():
    """Print the LaTeX comparison table against non-RL generators (GAN, DDPM)."""

    def _print_block(_task):
        # 2 non-RL baselines + 6 lambda settings -> 8 table columns.
        fds = [
            f'GAN-{_task}', f'DDPM-{_task}',
            f'varpm-{_task}/l0.0_m5', f'varpm-{_task}/l0.1_m5',
            f'varpm-{_task}/l0.2_m5', f'varpm-{_task}/l0.3_m5',
            f'varpm-{_task}/l0.4_m5', f'varpm-{_task}/l0.5_m5'
        ]
        rewards, divs = _collect_reward_div(fds)
        print(' & \\multirow{2}{*}{Reward}')
        _print_mean_std_rows(rewards)
        print(' \\cline{2-10}')
        print(' & \\multirow{2}{*}{Diversity}')
        _print_mean_std_rows(divs)
        print(' \\cline{2-10}')
        # G-mean / N-rank rows intentionally omitted for this table.

    print(' \\multirow{4}{*}{MarioPuzzle}')
    _print_block('fhp')
    print(' \\midrule')
    print(' \\multirow{4}{*}{MultiFacet}')
    _print_block('lgp')


def plot_cmp_learning_curves(task, save_path='', title=''):
    """Plot reward / diversity / G-mean learning curves for all algorithms on ``task``.

    Saves the 3-panel figure to ``save_path`` when given, otherwise shows it;
    then shows a separate legend-only figure.
    """
    # Throw-away plot in the seaborn style, just to harvest its colour cycle.
    plt.style.use('seaborn')
    colors = [plt.plot([0, 1], [-1000, -1000])[0].get_color() for _ in range(6)]
    plt.cla()
    plt.style.use('default')

    def _get_algo_data(fd):
        # Read the 5 training runs' step_tests.csv; each trajectory is a list
        # of [step, avg-reward, diversity] sorted by step (26 evals expected).
        res = []
        for i in range(1, 6):
            path = getpath(fd, f't{i}', 'step_tests.csv')
            try:
                data = pds.read_csv(path)
                trajectory = [
                    [float(item['step']), float(item['r-avg']), float(item['diversity'])]
                    for _, item in data.iterrows()
                ]
                trajectory.sort(key=lambda x: x[0])
                res.append(trajectory)
                if len(trajectory) != 26:
                    print('Not complete (%d)/26:' % len(trajectory), path)
            except FileNotFoundError:
                print(path)
        res = np.array(res)
        gmean = np.sqrt(res[:, :, 1] * res[:, :, 2])
        steps = res[0, :, 0]
        _performances = {
            'reward': (np.mean(res[:, :, 1], axis=0), np.std(res[:, :, 1], axis=0)),
            'diversity': (np.mean(res[:, :, 2], axis=0), np.std(res[:, :, 2], axis=0)),
            'gmean': (np.mean(gmean, axis=0), np.std(gmean, axis=0)),
        }
        return steps, _performances

    def _plot_criterion(_ax, _criterion):
        # Ours (lambda labels) get solid lines; baselines reuse the same colour
        # sequence with ':' (off-policy SAC family) or '--' (the rest).
        lam_i, base_i = 0, 0
        for algo, (steps, _performances) in performances.items():
            avgs, stds = _performances[_criterion]
            if r'\lambda' in algo:
                ls, _c = '-', colors[lam_i]
                lam_i += 1
            elif algo in {'SAC', 'EGSAC', 'ASAC'}:
                ls, _c = ':', colors[base_i]
                base_i += 1
            else:
                ls, _c = '--', colors[base_i]
                base_i += 1
            _ax.plot(steps, avgs, color=_c, label=algo, ls=ls)
            _ax.fill_between(steps, avgs - stds, avgs + stds, color=_c, alpha=0.15)
            _ax.grid(False)
        _ax.set_xlabel('Time step')

    fig, ax = plt.subplots(1, 3, figsize=(9.6, 3.2), dpi=250, width_ratios=[1, 1, 1])
    # Interleaved so each baseline shares its colour with one lambda setting.
    performances = {
        'SUNRISE': _get_algo_data(f'test_data/sunrise/{task}'),
        r'$\lambda$=0.0': _get_algo_data(f'test_data/varpm-{task}/l0.0_m5'),
        'DvD': _get_algo_data(f'test_data/dvd/{task}'),
        r'$\lambda$=0.1': _get_algo_data(f'test_data/varpm-{task}/l0.1_m5'),
        'PMOE': _get_algo_data(f'test_data/pmoe/{task}'),
        r'$\lambda$=0.2': _get_algo_data(f'test_data/varpm-{task}/l0.2_m5'),
        'SAC': _get_algo_data(f'test_data/sac/{task}'),
        r'$\lambda$=0.3': _get_algo_data(f'test_data/varpm-{task}/l0.3_m5'),
        'EGSAC': _get_algo_data(f'test_data/egsac/{task}'),
        r'$\lambda$=0.4': _get_algo_data(f'test_data/varpm-{task}/l0.4_m5'),
        'ASAC': _get_algo_data(f'test_data/asyncsac/{task}'),
        r'$\lambda$=0.5': _get_algo_data(f'test_data/varpm-{task}/l0.5_m5'),
    }
    _plot_criterion(ax[0], 'reward')
    _plot_criterion(ax[1], 'diversity')
    _plot_criterion(ax[2], 'gmean')
    ax[0].set_title('Cumulative Reward')
    ax[1].set_title('Diversity Score')
    ax[2].set_title('G-mean')
    lines, labels = fig.axes[-1].get_legend_handles_labels()
    fig.suptitle(title, fontsize=14)
    plt.tight_layout(pad=0.5)
    if save_path:
        plt.savefig(getpath(save_path))
    else:
        plt.show()
    # Separate legend-only figure (shared by the saved panels).
    plt.cla()
    plt.figure(figsize=(9.6, 2.4), dpi=250)
    plt.grid(False)
    plt.axis('off')
    plt.yticks([1.0])
    plt.legend(
        lines, labels, loc='lower center', ncol=6, edgecolor='white', fontsize=15,
        columnspacing=0.8, borderpad=0.16, labelspacing=0.2, handlelength=2.4,
        handletextpad=0.3
    )
    plt.tight_layout(pad=0.5)
    plt.show()


def plot_crosstest_scatters(rfunc, xrange=None, yrange=None, title=''):
    """Scatter reward vs diversity for baselines and all varpm (lambda, m) settings.

    ``rfunc`` is the task name ('fhp'/'lgp'); a translucent Pareto front over
    all points is drawn on top.
    """

    def get_pareto():
        # Brute-force non-dominated filter over every plotted point
        # (maximise both reward and diversity), sorted by reward.
        all_points = list(chain(*scatter_groups.values())) + cmp_points
        res = []
        for p in all_points:
            non_dominated = True
            for q in all_points:
                if q[0] >= p[0] and q[1] >= p[1] and (q[0] > p[0] or q[1] > p[1]):
                    non_dominated = False
                    break
            if non_dominated:
                res.append(p)
        res.sort(key=lambda item: item[0])
        return np.array(res)

    scatter_groups = {}
    all_lbd = set()
    plt.style.use('seaborn-v0_8-dark-palette')
    plt.figure(figsize=(2.5, 2.5), dpi=256)
    plt.axes().set_axisbelow(True)

    # Competitors' performances.
    cmp_folders = ['asyncsac', 'egsac', 'sac', 'sunrise', 'dvd', 'pmoe']
    cmp_names = ['ASAC', 'EGSAC', 'SAC', 'SUNRISE', 'DvD', 'PMOE']
    cmp_labels = ['A', 'E', 'S', 'R', 'D', 'M']
    cmp_markers = ['2', 'x', '+', 'o', '*', 'D']
    cmp_sizes = [42, 20, 32, 16, 24, 10]  # zip truncates to 6 anyway
    cmp_points = []
    for name, folder, label, mk, s in zip(cmp_names, cmp_folders, cmp_labels, cmp_markers, cmp_sizes):
        path_fmt = getpath('test_data', folder, rfunc, '*', 'performance.csv')
        xs, ys = [], []
        for path in glob.glob(path_fmt, recursive=True):
            try:
                x, y = load_dict_json(path, 'reward', 'diversity')
                xs.append(x)
                ys.append(y)
                cmp_points.append([x, y])
            except FileNotFoundError:
                print(path)
        if label in {'A', 'E', 'S'}:
            # Line-style markers: filled dark grey.
            plt.scatter(xs, ys, marker=mk, zorder=2, s=s, label=name, color='#202020')
        else:
            # Shape markers: hollow with dark edge.
            plt.scatter(
                xs, ys, marker=mk, zorder=2, s=s, label=name,
                color=[0., 0., 0., 0.], edgecolors='#202020', linewidths=1
            )

    # varpm performances, grouped by 'l<lambda>_m<m>' folder name.
    for path in glob.glob(getpath('test_data', f'varpm-{rfunc}', '**', 'performance.csv'),
                          recursive=True):
        try:
            x, y = load_dict_json(path, 'reward', 'diversity')
            # Fix: was path.split('\\')[-3], which only worked on Windows.
            key = os.path.normpath(path).split(os.sep)[-3]
            ltxt, mtxt = key.split('_')
            lbd = float(ltxt[1:])
            all_lbd.add(lbd)
            scatter_groups.setdefault(key, []).append([x, y])
        except Exception as e:
            print(path)
            print(e)

    # Seismic-derived palette, immediately overridden by hand-tuned colours.
    palette = plt.get_cmap('seismic')
    color_x = [0.2, 0.33, 0.4, 0.61, 0.67, 0.79]
    colors = {lbd: matplotlib.colors.to_hex(c) for c, lbd in zip(palette(color_x), sorted(all_lbd))}
    colors = {0.0: '#150080', 0.1: '#066598', 0.2: '#01E499', 0.3: '#9FD40C', 0.4: '#F3B020', 0.5: '#FA0000'}
    # Off-screen dummy lines give the legend one swatch per lambda.
    for lbd in sorted(all_lbd):
        plt.plot([-20], [-20], label=f'$\\lambda={lbd:.1f}$', lw=6, c=colors[lbd])
    markers = {2: 'o', 3: '^', 4: 'D', 5: 'p', 6: 'h'}
    msizes = {2: 25, 3: 25, 4: 16, 5: 28, 6: 32}
    for key, group in scatter_groups.items():
        ltxt, mtxt = key.split('_')
        l = float(ltxt[1:])
        m = int(mtxt[1:])
        arr = np.array(group)
        plt.scatter(
            arr[:, 0], arr[:, 1], marker=markers[m], s=msizes[m], color=[0., 0., 0., 0.],
            zorder=2, edgecolors=colors[l], linewidths=1
        )
    plt.xlim(xrange)
    plt.ylim(yrange)
    pareto = get_pareto()
    plt.plot(
        pareto[:, 0], pareto[:, 1], color='black', alpha=0.18, lw=6, zorder=3,
        solid_joinstyle='round', solid_capstyle='round'
    )
    plt.title(title)
    plt.grid()
    plt.tight_layout(pad=0.4)
    plt.show()


def plot_varpm_heat(task, name):
    """Render mean/std heat maps of reward, diversity and G-mean over the (m, lambda) grid.

    Saves one figure per criterion to ``results/heat/<name>-<criterion>.png``.
    """

    def _get_score(m, l):
        # Aggregate the 5 runs of setting l<l>_m<m>.
        fd = getpath('test_data', f'varpm-{task}', f'l{l}_m{m}')
        rewards, divs = [], []
        for i in range(5):
            reward, div = load_dict_json(f'{fd}/t{i+1}/performance.csv', 'reward', 'diversity')
            rewards.append(reward)
            divs.append(div)
        gmean = [sqrt(r * d) for r, d in zip(rewards, divs)]
        return np.mean(rewards), np.std(rewards), \
            np.mean(divs), np.std(divs), \
            np.mean(gmean), np.std(gmean)

    def _fmt_cell(v):
        # Fit the value into ~5 characters: '1234', '12.34' or '.1234'.
        s = '%.4f' % v
        if v >= 1000:
            return s[:4]
        if v >= 1:
            return s[:5]
        return s[1:6]

    lambda_ticks = [r'$\lambda$=0.0', r'$\lambda$=0.1', r'$\lambda$=0.2',
                    r'$\lambda$=0.3', r'$\lambda$=0.4', r'$\lambda$=0.5']

    def _plot_map(avg_map, std_map, criterion):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3), dpi=256, width_ratios=(1, 1))
        for ax, data_map, subtitle in ((ax1, avg_map, 'Average'),
                                       (ax2, std_map, 'Standard Deviation')):
            heat = ax.imshow(data_map, cmap='spring')
            ax.set_xlim([-0.5, 5.5])
            ax.set_xticks([0, 1, 2, 3, 4, 5], lambda_ticks)
            ax.set_ylim([-0.5, 3.5])
            # NOTE(review): rows are filled m=2..5 top-to-bottom but labelled
            # 'm=5'..'m=2'; with the inverted ylim this may mislabel rows —
            # confirm against the published figure.
            ax.set_yticks([0, 1, 2, 3], ['m=5', 'm=4', 'm=3', 'm=2'])
            ax.set_title(subtitle)
            for x, y in product(range(6), range(4)):
                ax.text(x, y, _fmt_cell(data_map[y, x]), va='center', ha='center')
            plt.colorbar(heat, ax=ax, shrink=0.9)
        fig.suptitle(f'{name}: {criterion}', fontsize=14)
        plt.tight_layout()
        plt.savefig(getpath(f'results/heat/{name}-{criterion}.png'))

    r_mean_map, r_std_map, d_mean_map, d_std_map, g_mean_map, g_std_map = \
        (np.zeros([4, 6], dtype=float) for _ in range(6))
    ms = [2, 3, 4, 5]
    ls = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']
    for i, j in product(range(4), range(6)):
        r_mean, r_std, d_mean, d_std, g_mean, g_std = _get_score(ms[i], ls[j])
        r_mean_map[i, j] = r_mean
        r_std_map[i, j] = r_std
        d_mean_map[i, j] = d_mean
        d_std_map[i, j] = d_std
        g_mean_map[i, j] = g_mean
        g_std_map[i, j] = g_std
    _plot_map(r_mean_map, r_std_map, 'Reward')
    _plot_map(d_mean_map, d_std_map, 'Diversity')
    _plot_map(g_mean_map, g_std_map, 'G-mean')


def vis_samples():
    """Render the first 10 DDPM-fhp sample levels of each run into one image sheet per run.

    (Earlier variants rendered the varpm and RL-baseline samples the same way.)
    """
    for i in range(1, 6):
        lvls = load_batch(f'{PRJROOT}/test_data/DDPM-fhp/t{i}/samples.lvls')
        imgs = [lvl.to_img(save_path=None) for lvl in lvls[:10]]
        make_img_sheet(imgs, 1, save_path=f'{PRJROOT}/test_data/DDPM-fhp/t{i}/samples.png')


def make_tsne(task, title, n=500, save_path=None):
    """t-SNE embedding of level samples from all methods, via pairwise Hamming distances.

    The distance matrix is cached at ``test_data/samples_dist-<task>_<n>.npy``;
    delete that file to force recomputation.
    """
    cache_path = getpath('test_data', f'samples_dist-{task}_{n}.npy')
    if not os.path.exists(cache_path):
        samples = []
        # Order matters: baseline algorithms first, then the lambda settings —
        # the plotting loops below slice the embedding in this same order.
        for algo in ['dvd', 'egsac', 'pmoe', 'sunrise', 'asyncsac', 'sac']:
            for t in range(5):
                lvls = load_batch(getpath('test_data', algo, task, f't{t+1}', 'samples.lvls'))
                samples += lvls[:n]
        for l in ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']:
            for t in range(5):
                lvls = load_batch(getpath('test_data', f'varpm-{task}', f'l{l}_m5', f't{t+1}', 'samples.lvls'))
                samples += lvls[:n]
        distmat = np.array([[hamming_dis(a, b) for b in samples] for a in samples])
        np.save(cache_path, distmat)
    # NOTE(review): the original listed the lambda labels first, which crossed
    # the legend with the data (baselines are generated first, see above).
    # Reordered to match the sample order; verify against a cached .npy that
    # may have been produced with the opposite ordering.
    labels = (
        'DvD', 'EGSAC', 'PMOE', 'SUNRISE', 'ASAC', 'SAC',
        r'$\lambda$=0.0', r'$\lambda$=0.1', r'$\lambda$=0.2',
        r'$\lambda$=0.3', r'$\lambda$=0.4', r'$\lambda$=0.5'
    )
    tsne = TSNE(learning_rate='auto', n_components=2, metric='precomputed')
    data = np.load(cache_path)
    print(data.shape)
    embx = np.array(tsne.fit_transform(data))
    plt.style.use('seaborn-dark-palette')
    plt.figure(figsize=(5, 5), dpi=384)
    # Off-screen dummy plots harvest the style's colour cycle; colours are
    # shared between group i and group i+6.
    colors = [plt.plot([-1000, -1100], [0, 0])[0].get_color() for _ in range(6)]
    for i in range(6):
        x, y = embx[i*n*5:(i+1)*n*5, 0], embx[i*n*5:(i+1)*n*5, 1]
        plt.scatter(x, y, s=10, label=labels[i], marker='x', c=colors[i])
    for i in range(6, 12):
        x, y = embx[i*n*5:(i+1)*n*5, 0], embx[i*n*5:(i+1)*n*5, 1]
        plt.scatter(x, y, s=8, linewidths=0, label=labels[i], c=colors[i-6])
    xrange = [1.05 * embx[:, 0].min(), 1.05 * embx[:, 0].max()]
    # Extra headroom at the top for the legend.
    yrange = [1.05 * embx[:, 1].min(), 1.25 * embx[:, 1].max()]
    plt.xlim(xrange)
    plt.ylim(yrange)
    plt.xticks([])
    plt.yticks([])
    plt.title(title)
    plt.legend(loc='upper center', ncol=6, fontsize=9, handlelength=.5,
               handletextpad=0.5, columnspacing=0.3, framealpha=0.)
    plt.tight_layout(pad=0.2)
    if save_path:
        plt.savefig(getpath(save_path))
    else:
        plt.show()


def _prob_fmt(p, digitals=3, threshold=0.001):
    """Format probability ``p`` as LaTeX math: '$.123$', '$\\approx 0$' below threshold.

    A value rounding to 1 prints as '$1.00$' (this special case assumes the
    default digitals=3).
    """
    fmt = '%.' + str(digitals) + 'f'
    if p < threshold:
        return '$\\approx 0$'
    txt = '$%s$' % ((fmt % p)[1:])  # drop the leading '0'
    if txt == '$.000$':
        txt = '$1.00$'
    return txt


def _g_fmt(v, digitals=4):
    """Format ``v`` with %g to exactly ``digitals`` significant digits, zero-padded."""
    fmt = '%.' + str(digitals) + 'g'
    txt = fmt % v
    lack = digitals - len(txt.replace('-', '').replace('.', ''))
    if lack > 0 and '.' not in txt:
        txt += '.'
    return txt + '0' * lack


def print_selection_prob(path, h=15, runs=2):
    """Roll out a trained policy and print its per-step component-selection betas as LaTeX rows.

    Also renders each generated level to ``figures/gen_process/run<r>-01.png``.
    """
    s0 = 0  # index of the initial segment to start generation from
    model = torch.load(getpath(f'{path}/policy.pth'), map_location='cpu')
    model.requires_grad_(False)
    model.to('cpu')
    n = 11  # observation window length; was load_cfgs(path, 'N')
    init_vec = np.load(getpath('analysis/initial_seg.npy'))[s0]
    decoder = get_decoder(device='cpu')
    obs_buffer = RingQueue(n)
    for r in range(runs):
        # Zero-pad the history, then seed with the initial latent vector.
        for _ in range(h):
            obs_buffer.push(np.zeros([nz]))
        obs_buffer.push(init_vec)
        level_latvecs = [init_vec]
        probs = np.zeros([model.m, h])
        selects = []
        for t in range(h):
            obs = torch.tensor(np.concatenate(obs_buffer.to_list(), axis=-1),
                               dtype=torch.float).view([1, -1])
            # Fix: obs was wrapped in torch.tensor() a second time (redundant copy).
            muss, stdss, betas = model.get_intermediate(obs)
            i = torch.multinomial(betas.squeeze(), 1).item()
            mu, std = muss[0][i], stdss[0][i]
            action = Normal(mu, std).rsample([1]).squeeze().numpy()
            obs_buffer.push(action)
            level_latvecs.append(action)
            probs[:, t] = betas.squeeze().numpy()
            selects.append(i)
        onehots = decoder(torch.tensor(np.stack(level_latvecs)).view(-1, nz, 1, 1))
        segs = process_onehot(onehots)
        lvl = lvlhcat(segs)
        lvl.to_img(f'figures/gen_process/run{r}-01.png')
        txts = [[_prob_fmt(p) for p in row] for row in probs]
        # Bold the beta that was actually sampled at each step.
        for t, i in enumerate(selects):
            txts[i][t] = r'$\boldsymbol{%s}$' % txts[i][t][1:-1]
        for i, txt in enumerate(txts):
            print(f' & $\\beta_{i+1}$ &', ' & '.join(txt), r'\\')
        print(r'\midrule')


def calc_selection_freqs(task, n):
    """Estimate component-selection frequencies for every (lambda, m, run) policy on GPU.

    Writes the 30 rows (6 lambdas x 5 runs, m=2..5 concatenated per row) to
    ``analysis/select_freqs-<task>.json``.
    """

    def _count_one_init():
        # Batched rollout over all initial vectors; counts[i] = how often
        # component i was sampled.
        counts = np.zeros([model.m])
        obs_buffer = RingQueue(n)
        for _ in range(runs):
            for _ in range(h):
                obs_buffer.push(np.zeros([len(init_vecs), nz]))
            obs_buffer.push(init_vecs)
            for _ in range(h):
                obs = np.concatenate(obs_buffer.to_list(), axis=-1)
                obs = torch.tensor(obs, device='cuda:0', dtype=torch.float)
                muss, stdss, betas = model.get_intermediate(obs)
                selects = torch.multinomial(betas.squeeze(), 1).squeeze()
                mus = muss[[*range(len(init_vecs))], selects, :]
                stds = stdss[[*range(len(init_vecs))], selects, :]
                actions = Normal(mus, stds).rsample().squeeze().cpu().numpy()
                obs_buffer.push(actions)
                for i in selects:
                    counts[i] = counts[i] + 1
        return counts

    init_vecs = np.load(getpath('analysis/initial_seg.npy'))
    freqs = [[] for _ in range(30)]
    start_line = 0
    for l in ('0.0', '0.1', '0.2', '0.3', '0.4', '0.5'):
        print(r' \midrule')
        for t, m in product(range(1, 6), (2, 3, 4, 5)):
            path = getpath(f'test_data/varpm-{task}/l{l}_m{m}/t{t}')
            model = torch.load(getpath(f'{path}/policy.pth'), map_location='cuda:0')
            model.requires_grad_(False)
            freq = np.zeros([m])
            runs, h = 100, 25
            freq += _count_one_init()
            freq /= (len(init_vecs) * runs * h)
            freq = np.sort(freq)[::-1]  # descending: most-used component first
            i = start_line + t - 1
            freqs[i] += freq.tolist()
            print(freqs[i])
        start_line += 5
    print(freqs)
    with open(getpath(f'analysis/select_freqs-{task}.json'), 'w') as f:
        json.dump(freqs, f)


def print_selection_freq():
    """Print the cached selection-frequency table as LaTeX (computing it if absent)."""
    # task, n = 'lgp', 5
    task, n = 'fhp', 11
    if not os.path.exists(getpath(f'analysis/select_freqs-{task}.json')):
        calc_selection_freqs(task, n)
    with open(getpath(f'analysis/select_freqs-{task}.json'), 'r') as f:
        freqs = json.load(f)
    lbds = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']
    for i, row_data in enumerate(freqs):
        if i % 5 == 0:
            print(r' \midrule')
            print(r' \multirow{5}{*}{$%s$}' % lbds[i//5])
        txt = ' & '.join(map(_prob_fmt, row_data))
        print(f' & {i%5+1} &', txt, r'\\')


def print_individual_performances(task):
    """Print per-run (reward, diversity) pairs for every (m, lambda) setting as LaTeX rows."""
    for m, l in product((2, 3, 4, 5), ('0.0', '0.1', '0.2', '0.3', '0.4', '0.5')):
        values = []
        if l == '0.0':
            print(r' \midrule')
            print(r' \multirow{6}{*}{%d}' % m)
        for t in range(1, 6):
            path = f'test_data/varpm-{task}/l{l}_m{m}/t{t}/performance.csv'
            reward, diversity = load_dict_json(path, 'reward', 'diversity')
            values.append([reward, diversity])
        values.sort(key=lambda item: -item[0])  # best run first
        values = [*chain(*values)]
        txts = [_g_fmt(v) for v in values]
        print(' &', f'${l}$ & ', ' & '.join(txts), r'\\')


if __name__ == '__main__':
    # Uncomment whichever analysis to reproduce:
    # print_selection_prob('test_data/varpm-fhp/l0.5_m5/t5')
    # print_selection_prob('test_data/varpm-fhp/l0.1_m5/t5')
    # print_selection_freq()
    # print_compare_tab_nonrl()
    # print_individual_performances('fhp')
    # print_individual_performances('lgp')
    # plot_cmp_learning_curves('fhp', save_path='results/learning_curves/fhp.png', title='MarioPuzzle')
    # plot_cmp_learning_curves('lgp', save_path='results/learning_curves/lgp.png', title='MultiFacet')
    # plot_crosstest_scatters('fhp', yrange=(0, 2500), xrange=(20, 70), title='MarioPuzzle')
    # plot_crosstest_scatters('lgp', yrange=(0, 1500), xrange=(20, 50), title='MultiFacet')
    # plot_varpm_heat('fhp', 'MarioPuzzle')
    # plot_varpm_heat('lgp', 'MultiFacet')
    vis_samples()
    # make_tsne('fhp', 'MarioPuzzle', n=100)
    # make_tsne('lgp', 'MultiFacet', n=100)