composer / app.py
Ron Au
feat(ui): Add frontend
ca5bd83
raw
history blame
4.93 kB
# Copyright 2022 Tristan Behrens.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
from flask import Flask, render_template, request, send_file, jsonify, redirect, url_for
from PIL import Image
import os
import io
import random
import base64
import torch
import wave
from source.logging import create_logger
from source.tokensequence import token_sequence_to_audio, token_sequence_to_image
from source import constants
from transformers import AutoTokenizer, AutoModelForCausalLM
logger = create_logger(__name__)

# Model identifier passed to from_pretrained() for both the tokenizer and the model.
MODEL_ID = "ai-guru/lakhclean_mmmtrack_4bars_d-2048"

# Read the auth token from the "authtoken" environment variable
# (os.getenv returns None when it is unset).
auth_token = os.getenv("authtoken")

# Load the tokenizer and model once at import time; the request handlers
# below reuse these module-level objects.
logger.info("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=auth_token)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=auth_token)
logger.info("Done.")

# Create the app. static_url_path="" serves the static folder from the URL root.
logger.info("Creating app...")
app = Flask(__name__, static_url_path="")
logger.info("Done.")
@app.route("/")
def index():
    """Serve index.html, filled with the style/density/temperature options for the UI."""
    template_context = dict(
        compose_styles=constants.get_compose_styles_for_ui(),
        densities=constants.get_densities_for_ui(),
        temperatures=constants.get_temperatures_for_ui(),
    )
    return render_template("index.html", **template_context)
@app.route("/compose", methods=["POST"])
def compose():
    """Generate a piece of music from the JSON parameters of the request.

    Expects a JSON object with the keys ``music_style``, ``density`` and
    ``temperature``. Their raw values are mapped through
    ``constants.get_instruments`` / ``get_density`` / ``get_temperature``
    before generation.

    Returns:
        On success, a JSON response with:
        ``tokens`` (the generated token sequence),
        ``audio`` (a ``data:audio/wav;base64,...`` URL),
        ``image`` (a ``data:image/png;base64,...`` URL of the image produced
        by ``token_sequence_to_image``), and ``status: "OK"``.
        If the body is not a JSON object, or a required key is missing,
        a JSON error with HTTP status 400 is returned instead.
    """
    # silent=True makes get_json() return None for a missing/invalid JSON body
    # instead of raising, so the client gets a JSON 400 rather than an error page.
    params = request.get_json(silent=True)
    if not isinstance(params, dict):
        return jsonify({"status": "ERROR", "error": "Expected a JSON object."}), 400
    missing = [key for key in ("music_style", "density", "temperature") if key not in params]
    if missing:
        return jsonify({"status": "ERROR", "error": f"Missing parameter(s): {', '.join(missing)}"}), 400

    instruments = constants.get_instruments(params["music_style"])
    density = constants.get_density(params["density"])
    temperature = constants.get_temperature(params["temperature"])
    logger.info(f"instruments: {instruments} density: {density} temperature: {temperature}")

    # Generate with the given parameters.
    logger.info("Generating token sequence...")
    generated_sequence = generate_sequence(instruments, density, temperature)
    logger.info(f"Generated token sequence: {generated_sequence}")

    # Render the tokens to audio samples and wrap them as a base64 WAV file.
    logger.info("Generating audio...")
    sample_rate, audio_data = token_sequence_to_audio(generated_sequence)
    logger.info(f"Done. Audio data: {len(audio_data)}")
    audio_data_base64 = _encode_wav_base64(sample_rate, audio_data)

    # Render the tokens to a PIL image.
    image = token_sequence_to_image(generated_sequence)
    # NOTE(review): every request overwrites compose.png in the current working
    # directory, so concurrent requests race on this file. Confirm whether
    # anything reads it before keeping this write.
    logger.debug(f"Saving image to harddrive... {type(image)}")
    image.save("compose.png", "PNG")
    image_data_base64 = _encode_png_base64(image)

    return jsonify({
        "tokens": generated_sequence,
        "audio": "data:audio/wav;base64," + audio_data_base64,
        "image": "data:image/png;base64," + image_data_base64,
        "status": "OK"
    })


def _encode_wav_base64(sample_rate, audio_data):
    """Write samples as a mono, 2-byte-per-sample in-memory WAV and return it base64-encoded."""
    buffer = io.BytesIO()
    # The context manager closes the writer (finalizing the WAV header) even if
    # writeframes() raises. wave does not close a file object it was handed,
    # so the buffer stays readable afterwards.
    with wave.open(buffer, "wb") as wave_file:
        wave_file.setframerate(sample_rate)
        wave_file.setnchannels(1)
        wave_file.setsampwidth(2)
        wave_file.writeframes(audio_data)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _encode_png_base64(image):
    """Encode a PIL image as PNG in memory and return it base64-encoded."""
    buffer = io.BytesIO()
    # Pillow's PNG encoder does not use a "quality" option, so none is passed.
    image.save(buffer, "PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def generate_sequence(instruments, density, temperature, max_length=2048):
    """Sample a token sequence from the model for the given instruments.

    The prompt is ``PIECE_START`` followed by one
    ``TRACK_START INST=<instrument> DENSITY=<density>`` segment per
    instrument, with the instruments in random order. Sampling stops when
    ``max_length`` is reached or the ``TRACK_END`` token is emitted.

    Args:
        instruments: Iterable of instrument identifiers. A shuffled copy is
            used; the caller's object is never mutated.
        density: Value inserted into each ``DENSITY=`` field of the prompt.
        temperature: Sampling temperature passed to ``model.generate``.
        max_length: Maximum total length (prompt included) passed to
            ``model.generate``. Defaults to 2048, the previously hard-coded
            value.

    Returns:
        The generated token ids decoded back to a string by the tokenizer.
    """
    # Shuffle a private copy. list() also accepts tuples and other iterables,
    # which the old instruments[::] + random.shuffle combination did not.
    instruments = list(instruments)
    random.shuffle(instruments)

    prompt_parts = [tokenizer.encode("PIECE_START", return_tensors="pt")[0]]
    prompt_parts.extend(
        tokenizer.encode(f"TRACK_START INST={instrument} DENSITY={density}", return_tensors="pt")[0]
        for instrument in instruments
    )
    # Concatenate once, then add the leading batch dimension for generate().
    generated_ids = torch.cat(prompt_parts).unsqueeze(0)

    generated_ids = model.generate(
        generated_ids,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        # TRACK_END is used as the end-of-sequence token, so sampling halts there.
        eos_token_id=tokenizer.encode("TRACK_END")[0],
    )[0]
    return tokenizer.decode(generated_ids)
if __name__ == "__main__":
    # Run the development server directly, listening on all interfaces on port 7860.
    app.run(port=7860, host="0.0.0.0")