Spaces:
Build error
Build error
import os, torch, numpy, base64, json, re, threading, random | |
from torch.utils.data import TensorDataset, DataLoader | |
from collections import defaultdict | |
from netdissect.easydict import EasyDict | |
from netdissect.modelconfig import create_instrumented_model | |
from netdissect.runningstats import RunningQuantile | |
from netdissect.dissection import safe_dir_name | |
from netdissect.zdataset import z_sample_for_model | |
from PIL import Image | |
from io import BytesIO | |
class DissectionProject:
    '''
    DissectionProject understands how to drive a GanTester within a
    dissection project directory structure: it caches data in files,
    creates image files, and translates data between plain python data
    types and the pytorch-specific tensors required by GanTester.
    '''

    def __init__(self, config, project_dir, path_url, public_host):
        '''
        config: dissection configuration; ``config.settings`` is handed to
            GanTester and ``config['layers']`` supplies per-layer
            dissection records (units, rankings).
        project_dir: directory containing the dissection results; generated
            images are cached under its ``cache`` subdirectory.
        path_url: URL path prefix under which project_dir is served.
        public_host: host name embedded in generated twitter-card pages.
        '''
        print('config done', project_dir)
        self.use_cuda = torch.cuda.is_available()
        self.dissect = config
        self.project_dir = project_dir
        self.path_url = path_url
        self.public_host = public_host
        self.cachedir = os.path.join(self.project_dir, 'cache')
        self.tester = GanTester(
            config.settings, dissectdir=project_dir,
            device=torch.device('cuda') if self.use_cuda
            else torch.device('cpu'))
        # Cache of standard z samples (one row per id), grown on demand.
        # Kept as a numpy array from the start so slicing / .tolist() /
        # fancy indexing work even before the first sample is drawn.
        self.stdz = numpy.zeros((0, 0), dtype=numpy.float32)

    def get_zs(self, size):
        '''Return the first ``size`` standard z vectors as nested lists.'''
        if size > len(self.stdz):
            # Regenerate the whole cache at the larger size; this relies on
            # standard_z_sample returning the same leading rows for a given
            # seed regardless of size (as its docstring states).
            z_tensor = self.tester.standard_z_sample(size)
            self.stdz = z_tensor.cpu().numpy()
        return self.stdz[:size].tolist()

    def get_z(self, id):
        '''Return standard z vector number ``id`` as a plain list.'''
        if id >= len(self.stdz):
            # Over-allocate so nearby ids do not force another regeneration.
            self.get_zs((id + 1) * 2)
        return self.stdz[id].tolist()

    def get_zs_for_ids(self, ids):
        '''
        Return a numpy array whose rows are the standard z vectors for
        ``ids`` (any iterable of ints; an empty iterable gives no rows).
        '''
        # A list forces numpy fancy (row) indexing; a tuple would instead
        # be interpreted as a single multi-dimensional index.
        ids = list(ids)
        if ids and max(ids) >= len(self.stdz):
            self.get_z(max(ids))
        return self.stdz[ids]

    def get_layers(self):
        '''Describe each tested layer: name, channel count, spatial shape.'''
        layer_shapes = self.tester.layer_shapes()
        result = []
        for layer in self.tester.layers:
            shape = layer_shapes[layer]
            result.append(dict(
                layer=layer,
                channels=shape[1],
                shape=[shape[2], shape[3]]))
        return result

    def _find_layer(self, layer):
        '''Return the dissection record for ``layer``, or None if absent.'''
        try:
            return next((dl for dl in self.dissect['layers']
                         if dl['layer'] == layer), None)
        except (KeyError, TypeError):
            # Missing or malformed 'layers' data is treated as "not found".
            return None

    def get_units(self, layer):
        '''
        List the units of ``layer`` with their top-image URL and IoU label,
        or None if the layer has no dissection record.
        '''
        dlayer = self._find_layer(layer)
        if dlayer is None:
            return None
        return [dict(unit=unit_num,
                     img='/%s/%s/s-image/%d-top.jpg' %
                         (self.path_url, layer, unit_num),
                     label=unit['iou_label'])
                for unit_num, unit in enumerate(dlayer['units'])]

    def get_rankings(self, layer):
        '''
        List the unit rankings recorded for ``layer`` (name, optional
        metric, scores), or None if the layer has no dissection record.
        '''
        dlayer = self._find_layer(layer)
        if dlayer is None:
            return None
        return [dict(name=ranking['name'],
                     metric=ranking.get('metric', None),
                     scores=ranking['score'])
                for ranking in dlayer['rankings']]

    def get_levels(self, layer, quantiles):
        '''Return the activation levels of ``layer`` at ``quantiles``.'''
        levels = self.tester.levels(
            layer, torch.from_numpy(numpy.array(quantiles)))
        return levels.cpu().numpy().tolist()

    def generate_images(self, zs, ids, interventions, return_urls=False):
        '''
        Generate images either for explicit ``zs`` or for standard-z ``ids``.

        * ids given and no interventions: images are rendered once into
          cache/img/id/<id>.jpg; returns [dict(id=..., d=url), ...].
        * otherwise, with return_urls: each image (plus a twitter-card
          html page) is written under cache/img/uniq/<random>/; returns
          [dict(d=url), ...].
        * otherwise: returns [dict(d=data_url), ...] with base64 jpegs.
        '''
        if ids is not None:
            assert zs is None
            ids = list(ids)
            zs = self.get_zs_for_ids(ids)
            if not interventions:
                # Do file caching when ids are given (and no ablations).
                return self._cached_id_images(zs, ids)
        # No file caching when ids are not given (or ablations are applied)
        z_tensor = torch.tensor(zs).float().to(self.tester.device)
        imgs = self.tester.generate_images(
            z_tensor,
            intervention=decode_intervention_array(
                interventions, self.tester.layer_shapes()),
        ).cpu().numpy()
        if return_urls:
            return self._save_unique_images(imgs)
        return [dict(d=img2base64(img.transpose(1, 2, 0))) for img in imgs]

    @staticmethod
    def _save_jpeg(img, path):
        '''Save a (3, H, W) uint8 array as a high-quality progressive jpeg.'''
        Image.fromarray(img.transpose(1, 2, 0)).save(
            path, 'jpeg', quality=99, optimize=True, progressive=True)

    def _cached_id_images(self, zs, ids):
        '''Render any not-yet-cached id images, then return their URLs.'''
        imgdir = os.path.join(self.cachedir, 'img', 'id')
        os.makedirs(imgdir, exist_ok=True)
        exist = set(os.listdir(imgdir))
        unfinished = numpy.array(
            [('%d.jpg' % i) not in exist for i in ids], dtype=bool)
        needed_z_tensor = torch.tensor(zs[unfinished]).float().to(
            self.tester.device)
        needed_ids = numpy.array(ids, dtype=int)[unfinished]
        # Generate image files for just the needed images.
        if len(needed_z_tensor):
            imgs = self.tester.generate_images(needed_z_tensor).cpu().numpy()
            for i, img in zip(needed_ids, imgs):
                self._save_jpeg(img, os.path.join(imgdir, '%d.jpg' % i))
        return [dict(id=i, d='/%s/cache/img/id/%d.jpg' % (self.path_url, i))
                for i in ids]

    def _save_unique_images(self, imgs):
        '''Write images and twitter-card pages to a random dir; return URLs.'''
        randdir = '%03d' % random.randrange(1000)
        imgdir = os.path.join(self.cachedir, 'img', 'uniq', randdir)
        os.makedirs(imgdir, exist_ok=True)
        startind = random.randrange(100000)
        imgurls = []
        for i, img in enumerate(imgs):
            filename = '%d.jpg' % (i + startind)
            self._save_jpeg(img, os.path.join(imgdir, filename))
            image_url_path = ('/%s/cache/img/uniq/%s/%s'
                              % (self.path_url, randdir, filename))
            imgurls.append(image_url_path)
            tweet_filename = 'tweet-%d.html' % (i + startind)
            tweet_url_path = ('/%s/cache/img/uniq/%s/%s'
                              % (self.path_url, randdir, tweet_filename))
            with open(os.path.join(imgdir, tweet_filename), 'w',
                      encoding='utf-8') as f:
                f.write(twitter_card(image_url_path, tweet_url_path,
                                     self.public_host))
        return [dict(d=d) for d in imgurls]

    def get_features(self, ids, masks, layers, interventions):
        '''
        Compute per-layer feature statistics (from GanTester.feature_stats)
        for the standard z ``ids`` under ``masks`` and ``interventions``,
        converted to {layer: {stat_name: nested list}}.
        '''
        zs = self.get_zs_for_ids(ids)
        z_tensor = torch.tensor(zs).float().to(self.tester.device)
        t_masks = torch.stack(
            [torch.from_numpy(mask_to_numpy(mask)) for mask in masks]
        )[:, None, :, :].to(self.tester.device)
        t_features = self.tester.feature_stats(
            z_tensor, t_masks,
            decode_intervention_array(interventions,
                                      self.tester.layer_shapes()),
            layers)
        # Convert torch arrays to plain python lists before returning.
        return {layer: {key: value.cpu().numpy().tolist()
                        for key, value in feature.items()}
                for layer, feature in t_features.items()}

    def get_featuremaps(self, ids, layers, interventions):
        '''
        Return quantile-normalized feature maps for the standard z ``ids``,
        scaled to 0-255 ints: {layer: [per-image nested lists]}. An empty
        or None ``layers`` means all layers.
        '''
        zs = self.get_zs_for_ids(ids)
        z_tensor = torch.tensor(zs).float().to(self.tester.device)
        # Quantilized features are returned.
        q_features = self.tester.feature_maps(
            z_tensor,
            decode_intervention_array(interventions,
                                      self.tester.layer_shapes()),
            layers)
        # Scale them 0-255 and return them.
        # TODO: turn them into pngs for returning.
        return {layer: [
                    value.clamp(0, 1).mul(255).byte().cpu().numpy().tolist()
                    for value in valuelist]
                for layer, valuelist in q_features.items()
                if (not layers) or (layer in layers)}

    def get_recipes(self):
        '''Load every json file in <project_dir>/recipe, in filename order.'''
        recipedir = os.path.join(self.project_dir, 'recipe')
        if not os.path.isdir(recipedir):
            return []
        result = []
        # Sorted so the order is deterministic (os.listdir order is not).
        for filename in sorted(os.listdir(recipedir)):
            with open(os.path.join(recipedir, filename)) as f:
                result.append(json.load(f))
        return result
class GanTester:
    '''
    GanTester holds on to a specific model to test.
    (1) loads and instantiates the GAN;
    (2) instruments it at every layer so that units can be ablated
    (3) precomputes z dimensionality, and output image dimensions.

    Every method that runs the model holds ``self.modellock`` while it
    installs an intervention and runs the model, because interventions
    are applied as edits on the single shared model object.
    '''

    def __init__(self, args, dissectdir=None, device=None):
        '''
        args: model settings; copied into an EasyDict with ``edit=True``
            and passed to create_instrumented_model.
        dissectdir: dissection output directory. Per-layer
            ``<layer>/quantiles.npz`` files are loaded from it when present.
            When None, no quantiles are loaded and cachedir is None.
        device: torch.device to run on (defaults to CPU).
        '''
        self.dissectdir = dissectdir
        self.cachedir = (os.path.join(dissectdir, 'cache')
                         if dissectdir is not None else None)
        self.device = device if device is not None else torch.device('cpu')
        self.modellock = threading.Lock()
        # Load the generator, instrumented so that its layers can be edited.
        args_copy = EasyDict(args)
        args_copy.edit = True
        model = create_instrumented_model(args_copy)
        model.eval()
        self.model = model
        # Layers of interest: every layer whose features the model retains.
        self.layers = sorted(model.retained_features().keys())
        # Move the model to the resolved device (the raw ``device`` argument
        # may be None).
        model.to(self.device)
        self.quantiles = {
            layer: (load_quantile_if_present(
                        os.path.join(self.dissectdir, safe_dir_name(layer)),
                        'quantiles.npz', device=torch.device('cpu'))
                    if self.dissectdir is not None else None)
            for layer in self.layers}

    def layer_shapes(self):
        '''Return the model's {layer: feature shape} mapping.'''
        return self.model.feature_shape

    def standard_z_sample(self, size=100, seed=1, device=None):
        '''
        Generate a standard set of random Z as a (size, z_dimension) tensor.
        With the same random seed, it always returns the same z (e.g.,
        the first one is always the same regardless of the size.)
        '''
        result = z_sample_for_model(self.model, size, seed=seed)
        if device is not None:
            result = result.to(device)
        return result

    def reset_intervention(self):
        '''Remove any edits previously installed on the model.'''
        self.model.remove_edits()

    def apply_intervention(self, intervention):
        '''
        Install an intervention, replacing any previous one.

        intervention: {layer: tensor} as produced by
        decode_intervention_array, where plane 0 of each tensor is the
        per-unit ablation alpha and plane 1 the replacement value.
        A falsy intervention just clears all edits.
        '''
        self.reset_intervention()
        if not intervention:
            return
        for layer, (a, v) in intervention.items():
            self.model.edit_layer(layer, ablation=a, replacement=v)

    def _loader(self, *tensors):
        '''Batch tensors (z first) for the model, pinning memory for CUDA.'''
        return DataLoader(
            TensorDataset(*tensors),
            batch_size=10,
            pin_memory=('cuda' == self.device.type
                        and tensors[0].device.type == 'cpu'))

    def generate_images(self, z_batch, intervention=None):
        '''
        Render images for z_batch (N, z_dim) under ``intervention``.

        Returns a uint8 tensor of shape (N, 3, H, W) on self.device; model
        output is mapped from the assumed [-1, 1] range to 0..255.
        '''
        with torch.no_grad(), self.modellock:
            self.apply_intervention(intervention)
            result_img = torch.zeros(
                *((len(z_batch), 3) + self.model.output_shape[2:]),
                dtype=torch.uint8, device=self.device)
            offset = 0
            for [batch_z] in self._loader(z_batch[:, :, None, None]):
                batch_z = batch_z.to(self.device)
                out = self.model(batch_z)
                result_img[offset:offset + len(batch_z)] = (
                    (((out + 1) / 2) * 255).clamp(0, 255).byte())
                offset += len(batch_z)
            return result_img

    def get_layers(self):
        '''Return the sorted list of instrumented layer names.'''
        return self.layers

    def feature_stats(self, z_batch,
                      masks=None, intervention=None, layers=None):
        '''
        Summarize per-unit activations over a batch of z, restricted to a mask.

        masks: optional (N, 1, H, W) tensor; defaults to all-ones.
        layers: optional collection of layer names to include (None = all).
        Returns {layer: {'max': ..., 'mean': ...}} of per-channel tensors,
        plus 'max_quantile' / 'mean_quantile' when quantiles are loaded.
        '''
        feature_stat = defaultdict(dict)
        with torch.no_grad(), self.modellock:
            self.apply_intervention(intervention)
            if masks is None:
                masks = torch.ones(z_batch.size(0), 1, 1, 1,
                                   device=z_batch.device, dtype=z_batch.dtype)
            else:
                assert masks.shape[0] == z_batch.shape[0]
                assert masks.shape[1] == 1
            processed = 0
            for batch_z, batch_m in self._loader(z_batch[:, :, None, None],
                                                 masks):
                batch_z = batch_z.to(self.device)
                batch_m = batch_m.to(self.device)
                # Run model but disregard output
                self.model(batch_z)
                processing = batch_z.shape[0]
                for layer, feature in self.model.retained_features().items():
                    if layers is not None and layer not in layers:
                        continue
                    spatial = (feature.shape[2], feature.shape[3])
                    stat = feature_stat[layer]
                    # Compute max features touching mask
                    resized_max = torch.nn.functional.adaptive_max_pool2d(
                        batch_m, spatial)
                    max_feature = (feature * resized_max).view(
                        feature.shape[0], feature.shape[1], -1
                    ).max(2)[0].max(0)[0]
                    if 'max' in stat:
                        stat['max'] = torch.max(stat['max'], max_feature)
                    else:
                        stat['max'] = max_feature
                    # Compute mean features weighted by overlap with mask
                    resized_mean = torch.nn.functional.adaptive_avg_pool2d(
                        batch_m, spatial)
                    mean_feature = (feature * resized_mean).view(
                        feature.shape[0], feature.shape[1], -1
                    ).sum(2).sum(0) / (resized_mean.sum() + 1e-15)
                    if 'mean' in stat:
                        # Running average across batches, weighted by the
                        # number of images seen so far.
                        stat['mean'] = (
                            processed * stat['mean']
                            + processing * mean_feature) / (
                            processed + processing)
                    else:
                        stat['mean'] = mean_feature
                # Advance once per batch, not once per layer.
                processed += processing
            # After summaries are done, also compute quantile stats
            for layer, stats in feature_stat.items():
                if self.quantiles.get(layer, None) is not None:
                    for statname in ['max', 'mean']:
                        stats['%s_quantile' % statname] = (
                            self.quantiles[layer].normalize(stats[statname]))
        return feature_stat

    def levels(self, layer, quantiles):
        '''Return the activation levels of ``layer`` at ``quantiles``.'''
        return self.quantiles[layer].quantiles(quantiles)

    def feature_maps(self, z_batch, intervention=None, layers=None,
                     quantiles=True):
        '''
        Return {layer: [per-image feature map, ...]} for z_batch.

        layers: optional collection of layer names; empty/None means all.
        quantiles: when true, each map is normalized with the layer's
            loaded quantiles (which must be present for that layer).
        '''
        feature_map = defaultdict(list)
        with torch.no_grad(), self.modellock:
            self.apply_intervention(intervention)
            for [batch_z] in self._loader(z_batch[:, :, None, None]):
                batch_z = batch_z.to(self.device)
                # Run model but disregard output
                self.model(batch_z)
                for layer, feature in self.model.retained_features().items():
                    if layers and layer not in layers:
                        continue
                    for single_featuremap in feature:
                        if quantiles:
                            single_featuremap = (self.quantiles[layer]
                                                 .normalize(single_featuremap))
                        feature_map[layer].append(single_featuremap)
        return feature_map
def load_quantile_if_present(outdir, filename, device):
    '''
    Load a RunningQuantile from ``outdir/filename`` if that file exists.

    Returns the RunningQuantile moved to ``device``, or None when the file
    is missing.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    # numpy.load on an .npz returns an NpzFile holding an open file handle;
    # close it once the state has been handed to RunningQuantile.
    # NOTE(review): assumes RunningQuantile reads the arrays it needs inside
    # its constructor (it is given the NpzFile as ``state``) — confirm.
    with numpy.load(filepath) as data:
        result = RunningQuantile(state=data)
    result.to_(device)
    return result
if __name__ == '__main__':
    # `test_main` is not defined anywhere in this file; look it up
    # dynamically so running the module directly fails with a clear
    # message rather than a bare NameError.
    _test_main = globals().get('test_main')
    if _test_main is None:
        raise SystemExit('No test_main() is defined in this module.')
    _test_main()
def mask_to_numpy(mask_record):
    '''
    Decode a mask record into a float32 numpy array of 0.0 / 1.0 values.

    mask_record['bitstring'] is either a base64 image data URL (detected
    by the substring 'image/png;base64,') or a string in which every
    character other than '0' marks a set element, listed in row-major
    order over the bitbounds region.
    mask_record['shape'] (optional) is the output shape; it defaults to
    the decoded image's height/width, or (256, 256) for plain bitstrings.
    mask_record['bitbounds'] (optional) is the start indexes followed by
    the end indexes of the region the bits cover; defaults to everything.

    Raises ValueError if a plain bitstring's length does not match the
    size of the bitbounds region.
    '''
    bitstring = mask_record['bitstring']
    bitnumpy = None
    default_shape = (256, 256)
    # Detect a png image mask.
    if 'image/png;base64,' in bitstring:
        bitnumpy = base642img(bitstring)
        default_shape = bitnumpy.shape[:2]
    # Set up results
    shape = mask_record.get('shape', None)
    if not shape:  # None or empty []
        shape = default_shape
    result = numpy.zeros(shape=shape, dtype=numpy.float32)
    ndim = result.ndim
    bitbounds = mask_record.get('bitbounds', None)
    if not bitbounds:  # None or empty []
        bitbounds = [0] * ndim + list(result.shape)
    start = list(bitbounds[:ndim])
    end = list(bitbounds[ndim:])
    if bitnumpy is not None:
        if bitnumpy.ndim == 2:
            # Single-channel image: any nonwhite pixel.
            painted = bitnumpy < 255
        elif bitnumpy.shape[2] == 4:
            # Mask is any nontransparent bits in the alpha channel if present
            painted = bitnumpy[:, :, 3] > 0
        else:
            # Or any nonwhite pixels in the red channel if no alpha.
            painted = bitnumpy[:, :, 0] < 255
        result[start[0]:end[0], start[1]:end[1]] = painted
        return result
    # Or bitstring can be just ones and zeros, in row-major order over the
    # bitbounds region (last index varies fastest).
    region_shape = tuple(max(e - s, 0) for s, e in zip(start, end))
    bits = numpy.array([ch != '0' for ch in bitstring], dtype=bool)
    expected = int(numpy.prod(region_shape))
    if bits.size != expected:
        raise ValueError('bitstring has %d bits but bitbounds region %r '
                         'needs %d' % (bits.size, region_shape, expected))
    region = tuple(slice(s, e) for s, e in zip(start, end))
    result[region] = bits.reshape(region_shape)
    return result
def decode_intervention_array(interventions, layer_shapes):
    '''
    Merge a list of intervention records into one {layer: tensor} dict.

    Each record is decoded with decode_intervention, giving per-layer
    tensors whose plane 0 is the per-unit alpha and plane 1 the value.
    Later records are composited over earlier ones with "over" alpha
    blending. None / falsy ``interventions`` and None entries contribute
    nothing.
    '''
    result = {}
    for intervention in (interventions or []):
        channels = decode_intervention(intervention, layer_shapes)
        if channels is None:
            # decode_intervention returns None for a None record.
            continue
        for layer, channel in channels.items():
            if layer not in result:
                result[layer] = channel
                continue
            accum = result[layer]
            # "Over" compositing of the new channel on top of the
            # accumulated one; both are computed before accum is updated.
            newalpha = 1 - (1 - channel[:1]) * (1 - accum[:1])
            # The tiny epsilon avoids 0/0 where both alphas are zero (the
            # numerator is also zero there).
            newvalue = (accum[1:] * accum[:1] * (1 - channel[:1]) +
                        channel[1:] * channel[:1]) / (newalpha + 1e-40)
            accum[:1] = newalpha
            accum[1:] = newvalue
    return result
def decode_intervention(intervention, layer_shapes):
    '''
    Decode one intervention record into {layer: tensor of shape (2, C, H, W)}.

    Every plane of an intervention is a solid choice of activation over a
    set of channels, with a mask applied to alpha-blended channels (when
    the mask resolution differs from the feature map, it is either
    max-pooled or average-pooled to the proper resolution). This reduces
    to a single alpha-blended featuremap per layer: plane 0 holds the
    per-unit alpha, plane 1 the per-unit value.

    intervention: dict with optional 'ablations' (list of {'layer', 'unit',
        'alpha', 'value'}), 'mask' (a mask record for mask_to_numpy) and
        'maskpooling' ('mean' or anything else for max). None -> None.
    layer_shapes: {layer: feature shape}, indexed as (N, C, H, W).
    '''
    if intervention is None:
        return None
    mask_record = intervention.get('mask', None)
    # Any falsy mask (None, {}, '') means "no mask"; only a real record is
    # decoded, so the pooling step below always receives a tensor.
    mask = (torch.from_numpy(mask_to_numpy(mask_record))
            if mask_record else None)
    maskpooling = intervention.get('maskpooling', 'max')
    channels = {}  # layer -> tensor (2, C, H, W): [alpha plane, value plane]
    for arec in intervention.get('ablations', []):
        unit = arec['unit']
        layer = arec['layer']
        alpha = arec.get('alpha', 1.0)
        if alpha is None:
            alpha = 1.0
        value = arec.get('value', 0.0)
        if value is None:
            value = 0.0
        # An entry with zero alpha and zero value is a no-op; skip it.
        if alpha != 0.0 or value != 0.0:
            if layer not in channels:
                channels[layer] = torch.zeros(2, *layer_shapes[layer][1:])
            channels[layer][0, unit] = alpha
            channels[layer][1, unit] = value
    if mask is not None:
        for layer in channels:
            layer_shape = layer_shapes[layer][2:]
            if maskpooling == 'mean':
                layer_mask = torch.nn.functional.adaptive_avg_pool2d(
                    mask[None, None, ...], layer_shape)[0]
            else:
                layer_mask = torch.nn.functional.adaptive_max_pool2d(
                    mask[None, None, ...], layer_shape)[0]
            # Scale only the alpha plane by the (pooled) mask.
            channels[layer][0] *= layer_mask
    return channels
def img2base64(imgarray, for_html=True, image_format='jpeg'):
    '''
    Encode a numpy image array as base64 text.

    imgarray: array accepted by PIL.Image.fromarray (e.g. H x W x 3 uint8).
    for_html: when true, prepend a 'data:image/<format>;base64,' header so
        the result can be used directly as an image URL.
    image_format: PIL format name, used both to encode and in the header.
    '''
    buffer = BytesIO()
    pil_image = Image.fromarray(imgarray)
    pil_image.save(buffer, image_format,
                   quality=99, optimize=True, progressive=True)
    encoded = base64.b64encode(buffer.getvalue()).decode('ascii')
    if not for_html:
        return encoded
    return 'data:image/' + image_format + ';base64,' + encoded
def base642img(stringdata):
    '''
    Decode base64 image data (optionally a 'data:image/<fmt>;base64,' URL)
    into a numpy array of pixels.
    '''
    # Raw string: '\w' in a normal string literal is an invalid escape
    # sequence (SyntaxWarning on Python 3.12+).
    stringdata = re.sub(r'^(?:data:)?image/\w+;base64,', '', stringdata)
    with Image.open(BytesIO(base64.b64decode(stringdata))) as im:
        return numpy.array(im)
def twitter_card(image_path, tweet_path, public_host):
    '''
    Return a small HTML page carrying Twitter "summary_large_image" card
    metadata for a generated image; the page redirects to the GANPaint demo.

    image_path / tweet_path: URL paths (starting with '/') of the image and
    of this page itself; public_host is prefixed to form absolute URLs.
    '''
    return '''\
<!doctype html>
<html>
<head>
<meta name="twitter:card" content="summary_large_image" />
<meta name="twitter:title" content="Painting with GANs from MIT-IBM Watson AI Lab" />
<meta name="twitter:description" content="This demo lets you modify a selection of meaningful GAN units for a generated image by simply painting." />
<meta name="twitter:image" content="http://{public_host}{image_path}" />
<meta name="twitter:url" content="http://{public_host}{tweet_path}" />
<meta http-equiv="refresh" content="10; url=http://bit.ly/ganpaint">
</head>
<style>
body {{ font: 12px Arial, sans-serif; }}
</style>
<body>
<center>
<h1>Painting with GANs from MIT-IBM Watson AI Lab</h1>
<p>This demo lets you modify a selection of meaningful GAN units for a generated image by simply painting.</p>
<img src="{image_path}">
<p>Redirecting to
<a href="http://bit.ly/ganpaint">GANPaint</a>
</p>
</center>
</body>
</html>
'''.format(
        image_path=image_path,
        tweet_path=tweet_path,
        public_host=public_host)