# Spaces: Running on Zero
# imports from gradio_demo.py and sample.py
# stdlib
import argparse
import os
import sys
from pathlib import Path
from subprocess import call
from time import sleep

# third-party
import accelerate
import gradio as gr
import k_diffusion as K
import numpy as np
import pandas as pd
import safetensors.torch as safetorch
import torch
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
from PIL import Image
from torchvision import transforms
from torchvision.transforms import ToTensor, ToPILImage
from tqdm import trange, tqdm

# local (NOTE: pom.square_to_rect deliberately shadows pom.pianoroll's square_to_rect)
from pom.pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
from pom.square_to_rect import square_to_rect
def infer_mask_from_init_img(img, mask_with='grey'):
    """Extract the inpainting mask from an init image whose mask region is marked.

    Works whether the image is normalized on 0..1 or -1..1, but not 0..255.

    Args:
        img: PIL image or CHW tensor whose mask area is painted in the
            marker color named by ``mask_with``.
        mask_with: one of 'blue', 'white', 'grey' -- the marker color.

    Returns:
        Float tensor of shape (H, W): 1.0 inside the masked region, 0.0 outside.
    """
    assert mask_with in ['blue', 'white', 'grey']
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    # mask shape is the image shape without the channel dimension
    # (for a 2D input, shape[-2:] is already the full shape)
    mask = torch.zeros(img.shape[-2:])
    if mask_with == 'white':
        # all three channels saturated at 1
        mask[(img[0, :, :] == 1) & (img[1, :, :] == 1) & (img[2, :, :] == 1)] = 1
    elif mask_with == 'blue':
        mask[img[2, :, :] == 1] = 1  # blue channel saturated
    elif mask_with == 'grey':
        # nonzero pixels with all channels equal (any shade of grey, incl. white)
        mask[(img[0, :, :] != 0) & (img[0, :, :] == img[1, :, :]) & (img[2, :, :] == img[1, :, :])] = 1
    return mask * 1.0
def count_notes_in_mask(img, mask):
    """Count the pixels of new notes (green channel > 0) that fall inside ``mask``."""
    green = ToTensor()(img)[1, :, :]   # green channel marks newly generated notes
    hits = mask * (green > 0)
    return hits.sum().item()
def grab_dense_gen(init_img,
                   PREFIX,
                   num_to_gen=64,
                   busyness=100,  # after ranking images by notes-in-mask, which percentile to grab
                   ):
    """Rank generated images by how many new notes landed in the masked region
    and return the filename at the requested busyness percentile.

    Args:
        init_img: init image whose grey-marked area defines the mask.
        PREFIX: filename prefix of the generated images ({PREFIX}_00000.png, ...).
        num_to_gen: how many generated images to scan.
        busyness: percentile (1..100) into the ascending note-count ranking;
            100 grabs the image with the most new notes in the mask.

    Returns:
        Filename (str) of the chosen generated image.
    """
    mask = infer_mask_from_init_img(init_img, mask_with='grey')
    # Collect rows first and build the DataFrame once -- pd.concat inside the
    # loop was O(n^2) in the number of generated images.
    rows = []
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        gen_img = Image.open(filename)
        gen_img_rect = square_to_rect(gen_img)
        new_notes = count_notes_in_mask(gen_img, mask)
        rows.append([filename, new_notes, gen_img_rect])
    df = pd.DataFrame(rows, columns=['filename', 'new_notes', 'img_rect'])
    # sort ascending by note count so a higher busyness indexes a busier image
    df = df.sort_values(by='new_notes', ascending=True)
    grab_index = (len(df) - 1) * busyness // 100
    print("grab_index = ", grab_index)
    dense_filename = df.iloc[grab_index]['filename']
    print("Grabbing filename = ", dense_filename)
    return dense_filename
def process_image(image, repaint, busyness):
    """Run the diffusion sampler on the piano-roll image from the Gradio editor.

    Args:
        image: gradio ImageEditor value dict; only the 'composite' layer is used.
        repaint: RePaint steps for the sampler (larger = more notes, but slower).
        busyness: percentile used to pick among the generated batch.

    Returns:
        (generated piano-roll PIL image, MIDI-player iframe HTML, audio filename)
    """
    # --- get image ready and execute sampler
    print("image = ", image)
    image = image['composite']
    if isinstance(image, np.ndarray):  # gradio may hand us a numpy array
        image = ToPILImage()(image)
    image = image.convert("RGB").crop((0, 0, 512, 128))
    image = rect_to_square(image)
    masked_img_file = 'gradio_masked_image.png'  # TODO: could allow for clobber at scale
    print("Saving masked image file to ", masked_img_file)
    image.save(masked_img_file)

    num = 64  # number of images to generate; we'll take the busiest in the masked region
    bs = num
    USER = 'shawley'
    RUN_HOME = f'/runs/{USER}/k-diffusion/pop909/full_chords'
    CKPT = f'{RUN_HOME}/256_chords_00130000.pth'
    # BUGFIX: CT_HOME was referenced below but never defined (NameError at runtime).
    # Default to the k-diffusion checkout under the user's home dir; override via
    # the CT_HOME environment variable. TODO confirm the correct default path.
    CT_HOME = os.environ.get('CT_HOME', f'/home/{USER}/k-diffusion')
    PREFIX = 'gradiodemo'

    print("Reading init image from ", masked_img_file, ", repaint = ", repaint)
    cmd = f'/home/shawley/envs/hs/bin/python {CT_HOME}/sample.py --batch-size {bs} --checkpoint {CKPT} --config {CT_HOME}/configs/config_pop909_256x256_chords.json -n {num} --prefix {PREFIX} --init-image {masked_img_file} --steps=100 --repaint={repaint}'
    print("Will run command: ", cmd)
    args = cmd.split(' ')  # list form => call() runs without a shell
    print("Calling: ", args)
    return_value = call(args)
    print("Return value = ", return_value)

    # --- find gen'd image and convert to midi piano roll
    # BUGFIX: busyness from the UI slider was previously ignored here.
    gen_file = grab_dense_gen(image, PREFIX, num_to_gen=num, busyness=busyness)
    gen_image = square_to_rect(Image.open(gen_file))
    midi_file = img_file_2_midi_file(gen_file)
    srcdoc = MIDIPlayer(midi_file, 300, styler=dark).html
    srcdoc = srcdoc.replace("\"", "'")  # the iframe srcdoc attribute uses double quotes
    html = f'''<iframe srcdoc="{srcdoc}" height="500" width="100%" title="Iframe Example"></iframe>'''

    # --- convert the midi to audio too
    audio_file = 'gradio_demo_out.mp3'
    cmd = f'timidity {midi_file} -Ow -o {audio_file}'
    print("Converting midi to audio with: ", cmd)
    return_value = call(cmd.split(' '))
    print("Return value = ", return_value)
    return gen_image, html, audio_file
# def greet(name): | |
# return "Hello " + name + "!!" | |
# demo = gr.Interface(fn=greet, inputs="text", outputs="text") | |
# demo.launch() | |
def make_dict(img):
    """Wrap one image path into the dict shape gradio's ImageEditor expects."""
    return {'background': img, 'composite': img, 'layers': [img]}
# Build and launch the Gradio UI: piano-roll editor in, generated roll / MIDI player / audio out.
editor = gr.ImageEditor(sources=["upload", 'clipboard'],
                        label="Input Piano Roll Image (White = Gen Notes Here)",
                        value=make_dict('all_black.png'),
                        brush=gr.Brush(colors=["#FFFFFF", "#000000"]))
repaint_slider = gr.Slider(minimum=1, maximum=10, step=1, value=2,
                           label="RePaint (Larger = More Notes, But Crazier. Also Slower.)")
busyness_slider = gr.Slider(minimum=1, maximum=100, step=1, value=100,
                            label="Busy-ness Percentile (Based on Notes Generated)")
output_widgets = [gr.Image(width=512, height=128, label='Generated Piano Roll Image'),
                  gr.HTML(label="MIDI Player"),
                  gr.Audio(label="MIDI as Audio")]
example_rows = ([[make_dict(y), 1, 100] for y in ['all_white.png', 'all_black.png', 'init_img_melody.png',
                                                  'init_img_accomp.png', 'init_img_cont.png', ]]
                + [[make_dict(x), 2, 100] for x in ['584_TOTAL_crop.png', '780_TOTAL_crop_bg.png',
                                                    '780_TOTAL_crop_draw.png', 'loop_middle_2.png']]
                + [[make_dict(z), 3, 100] for z in ['584_TOTAL_crop_draw.png', 'loop_middle.png']]
                + [[make_dict('ismir_mask_2.png'), 6, 100]])
demo = gr.Interface(fn=process_image,
                    inputs=[editor, repaint_slider, busyness_slider],
                    outputs=output_widgets,
                    examples=example_rows,
                    )
demo.queue().launch()