import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import librosa
import warnings

warnings.filterwarnings("ignore")
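
# The imports above imply these pip dependencies (a sketch, not a pinned
# requirements.txt): gradio, numpy, pandas, torch, transformers, librosa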

class EmotionRecognizer:
    def __init__(self):
        # Initialize the model and feature extractor
        self.model_name = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
        self.model = AutoModelForAudioClassification.from_pretrained(self.model_name)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()
        self.sample_rate = 16000
        # Read the emotion labels from the model config rather than hardcoding
        # them, so the list always matches the classifier head's output size
        self.labels = [self.model.config.id2label[i] for i in range(self.model.config.num_labels)]
    def process_audio(self, audio):
        """Process audio and return emotions with confidence scores."""
        try:
            # Gradio's type="numpy" Audio component yields a
            # (sample_rate, data) tuple
            if isinstance(audio, tuple):
                sample_rate, audio_data = audio
            else:
                return "Error: Invalid audio format", None

            # Convert integer PCM to float32 in [-1, 1]; librosa.resample
            # requires floating-point input
            if np.issubdtype(audio_data.dtype, np.integer):
                audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
            else:
                audio_data = audio_data.astype(np.float32)

            # Downmix stereo (samples, channels) to mono
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)

            # Resample to the model's expected rate if necessary
            if sample_rate != self.sample_rate:
                audio_data = librosa.resample(
                    audio_data, orig_sr=sample_rate, target_sr=self.sample_rate
                )

            # Extract features
            inputs = self.feature_extractor(
                audio_data,
                sampling_rate=self.sample_rate,
                return_tensors="pt",
                padding=True,
            ).to(self.device)

            # Get model predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
                predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # Process results
            scores = predictions[0].cpu().numpy()
            results = [
                {"label": label, "score": float(score)}
                for label, score in zip(self.labels, scores)
            ]

            # Sort by confidence
            results.sort(key=lambda x: x["score"], reverse=True)

            # Format results for display
            output_text = "Emotion Analysis Results:\n\n"
            output_text += "\n".join(
                f"{result['label'].title()}: {result['score'] * 100:.2f}%"
                for result in results
            )

            # gr.BarPlot plots a pandas DataFrame, so build one rather than a
            # plain dict
            plot_data = pd.DataFrame({
                "labels": [r["label"].title() for r in results],
                "values": [r["score"] * 100 for r in results],
            })

            return output_text, plot_data
        except Exception as e:
            return f"Error processing audio: {e}", None

def create_interface():
    # Initialize the emotion recognizer
    recognizer = EmotionRecognizer()

    # Define the processing function for Gradio
    def process_audio_file(audio):
        if audio is None:
            return "Please provide an audio input.", None
        output_text, plot_data = recognizer.process_audio(audio)
        if plot_data is not None:
            # In Gradio 4.x a component is updated by returning a new
            # instance; gr.BarPlot.update() no longer exists
            return (
                output_text,
                gr.BarPlot(
                    value=plot_data,
                    x="labels",
                    y="values",
                    title="Emotion Confidence Scores",
                    x_title="Emotions",
                    y_title="Confidence (%)",
                ),
            )
        return output_text, None
    # Create the Gradio interface
    with gr.Blocks(title="Audio Emotion Recognition") as interface:
        gr.Markdown("# 🎭 Audio Emotion Recognition")
        gr.Markdown("""
        Upload an audio file or record directly to analyze the emotional content.
        The model will detect emotions like angry, happy, sad, neutral, and fearful.
        """)
        with gr.Row():
            with gr.Column():
                # Input audio component (updated format)
                audio_input = gr.Audio(
                    label="Upload or Record Audio",
                    type="numpy",
                    sources=["microphone", "upload"],
                )
                # Process button
                process_btn = gr.Button("Analyze Emotion", variant="primary")
            with gr.Column():
                # Output components
                output_text = gr.Textbox(
                    label="Analysis Results",
                    lines=6,
                )
                output_plot = gr.BarPlot(
                    title="Emotion Confidence Scores",
                    x_title="Emotions",
                    y_title="Confidence (%)",
                )

        # Set up the event handler
        process_btn.click(
            fn=process_audio_file,
            inputs=[audio_input],
            outputs=[output_text, output_plot],
        )
gr.Markdown("""
### Usage Instructions:
1. Click the microphone button to record audio or upload an audio file
2. Click "Analyze Emotion" to process the audio
3. View the results and confidence scores
### Notes:
- For best results, ensure clear audio with minimal background noise
- Speak naturally and clearly when recording
- The model works best with speech in English
""")
return interface

def main():
    # Create and launch the interface
    interface = create_interface()
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )


if __name__ == "__main__":
    main()
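
# To run locally (assuming the dependencies listed near the imports are
# installed): `python app.py`, then open http://localhost:7860. On Hugging
# Face Spaces the app is launched automatically; share=True only matters for
# local runs, where it requests a temporary public gradio.live link.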