Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
import argparse, connexion, os, sys, yaml, json, socket | |
from netdissect.easydict import EasyDict | |
from flask import send_from_directory, redirect | |
from flask_cors import CORS | |
from netdissect.serverstate import DissectionProject | |
__author__ = 'Hendrik Strobelt, David Bau' | |
CONFIG_FILE_NAME = 'dissect.json' | |
projects = {} | |
app = connexion.App(__name__, debug=False) | |
def get_all_projects(): | |
res = [] | |
for key, project in projects.items(): | |
# print key | |
res.append({ | |
'project': key, | |
'info': { | |
'layers': [layer['layer'] for layer in project.get_layers()] | |
} | |
}) | |
return sorted(res, key=lambda x: x['project']) | |
def get_layers(project): | |
return { | |
'request': {'project': project}, | |
'res': projects[project].get_layers() | |
} | |
def get_units(project, layer): | |
return { | |
'request': {'project': project, 'layer': layer}, | |
'res': projects[project].get_units(layer) | |
} | |
def get_rankings(project, layer): | |
return { | |
'request': {'project': project, 'layer': layer}, | |
'res': projects[project].get_rankings(layer) | |
} | |
def get_levels(project, layer, quantiles): | |
return { | |
'request': {'project': project, 'layer': layer, 'quantiles': quantiles}, | |
'res': projects[project].get_levels(layer, quantiles) | |
} | |
def get_channels(project, layer): | |
answer = dict(channels=projects[project].get_channels(layer)) | |
return { | |
'request': {'project': project, 'layer': layer}, | |
'res': answer | |
} | |
def post_generate(gen_req): | |
project = gen_req['project'] | |
zs = gen_req.get('zs', None) | |
ids = gen_req.get('ids', None) | |
return_urls = gen_req.get('return_urls', False) | |
assert (zs is None) != (ids is None) # one or the other, not both | |
ablations = gen_req.get('ablations', []) | |
interventions = gen_req.get('interventions', None) | |
# no z avilable if ablations | |
generated = projects[project].generate_images(zs, ids, interventions, | |
return_urls=return_urls) | |
return { | |
'request': gen_req, | |
'res': generated | |
} | |
def post_features(feat_req): | |
project = feat_req['project'] | |
ids = feat_req['ids'] | |
masks = feat_req.get('masks', None) | |
layers = feat_req.get('layers', None) | |
interventions = feat_req.get('interventions', None) | |
features = projects[project].get_features( | |
ids, masks, layers, interventions) | |
return { | |
'request': feat_req, | |
'res': features | |
} | |
def post_featuremaps(feat_req): | |
project = feat_req['project'] | |
ids = feat_req['ids'] | |
layers = feat_req.get('layers', None) | |
interventions = feat_req.get('interventions', None) | |
featuremaps = projects[project].get_featuremaps( | |
ids, layers, interventions) | |
return { | |
'request': feat_req, | |
'res': featuremaps | |
} | |
def send_static(path): | |
""" serves all files from ./client/ to ``/client/<path:path>`` | |
:param path: path from api call | |
""" | |
return send_from_directory(args.client, path) | |
def send_data(path): | |
""" serves all files from the data dir to ``/dissect/<path:path>`` | |
:param path: path from api call | |
""" | |
print('Got the data route for', path) | |
return send_from_directory(args.data, path) | |
def redirect_home(): | |
return redirect('/client/index.html', code=302) | |
def load_projects(directory): | |
""" | |
searches for CONFIG_FILE_NAME in all subdirectories of directory | |
and creates data handlers for all of them | |
:param directory: scan directory | |
:return: null | |
""" | |
project_dirs = [] | |
# Don't search more than 2 dirs deep. | |
search_depth = 2 + directory.count(os.path.sep) | |
for root, dirs, files in os.walk(directory): | |
if CONFIG_FILE_NAME in files: | |
project_dirs.append(root) | |
# Don't get subprojects under a project dir. | |
del dirs[:] | |
elif root.count(os.path.sep) >= search_depth: | |
del dirs[:] | |
for p_dir in project_dirs: | |
print('Loading %s' % os.path.join(p_dir, CONFIG_FILE_NAME)) | |
with open(os.path.join(p_dir, CONFIG_FILE_NAME), 'r') as jf: | |
config = EasyDict(json.load(jf)) | |
dh_id = os.path.split(p_dir)[1] | |
projects[dh_id] = DissectionProject( | |
config=config, | |
project_dir=p_dir, | |
path_url='data/' + os.path.relpath(p_dir, directory), | |
public_host=args.public_host) | |
app.add_api('server.yaml') | |
# add CORS support | |
CORS(app.app, headers='Content-Type') | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--nodebug", default=False) | |
parser.add_argument("--address", default="127.0.0.1") # 0.0.0.0 for nonlocal use | |
parser.add_argument("--port", default="5001") | |
parser.add_argument("--public_host", default=None) | |
parser.add_argument("--nocache", default=False) | |
parser.add_argument("--data", type=str, default='dissect') | |
parser.add_argument("--client", type=str, default='client_dist') | |
if __name__ == '__main__': | |
args = parser.parse_args() | |
for d in [args.data, args.client]: | |
if not os.path.isdir(d): | |
print('No directory %s' % d) | |
sys.exit(1) | |
args.data = os.path.abspath(args.data) | |
args.client = os.path.abspath(args.client) | |
if args.public_host is None: | |
args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port)) | |
app.run(port=int(args.port), debug=not args.nodebug, host=args.address, | |
use_reloader=False) | |
else: | |
args, _ = parser.parse_known_args() | |
if args.public_host is None: | |
args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port)) | |
load_projects(args.data) | |