Create app.py
app.py
ADDED
@@ -0,0 +1,70 @@
+import gradio as gr
+from transformers import pipeline
+import numpy as np
+import soundfile as sf
+import pandas as pd
+import matplotlib.pyplot as plt
+import random
+
+# Load the zero-shot audio classification model
+audio_classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-unfused")
+
+# Function to generate a random color
+def random_color():
+    return [random.uniform(0, 1) for _ in range(3)]
+
+# Define the classification function
+def classify_audio(audio_filepath, labels):
+    # Parse the comma-separated labels, dropping whitespace and empty entries
+    labels = [label.strip() for label in labels.split(',') if label.strip()]
+    if not labels:
+        raise gr.Error("Please enter at least one candidate label.")
+    audio_data, sample_rate = sf.read(audio_filepath)  # Read the audio file
+
+    # Convert to mono if audio is multi-channel
+    if audio_data.ndim > 1:
+        audio_data = np.mean(audio_data, axis=1)
+
+    # The pipeline assumes a raw array is already at the model's sampling
+    # rate (48 kHz for CLAP); resample with linear interpolation if not
+    target_rate = audio_classifier.feature_extractor.sampling_rate
+    if sample_rate != target_rate:
+        duration = len(audio_data) / sample_rate
+        new_length = int(duration * target_rate)
+        audio_data = np.interp(
+            np.linspace(0.0, duration, new_length, endpoint=False),
+            np.linspace(0.0, duration, len(audio_data), endpoint=False),
+            audio_data,
+        )
+
+    # Get classification results
+    results = audio_classifier(audio_data, candidate_labels=labels)
+
+    # Convert scores to percentages and create a DataFrame
+    data = [(result['label'], round(result['score'] * 100, 2)) for result in results]
+    df = pd.DataFrame(data, columns=["Label", "Score (%)"])
+
+    # Create a horizontal bar chart with random colors
+    fig, ax = plt.subplots(figsize=(10, len(labels)))
+    for i in range(len(df)):
+        ax.barh(df['Label'][i], df['Score (%)'][i], color=random_color())
+    ax.set_xlabel('Score (%)')
+    ax.set_title('Audio Classification Scores')
+    ax.grid(axis='x')
+
+    return df, fig
+
+# Create the Gradio interface
+iface = gr.Interface(
+    classify_audio,
+    inputs=[
+        gr.Audio(label="Upload your audio file", type="filepath"),
+        gr.Textbox(label="Enter candidate labels separated by commas")
+    ],
+    outputs=[gr.Dataframe(), gr.Plot()],
+    title="Zero-Shot Audio Classifier",
+    description="Upload an audio file and enter candidate labels to classify the audio."
+)
+
+# Launch the interface
+iface.launch()
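
For reference, a minimal sketch of calling the classifier directly, outside the Gradio UI. The audio path and labels below are hypothetical placeholders, and passing a path string assumes ffmpeg is available for decoding; the pipeline returns a list of {'score', 'label'} dicts sorted by descending score.

    # Hypothetical smoke test; "dog_bark.wav" is a placeholder path
    results = audio_classifier(
        "dog_bark.wav",
        candidate_labels=["dog barking", "car horn", "rain"],
    )
    print(results)  # [{'score': ..., 'label': ...}, ...], highest score first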
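
Deploying this file as a Gradio Space would also require the imported libraries in the environment; a plausible, unpinned requirements.txt (an assumption, not part of this commit) would look like:

    gradio
    transformers
    torch
    numpy
    soundfile
    pandas
    matplotlib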