Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
To run this script from the root of the repo (make sure Flask is installed):
FLASK_DEBUG=1 FLASK_APP=scripts.mos flask run -p 4567 | |
# or if you have gunicorn | |
gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile - | |
""" | |
from collections import defaultdict | |
from functools import wraps | |
from hashlib import sha1 | |
import json | |
import math | |
from pathlib import Path | |
import random | |
import typing as tp | |
from flask import Flask, redirect, render_template, request, session, url_for | |
from audiocraft import train | |
from audiocraft.utils.samples.manager import get_samples_for_xps | |
# Number of samples rated on a single survey page.
SAMPLES_PER_PAGE = 8
# Ratings go from 1 to MAX_RATING (inclusive).
MAX_RATING = 5

# Survey manifests and submitted ratings are stored under the Dora directory.
storage = Path(train.main.dora.dir / 'mos_storage')
surveys = storage / 'surveys'
storage.mkdir(exist_ok=True)
surveys.mkdir(exist_ok=True)

# Two directories above the `audiocraft.train` module file; the Flask
# static files and templates are looked up under its `scripts/` folder.
magma_root = Path(train.__file__).parent.parent
app = Flask(
    'mos',
    static_folder=str(magma_root / 'scripts' / 'static'),
    template_folder=str(magma_root / 'scripts' / 'templates'),
)
# NOTE(review): hard-coded key used by Flask to sign session cookies; anyone
# reading this source can forge a session. Consider loading it from the env.
app.secret_key = b'audiocraft makes the best songs'
def normalize_path(path: Path):
    """Shorten `path` for display by expressing it relative to `<dora dir>/xps`.

    Both `path` and the Dora directory are resolved first. `Path.relative_to`
    raises `ValueError` when `path` is not inside that folder.
    """
    absolute = path.resolve()
    xps_root = train.main.dora.dir.resolve() / 'xps'
    return absolute.relative_to(xps_root)
def get_full_path(normalized_path: Path):
    """Inverse of `normalize_path`: re-anchor a relative path under `<dora dir>/xps`."""
    xps_root = train.main.dora.dir.resolve() / 'xps'
    return xps_root / normalized_path
def get_signature(xps: tp.List[str]):
    """Derive a short identifier from a list of XP signatures.

    The list is JSON-encoded, hashed with SHA-1, and the first 10 hex digits
    are returned. The list order matters: a permutation gives another result.
    """
    payload = json.dumps(xps).encode()
    digest = sha1(payload).hexdigest()
    return digest[:10]
def ensure_logged(func):
    """Decorator sending visitors without a session user to the login page.

    When the Flask session holds no `user`, the visitor is redirected to the
    `login` endpoint with `redirect_to` set to the current URL. Otherwise
    `func` is called with the original arguments.
    """
    # `wraps` keeps `func.__name__` (and docstring): Flask derives endpoint
    # names from `__name__`, so without it every decorated view would be
    # registered as `_wrapped` and the endpoints would collide.
    @wraps(func)
    def _wrapped(*args, **kwargs):
        if session.get('user') is None:
            return redirect(url_for('login', redirect_to=request.url))
        return func(*args, **kwargs)
    return _wrapped
def login():
    """Log the user in if needed, then redirect.

    Anonymous visitors get `login.html` on GET, or on a POST with an empty
    `user` field (with an error message). A POST with a non-empty `user`
    stores it in the session. The final redirect goes to the `redirect_to`
    query argument when it stays on this site, and to `index` otherwise.
    """
    from urllib.parse import urlparse

    user = session.get('user')
    if user is None:
        error = None
        if request.method == 'POST':
            user = request.form['user']
            if not user:
                error = 'User cannot be empty'
        if user is None or error:
            return render_template('login.html', error=error)
        session['user'] = user

    redirect_to = request.args.get('redirect_to')
    if redirect_to is not None:
        # `ensure_logged` passes an absolute URL on this host, so accept that
        # and relative paths, but refuse any target on another site (open redirect).
        target = urlparse(redirect_to)
        off_site = (
            target.scheme not in ('', 'http', 'https')
            or target.netloc not in ('', request.host)
            or '\\' in redirect_to  # browsers may read `/\host` as `//host`
        )
        if off_site:
            redirect_to = None
    if redirect_to is None:
        redirect_to = url_for('index')
    return redirect(redirect_to)
def index():
    """Offer to create a new study.

    On POST, the whitespace-separated `xps` form field lists XP signatures
    (folders under `<dora dir>/xps`) and/or grid names (folders under
    `<dora dir>/grids`, whose symlinked children are taken as the grid's XPs).
    If everything resolves, the XP list is written to `manifest.json` under
    `surveys/<signature>` and the user is redirected to the survey page.
    Otherwise `index.html` is rendered with the error messages.
    """
    errors = []
    if request.method == 'POST':
        xps = set()
        # `str.split()` already drops surrounding whitespace and empty tokens.
        for xp_or_grid in request.form['xps'].split():
            xp_path = train.main.dora.dir / 'xps' / xp_or_grid
            if xp_path.exists():
                xps.add(xp_or_grid)
                continue
            grid_path = train.main.dora.dir / 'grids' / xp_or_grid
            if grid_path.exists():
                for child in grid_path.iterdir():
                    if child.is_symlink():
                        xps.add(child.name)
                continue
            errors.append(f'{xp_or_grid} is neither an XP nor a grid!')
        if not xps and not errors:
            # Empty form, or only grids without any XP: report it instead of failing.
            errors.append('No XP found, please provide at least one XP or grid.')
        blind = 'true' if request.form.get('blind') == 'on' else 'false'
        # Sort so the signature (and manifest) is identical in every worker:
        # iteration order of a set of strings depends on the per-process hash seed.
        xps = sorted(xps)
        if not errors:
            signature = get_signature(xps)
            manifest = {
                'xps': xps,
            }
            survey_path = surveys / signature
            survey_path.mkdir(exist_ok=True)
            with open(survey_path / 'manifest.json', 'w') as f:
                json.dump(manifest, f, indent=2)
            return redirect(url_for('survey', blind=blind, signature=signature))
    return render_template('index.html', errors=errors)
def survey(signature):
    """Show one page of a survey and record the ratings submitted on it.

    Query arguments:
        seed: selects which samples are shown and, in blind mode, how the
            models of each sample are shuffled. A saved page moves to seed + 1.
        blind: shuffle the models of each sample.
        exclude_prompted / exclude_unprompted / max_epoch: forwarded to
            `get_samples_for_xps` to filter the samples.
        success: set on the redirect that follows a saved page.

    On a POST where every model received a rating between 1 and MAX_RATING,
    the ratings are saved to `results/<user>_<seed>.json` in the survey folder.
    Otherwise the page is rendered again with per-model error messages.
    """
    truthy = ['true', 'on', 'True']
    success = request.args.get('success', False)
    seed = int(request.args.get('seed', 4321))
    blind = request.args.get('blind', 'false') in truthy
    exclude_prompted = request.args.get('exclude_prompted', 'false') in truthy
    exclude_unprompted = request.args.get('exclude_unprompted', 'false') in truthy
    max_epoch = int(request.args.get('max_epoch', '-1'))
    survey_path = surveys / signature
    assert survey_path.exists(), survey_path

    user = session['user']
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)
    # The user name is free text typed at login: neutralize path separators so
    # the result file cannot be written outside of `result_folder`.
    safe_user = user.replace('/', '_').replace('\\', '_')
    result_file = result_folder / f'{safe_user}_{seed}.json'

    with open(survey_path / 'manifest.json') as f:
        manifest = json.load(f)
    xps = [train.main.get_xp_from_sig(xp) for xp in manifest['xps']]
    names, ref_name = train.main.get_names(xps)

    samples_kwargs = {
        'exclude_prompted': exclude_prompted,
        'exclude_unprompted': exclude_unprompted,
        'max_epoch': max_epoch,
    }
    matched_samples = get_samples_for_xps(xps, epoch=-1, **samples_kwargs)  # fetch latest epoch
    # For each sample id: one entry per XP, in the same order as `xps`.
    models_by_id = {
        sample_id: [{
            'xp': xps[idx],
            'xp_name': names[idx],
            'model_id': f'{xps[idx].sig}-{sample.id}',
            'sample': sample,
            'is_prompted': sample.prompt is not None,
            'errors': [],
        } for idx, sample in enumerate(samples)]
        for sample_id, samples in matched_samples.items()
    }
    # The epoch of each XP is read from any sample id; when nothing matched,
    # the epoch is None instead of raising IndexError.
    first_samples = next(iter(matched_samples.values()), None)
    experiments = [
        {'xp': xp, 'name': names[idx],
         'epoch': first_samples[idx].epoch if first_samples else None}
        for idx, xp in enumerate(xps)
    ]

    # Sort, then shuffle with the seed, so a given seed always yields the same
    # page (needed for the POST to match the page that was displayed).
    keys = sorted(matched_samples.keys())
    rng = random.Random(seed)
    rng.shuffle(keys)
    model_ids = keys[:SAMPLES_PER_PAGE]
    if blind:
        for key in model_ids:
            rng.shuffle(models_by_id[key])

    ok = True
    if request.method == 'POST':
        all_samples_results = []
        for sample_id in model_ids:
            models = models_by_id[sample_id]
            sample_result = {
                'id': sample_id,
                'is_prompted': models[0]['is_prompted'],
                'models': {}
            }
            all_samples_results.append(sample_result)
            for model in models:
                raw_rating = request.form[model['model_id']]
                if not raw_rating:
                    ok = False
                    model['errors'].append('Please rate this model.')
                    continue
                # Explicit validation (not `assert`): report bad values to the
                # user instead of failing with a server error.
                try:
                    rating = int(raw_rating)
                except ValueError:
                    rating = None
                if rating is None or not 1 <= rating <= MAX_RATING:
                    ok = False
                    model['errors'].append(f'Rating must be between 1 and {MAX_RATING}.')
                    continue
                sample_result['models'][model['xp'].sig] = rating
                model['rating'] = rating
        if ok:
            record = {
                'results': all_samples_results,
                'seed': seed,
                'user': user,
                'blind': blind,
                'exclude_prompted': exclude_prompted,
                'exclude_unprompted': exclude_unprompted,
            }
            print(record)
            with open(result_file, 'w') as f:
                json.dump(record, f)
            seed = seed + 1
            return redirect(url_for(
                'survey', signature=signature, blind=blind, seed=seed,
                exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted,
                max_epoch=max_epoch, success=True))

    ratings = list(range(1, MAX_RATING + 1))
    return render_template(
        'survey.html', ratings=ratings, blind=blind, seed=seed, signature=signature, success=success,
        exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, max_epoch=max_epoch,
        experiments=experiments, models_by_id=models_by_id, model_ids=model_ids, errors=[],
        ref_name=ref_name, already_filled=result_file.exists())
def audio(path: str):
    """Serve an `.mp3` or `.wav` file from the local filesystem.

    `path` is joined to `/`, so it is treated as an absolute path whose leading
    slash may be missing. The file bytes are returned with a Content-Type
    matching the extension; other extensions and missing files give a 404.
    """
    # NOTE(review): any .mp3/.wav readable by the server process can be
    # fetched; consider restricting this to the Dora directory.
    content_types = {'.mp3': 'audio/mpeg', '.wav': 'audio/wav'}
    full_path = Path('/') / path
    content_type = content_types.get(full_path.suffix)
    # Explicit check instead of `assert`, which `python -O` strips.
    if content_type is None or not full_path.is_file():
        return 'Not found', 404
    return full_path.read_bytes(), {'Content-Type': content_type}
def mean(x):
    """Arithmetic mean of a non-empty sequence (ZeroDivisionError when empty)."""
    total = sum(x)
    count = len(x)
    return total / count
def std(x):
    """Population standard deviation (divides by n, not n - 1) of a non-empty sequence."""
    n = len(x)
    center = sum(x) / n
    squared_deviations = sum((value - center) ** 2 for value in x)
    return math.sqrt(squared_deviations / n)
def results(signature):
    """Aggregate every saved page of a survey and render `results.html`.

    Each `.json` file in `surveys/<signature>/results` holds one submitted
    page. Ratings are pooled per XP signature. For each XP the view reports
    the number of ratings, their mean, and `1.96 * std / sqrt(n)`.
    """
    survey_path = surveys / signature
    assert survey_path.exists(), survey_path
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)

    # All ratings received by each XP signature, across users and pages.
    ratings_per_model = defaultdict(list)
    users = []
    # Sorted so the processing order (and thus `users`) is deterministic.
    for result_file in sorted(result_folder.iterdir()):
        if result_file.suffix != '.json':
            continue
        with open(result_file) as f:
            # Named `page`, not `results`, to avoid shadowing this view function.
            page = json.load(f)
        users.append(page['user'])
        for sample_result in page['results']:
            for sig, rating in sample_result['models'].items():
                ratings_per_model[sig].append(rating)

    fmt = '{:.2f}'
    models = []
    for sig in sorted(ratings_per_model):
        ratings = ratings_per_model[sig]
        models.append({
            'sig': sig,
            'samples': len(ratings),
            'mean_rating': fmt.format(mean(ratings)),
            # 1.96 is the 97.5% quantile of the standard normal: this is the
            # half-width of a normal-approximation 95% confidence interval on
            # the mean rating (using the population std of the ratings).
            'std_rating': fmt.format(1.96 * std(ratings) / len(ratings)**0.5),
        })
    return render_template('results.html', signature=signature, models=models, users=users)