import argparse
import json
import os
from netdissect.easydict import EasyDict
from xml.etree import ElementTree as et
from collections import defaultdict
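

# Hypothetical helper (not called by main): a minimal sketch of how this
# script turns an ElementTree element into a standalone SVG document. It
# serializes the root <svg> element, then prepends the XML declaration and
# the SVG 1.1 DOCTYPE, mirroring the last steps of main(). The function name
# and the (x, y, w, h, fill) rect tuples are assumptions for illustration.
def sketch_svg_document(width, height, rects):
    from xml.etree import ElementTree as et
    svg = et.Element('svg', width=str(width), height=str(height),
            version='1.1', xmlns='http://www.w3.org/2000/svg')
    # One <rect> child per (x, y, w, h, fill) tuple.
    for x, y, w, h, fill in rects:
        et.SubElement(svg, 'rect', x=str(x), y=str(y), width=str(w),
                height=str(h), fill=fill)
    # encoding='unicode' returns a str, like tostring(...).decode('utf-8').
    body = et.tostring(svg, encoding='unicode')
    return ''.join([
            '<?xml version="1.0" standalone="no"?>\n',
            '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"\n',
            '"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n',
            body])

# Usage: sketch_svg_document(10, 20, [(0, 0, 3, 14, '#4B4CBF')]) yields a
# document whose single <rect> is filled with the 'object' bar color.
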

def parseargs():
    parser = argparse.ArgumentParser()
    def aa(*args, **kwargs):
        parser.add_argument(*args, **kwargs)
    aa('--model', choices=['resnet18-bn', 'resnet18-hn', 'resnet18-ha'],
            default='resnet18-bn')
    aa('--iteration', type=int, default=0)
    aa('--dataset', choices=['places'],
            default='places')
    aa('--seg', choices=['net', 'netp', 'netq', 'netpq',
            'netpqc', 'netpqxc', 'human'],
            default='net')
    aa('--layers', nargs='+', required=True)
    aa('--quantile', type=float, default=0.01)
    aa('--miniou', type=float, default=0.025)
    args = parser.parse_args()
    return args
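

# Hypothetical helper (not called by main): a sketch of the results-directory
# naming scheme main() uses to locate each layer's report.json:
# results/<model>-<iteration>-<dataset>-<seg>-<layer><qdir>/report.json,
# where <qdir> is empty for the default quantile of 0.01 and
# '-<quantile*1000>' otherwise. The helper name and signature are
# assumptions for illustration only.
def sketch_report_path(model, iteration, dataset, seg, layer, quantile=0.01):
    # Same suffix rule as main(): only non-default quantiles are encoded.
    qdir = '-%d' % (quantile * 1000) if quantile != 0.01 else ''
    return 'results/%s-%d-%s-%s-%s%s/report.json' % (
            model, iteration, dataset, seg, layer, qdir)

# Example: sketch_report_path('resnet18-bn', 0, 'places', 'net', 'layer4')
# returns 'results/resnet18-bn-0-places-net-layer4/report.json'.
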

def main():
    args = parseargs()
    threshold_iou = args.miniou
    layer_report = {}
    qdir = '-%d' % (args.quantile * 1000) if args.quantile != 0.01 else ''
    for layer in args.layers:
        input_filename = 'results/%s-%d-%s-%s-%s%s/report.json' % (
                args.model, args.iteration, args.dataset, args.seg, layer, qdir)
        with open(input_filename) as f:
            layer_report[layer] = EasyDict(json.load(f))
    # Now assemble the data needed for the graph:
    # [(layername, [(catname, [unitcount, unitcount, ...]), ...]), ...]
    cat_order = ['object', 'part', 'material', 'color']
    graph_data = []
    for layer in args.layers:
        layer_data = []
        catmap = defaultdict(lambda: defaultdict(int))
        units = layer_report[layer].get('units',
                layer_report[layer].get('images', None)) # old format
        for unitrec in units:
            if unitrec.iou is None or unitrec.iou < threshold_iou:
                continue
            catmap[unitrec.cat][unitrec.label] += 1
        for cat in cat_order:
            if cat not in catmap:
                continue
            # For this graph we do not need labels
            cat_data = list(catmap[cat].values())
            cat_data.sort(reverse=True)  # largest counts first
            layer_data.append((cat, cat_data))
        graph_data.append((layer, layer_data))
    # Now make the actual graph
    largest_layer = max(sum(len(cat_data)
            for cat, cat_data in layer_data)
            for layer, layer_data in graph_data)
    layer_height = 14
    layer_gap = 2
    barwidth = 3
    bargap = 0
    leftmargin = 48
    margin = 8
    svgwidth = largest_layer * (barwidth + bargap) + margin + leftmargin
    svgheight = ((layer_height + layer_gap) * len(args.layers) - layer_gap +
            2 * margin)
    textsize = 10

    # create an SVG XML element
    svg = et.Element('svg', width=str(svgwidth), height=str(svgheight),
            version='1.1', xmlns='http://www.w3.org/2000/svg')

    # Draw one row per layer: the layer name, then for each category a
    # light background rectangle with that category's bars drawn on top.
    y = margin
    for layer, layer_data in graph_data:
        et.SubElement(svg, 'text', x='0', y='0',
            style=('font-family:sans-serif;font-size:%dpx;' +
                'text-anchor:end;alignment-baseline:hanging;' +
                'transform:translate(%dpx, %dpx);') %
                (textsize, leftmargin - 4, y + (layer_height - textsize) / 2)
            ).text = str(layer)
        barmax = max(max(cat_data) if len(cat_data) else 1
                for cat, cat_data in layer_data) if len(layer_data) else 1
        barscale = float(layer_height) / barmax
        x = leftmargin
        for cat, cat_data in layer_data:
            catwidth = len(cat_data) * (barwidth + bargap)
            et.SubElement(svg, 'rect',
                    x=str(x), y=str(y),
                    width=str(catwidth),
                    height=str(layer_height),
                    fill=cat_palette[cat][1])
            for bar in cat_data:
                barheight = barscale * bar
                et.SubElement(svg, 'rect',
                        x=str(x), y=str(y + layer_height - barheight),
                        width=str(barwidth),
                        height=str(barheight),
                        fill=cat_palette[cat][0])
                x += barwidth + bargap
        y += layer_height + layer_gap

    # Serialize the bare <svg> element.
    result = et.tostring(svg).decode('utf-8')
    # Prepend the XML declaration and the SVG 1.1 DOCTYPE.
    result = ''.join([
            '<?xml version="1.0" standalone="no"?>\n',
            '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"\n',
            '"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n',
            result])
    output_filename = 'results/%s-%s-%s-%s%s/multilayer-%d.svg' % (
                args.model, args.iteration, args.dataset, args.seg, qdir,
                args.miniou * 1000)
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    print('writing to %s' % output_filename)
    with open(output_filename, 'w') as f:
        f.write(result)
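

# Hypothetical helper (not called by main): a self-contained sketch of the
# tallying step in main(), using plain dicts instead of EasyDict records.
# Each unit record is assumed to carry 'cat', 'label' and 'iou' keys, as the
# report.json entries read by main() appear to. Units whose IoU is missing
# or below the threshold are dropped; the rest are counted per
# (category, label), and each category keeps only its counts, largest first.
def sketch_tally_units(units, threshold_iou,
        cat_order=('object', 'part', 'material', 'color')):
    from collections import defaultdict
    catmap = defaultdict(lambda: defaultdict(int))
    for unit in units:
        iou = unit.get('iou')
        if iou is None or iou < threshold_iou:
            continue
        catmap[unit['cat']][unit['label']] += 1
    # Labels are discarded; the graph only needs the sorted bar heights.
    # 'cat in catmap' does not create entries in a defaultdict.
    return [(cat, sorted(catmap[cat].values(), reverse=True))
            for cat in cat_order if cat in catmap]

# Design note: counting per label before sorting means each bar in the
# graph is one concept, and its height is how many units detect it.
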

# Per-category colors: (bar fill, background fill).
cat_palette = {
    'object':   ('#4B4CBF', '#B6B6F2'),
    'part':     ('#55B05B', '#B6F2BA'),
    'material': ('#50BDAC', '#A5E5DB'),
    'texture':  ('#81C679', '#C0FF9B'),
    'color':    ('#F0883B', '#F2CFB6'),
    'other1':   ('#D4CF24', '#F2F1B6'),
    'other2':   ('#D92E2B', '#F2B6B6'),
    'other3':   ('#AB6BC6', '#CFAAFF')
}
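

# Hypothetical helper (not called by main): a sketch of the layout arithmetic
# in main(). Each layer is a row layer_height pixels tall, and each counted
# label becomes a bar barwidth pixels wide, so the row with the most bars
# fixes the SVG width. Within a row, bars are scaled so the tallest one
# exactly fills the row height. Defaults copy the constants in main(); the
# function name and return shape are assumptions for illustration.
def sketch_layout(graph_data, layer_height=14, layer_gap=2, barwidth=3,
        bargap=0, leftmargin=48, margin=8):
    largest_layer = max(sum(len(cat_data) for _, cat_data in layer_data)
            for _, layer_data in graph_data)
    width = largest_layer * (barwidth + bargap) + margin + leftmargin
    height = ((layer_height + layer_gap) * len(graph_data) - layer_gap
            + 2 * margin)
    bar_heights = []
    for _, layer_data in graph_data:
        # Flatten all categories of this row; the row's tallest bar sets
        # the scale.
        counts = [c for _, cat_data in layer_data for c in cat_data]
        barscale = float(layer_height) / (max(counts) if counts else 1)
        bar_heights.append([barscale * c for c in counts])
    return width, height, bar_heights

# Usage: a two-layer graph with 3 and 1 bars gives a 65 x 46 pixel canvas.
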

if __name__ == '__main__':
    main()