import gradio as gr
from transformers import pipeline
import numpy as np
import soundfile as sf
import pandas as pd
import matplotlib.pyplot as plt
import random
# Load the zero-shot audio classification model
audio_classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-unfused")
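
# Hedged sanity check: a transformers audio pipeline also accepts a local file
# path (decoded and resampled via ffmpeg), so a quick one-off test could be:
#     print(audio_classifier("sample.wav", candidate_labels=["dog barking", "rain"]))
# where "sample.wav" is a hypothetical file, not shipped with this Space.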

# Function to generate a random color for the chart bars
def random_color():
    return [random.uniform(0, 1) for _ in range(3)]

# Classification function: read the upload, run the model, and return a
# score table plus a bar chart
def classify_audio(audio_filepath, labels):
    # Split the comma-separated label string and strip stray whitespace
    labels = [label.strip() for label in labels.split(',')]
    audio_data, sample_rate = sf.read(audio_filepath)  # Read the audio file
    # Convert to mono if the audio is multi-channel
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)
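    # Hedged note: given a raw NumPy array, the pipeline assumes the audio is
    # already at the model's expected sampling rate (48 kHz for
    # laion/clap-htsat-unfused). A sketch of explicit resampling, assuming the
    # optional librosa dependency is installed:
    #     import librosa
    #     target_sr = audio_classifier.feature_extractor.sampling_rate
    #     if sample_rate != target_sr:
    #         audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=target_sr)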
    # Run zero-shot classification against the candidate labels
    results = audio_classifier(audio_data, candidate_labels=labels)
    # Convert scores to percentages and collect them in a DataFrame
    data = [(result['label'], round(result['score'] * 100, 2)) for result in results]
    df = pd.DataFrame(data, columns=["Label", "Score (%)"])
    # Create a horizontal bar chart with a random color per label
    fig, ax = plt.subplots(figsize=(10, len(labels)))
    for i in range(len(df)):
        ax.barh(df['Label'][i], df['Score (%)'][i], color=random_color())
    ax.set_xlabel('Score (%)')
    ax.set_title('Audio Classification Scores')
    ax.grid(axis='x')
    return df, fig
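
# Hedged usage sketch, independent of the UI below; "example.wav" is a
# hypothetical local file, not shipped with this script:
#     df, fig = classify_audio("example.wav", "dog barking, car horn, rain")
#     print(df)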

# Create the Gradio interface
iface = gr.Interface(
    classify_audio,
    inputs=[
        gr.Audio(label="Upload your audio file", type="filepath"),
        gr.Textbox(label="Enter candidate labels separated by commas"),
    ],
    outputs=[gr.Dataframe(), gr.Plot()],
    title="Zero-Shot Audio Classifier",
    description="Upload an audio file and enter candidate labels to classify the audio.",
)
# Launch the interface
iface.launch()
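
# Note: outside Hugging Face Spaces, iface.launch(share=True) would also
# create a temporary public link (a standard Gradio option, not enabled here).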