# (Scraped page banner: "Spaces: Build error" -- not part of the Python module.)
import os, sys, numpy, torch, argparse, skimage, json, shutil | |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |
from matplotlib.figure import Figure | |
from matplotlib.ticker import MaxNLocator | |
import matplotlib | |
def main():
    """Parse command-line options and hand them to ``run_command``.

    ``--classname``, ``--outdir`` and ``--l2_lambda`` are mandatory because
    ``run_command`` joins the first two into filesystem paths and iterates
    over the third; leaving any of them out previously surfaced as a
    ``TypeError`` deep inside the plotting code.  They are now rejected by
    argparse with a normal usage message (exit status 2).

    ``--layer`` defaults to ``'layer4'`` and ``--variant`` to ``'ace'``.
    """
    parser = argparse.ArgumentParser(
        description='ACE optimization utility',
        prog='python -m netdissect.aceoptimize')
    parser.add_argument('--classname', type=str, required=True,
                        help='intervention classname')
    parser.add_argument('--layer', type=str, default='layer4',
                        help='layer name')
    # One or more values; each one selects a separate experiment directory.
    parser.add_argument('--l2_lambda', type=float, nargs='+', required=True,
                        help='l2 regularizer hyperparameter')
    parser.add_argument('--outdir', type=str, required=True,
                        help='dissection directory')
    # Same effect as the former "if args.variant is None" fix-up after
    # parsing: an omitted --variant resolves to 'ace'.
    parser.add_argument('--variant', type=str, default='ace',
                        help='experiment variant')
    args = parser.parse_args()
    run_command(args)
def run_command(args):
    """Plot the per-epoch average loss for each l2 regularizer setting.

    For every value in ``args.l2_lambda`` the ten snapshot files
    ``epoch-0.pth`` .. ``epoch-9.pth`` are loaded from
    ``<outdir>/<layer>/<variant>/<classname>/snapshots``.  Each epoch's
    ``avg_loss`` and the number of ``ablation`` entries equal to 1 are
    printed, and the loss curve is added to one shared figure.  The figure
    is written to ``<outdir>/<layer>/<args.variant>/<classname>/loss-plot.png``.

    If the snapshots of any setting cannot be loaded, a message is printed
    and the function returns immediately without saving a plot.

    Args:
        args: namespace providing ``outdir``, ``layer``, ``variant`` and
            ``classname`` (path components) plus ``l2_lambda`` (an iterable
            of floats).
    """
    num_epochs = 10  # snapshots epoch-0 .. epoch-9 are expected per setting

    fig = Figure(figsize=(4.5, 3.5))
    # Constructing the Agg canvas binds it to ``fig``; the return value is
    # not needed.
    FigureCanvas(fig)
    ax = fig.add_subplot(111)

    for l2_lambda in args.l2_lambda:
        # The 0.01 setting is read from the unsuffixed variant directory;
        # every other value uses a "_reg<value>" suffix.
        variant = args.variant
        if l2_lambda != 0.01:
            variant += '_reg%g' % l2_lambda
        dirname = os.path.join(args.outdir, args.layer, variant,
                               args.classname)
        snapshots = os.path.join(dirname, 'snapshots')
        try:
            # NOTE(review): torch.load unpickles the file contents; this
            # presumes the snapshots are trusted, locally produced files.
            dat = [torch.load(os.path.join(snapshots, 'epoch-%d.pth' % i))
                   for i in range(num_epochs)]
        except Exception:
            # Narrowed from a bare ``except:`` so that Ctrl-C and
            # SystemExit are no longer reported as missing snapshots.
            print('Missing %s snapshots' % dirname)
            return
        print('reg %g' % l2_lambda)
        for i, snap in enumerate(dat):
            print(i, snap['avg_loss'],
                  len((snap['ablation'] == 1).nonzero()))
        ax.plot([snap['avg_loss'] for snap in dat],
                label='reg %g' % l2_lambda)

    ax.set_title('%s %s' % (args.classname, args.variant))
    ax.grid(True)
    ax.legend()
    ax.set_ylabel('Loss')
    ax.set_xlabel('Epochs')
    fig.tight_layout()

    dirname = os.path.join(args.outdir, args.layer,
                           args.variant, args.classname)
    # The unsuffixed variant directory need not exist when 0.01 is not among
    # the requested lambdas; create it so savefig cannot fail on a missing
    # path.
    os.makedirs(dirname, exist_ok=True)
    fig.savefig(os.path.join(dirname, 'loss-plot.png'))
# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main()