FashionGAN / netdissect /evalablate.py
fiesty-bear
Initial Commit
6064c9d
import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL
from torchvision import transforms
from torch.utils.data import TensorDataset
from netdissect.progress import default_progress, post_progress, desc_progress
from netdissect.progress import verbose_progress, print_progress
from netdissect.nethook import edit_layers
from netdissect.zdataset import standard_z_sample
from netdissect.autoeval import autoimport_eval
from netdissect.easydict import EasyDict
from netdissect.modelconfig import create_instrumented_model
# Usage text shown after the argument list by ``--help``.  The output layout
# below mirrors exactly what ``main()`` writes.  Trailing ``\\`` renders as a
# literal backslash so the example shows real shell line continuations.
help_epilog = '''\
Example:

python -m netdissect.evalablate \\
    --segmenter "netdissect.segmenter.UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')" \\
    --model "proggan.from_pth_file('models/lsun_models/${SCENE}_lsun.pth')" \\
    --outdir dissect/dissectdir \\
    --classes mirror coffeetable tree \\
    --layers layer4 \\
    --size 1000

Output layout (one file per class, layer and --metric, default iou):
dissectdir/layer4/pixablation/mirror-iou.json
{ classname: "mirror",
  classnum: 43,
  baseline: 9621.2,
  layer: "layer4",
  metric: "iou",
  ablation_units: [341, 23, 12, 142, 83, ...],
  ablation_effects: [9123.5, 8730.1, 8422.9, ...]
}
baseline is the class area measured at the layer's feature resolution;
ablation_effects[k-1] is that area with the first k of ablation_units ablated.
'''
def main():
    """Command-line entry point for the per-pixel ablation evaluation.

    Reads ``<outdir>/dissect.json`` (from an earlier dissection run).  That
    file supplies defaults for ``--model``, ``--pthfile`` and ``--segmenter``
    and provides the per-layer unit rankings.

    For every requested class and every editable layer, units are ordered by
    ``numpy.argsort`` of the ranking named ``<class>-<metric>``.
    ``measure_ablation`` is run on the first ``--unitcount`` of them, and the
    result is written to ``<outdir>/<layer>/pixablation/<class>-<metric>.json``.

    An existing output file is kept, and that layer skipped, when either:
      * it already holds at least ``--unitcount`` effects; or
      * its unit ordering disagrees with the current ranking (a warning is
        printed in that case).

    Exits with status 1 when no model can be built, a class is unknown to
    the segmenter, or the requested ranking is missing.
    """
    def strpair(arg):
        # 'name' -> ('name', 'name'); 'name:reported' -> ('name', 'reported').
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(description='Ablation eval',
            epilog=textwrap.dedent(help_epilog),
            formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='dissect', required=True,
                        help='directory for dissection output')
    parser.add_argument('--layers', type=strpair, nargs='+',
                        help='space-separated list of layer names to edit' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--classes', type=str, nargs='+',
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric', type=str, default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--unitcount', type=int, default=30,
                        help='number of units to ablate')
    parser.add_argument('--segmenter', type=str,
                        help='constructor expression for the segmenter '
                        '(defaults to the one recorded in dissect.json)')
    parser.add_argument('--netname', type=str, default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size', type=int, default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size', type=int, default=200,
                        help='number of images to test')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    if not args.classes:
        # Without classes there is nothing to ablate; fail clearly instead of
        # crashing later when iterating over None.
        parser.error('--classes is required')

    # Set up console output.
    verbose_progress(not args.quiet)
    # Speed up pytorch (set once; it was previously set twice).
    torch.backends.cudnn.benchmark = True
    # args.cuda is kept on the namespace for create_instrumented_model(args).
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Take defaults for model constructor etc. from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter

    # Instantiate the (editable) generator.
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)
    device = next(model.parameters()).device
    input_shape = model.input_shape
    # 4d input if convolutional, 2d input if first layer is linear.
    raw_sample = standard_z_sample(args.size, input_shape[1], seed=2).view(
        (args.size,) + input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Create the segmenter and map class names to label numbers.
    segmenter = autoimport_eval(args.segmenter)
    labelnames, _ = segmenter.get_label_and_category_names(dataset)
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}
    segloader = torch.utils.data.DataLoader(dataset,
            batch_size=args.batch_size, num_workers=10,
            pin_memory=(device.type == 'cuda'))

    # Index the dissection layers by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}
    # Start with no ablation active on any layer.
    for l in model.ablation:
        model.ablation[l] = None

    # For each class and each layer, do an ablation.
    progress = default_progress()
    for classname in progress(args.classes):
        post_progress(c=classname)
        rankname = '%s-%s' % (classname, args.metric)
        if classname not in labelnum_from_name:
            print('%s is not a segmenter class' % classname)
            sys.exit(1)
        classnum = labelnum_from_name[classname]
        for layername in progress(model.ablation):
            post_progress(l=layername)
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                               if r.name == rankname)
            except (KeyError, StopIteration):
                # KeyError: layer absent from dissect.json;
                # StopIteration: no ranking with this name for the layer.
                print('%s not found' % rankname)
                sys.exit(1)
            ordering = numpy.argsort(ranking.score)

            # Check if already done.
            ablationdir = os.path.join(args.outdir, layername, 'pixablation')
            outfile = os.path.join(ablationdir, '%s.json' % rankname)
            if os.path.isfile(outfile):
                with open(outfile) as f:
                    data = EasyDict(json.load(f))
                # If the unit ordering is not the same, something is wrong:
                # keep the existing file, but do not skip it silently.
                if not all(a == o
                           for a, o in zip(data.ablation_units, ordering)):
                    print('%s: unit ordering differs from ranking %s; '
                          'skipping' % (outfile, rankname), file=sys.stderr)
                    continue
                if len(data.ablation_effects) >= args.unitcount:
                    continue  # file already done.
                # Otherwise the file is incomplete; recompute from scratch.

            measurements = measure_ablation(segmenter, segloader,
                    model, classnum, layername, ordering[:args.unitcount])
            measurements = measurements.cpu().numpy().tolist()
            os.makedirs(ablationdir, exist_ok=True)
            with open(outfile, 'w') as f:
                json.dump(dict(
                    classname=classname,
                    classnum=classnum,
                    baseline=measurements[0],
                    layer=layername,
                    metric=args.metric,
                    ablation_units=ordering.tolist(),
                    ablation_effects=measurements[1:]), f)
def _pooled_class_mask(segmenter, images, classnum, feature_shape):
    """Segment ``images`` and average-pool the ``classnum`` mask to ``feature_shape``.

    Returns a float tensor of shape ``(batch,) + feature_shape``.  Each value
    is the fraction of pixels in that feature cell labelled ``classnum`` in any
    segmentation channel (dim 1 of the segmenter output).
    """
    seg = segmenter.segment_batch(images, downsample=2)
    mask = (seg == classnum).max(1)[0]
    return torch.nn.functional.adaptive_avg_pool2d(
        mask.float()[:, None, :, :], feature_shape)[:, 0, :, :]


def measure_ablation(segmenter, loader, model, classnum, layer, ordering):
    """Measure how much of class ``classnum`` remains as units are ablated.

    For each batch of latent vectors from ``loader`` (first element of each
    batch):

    * The class mask of the unedited output is pooled to the feature
      resolution of ``layer`` and summed into the baseline.
    * For every feature location whose pooled value is nonzero, and for every
      ``k`` in ``1..len(ordering)``, the same latent is regenerated with an
      ablation mask on ``layer``.  The mask is 1 only at that single location
      and only for units ``ordering[:k]``.  The pooled class value at that
      location is then recorded.

    The mask is passed to the instrumented model via ``model.ablation[layer]``
    as a ``(n, units, h, w)`` 0/1 tensor.  How the model applies it is defined
    elsewhere (presumably it zeroes the masked activations; confirm in
    nethook/modelconfig).

    Args:
        segmenter: object with ``segment_batch(images, downsample=2)``.
        loader: iterable of batches whose first element is a latent batch.
        model: instrumented generator exposing ``ablation`` and
            ``feature_shape``.
        classnum: label number of the class to measure.
        layer: name of the layer to ablate.
        ordering: sequence of unit indices, most important first.

    Returns:
        CPU float tensor of length ``len(ordering) + 1``.  Entry 0 is the
        baseline sum; entry ``k`` is the summed score with the first ``k``
        units ablated.  ``model.ablation[layer]`` is reset to ``None`` on
        return.
    """
    device = next(model.parameters()).device
    progress = default_progress()
    for l in model.ablation:
        model.ablation[l] = None
    feature_units = model.feature_shape[layer][1]
    # tuple() so it concatenates with a tuple below whatever sequence type
    # feature_shape is stored as.
    feature_shape = tuple(model.feature_shape[layer][2:])
    repeats = len(ordering)
    total_scores = torch.zeros(repeats + 1)
    try:
        with torch.no_grad():  # pure evaluation; no autograd graph needed
            for batch in progress(loader):
                # Move latents to the model device once; index the device
                # copy below so index tensors from nonzero() (which live on
                # the segmentation's device) never index a CPU tensor.
                z_batch = batch[0].to(device)
                batch_size = len(z_batch)
                model.ablation[layer] = None
                downsampled_seg = _pooled_class_mask(
                    segmenter, model(z_batch), classnum, feature_shape)
                total_scores[0] += downsampled_seg.sum().cpu()
                # Now we need to do an intervention for every location
                # that had a nonzero downsampled_seg, if any.
                interventions_needed = downsampled_seg.nonzero()
                location_count = len(interventions_needed)
                if location_count == 0:
                    continue
                # Rows are grouped as [repeat 0 locations, repeat 1 ...].
                interventions_needed = interventions_needed.repeat(repeats, 1)
                inter_z = z_batch[interventions_needed[:, 0]]
                inter_chan = torch.zeros(repeats, location_count,
                                         feature_units, device=device)
                # Cumulative: repeat j ablates units ordering[0..j].
                for j, u in enumerate(ordering):
                    inter_chan[j:, :, u] = 1
                inter_chan = inter_chan.view(len(inter_z), feature_units)
                inter_loc = interventions_needed[:, 1:]
                scores = torch.zeros(len(inter_z))
                for j in range(0, len(inter_z), batch_size):
                    ibz = inter_z[j:j + batch_size]
                    ibl = inter_loc[j:j + batch_size].t()
                    # (row, y, x) index selecting each sample's own location.
                    index = (torch.arange(len(ibz), device=ibl.device),) \
                        + tuple(ibl)
                    imask = torch.zeros((len(ibz),) + feature_shape,
                                        device=ibz.device)
                    imask[index] = 1
                    ibc = inter_chan[j:j + batch_size]
                    model.ablation[layer] = (
                        imask[:, None, :, :] * ibc[:, :, None, None])
                    downsampled_iseg = _pooled_class_mask(
                        segmenter, model(ibz), classnum, feature_shape)
                    scores[j:j + batch_size] = downsampled_iseg[index].cpu()
                total_scores[1:] += scores.view(
                    repeats, location_count).sum(1)
    finally:
        # Never leave the model in an ablated state for the caller.
        model.ablation[layer] = None
    return total_scores
def count_segments(segmenter, loader, model):
    """Return how often each segmentation label occurs, per pixel position.

    Runs ``model`` on the first element of every batch from ``loader``.  Each
    output is segmented with ``segmenter.segment_batch(..., downsample=2)``,
    and label values are counted across all segmentation channels.

    Returns:
        A 1-D float CPU tensor indexed by label number.  Entry ``c`` is the
        number of occurrences of label ``c`` divided by the number of pixel
        positions (images x height x width).  A segmentation can have several
        channels (dim 1), so the entries may sum to more than 1.  The length
        is one more than the largest label seen.  An empty loader yields an
        empty tensor.
    """
    total_bincount = torch.zeros(0)
    data_size = 0
    device = next(model.parameters()).device
    progress = default_progress()
    with torch.no_grad():
        for batch in progress(loader):
            tensor_images = model(batch[0].to(device))
            seg = segmenter.segment_batch(tensor_images, downsample=2)
            # bincount needs non-negative integer labels
            # (NOTE(review): assumed of the segmenter output; confirm).
            counts = seg.reshape(-1).long().bincount(
                minlength=len(total_bincount)).float().cpu()
            if len(counts) > len(total_bincount):
                # A larger label appeared; grow the running total to match.
                total_bincount = torch.cat([
                    total_bincount,
                    torch.zeros(len(counts) - len(total_bincount))])
            total_bincount += counts
            data_size += seg.shape[0] * seg.shape[2] * seg.shape[3]
    if data_size == 0:
        return total_bincount
    return total_bincount / data_size
# Script entry point, e.g. ``python -m netdissect.evalablate --outdir ...``.
if __name__ == '__main__':
    main()