|
import gradio as gr |
|
from transformers import pipeline |
|
import numpy as np |
|
import soundfile as sf |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import random |
|
|
|
|
|
# CLAP (laion/clap-htsat-unfused) scores audio against arbitrary text labels
# with no task-specific fine-tuning. Built once at import time; the first run
# downloads the model weights.
audio_classifier = pipeline(
    "zero-shot-audio-classification",
    model="laion/clap-htsat-unfused",
)
|
|
|
|
|
def random_color():
    """Return an RGB triple with each channel drawn uniformly from [0, 1]."""
    return [random.random() for _ in range(3)]
|
|
|
|
|
def classify_audio(audio_filepath, labels):
    """Classify an audio file against user-supplied candidate labels.

    Args:
        audio_filepath: Path to the uploaded audio file (from ``gr.Audio``
            with ``type="filepath"``).
        labels: Comma-separated candidate labels, e.g. ``"dog, siren, rain"``.

    Returns:
        A ``(DataFrame, Figure)`` pair: per-label scores as percentages, and
        a horizontal bar chart of those scores.

    Raises:
        gr.Error: If no non-empty candidate label was provided.
    """
    # Strip whitespace around each label and drop empties so input like
    # "dog, cat," does not yield candidates such as " cat" or "".
    candidate_labels = [label.strip() for label in labels.split(',') if label.strip()]
    if not candidate_labels:
        raise gr.Error("Please enter at least one candidate label.")

    # Hand the file path straight to the pipeline: it decodes the audio,
    # downmixes to mono, and resamples to the rate the model expects.
    # Feeding a raw array at an arbitrary sample rate (the old approach)
    # silently skews the scores for files not recorded at that rate.
    # NOTE(review): the filepath input form needs ffmpeg available — confirm
    # it is installed in the deployment image.
    results = audio_classifier(audio_filepath, candidate_labels=candidate_labels)

    data = [(result['label'], round(result['score'] * 100, 2)) for result in results]
    df = pd.DataFrame(data, columns=["Label", "Score (%)"])

    return df, _plot_scores(df, len(candidate_labels))


def _plot_scores(df, n_labels):
    """Render a horizontal bar chart of label scores, one random color per bar."""
    # Height scales with the number of labels so bars stay readable.
    fig, ax = plt.subplots(figsize=(10, n_labels))
    ax.barh(df["Label"], df["Score (%)"], color=[random_color() for _ in range(len(df))])
    ax.set_xlabel('Score (%)')
    ax.set_title('Audio Classification Scores')
    ax.grid(axis='x')
    # Close the figure so repeated requests don't accumulate open figures in
    # this long-running process (gradio can still render a closed Figure).
    plt.close(fig)
    return fig
|
|
|
|
|
# Wire the classifier into a Gradio UI: an audio upload and a free-text
# label box go in; a score table and a bar chart come out.
iface = gr.Interface(
    fn=classify_audio,
    inputs=[
        gr.Audio(label="Upload your audio file", type="filepath"),
        gr.Textbox(label="Enter candidate labels separated by commas"),
    ],
    outputs=[gr.Dataframe(), gr.Plot()],
    title="Zero-Shot Audio Classifier",
    description="Upload an audio file and enter candidate labels to classify the audio.",
)


iface.launch()
|
|