# File size: 1,874 Bytes
# 3dfefec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
from transformers import pipeline
import numpy as np
import soundfile as sf
import pandas as pd
import matplotlib.pyplot as plt
import random

# Build the zero-shot audio classification pipeline once, at import time,
# so the model is not reloaded on every classification call
audio_classifier = pipeline(
    task="zero-shot-audio-classification",
    model="laion/clap-htsat-unfused",
)

# Helper that picks a random color as three float components
def random_color():
    """Return a list of three floats, each drawn with ``random.uniform(0, 1)``."""
    components = []
    for _channel in range(3):
        components.append(random.uniform(0, 1))
    return components

# Define the classification function
def classify_audio(audio_filepath, labels):
    """Score an audio file against user-supplied candidate labels.

    Args:
        audio_filepath: Path of the audio file to classify; it is read with
            ``soundfile.read``.
        labels: Comma-separated candidate labels, e.g. ``"dog, cat, rain"``.
            Whitespace around each label is ignored and empty entries are
            dropped.

    Returns:
        A ``(df, fig)`` tuple: a DataFrame with ``"Label"`` and
        ``"Score (%)"`` columns (classifier scores multiplied by 100 and
        rounded to two decimals), and a matplotlib figure with one
        horizontal bar per result.

    Raises:
        gr.Error: If no audio file was provided, or if no non-empty label
            remains after parsing ``labels``.
    """
    if not audio_filepath:
        raise gr.Error("Please upload an audio file.")

    # Strip each entry so "dog, cat" yields "cat" rather than " cat", and drop
    # blanks so an empty textbox or a trailing comma never becomes a label.
    stripped = (part.strip() for part in (labels or "").split(","))
    candidate_labels = [label for label in stripped if label]
    if not candidate_labels:
        raise gr.Error("Please enter at least one candidate label.")

    # NOTE(review): sample_rate is read but never used, and the raw array is
    # handed to the pipeline without resampling. Presumably the model expects
    # a specific sampling rate — confirm and resample if the file's differs.
    audio_data, sample_rate = sf.read(audio_filepath)

    # Convert to mono if audio is multi-channel by averaging over axis 1
    # (presumably the channel axis of soundfile's 2-D output).
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)

    # Get classification results
    results = audio_classifier(audio_data, candidate_labels=candidate_labels)

    # Convert scores to percentages and create a DataFrame
    rows = [(result["label"], round(result["score"] * 100, 2)) for result in results]
    df = pd.DataFrame(rows, columns=["Label", "Score (%)"])

    # Create a horizontal bar chart with random colors; the figure height
    # grows with the number of labels.
    fig, ax = plt.subplots(figsize=(10, len(candidate_labels)))
    for label, score in zip(df["Label"], df["Score (%)"]):
        ax.barh(label, score, color=random_color())
    ax.set_xlabel("Score (%)")
    ax.set_title("Audio Classification Scores")
    ax.grid(axis="x")

    # Unregister the figure from pyplot's global state so figures do not pile
    # up across calls in a long-running server. The Figure object itself is
    # still returned for rendering.
    plt.close(fig)

    return df, fig

# Wire classify_audio into a Gradio UI: an audio upload and a free-text label
# box as inputs, a score table and a bar chart as outputs
iface = gr.Interface(
    fn=classify_audio,
    title="Zero-Shot Audio Classifier",
    description="Upload an audio file and enter candidate labels to classify the audio.",
    inputs=[
        gr.Audio(type="filepath", label="Upload your audio file"),
        gr.Textbox(label="Enter candidate labels separated by commas"),
    ],
    outputs=[gr.components.Dataframe(), gr.components.Plot()],
)

# Start serving the app
iface.launch()