Spaces:
Sleeping
Sleeping
import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL | |
from torchvision import transforms | |
from torch.utils.data import TensorDataset | |
from netdissect.progress import default_progress, post_progress, desc_progress | |
from netdissect.progress import verbose_progress, print_progress | |
from netdissect.nethook import edit_layers | |
from netdissect.zdataset import standard_z_sample | |
from netdissect.autoeval import autoimport_eval | |
from netdissect.easydict import EasyDict | |
from netdissect.modelconfig import create_instrumented_model | |
# Epilog shown by --help.  The trailing backslashes of the shell example are
# escaped ("\\") so that they are printed literally; an unescaped backslash
# before a newline inside a non-raw string is a line continuation and would
# collapse the whole example onto a single line.  The output layout mirrors
# the JSON record that main() writes for each (layer, class) pair.
help_epilog = '''\
Example:
python -m netdissect.evalablate \\
      --segmenter "netdissect.segmenter.UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')" \\
      --model "proggan.from_pth_file('models/lsun_models/${SCENE}_lsun.pth')" \\
      --outdir dissect/dissectdir \\
      --classes mirror coffeetable tree \\
      --layers layer4 \\
      --size 1000
Output layout:
dissect/dissectdir/layer4/pixablation/mirror-iou.json
{ classname: "mirror",
  classnum: 43,
  baseline: 1234.5,
  layer: "layer4",
  metric: "iou",
  ablation_units: [341, 23, 12, 142, 83, ...],
  ablation_effects: [1180.2, 1031.7, 902.4, ...]
}
'''
def main():
    """Command-line entry point for the ablation evaluation.

    Reads ``<outdir>/dissect.json`` for default settings and unit rankings,
    builds the instrumented generator and the segmenter, and then, for every
    requested class and every editable layer of the model, measures the
    effect of ablating the top-ranked units (see ``measure_ablation``).

    One JSON file per (layer, class) pair is written to
    ``<outdir>/<layer>/pixablation/<class>-<metric>.json`` with the keys
    ``classname``, ``classnum``, ``baseline``, ``layer``, ``metric``,
    ``ablation_units`` and ``ablation_effects``.  An existing file that
    already covers at least ``--unitcount`` units is left untouched.

    Exits with status 1 when no arguments are given, when no model can be
    created, when a class name is unknown to the segmenter, or when the
    required ranking is missing from the dissection.
    """
    def strpair(arg):
        """Parse 'layername[:reportedname]' into a 2-tuple of strings."""
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    # Command-line settings
    parser = argparse.ArgumentParser(description='Ablation eval',
            epilog=textwrap.dedent(help_epilog),
            formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='dissect', required=True,
                        help='directory for dissection output')
    parser.add_argument('--layers', type=strpair, nargs='+',
                        help='space-separated list of layer names to edit' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--classes', type=str, nargs='+',
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric', type=str, default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--unitcount', type=int, default=30,
                        help='number of units to ablate')
    # The value is evaluated with autoimport_eval below, i.e. it is a
    # constructor expression and not a directory.
    parser.add_argument('--segmenter', type=str,
                        help='constructor for the segmenter to use')
    parser.add_argument('--netname', type=str, default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size', type=int, default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size', type=int, default=200,
                        help='number of images to test')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    if not args.classes:
        # Without this check the loop below would fail on iterating None.
        parser.error('at least one class name must be given with --classes')

    # Set up console output
    verbose_progress(not args.quiet)

    # Speed up pytorch (set once; it was previously assigned twice)
    torch.backends.cudnn.benchmark = True

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Take defaults for model constructor etc from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter

    # Instantiate generator
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)

    # Build the fixed sample of latent inputs.
    device = next(model.parameters()).device
    input_shape = model.input_shape
    # 4d input if convolutional, 2d input if first layer is linear.
    raw_sample = standard_z_sample(args.size, input_shape[1], seed=2).view(
            (args.size,) + input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Create the segmenter
    segmenter = autoimport_eval(args.segmenter)
    labelnames, _ = segmenter.get_label_and_category_names(dataset)
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}

    # Fail fast on misspelled class names instead of hitting a KeyError
    # after earlier classes have already been processed.
    unknown = [c for c in args.classes if c not in labelnum_from_name]
    if unknown:
        print('Unknown class name(s): %s' % ', '.join(unknown))
        sys.exit(1)

    segloader = torch.utils.data.DataLoader(dataset,
            batch_size=args.batch_size, num_workers=10,
            pin_memory=(device.type == 'cuda'))

    # Index the dissection layers by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}

    # Start from a clean state: no ablation on any layer.
    for l in model.ablation:
        model.ablation[l] = None

    # For each class and layer, ablate units in ranking order.
    progress = default_progress()
    for classname in progress(args.classes):
        post_progress(c=classname)
        classnum = labelnum_from_name[classname]
        rankname = '%s-%s' % (classname, args.metric)
        for layername in progress(model.ablation):
            post_progress(l=layername)
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                        if r.name == rankname)
            except (KeyError, AttributeError, StopIteration):
                # Only lookup failures are handled; a bare except would
                # also intercept KeyboardInterrupt.
                print('%s not found' % rankname)
                sys.exit(1)
            ordering = numpy.argsort(ranking.score)

            # Check if already done
            ablationdir = os.path.join(args.outdir, layername, 'pixablation')
            resultfile = os.path.join(ablationdir, '%s.json' % rankname)
            if os.path.isfile(resultfile):
                with open(resultfile) as f:
                    data = EasyDict(json.load(f))
                # If the unit ordering is not the same, something is wrong:
                # keep the existing file, but say so.
                if not all(a == o
                        for a, o in zip(data.ablation_units, ordering)):
                    print('%s: existing result has a different unit '
                          'ordering; skipping' % resultfile, file=sys.stderr)
                    continue
                if len(data.ablation_effects) >= args.unitcount:
                    continue  # file already done.

            measurements = measure_ablation(segmenter, segloader,
                    model, classnum, layername, ordering[:args.unitcount])
            measurements = measurements.cpu().numpy().tolist()
            os.makedirs(ablationdir, exist_ok=True)
            with open(resultfile, 'w') as f:
                json.dump(dict(
                    classname=classname,
                    classnum=classnum,
                    baseline=measurements[0],
                    layer=layername,
                    metric=args.metric,
                    ablation_units=ordering.tolist(),
                    ablation_effects=measurements[1:]), f)
def measure_ablation(segmenter, loader, model, classnum, layer, ordering):
    """Measure how ablating units of `layer` changes the presence of a class.

    For every batch of latent inputs the model output is segmented, the
    mask of pixels labelled `classnum` is average-pooled down to the
    spatial size of `layer`, and the sum of the pooled mask is added to
    the baseline.  Then, for each feature location whose pooled value is
    nonzero and for each k in 1..len(ordering), the first k units of
    `ordering` are ablated at that single location only, the image is
    regenerated and re-segmented, and the pooled class value at that same
    location is accumulated.

    Args:
        segmenter: object providing segment_batch(images, downsample=2).
            NOTE(review): the result is indexed as (batch, channel, y, x)
            with integer labels -- confirm against the segmenter in use.
        loader: iterable of batches; batch[0] holds the latent inputs.
        model: instrumented model exposing `ablation` (dict keyed by layer
            name), `feature_shape` (dict of layer name -> shape) and
            `parameters()`.
        classnum: label number of the class to track.
        layer: name of the layer whose units are ablated.
        ordering: sequence of unit indices, in ablation order.

    Returns:
        A CPU float tensor of length len(ordering) + 1.  Element 0 is the
        unablated baseline; element k is the accumulated score after
        ablating the first k units of `ordering`.

    Side effects:
        All entries of model.ablation are reset to None on entry, and
        model.ablation[layer] is reset to None again before returning.
    """
    device = next(model.parameters()).device
    progress = default_progress()
    for l in model.ablation:
        model.ablation[l] = None
    feature_units = model.feature_shape[layer][1]
    feature_shape = model.feature_shape[layer][2:]
    repeats = len(ordering)
    total_scores = torch.zeros(repeats + 1)
    # Nothing here is ever back-propagated, so skip building autograd graphs.
    with torch.no_grad():
        for batch in progress(loader):
            z_batch = batch[0].to(device)
            model.ablation[layer] = None
            tensor_images = model(z_batch)
            seg = segmenter.segment_batch(tensor_images, downsample=2)
            mask = (seg == classnum).max(1)[0]
            downsampled_seg = torch.nn.functional.adaptive_avg_pool2d(
                    mask.float()[:, None, :, :], feature_shape)[:, 0, :, :]
            total_scores[0] += downsampled_seg.sum().cpu()
            # Now we need to do an intervention for every location
            # that had a nonzero downsampled_seg, if any.
            interventions_needed = downsampled_seg.nonzero()
            location_count = len(interventions_needed)
            if location_count == 0:
                continue
            # Rows are (sample, y, x); tile them once per ablation depth.
            interventions_needed = interventions_needed.repeat(repeats, 1)
            # Index the on-device copy: indexing the CPU batch with an index
            # tensor that lives on another device is rejected by PyTorch.
            inter_z = z_batch[interventions_needed[:, 0]]
            inter_chan = torch.zeros(repeats, location_count, feature_units,
                    device=device)
            # Cumulative ablation: depth j zeroes out units ordering[:j+1].
            for j, u in enumerate(ordering):
                inter_chan[j:, :, u] = 1
            inter_chan = inter_chan.view(len(inter_z), feature_units)
            inter_loc = interventions_needed[:, 1:]
            scores = torch.zeros(len(inter_z))
            batch_size = len(z_batch)
            for j in range(0, len(inter_z), batch_size):
                ibz = inter_z[j:j + batch_size]
                ibl = inter_loc[j:j + batch_size].t()
                # One (row, y, x) index triple per intervention in the chunk.
                rows = (torch.arange(len(ibz), device=ibl.device),
                        ) + tuple(ibl)
                imask = torch.zeros((len(ibz),) + feature_shape,
                        device=ibz.device)
                imask[rows] = 1
                ibc = inter_chan[j:j + batch_size]
                model.ablation[layer] = (
                        imask.float()[:, None, :, :] * ibc[:, :, None, None])
                tensor_images = model(ibz)
                seg = segmenter.segment_batch(tensor_images, downsample=2)
                mask = (seg == classnum).max(1)[0]
                downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d(
                        mask.float()[:, None, :, :],
                        feature_shape)[:, 0, :, :]
                scores[j:j + batch_size] = downsampled_iseg[rows]
            scores = scores.view(repeats, location_count).sum(1)
            total_scores[1:] += scores
    # Do not leave the last per-location mask installed on the model.
    model.ablation[layer] = None
    return total_scores
def count_segments(segmenter, loader, model):
    """Return the per-pixel frequency of every segmentation label.

    The previous body referenced names that were never defined in this
    function (`z_batch`, `device`, `index`, `self`, `batch_label_counts`)
    and so raised NameError on the first batch.  This version counts label
    occurrences directly.

    NOTE(review): the intent is reconstructed from the broken original --
    confirm that a flat per-label frequency is what callers expect.

    Args:
        segmenter: object providing segment_batch(images, downsample=2).
            NOTE(review): assumed to return a non-negative integer tensor
            indexed as (batch, channel, y, x) -- confirm.
        loader: iterable of batches; batch[0] holds the latent inputs.
        model: generator called on each latent batch; its parameters
            determine the device used.

    Returns:
        A 1-d CPU float tensor whose entry i is the number of occurrences
        of label i divided by the number of pixels seen (batch * height *
        width).  Its length is one more than the largest label observed.
        An empty tensor is returned when the loader yields no data.
    """
    device = next(model.parameters()).device
    progress = default_progress()
    total_bincount = torch.zeros(0)
    data_size = 0
    # Pure measurement: no gradients are needed.
    with torch.no_grad():
        for batch in progress(loader):
            z_batch = batch[0]
            tensor_images = model(z_batch.to(device))
            seg = segmenter.segment_batch(tensor_images, downsample=2)
            counts = seg.reshape(-1).bincount().float().cpu()
            # Grow the running histogram if this batch has a higher label.
            if len(counts) > len(total_bincount):
                grown = torch.zeros(len(counts))
                grown[:len(total_bincount)] = total_bincount
                total_bincount = grown
            total_bincount[:len(counts)] += counts
            data_size += seg.shape[0] * seg.shape[2] * seg.shape[3]
    if data_size == 0:
        # Avoid dividing by zero when there was nothing to count.
        return total_bincount
    normalized_bincount = total_bincount / data_size
    return normalized_bincount
# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()