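"""Gradio demo that predicts which dances fit a recorded or uploaded song.

A six-second clip is converted to a spectrogram by AudioPipeline and classified
by a trained ShortChunkCNN; dances scoring above a probability threshold are shown.
"""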
import json
import os
from functools import cache
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import torch

from preprocessing.preprocess import AudioPipeline
from dancer_net.dancer_net import ShortChunkCNN

@cache
def get_model(device: str) -> tuple[ShortChunkCNN, np.ndarray]:
    """Load the trained network and its label array once per device."""
    model_path = "logs/20221226-230930"
    weights = os.path.join(model_path, "dancer_net.pt")
    config_path = os.path.join(model_path, "config.json")
    with open(config_path) as f:
        config = json.load(f)
    # Sorting fixes a deterministic label order that matches the model's outputs.
    labels = np.array(sorted(config["classes"]))
    model = ShortChunkCNN(n_class=len(labels))
    # map_location keeps the checkpoint loadable on hosts without the training device.
    model.load_state_dict(torch.load(weights, map_location=device))
    model = model.to(device).eval()
    return model, labels

@cache
def get_pipeline(sample_rate: int) -> AudioPipeline:
    """Build (and cache) the spectrogram pipeline for a given input sample rate."""
    return AudioPipeline(input_freq=sample_rate)


@cache
def get_dance_map() -> dict:
    """Map dance label ids to human-readable dance names."""
    df = pd.read_csv("data/dance_mapping.csv")
    return df.set_index("id").to_dict()["name"]

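# Inference flow: validate clip length, downmix and trim the waveform, normalize,
# compute a spectrogram, run the model, then threshold the per-dance probabilities.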
def predict(audio: tuple[int, np.ndarray]) -> dict[str, float] | str:
    sample_rate, waveform = audio
    expected_duration = 6  # seconds
    threshold = 0.5
    sample_len = sample_rate * expected_duration
    # Prefer Apple's MPS backend when it is available; otherwise fall back to CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    audio_pipeline = get_pipeline(sample_rate)
    model, labels = get_model(device)
    if sample_len > len(waveform):
        raise gr.Error("You must record for at least 6 seconds")
    # Gradio delivers stereo audio as (samples, channels); downmix to one channel.
    if len(waveform.shape) > 1 and waveform.shape[1] > 1:
        waveform = waveform.transpose(1, 0)
        waveform = waveform.mean(axis=0, keepdims=True)
    else:
        waveform = np.expand_dims(waveform, 0)
    waveform = waveform[:, :sample_len]
    # Min-max normalize into [-1, 1] before converting to a float tensor.
    waveform = (waveform - waveform.min()) / (waveform.max() - waveform.min()) * 2 - 1
    waveform = waveform.astype("float32")
    waveform = torch.from_numpy(waveform)
    spectrogram = audio_pipeline(waveform)
    spectrogram = spectrogram.unsqueeze(0).to(device)
    with torch.no_grad():
        results = model(spectrogram)
    dance_mapping = get_dance_map()
    results = results.squeeze(0).detach().cpu().numpy()
    # Keep only the dances whose probability clears the threshold.
    result_mask = results > threshold
    probs = results[result_mask]
    dances = labels[result_mask]
    return (
        {dance_mapping[dance_id]: float(prob) for dance_id, prob in zip(dances, probs)}
        if len(dances)
        else "Couldn't find a dance."
    )

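# The UI exposes two input tabs (microphone recording and file upload) that feed
# the same predict function; results render in a shared Label component.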
def demo():
    title = "Dance Classifier"
    description = "Record 6 seconds of a song and find out what dance fits the music."
    with gr.Blocks() as app:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)
        with gr.Tab("Record Song"):
            mic_audio = gr.Audio(source="microphone", label="Song Recording")
            mic_submit = gr.Button("Predict")
        with gr.Tab("Upload Song"):
            audio_file = gr.Audio(label="Song Audio File")
            audio_file_submit = gr.Button("Predict")
        song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
        # Skip hidden files (e.g. .DS_Store) when collecting the example clips.
        example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != "."]
        labels = gr.Label(label="Dances")
        gr.Markdown("## Examples")
        gr.Examples(
            examples=example_audio,
            inputs=audio_file,
            outputs=labels,
            fn=predict,
        )
        audio_file_submit.click(fn=predict, inputs=audio_file, outputs=labels)
        mic_submit.click(fn=predict, inputs=mic_audio, outputs=labels)
    return app

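# Running `python app.py` starts a local Gradio server (http://localhost:7860 by default).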
if __name__ == "__main__":
    demo().launch()