Spaces:
Runtime error
Runtime error
""" | |
This script evaluates the contribution of a technique from the ablation study for
improving the masker evaluation metrics. The differences in the metrics are computed
for all images of paired models, that is, those which only differ in the inclusion or
not of the given technique. Then, statistical inference is performed through the
percentile bootstrap to obtain robust estimates of the differences in the metrics and
confidence intervals. The script plots the distribution of the bootstrapped estimates.
""" | |
print("Imports...", end="") | |
from argparse import ArgumentParser | |
import yaml | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
import os | |
from pathlib import Path | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as mpatches | |
import matplotlib.transforms as transforms | |
# -----------------------
# -----  Constants  -----
# -----------------------

# Display label of each model, keyed by its "model_feats" string.
# Ablation models are labelled with an integer, the two baselines with a letter.
dict_models = {
    "md": 11,
    "dada_ms, msd, pseudo": 9,
    "msd, pseudo": 4,
    "dada, msd_spade, pseudo": 7,
    "msd": 13,
    "dada_m, msd": 17,
    "dada, msd_spade": 16,
    "msd_spade, pseudo": 5,
    "dada_ms, msd": 18,
    "dada, msd, pseudo": 6,
    "ms": 12,
    "dada, msd": 15,
    "dada_m, msd, pseudo": 8,
    "msd_spade": 14,
    "m": 10,
    "md, pseudo": 2,
    "ms, pseudo": 3,
    "m, pseudo": 1,
    "ground": "G",
    "instagan": "I",
}

# Human-readable metric names (used as axis labels) and the subset of metrics
# singled out as key metrics.
dict_metrics = {
    "names": dict(
        tpr="TPR, Recall, Sensitivity",
        tnr="TNR, Specificity, Selectivity",
        fpr="FPR",
        fpt="False positives relative to image size",
        fnr="FNR, Miss rate",
        fnt="False negatives relative to image size",
        mpr="May positive rate (MPR)",
        mnr="May negative rate (MNR)",
        accuracy="Accuracy (ignoring may)",
        error="Error",
        f05="F05 score",
        precision="Precision",
        edge_coherence="Edge coherence",
        accuracy_must_may="Accuracy (ignoring cannot)",
    ),
    "key_metrics": ["f05", "error", "edge_coherence"],
}

# Every accepted spelling of a technique, mapped to its canonical name.
# Built from (canonical name -> aliases) so that synonyms sit side by side.
dict_techniques = {
    alias: canonical
    for canonical, aliases in (
        ("depth", ("depth",)),
        ("seg", ("segmentation", "seg")),
        ("dada_seg", ("dada_s", "dada_seg", "dada_segmentation")),
        ("dada_masker", ("dada_m", "dada_masker")),
        ("spade", ("spade",)),
        ("pseudo", ("pseudo", "pseudo-labels", "pseudo_labels")),
    )
    for alias in aliases
}

# Markers: matplotlib marker symbol of each key metric
dict_markers = dict(error="o", f05="s", edge_coherence="^")

# Model features
model_feats = [
    "masker",
    "seg",
    "depth",
    "dada_seg",
    "dada_masker",
    "spade",
    "pseudo",
    "ground",
    "instagan",
]
# Colors
# All colours are drawn from seaborn's "colorblind" palette.
palette_colorblind = sns.color_palette("colorblind")

# One fixed colour per method
color_climategan = palette_colorblind[0]
color_munit = palette_colorblind[1]
color_cyclegan = palette_colorblind[6]
color_instagan = palette_colorblind[8]
color_maskinstagan = palette_colorblind[2]
color_paintedground = palette_colorblind[3]

# Two generic category colours and four shades of each, from lightest to darkest.
# Each palette_<shade> list holds [shade of category 1, shade of category 2].
color_cat1, color_cat2 = palette_colorblind[0], palette_colorblind[1]
palette_lightest = [
    sns.light_palette(base, n_colors=20)[3] for base in (color_cat1, color_cat2)
]
palette_light = [
    sns.light_palette(base, n_colors=3)[1] for base in (color_cat1, color_cat2)
]
palette_medium = [color_cat1, color_cat2]
palette_dark = [
    sns.dark_palette(base, n_colors=3)[1] for base in (color_cat1, color_cat2)
]

# Regroup the shades per category: [lightest, light, medium, dark]
palette_cat1, palette_cat2 = (
    list(shades)
    for shades in zip(palette_lightest, palette_light, palette_medium, palette_dark)
)
color_cat1_light, color_cat2_light = palette_light
def parsed_args():
    """
    Parse and return the command-line arguments.

    The arguments are read from ``sys.argv`` by ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace: the parsed arguments, with attributes ``input_csv``,
            ``output_dir``, ``models``, ``dpi``, ``n_bs``, ``alpha`` and ``bs_seed``
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--models",
        default="all",
        type=str,
        help="Models to display: all, pseudo, no_dada_masker, no_baseline",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--n_bs",
        # argparse applies `type` to string defaults only: a float default such
        # as 1e6 would be returned unchanged, so it must already be an int.
        default=1000000,
        type=int,
        help="Number of bootrstrap samples",
    )
    parser.add_argument(
        "--alpha",
        default=0.99,
        type=float,
        help="Confidence level",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )
    return parser.parse_args()
def plot_median_metrics(
    df, do_stripplot=True, dpi=200, bs_seed=37, n_bs=1000, **snskwargs
):
    """
    Plot the median of the error, F05 and edge coherence metrics of every model.

    The figure has three panels sharing the Y axis (one row per model). Each panel
    shows the median of one metric per model, with a 99 % bootstrap interval, and
    optionally the individual observations as a strip plot.

    Args:
        df (pd.DataFrame): metrics table. It must contain a "model_feats" column
            whose values are keys of ``dict_models`` and the columns "error",
            "f05" and "edge_coherence".
        do_stripplot (bool): whether to overlay the individual observations
        dpi (int): resolution of the figure
        bs_seed (int): random seed of the bootstrap, for reproducibility
        n_bs (int): number of bootstrap samples
        **snskwargs: accepted but currently unused

    Returns:
        matplotlib.figure.Figure: the figure with the three panels
    """

    def plot_metric(
        ax, df, metric, do_stripplot=True, dpi=200, bs_seed=37, marker="o", **snskwargs
    ):
        """
        Draw one metric on ``ax``. ``n_bs`` is taken from the enclosing function;
        ``dpi`` and ``snskwargs`` are accepted but unused.
        """
        y_labels = [dict_models[f] for f in df.model_feats.unique()]

        # Labels: numbered models first (ascending), then the lettered baselines.
        # y_order_* hold the matching "model_feats" keys, in the same order.
        y_labels_int = np.sort([el for el in y_labels if isinstance(el, int)]).tolist()
        y_order_int = [
            k for vs in y_labels_int for k, vu in dict_models.items() if vs == vu
        ]
        y_labels_int = [str(el) for el in y_labels_int]

        y_labels_str = sorted([el for el in y_labels if not isinstance(el, int)])
        y_order_str = [
            k for vs in y_labels_str for k, vu in dict_models.items() if vs == vu
        ]
        y_labels = y_labels_int + y_labels_str
        y_order = y_order_int + y_order_str

        # Palette: one colour per row, aligned with y_order
        palette = len(y_labels_int) * [color_climategan]
        for y in y_labels_str:
            if y == "G":
                palette = palette + [color_paintedground]
            if y == "I":
                palette = palette + [color_maskinstagan]

        # Median and bootstrap interval
        sns.pointplot(
            ax=ax,
            data=df,
            x=metric,
            y="model_feats",
            order=y_order,
            markers=marker,
            estimator=np.median,
            ci=99,
            seed=bs_seed,
            n_boot=n_bs,
            join=False,
            scale=0.6,
            errwidth=1.5,
            capsize=0.1,
            palette=palette,
        )
        # Remember the limits set by the point plot: the strip plot would widen them
        xlim = ax.get_xlim()

        if do_stripplot:
            sns.stripplot(
                ax=ax,
                data=df,
                x=metric,
                y="model_feats",
                # Same row order as the point plot, so that the observations, the
                # palette and the tick labels all refer to the same model
                order=y_order,
                size=1.5,
                palette=palette,
                alpha=0.2,
            )

        ax.set_xlim(xlim)

        # Set X-label
        ax.set_xlabel(dict_metrics["names"][metric], rotation=0, fontsize="medium")

        # Set Y-label
        ax.set_ylabel(None)

        ax.set_yticklabels(y_labels, fontsize="medium")

        # Change spines
        sns.despine(ax=ax, left=True, bottom=True)

        # Draw gray area on final model: full axes width (axes coordinates in X),
        # rows 5.5 to 6.5 (data coordinates in Y), i.e. the 7th row from the top.
        # NOTE(review): the row is hard-coded, it is not looked up in y_order.
        trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
        rect = mpatches.Rectangle(
            xy=(0.0, 5.5),
            width=1,
            height=1,
            transform=trans,
            linewidth=0.0,
            edgecolor="none",
            facecolor="gray",
            alpha=0.05,
        )
        ax.add_patch(rect)

    # Set up plot
    sns.set(style="whitegrid")
    plt.rcParams.update({"font.family": "serif"})
    plt.rcParams.update(
        {
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ]
        }
    )

    # Figure height grows with the number of models (rows)
    fig_h = 0.4 * len(df.model_feats.unique())
    fig, axes = plt.subplots(
        nrows=1, ncols=3, sharey=True, dpi=dpi, figsize=(18, fig_h)
    )

    # Error
    plot_metric(
        axes[0],
        df,
        "error",
        do_stripplot=do_stripplot,
        dpi=dpi,
        bs_seed=bs_seed,
        marker=dict_markers["error"],
    )
    axes[0].set_ylabel("Models")

    # F05
    plot_metric(
        axes[1],
        df,
        "f05",
        do_stripplot=do_stripplot,
        dpi=dpi,
        bs_seed=bs_seed,
        marker=dict_markers["f05"],
    )

    # Edge coherence
    plot_metric(
        axes[2],
        df,
        "edge_coherence",
        do_stripplot=do_stripplot,
        dpi=dpi,
        bs_seed=bs_seed,
        marker=dict_markers["edge_coherence"],
    )
    # Show the edge coherence ticks with three decimals
    xticks = axes[2].get_xticks()
    xticklabels = ["{:.3f}".format(x) for x in xticks]
    axes[2].set(xticks=xticks, xticklabels=xticklabels)

    plt.subplots_adjust(wspace=0.12)

    return fig
if __name__ == "__main__":
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine output dir: --output_dir if given, otherwise $SLURM_TMPDIR
    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
    elif "SLURM_TMPDIR" in os.environ:
        output_dir = Path(os.environ["SLURM_TMPDIR"])
    else:
        raise SystemExit(
            "No output directory: pass --output_dir or set the SLURM_TMPDIR "
            "environment variable"
        )
    # exist_ok=True: no separate existence check, hence no race with other jobs
    output_dir.mkdir(parents=True, exist_ok=True)

    # Store args
    output_yml = output_dir / "ablation_comparison_{}.yml".format(args.models)
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Read CSV
    df = pd.read_csv(args.input_csv, index_col="model_img_idx")

    # Determine models: the filters below are cumulative and skipped when
    # "all" appears in --models. The baselines (ground, instagan) are kept by
    # the "pseudo" and "no_dada_mask" filters.
    models = args.models.lower()
    if "all" not in models:
        if "no_baseline" in models:
            df = df.loc[(df.ground == False) & (df.instagan == False)]  # noqa: E712
        if "pseudo" in models:
            df = df.loc[
                (df.pseudo == True)  # noqa: E712
                | (df.ground == True)  # noqa: E712
                | (df.instagan == True)  # noqa: E712
            ]
        if "no_dada_mask" in models:
            df = df.loc[
                (df.dada_masker == False)  # noqa: E712
                | (df.ground == True)  # noqa: E712
                | (df.instagan == True)  # noqa: E712
            ]

    # NOTE(review): args.n_bs and args.alpha are parsed but not forwarded here, so
    # the plot uses the default n_bs of plot_median_metrics and its fixed interval.
    fig = plot_median_metrics(df, do_stripplot=True, dpi=args.dpi, bs_seed=args.bs_seed)

    # Save figure
    output_fig = output_dir / "ablation_comparison_{}.png".format(args.models)
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")