FashionGAN / netdissect /serverstate.py
fiesty-bear
Initial Commit
6064c9d
import os, torch, numpy, base64, json, re, threading, random
from torch.utils.data import TensorDataset, DataLoader
from collections import defaultdict
from netdissect.easydict import EasyDict
from netdissect.modelconfig import create_instrumented_model
from netdissect.runningstats import RunningQuantile
from netdissect.dissection import safe_dir_name
from netdissect.zdataset import z_sample_for_model
from PIL import Image
from io import BytesIO
class DissectionProject:
    '''
    DissectionProject understands how to drive a GanTester within a
    dissection project directory structure: it caches data in files,
    creates image files, and translates data between plain python data
    types and the pytorch-specific tensors required by GanTester.
    '''
    def __init__(self, config, project_dir, path_url, public_host):
        print('config done', project_dir)
        self.use_cuda = torch.cuda.is_available()
        self.dissect = config
        self.project_dir = project_dir
        self.path_url = path_url
        self.public_host = public_host
        self.cachedir = os.path.join(self.project_dir, 'cache')
        self.tester = GanTester(
                config.settings, dissectdir=project_dir,
                device=torch.device('cuda') if self.use_cuda
                else torch.device('cpu'))
        # Cache of the standard z sample, one row per z vector.  It starts as
        # an empty numpy array (not a list) so that slicing + .tolist() and
        # fancy indexing behave the same before and after the first draw.
        self.stdz = numpy.zeros((0, 0), dtype=numpy.float32)

    def get_zs(self, size):
        '''
        Returns the first `size` standard z vectors as nested python lists.
        When the cache is too small, a new standard sample of `size` vectors
        is drawn from the tester and replaces the cache.
        '''
        if size <= len(self.stdz):
            return self.stdz[:size].tolist()
        z_tensor = self.tester.standard_z_sample(size)
        self.stdz = z_tensor.cpu().numpy()
        return self.stdz.tolist()

    def get_z(self, id):
        '''
        Returns standard z vector number `id`.  On a cache miss the cache is
        grown to 2 * (id + 1) vectors to amortize regeneration.
        '''
        # NOTE(review): a cache hit returns a numpy row, a miss returns a
        # python list; callers in this class ignore the return value.
        if id < len(self.stdz):
            return self.stdz[id]
        return self.get_zs((id + 1) * 2)[id]

    def get_zs_for_ids(self, ids):
        '''
        Returns the standard z vectors for the given ids as a numpy array
        (one row per id, in the order given), growing the cache if needed.
        An empty `ids` yields an empty array.
        '''
        if len(ids) and max(ids) >= len(self.stdz):
            self.get_z(max(ids))
        return self.stdz[ids]

    def get_layers(self):
        '''
        Lists the tester's layers as dicts with the layer name, channel
        count (dimension 1 of the feature shape) and [dim 2, dim 3] shape.
        '''
        result = []
        layer_shapes = self.tester.layer_shapes()
        for layer in self.tester.layers:
            shape = layer_shapes[layer]
            result.append(dict(
                layer=layer,
                channels=shape[1],
                shape=[shape[2], shape[3]]))
        return result

    def _find_dissect_layer(self, layer):
        '''Returns the dissection record for `layer`, or None if absent.'''
        try:
            dlayers = self.dissect['layers']
        except (KeyError, TypeError):
            return None
        return next((dl for dl in dlayers if dl.get('layer') == layer), None)

    def get_units(self, layer):
        '''
        Returns one dict per unit of `layer` (unit number, url of its
        top-activation image, and its 'iou_label'), or None when the layer
        is not part of the dissection.
        '''
        dlayer = self._find_dissect_layer(layer)
        if dlayer is None:
            return None
        return [dict(unit=unit_num,
                     img='/%s/%s/s-image/%d-top.jpg' %
                         (self.path_url, layer, unit_num),
                     label=unit['iou_label'])
                for unit_num, unit in enumerate(dlayer['units'])]

    def get_rankings(self, layer):
        '''
        Returns the unit rankings recorded for `layer` as dicts with name,
        metric (None when absent) and scores, or None when the layer is not
        part of the dissection.
        '''
        dlayer = self._find_dissect_layer(layer)
        if dlayer is None:
            return None
        return [dict(name=ranking['name'],
                     metric=ranking.get('metric', None),
                     scores=ranking['score'])
                for ranking in dlayer['rankings']]

    def get_levels(self, layer, quantiles):
        '''Returns the tester's activation levels at `quantiles` as lists.'''
        levels = self.tester.levels(
                layer, torch.from_numpy(numpy.array(quantiles)))
        return levels.cpu().numpy().tolist()

    def generate_images(self, zs, ids, interventions, return_urls=False):
        '''
        Renders images either for explicit z vectors (`zs`) or for standard
        z ids (`ids`; then `zs` must be None).

        With ids and no interventions, images are cached on disk as
        <cachedir>/img/id/<id>.jpg and returned as [dict(id=..., d=url)].
        Otherwise images are rendered fresh and returned as [dict(d=...)],
        where d is a url under cache/img/uniq/ (when return_urls is true;
        a twitter-card html page is also written per image) or an inline
        base64 jpeg data url.
        '''
        if ids is not None:
            assert zs is None
            zs = self.get_zs_for_ids(ids)
            if not interventions:
                # Do file caching when ids are given (and no ablations).
                imgdir = os.path.join(self.cachedir, 'img', 'id')
                os.makedirs(imgdir, exist_ok=True)
                exist = set(os.listdir(imgdir))
                unfinished = [('%d.jpg' % id) not in exist for id in ids]
                needed_z_tensor = torch.tensor(zs[unfinished]).float().to(
                        self.tester.device)
                needed_ids = numpy.array(ids)[unfinished]
                # Generate image files for just the needed images.
                if len(needed_z_tensor):
                    imgs = self.tester.generate_images(needed_z_tensor
                            ).cpu().numpy()
                    for i, img in zip(needed_ids, imgs):
                        Image.fromarray(img.transpose(1, 2, 0)).save(
                            os.path.join(imgdir, '%d.jpg' % i), 'jpeg',
                            quality=99, optimize=True, progressive=True)
                # Assemble a response.
                imgurls = ['/%s/cache/img/id/%d.jpg'
                        % (self.path_url, i) for i in ids]
                return [dict(id=i, d=d) for i, d in zip(ids, imgurls)]
        # No file caching when ids are not given (or ablations are applied)
        z_tensor = torch.tensor(zs).float().to(self.tester.device)
        imgs = self.tester.generate_images(z_tensor,
                intervention=decode_intervention_array(interventions,
                    self.tester.layer_shapes()),
                ).cpu().numpy()
        if return_urls:
            # Spread unique images over 1000 subdirectories with random
            # file numbers; names are not guaranteed to be collision-free.
            randdir = '%03d' % random.randrange(1000)
            imgdir = os.path.join(self.cachedir, 'img', 'uniq', randdir)
            os.makedirs(imgdir, exist_ok=True)
            startind = random.randrange(100000)
            imgurls = []
            for i, img in enumerate(imgs):
                filename = '%d.jpg' % (i + startind)
                Image.fromarray(img.transpose(1, 2, 0)).save(
                    os.path.join(imgdir, filename), 'jpeg',
                    quality=99, optimize=True, progressive=True)
                image_url_path = ('/%s/cache/img/uniq/%s/%s'
                        % (self.path_url, randdir, filename))
                imgurls.append(image_url_path)
                tweet_filename = 'tweet-%d.html' % (i + startind)
                tweet_url_path = ('/%s/cache/img/uniq/%s/%s'
                        % (self.path_url, randdir, tweet_filename))
                with open(os.path.join(imgdir, tweet_filename), 'w') as f:
                    f.write(twitter_card(image_url_path, tweet_url_path,
                        self.public_host))
            return [dict(d=d) for d in imgurls]
        imgurls = [img2base64(img.transpose(1, 2, 0)) for img in imgs]
        return [dict(d=d) for d in imgurls]

    def get_features(self, ids, masks, layers, interventions):
        '''
        Returns masked feature statistics (as computed by
        GanTester.feature_stats) for the standard z ids, converted to plain
        python lists: {layer: {stat_name: list}}.
        '''
        zs = self.get_zs_for_ids(ids)
        z_tensor = torch.tensor(zs).float().to(self.tester.device)
        t_masks = torch.stack(
                [torch.from_numpy(mask_to_numpy(mask)) for mask in masks]
                )[:,None,:,:].to(self.tester.device)
        t_features = self.tester.feature_stats(z_tensor, t_masks,
                decode_intervention_array(interventions,
                    self.tester.layer_shapes()), layers)
        # Convert torch arrays to plain python lists before returning.
        return { layer: { key: value.cpu().numpy().tolist()
                          for key, value in feature.items() }
                 for layer, feature in t_features.items() }

    def get_featuremaps(self, ids, layers, interventions):
        '''
        Returns quantile-normalized feature maps for the standard z ids,
        scaled to 0-255 integers as nested lists, restricted to `layers`
        when it is non-empty.
        '''
        zs = self.get_zs_for_ids(ids)
        z_tensor = torch.tensor(zs).float().to(self.tester.device)
        # Quantilized features are returned.
        q_features = self.tester.feature_maps(z_tensor,
                decode_intervention_array(interventions,
                    self.tester.layer_shapes()), layers)
        # Scale them 0-255 and return them.
        # TODO: turn them into pngs for returning.
        return { layer: [
                    value.clamp(0, 1).mul(255).byte().cpu().numpy().tolist()
                    for value in valuelist ]
                 for layer, valuelist in q_features.items()
                 if (not layers) or (layer in layers) }

    def get_recipes(self):
        '''
        Loads every file in <project_dir>/recipe as JSON, in sorted filename
        order, and returns the list of parsed objects ([] when the directory
        does not exist).
        '''
        recipedir = os.path.join(self.project_dir, 'recipe')
        if not os.path.isdir(recipedir):
            return []
        result = []
        # Sort so the response order does not depend on the filesystem.
        for filename in sorted(os.listdir(recipedir)):
            with open(os.path.join(recipedir, filename),
                      encoding='utf-8') as f:
                result.append(json.load(f))
        return result
class GanTester:
    '''
    GanTester holds on to a specific model to test.
    (1) loads and instantiates the GAN;
    (2) instruments it at every layer so that units can be ablated
    (3) precomputes z dimensionality, and output image dimensions.

    Methods that run the model hold self.modellock while an intervention
    is installed on the shared model.
    '''
    def __init__(self, args, dissectdir=None, device=None):
        self.device = device if device is not None else torch.device('cpu')
        self.dissectdir = dissectdir
        # Without a dissection directory there is no cache and no quantiles.
        self.cachedir = (os.path.join(dissectdir, 'cache')
                         if dissectdir is not None else None)
        self.modellock = threading.Lock()
        # Load the generator from the pth file.
        args_copy = EasyDict(args)
        args_copy.edit = True
        model = create_instrumented_model(args_copy)
        model.eval()
        self.model = model
        # Layers of interest: every layer whose features the instrumented
        # model retains, in sorted order.
        self.layers = sorted(model.retained_features().keys())
        # Move the model to the chosen device (cpu when none was given).
        model.to(self.device)
        self.quantiles = {
            layer: (load_quantile_if_present(
                        os.path.join(dissectdir, safe_dir_name(layer)),
                        'quantiles.npz', device=torch.device('cpu'))
                    if dissectdir is not None else None)
            for layer in self.layers }

    def layer_shapes(self):
        '''Returns the instrumented model's feature_shape mapping.'''
        return self.model.feature_shape

    def standard_z_sample(self, size=100, seed=1, device=None):
        '''
        Generate a standard set of random Z as a (size, z_dimension) tensor.
        With the same random seed, it always returns the same z (e.g.,
        the first one is always the same regardless of the size.)
        '''
        # NOTE(review): `seed` is not forwarded to z_sample_for_model, so
        # the reproducibility above relies on that helper's own default.
        result = z_sample_for_model(self.model, size)
        if device is not None:
            result = result.to(device)
        return result

    def reset_intervention(self):
        '''Removes every edit currently installed on the model.'''
        self.model.remove_edits()

    def apply_intervention(self, intervention):
        '''
        Installs an intervention: a dict mapping layer name -> a pair
        (ablation, replacement), e.g. the (2, ...) tensors produced by
        decode_intervention_array, which unpack along their first axis.
        Previous edits are always removed first; a falsy intervention
        just clears them.
        '''
        self.reset_intervention()
        if not intervention:
            return
        for layer, (a, v) in intervention.items():
            self.model.edit_layer(layer, ablation=a, replacement=v)

    def generate_images(self, z_batch, intervention=None):
        '''
        Runs the generator on z_batch (one z vector per row) in batches of
        10 with the given intervention, and returns a uint8 tensor of shape
        (len(z_batch), 3) + model.output_shape[2:] on self.device.
        Generator output is mapped with ((out + 1) / 2) * 255 and clamped,
        i.e. it presumably lies in [-1, 1].
        '''
        with torch.no_grad(), self.modellock:
            batch_size = 10
            self.apply_intervention(intervention)
            test_loader = DataLoader(TensorDataset(z_batch[:,:,None,None]),
                batch_size=batch_size,
                pin_memory=('cuda' == self.device.type
                            and z_batch.device.type == 'cpu'))
            result_img = torch.zeros(
                    *((len(z_batch), 3) + self.model.output_shape[2:]),
                    dtype=torch.uint8, device=self.device)
            for batch_num, [batch_z,] in enumerate(test_loader):
                batch_z = batch_z.to(self.device)
                out = self.model(batch_z)
                result_img[batch_num*batch_size:
                        batch_num*batch_size+len(batch_z)] = (
                                (((out + 1) / 2) * 255).clamp(0, 255).byte())
            return result_img

    def get_layers(self):
        '''Returns the sorted list of retained layer names.'''
        return self.layers

    def feature_stats(self, z_batch,
            masks=None, intervention=None, layers=None):
        '''
        Computes per-channel statistics of retained features over z_batch.

        masks: optional (len(z_batch), 1, H, W) tensor; defaults to all
            ones.  It is max-pooled and average-pooled to each layer's
            spatial size.
        layers: optional collection restricting which layers are reported.

        Returns {layer: {'max', 'mean'[, 'max_quantile', 'mean_quantile']}}
        where 'max' is the largest mask-touched activation per channel and
        'mean' is the activation per channel averaged with mask-overlap
        weights over all samples; the quantile variants are added when the
        layer has saved quantiles.
        '''
        feature_stat = defaultdict(dict)
        # Running numerator / denominator of the mask-weighted mean, so the
        # mean is exact across batches of different sizes and masks.
        mean_num = {}
        mean_den = {}
        with torch.no_grad(), self.modellock:
            batch_size = 10
            self.apply_intervention(intervention)
            if masks is None:
                masks = torch.ones(z_batch.size(0), 1, 1, 1,
                        device=z_batch.device, dtype=z_batch.dtype)
            else:
                assert masks.shape[0] == z_batch.shape[0]
                assert masks.shape[1] == 1
            test_loader = DataLoader(
                    TensorDataset(z_batch[:,:,None,None], masks),
                    batch_size=batch_size,
                    pin_memory=('cuda' == self.device.type
                                and z_batch.device.type == 'cpu'))
            for batch_z, batch_m in test_loader:
                batch_z, batch_m = [
                        d.to(self.device) for d in [batch_z, batch_m]]
                # Run model but disregard output; only the features retained
                # by the instrumentation are used.
                self.model(batch_z)
                for layer, feature in self.model.retained_features().items():
                    if layers is not None and layer not in layers:
                        continue
                    fsize = (feature.shape[2], feature.shape[3])
                    # Compute max features touching mask
                    resized_max = torch.nn.functional.adaptive_max_pool2d(
                            batch_m, fsize)
                    max_feature = (feature * resized_max).view(
                            feature.shape[0], feature.shape[1], -1
                            ).max(2)[0].max(0)[0]
                    if 'max' not in feature_stat[layer]:
                        feature_stat[layer]['max'] = max_feature
                    else:
                        torch.max(feature_stat[layer]['max'], max_feature,
                                out=feature_stat[layer]['max'])
                    # Accumulate features weighted by overlap with mask
                    resized_mean = torch.nn.functional.adaptive_avg_pool2d(
                            batch_m, fsize)
                    weighted_sum = (feature * resized_mean).view(
                            feature.shape[0], feature.shape[1], -1
                            ).sum(2).sum(0)
                    weight = resized_mean.sum()
                    if layer not in mean_num:
                        mean_num[layer] = weighted_sum
                        mean_den[layer] = weight
                    else:
                        mean_num[layer] = mean_num[layer] + weighted_sum
                        mean_den[layer] = mean_den[layer] + weight
            for layer, numerator in mean_num.items():
                feature_stat[layer]['mean'] = numerator / (
                        mean_den[layer] + 1e-15)
            # After summaries are done, also compute quantile stats
            for layer, stats in feature_stat.items():
                if self.quantiles.get(layer, None) is not None:
                    for statname in ['max', 'mean']:
                        stats['%s_quantile' % statname] = (
                            self.quantiles[layer].normalize(stats[statname]))
        return feature_stat

    def levels(self, layer, quantiles):
        '''
        Returns the activation levels of `layer` at the given quantiles,
        as reported by that layer's saved RunningQuantile.
        '''
        # NOTE(review): raises AttributeError when the layer has no saved
        # quantiles (self.quantiles[layer] is None).
        return self.quantiles[layer].quantiles(quantiles)

    def feature_maps(self, z_batch, intervention=None, layers=None,
            quantiles=True):
        '''
        Returns {layer: [featuremap per z]} for the retained layers,
        restricted to `layers` when it is non-empty.  With quantiles=True
        each map is normalized by the layer's saved RunningQuantile
        (which must exist for every reported layer).
        '''
        feature_map = defaultdict(list)
        with torch.no_grad(), self.modellock:
            batch_size = 10
            self.apply_intervention(intervention)
            test_loader = DataLoader(
                    TensorDataset(z_batch[:,:,None,None]),
                    batch_size=batch_size,
                    pin_memory=('cuda' == self.device.type
                                and z_batch.device.type == 'cpu'))
            for [batch_z] in test_loader:
                batch_z = batch_z.to(self.device)
                # Run model but disregard output
                self.model(batch_z)
                for layer, feature in self.model.retained_features().items():
                    if layers and layer not in layers:
                        continue
                    for single_featuremap in feature:
                        if quantiles:
                            feature_map[layer].append(self.quantiles[layer]
                                    .normalize(single_featuremap))
                        else:
                            feature_map[layer].append(single_featuremap)
        return feature_map
def load_quantile_if_present(outdir, filename, device):
    '''
    Restores a RunningQuantile saved at outdir/filename and moves it to
    `device`.  Returns None when no such file exists.
    '''
    path = os.path.join(outdir, filename)
    if not os.path.isfile(path):
        return None
    quantile = RunningQuantile(state=numpy.load(path))
    quantile.to_(device)
    return quantile
if __name__ == '__main__':
    # No test_main() is defined at this point in the module, so the former
    # call could only fail with NameError.  Exit with an explicit message.
    raise SystemExit('netdissect.serverstate provides no command-line '
                     'entry point; import it instead of running it.')
def mask_to_numpy(mask_record):
    '''
    Decodes a mask record into a float32 numpy array of 0s and 1s.

    mask_record keys:
      'bitstring': either a base64 png data url (containing
          'image/png;base64,') or a string of characters, one per element,
          listed in row-major order over the bitbounds region; any
          character other than '0' counts as set.
      'shape' (optional): shape of the result.  Defaults to the png's
          height and width, or (256, 256) for a plain bitstring.
      'bitbounds' (optional): [start..., end...] of the region covered by
          the bitstring or png.  Defaults to the whole result.

    For a png, masked pixels are those with nonzero alpha (4 channels),
    otherwise those below 255 in the first channel (or in the single
    channel of a grayscale png).

    Raises ValueError when a plain bitstring does not have exactly one
    character per element of the bitbounds region.
    '''
    bitstring = mask_record['bitstring']
    bitnumpy = None
    default_shape = (256, 256)
    # Detect a png image mask.
    if 'image/png;base64,' in bitstring:
        bitnumpy = base642img(bitstring)
        default_shape = bitnumpy.shape[:2]
    # Set up results
    shape = mask_record.get('shape', None)
    if not shape:  # None or empty []
        shape = default_shape
    result = numpy.zeros(shape=shape, dtype=numpy.float32)
    ndim = result.ndim
    bitbounds = mask_record.get('bitbounds', None)
    if not bitbounds:  # None or empty []
        bitbounds = ([0] * ndim) + list(result.shape)
    start = list(bitbounds[:ndim])
    end = list(bitbounds[ndim:])
    if bitnumpy is not None:
        if bitnumpy.ndim == 2:
            # Single-channel (e.g. grayscale) png: any nonwhite pixel.
            pixels = bitnumpy < 255
        elif bitnumpy.shape[2] == 4:
            # Mask is any nontransparent bits in the alpha channel if present
            pixels = bitnumpy[:, :, 3] > 0
        else:
            # Or any nonwhite pixels in the red channel if no alpha.
            pixels = bitnumpy[:, :, 0] < 255
        result[start[0]:end[0], start[1]:end[1]] = pixels
        return result
    # Or bitstring can be just ones and zeros, filling the bitbounds region
    # with the last index varying fastest (row-major order).
    extent = [e - s for s, e in zip(start, end)]
    bits = numpy.fromiter((ch != '0' for ch in bitstring), dtype=bool,
                          count=len(bitstring))
    expected = int(numpy.prod(extent))
    if bits.size != expected:
        raise ValueError('mask bitstring has %d bits but bitbounds cover %d'
                         % (bits.size, expected))
    region = tuple(slice(s, e) for s, e in zip(start, end))
    result[region] = bits.reshape(extent)
    return result
def decode_intervention_array(interventions, layer_shapes):
    '''
    Decodes a list of intervention records with decode_intervention and
    alpha-composites them into one tensor per layer.

    Each per-layer tensor holds alpha in [0] and value in [1] (shape
    (2,) + layer_shapes[layer][1:], as built by decode_intervention).
    Later interventions are composited over earlier ones:
        alpha = 1 - (1 - a_new) * (1 - a_old)
        value = (v_old * a_old * (1 - a_new) + v_new * a_new) / alpha
    `interventions` may be None or empty; None entries are skipped.
    Returns a dict {layer: tensor} ({} when there is nothing to apply).
    '''
    result = {}
    for intervention in interventions or []:
        channels = decode_intervention(intervention, layer_shapes)
        if not channels:  # None entry, or no effective ablations
            continue
        for layer, channel in channels.items():
            if layer not in result:
                result[layer] = channel
                continue
            accum = result[layer]
            newalpha = 1 - (1 - channel[:1]) * (1 - accum[:1])
            # Tiny epsilon avoids 0/0 where both alphas are zero.
            newvalue = (accum[1:] * accum[:1] * (1 - channel[:1]) +
                        channel[1:] * channel[:1]) / (newalpha + 1e-40)
            accum[:1] = newalpha
            accum[1:] = newvalue
    return result
def decode_intervention(intervention, layer_shapes):
    '''
    Converts one intervention record into per-layer alpha/value tensors.

    Every plane of an intervention is a solid choice of activation over a
    set of channels, with an optional mask applied to the alpha channel.
    When the mask resolution differs from the feature map, it is pooled to
    the layer's resolution ('maskpooling': 'mean' for average pooling,
    anything else for max pooling).  The result is one alpha-blended
    featuremap per layer.

    Returns None for a None intervention; otherwise a dict mapping layer
    name -> tensor of shape (2,) + layer_shapes[layer][1:], with alpha in
    [0] and value in [1].  Ablations whose alpha and value are both 0 are
    ignored; a None alpha means 1.0 and a None value means 0.0.
    '''
    if intervention is None:
        return None
    raw_mask = intervention.get('mask', None)
    mask = torch.from_numpy(mask_to_numpy(raw_mask)) if raw_mask else raw_mask
    pooling = intervention.get('maskpooling', 'max')
    planes = {}  # layer -> tensor [alpha, value] over (channel, y, x)
    for record in intervention.get('ablations', []):
        unit = record['unit']
        layer = record['layer']
        alpha = record.get('alpha', 1.0)
        alpha = 1.0 if alpha is None else alpha
        value = record.get('value', 0.0)
        value = 0.0 if value is None else value
        if alpha == 0.0 and value == 0.0:
            continue
        if layer not in planes:
            planes[layer] = torch.zeros(2, *layer_shapes[layer][1:])
        planes[layer][0, unit] = alpha
        planes[layer][1, unit] = value
    if mask is not None:
        if pooling == 'mean':
            pool = torch.nn.functional.adaptive_avg_pool2d
        else:
            pool = torch.nn.functional.adaptive_max_pool2d
        for layer, plane in planes.items():
            plane[0] *= pool(mask[None, None, ...], layer_shapes[layer][2:])[0]
    return planes
def img2base64(imgarray, for_html=True, image_format='jpeg'):
    '''
    Encodes a numpy image array with PIL (jpeg by default, quality 99) and
    returns it base64-encoded: as a 'data:image/<format>;base64,...' url
    when for_html is true, otherwise as bare base64 text.
    '''
    buffer = BytesIO()
    Image.fromarray(imgarray).save(buffer, image_format,
            quality=99, optimize=True, progressive=True)
    encoded = base64.b64encode(buffer.getvalue()).decode('ascii')
    if not for_html:
        return encoded
    return 'data:image/' + image_format + ';base64,' + encoded
def base642img(stringdata):
    '''
    Decodes base64 image data, optionally prefixed by 'data:image/<fmt>;base64,'
    or 'image/<fmt>;base64,', into a numpy array via PIL.
    '''
    # Raw string: '\w' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    stringdata = re.sub(r'^(?:data:)?image/\w+;base64,', '', stringdata)
    im = Image.open(BytesIO(base64.b64decode(stringdata)))
    return numpy.array(im)
def twitter_card(image_path, tweet_path, public_host):
    '''
    Returns an html page with Twitter card meta tags for a generated image.
    The page shows the image and refreshes to http://bit.ly/ganpaint after
    10 seconds.

    image_path and tweet_path are url paths on public_host; all three are
    interpolated verbatim (not html-escaped).
    '''
    return '''\
<!doctype html>
<html>
<head>
<meta name="twitter:card" content="summary_large_image" />
<meta name="twitter:title" content="Painting with GANs from MIT-IBM Watson AI Lab" />
<meta name="twitter:description" content="This demo lets you modify a selection of meaningful GAN units for a generated image by simply painting." />
<meta name="twitter:image" content="http://{public_host}{image_path}" />
<meta name="twitter:url" content="http://{public_host}{tweet_path}" />
<meta http-equiv="refresh" content="10; url=http://bit.ly/ganpaint">
<style>
body {{ font: 12px Arial, sans-serif; }}
</style>
</head>
<body>
<center>
<h1>Painting with GANs from MIT-IBM Watson AI Lab</h1>
<p>This demo lets you modify a selection of meaningful GAN units for a generated image by simply painting.</p>
<img src="{image_path}">
<p>Redirecting to
<a href="http://bit.ly/ganpaint">GANPaint</a>
</p>
</center>
</body>
</html>
'''.format(
        image_path=image_path,
        tweet_path=tweet_path,
        public_host=public_host)