"""Draw an SVG chart of how many units per layer match each concept label.

For every requested layer, per-unit records are read from that layer's
``report.json``. Units whose IoU is at least ``--miniou`` are counted per
(category, label). The result is drawn as one row per layer: a light
background rectangle per category and one dark bar per label. Bar height is
proportional to the number of units assigned that label.
"""
import argparse
import json
import os
import shutil
from collections import defaultdict
from xml.etree import ElementTree as et

from netdissect.easydict import EasyDict

# Written before the serialized <svg> element so the output is a standalone
# SVG 1.1 document.
_SVG_HEADER = (
    '<?xml version="1.0" standalone="no"?>\n'
    '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" '
    '"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n')

# Per-category (bar color, background color) pairs.
cat_palette = {
    'object': ('#4B4CBF', '#B6B6F2'),
    'part': ('#55B05B', '#B6F2BA'),
    'material': ('#50BDAC', '#A5E5DB'),
    'texture': ('#81C679', '#C0FF9B'),
    'color': ('#F0883B', '#F2CFB6'),
    'other1': ('#D4CF24', '#F2F1B6'),
    'other2': ('#D92E2B', '#F2B6B6'),
    'other3': ('#AB6BC6', '#CFAAFF'),
}


def parseargs():
    """Parse and return the command-line options."""
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        parser.add_argument(*args, **kwargs)

    aa('--model', choices=['resnet18-bn', 'resnet18-hn', 'resnet18-ha'],
       default='resnet18-bn')
    aa('--iteration', type=int, default=0)
    aa('--dataset', choices=['places'], default='places')
    aa('--seg', choices=['net', 'netp', 'netq', 'netpq', 'netpqc', 'netpqxc',
                         'human'], default='net')
    # Required: main() iterates over this list, so None would crash later.
    aa('--layers', nargs='+', required=True)
    aa('--quantile', type=float, default=0.01)
    aa('--miniou', type=float, default=0.025)
    args = parser.parse_args()
    return args


def _load_layer_reports(args, qdir):
    """Load each layer's report.json into an EasyDict, keyed by layer name."""
    layer_report = {}
    for layer in args.layers:
        input_filename = 'results/%s-%d-%s-%s-%s%s/report.json' % (
            args.model, args.iteration, args.dataset, args.seg, layer, qdir)
        with open(input_filename, encoding='utf-8') as f:
            layer_report[layer] = EasyDict(json.load(f))
    return layer_report


def _collect_graph_data(layer_report, layers, threshold_iou):
    """Count qualifying units per (category, label) for each layer.

    Returns a list of ``(layer, [(cat, [count, count, ...]), ...])`` tuples in
    the order of ``layers``. Within a layer, categories follow ``cat_order``
    and counts are sorted in descending order. Units whose ``iou`` is None or
    below ``threshold_iou`` are skipped.

    Raises:
        ValueError: a report has neither a 'units' nor an 'images' entry.
    """
    # Categories not listed here (e.g. 'texture') are not drawn.
    cat_order = ['object', 'part', 'material', 'color']
    graph_data = []
    for layer in layers:
        report = layer_report[layer]
        # Older report format stores the per-unit records under 'images'.
        units = report.get('units', report.get('images', None))
        if units is None:
            raise ValueError(
                'report for layer %s has neither "units" nor "images"' % layer)
        catmap = defaultdict(lambda: defaultdict(int))
        for unitrec in units:
            if unitrec.iou is None or unitrec.iou < threshold_iou:
                continue
            catmap[unitrec.cat][unitrec.label] += 1
        layer_data = []
        for cat in cat_order:
            if cat not in catmap:
                continue
            # The graph does not show labels, only per-label unit counts.
            layer_data.append(
                (cat, sorted(catmap[cat].values(), reverse=True)))
        graph_data.append((layer, layer_data))
    return graph_data


def _render_svg(graph_data):
    """Render graph_data (see _collect_graph_data) as a complete SVG string."""
    layer_height = 14
    layer_gap = 2
    barwidth = 3
    bargap = 0
    leftmargin = 48
    margin = 8
    textsize = 10
    # The widest row (most labels across its categories) sets the width.
    largest_layer = max(
        (sum(len(cat_data) for _, cat_data in layer_data)
         for _, layer_data in graph_data), default=0)
    svgwidth = largest_layer * (barwidth + bargap) + margin + leftmargin
    svgheight = ((layer_height + layer_gap) * len(graph_data)
                 - layer_gap + 2 * margin)
    svg = et.Element('svg', width=str(svgwidth), height=str(svgheight),
                     version='1.1', xmlns='http://www.w3.org/2000/svg')
    y = margin
    for layer, layer_data in graph_data:
        # Layer name, right-aligned just left of the bars.
        et.SubElement(
            svg, 'text', x='0', y='0',
            style=('font-family:sans-serif;font-size:%dpx;'
                   'text-anchor:end;alignment-baseline:hanging;'
                   'transform:translate(%dpx, %dpx);') % (
                textsize, leftmargin - 4,
                y + (layer_height - textsize) / 2)
        ).text = str(layer)
        # Bars are scaled per row, so the row's tallest bar fills layer_height.
        barmax = max((max(cat_data) if cat_data else 1
                      for _, cat_data in layer_data), default=1)
        barscale = float(layer_height) / barmax
        x = leftmargin
        for cat, cat_data in layer_data:
            # Light background rectangle spanning this category's bars.
            catwidth = len(cat_data) * (barwidth + bargap)
            et.SubElement(svg, 'rect', x=str(x), y=str(y),
                          width=str(catwidth), height=str(layer_height),
                          fill=cat_palette[cat][1])
            for bar in cat_data:
                barheight = barscale * bar
                et.SubElement(svg, 'rect', x=str(x),
                              y=str(y + layer_height - barheight),
                              width=str(barwidth), height=str(barheight),
                              fill=cat_palette[cat][0])
                x += barwidth + bargap
        y += layer_height + layer_gap
    return _SVG_HEADER + et.tostring(svg).decode('utf-8')


def main():
    """Load the per-layer reports, build the graph and write the SVG file."""
    args = parseargs()
    # Suffix for non-default quantiles. '%d' truncates the float; the input
    # directories presumably use the same naming, so the expression is kept
    # as-is to keep matching them.
    qdir = '-%d' % (args.quantile * 1000) if args.quantile != 0.01 else ''
    layer_report = _load_layer_reports(args, qdir)
    graph_data = _collect_graph_data(layer_report, args.layers, args.miniou)
    result = _render_svg(graph_data)
    # Round rather than truncate, e.g. 0.029 * 1000 == 28.999999999999996.
    output_filename = 'results/%s-%d-%s-%s%s/multilayer-%d.svg' % (
        args.model, args.iteration, args.dataset, args.seg, qdir,
        int(round(args.miniou * 1000)))
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    print('writing to %s' % output_filename)
    with open(output_filename, 'w', encoding='utf-8') as f:
        f.write(result)


if __name__ == '__main__':
    main()