Babyloncoder committed on
Commit
3dfefec
·
verified ·
1 Parent(s): 9a1a1d0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import numpy as np
4
+ import soundfile as sf
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ import random
8
+
9
# Build the zero-shot audio classification pipeline once at import time so
# every request handled by the app reuses the same loaded model.
audio_classifier = pipeline(
    task="zero-shot-audio-classification",
    model="laion/clap-htsat-unfused",
)
11
+
12
def random_color():
    """Return a random RGB colour as a list of three floats drawn uniformly between 0 and 1."""
    red = random.uniform(0, 1)
    green = random.uniform(0, 1)
    blue = random.uniform(0, 1)
    return [red, green, blue]
15
+
16
def _parse_labels(labels):
    """Split a comma-separated string into a list of stripped, non-empty labels."""
    return [label.strip() for label in (labels or "").split(",") if label.strip()]


def _resample_linear(samples, source_rate, target_rate):
    """Resample a 1-D signal from source_rate to target_rate by linear interpolation."""
    if source_rate == target_rate:
        return samples
    target_length = max(1, int(round(samples.shape[0] * target_rate / source_rate)))
    source_times = np.arange(samples.shape[0]) / source_rate
    target_times = np.arange(target_length) / target_rate
    # NOTE: plain interpolation applies no anti-aliasing filter, so it is best
    # suited to upsampling; it avoids adding a new resampling dependency.
    return np.interp(target_times, source_times, samples)


def classify_audio(audio_filepath, labels):
    """Score an audio file against user-supplied candidate labels.

    Args:
        audio_filepath: Path of the audio file to classify.
        labels: Candidate labels as one comma-separated string. Whitespace
            around each label is ignored and empty entries are dropped.

    Returns:
        A ``(df, fig)`` tuple: a DataFrame with "Label" and "Score (%)"
        columns, and a matplotlib figure with one horizontal bar per result.

    Raises:
        gr.Error: If no audio file was given, no usable label was entered,
            or the audio file contains no samples.
    """
    if not audio_filepath:
        raise gr.Error("Please upload an audio file.")

    candidate_labels = _parse_labels(labels)
    if not candidate_labels:
        raise gr.Error("Please enter at least one candidate label.")

    audio_data, sample_rate = sf.read(audio_filepath)

    # The pipeline only accepts single-channel input, so average the channels.
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)
    if audio_data.size == 0:
        raise gr.Error("The audio file contains no samples.")

    # A raw array is taken to be at the model's own sampling rate already, so
    # audio recorded at any other rate has to be resampled before inference.
    target_rate = audio_classifier.feature_extractor.sampling_rate
    audio_data = _resample_linear(audio_data, sample_rate, target_rate)

    results = audio_classifier(audio_data, candidate_labels=candidate_labels)

    # Express scores as percentages rounded to two decimals.
    df = pd.DataFrame(
        [(result["label"], round(result["score"] * 100, 2)) for result in results],
        columns=["Label", "Score (%)"],
    )

    # One horizontal bar per result, each in its own random colour.
    fig, ax = plt.subplots(figsize=(10, len(candidate_labels)))
    ax.barh(
        df["Label"],
        df["Score (%)"],
        color=[random_color() for _ in df.index],
    )
    ax.set_xlabel("Score (%)")
    ax.set_title("Audio Classification Scores")
    ax.grid(axis="x")

    # Detach the figure from pyplot's global registry so a long-running server
    # does not accumulate one open figure per request; the Figure object stays
    # usable for rendering by the caller.
    plt.close(fig)

    return df, fig
41
+
42
# Wire classify_audio into a Gradio UI: an audio upload and a free-text label
# box as inputs, a score table and a bar chart as outputs.
iface = gr.Interface(
    fn=classify_audio,
    inputs=[
        gr.Audio(type="filepath", label="Upload your audio file"),
        gr.Textbox(label="Enter candidate labels separated by commas"),
    ],
    outputs=[gr.components.Dataframe(), gr.components.Plot()],
    title="Zero-Shot Audio Classifier",
    description="Upload an audio file and enter candidate labels to classify the audio.",
)

# Start serving the app.
iface.launch()