omsandeeppatil committed (verified)
Commit 9c11a0a · Parent(s): 8581f9c

Update app.py

Files changed (1):
  1. app.py +67 -43
app.py CHANGED
@@ -1,61 +1,85 @@
  import gradio as gr
- import spaces ## For ZeroGPU
  import torch
- import torchaudio
  from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

  device = "cuda" if torch.cuda.is_available() else "cpu"
-
  model_name = "Hatman/audio-emotion-detection"
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
  model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)

- def preprocess_audio(audio):
-     waveform, sampling_rate = torchaudio.load(audio)
-     resampled_waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(waveform)
-     return {'speech': resampled_waveform.numpy().flatten(), 'sampling_rate': 16000}
-
- @spaces.GPU ## For ZeroGPU
- def inference(audio):
-     example = preprocess_audio(audio)
-     inputs = feature_extractor(example['speech'], sampling_rate=16000, return_tensors="pt", padding=True)
-     inputs = {k: v.to('cpu') for k, v in inputs.items()} # Not necessary on ZeroGPU
-     with torch.no_grad():
-         logits = model(**inputs).logits
-     predicted_ids = torch.argmax(logits, dim=-1)
-     return model.config.id2label[predicted_ids.item()], logits, predicted_ids
-
- @spaces.GPU ## For ZeroGPU
- def inference_label(audio):
-     example = preprocess_audio(audio)
-     inputs = feature_extractor(example['speech'], sampling_rate=16000, return_tensors="pt", padding=True)
-     inputs = {k: v.to('cpu') for k, v in inputs.items()} # Not necessary on ZeroGPU
-     with torch.no_grad():
-         logits = model(**inputs).logits
-     predicted_ids = torch.argmax(logits, dim=-1)
-     return model.config.id2label[predicted_ids.item()]
-
- with gr.Blocks() as demo:
-     gr.Markdown("# Audio Sentiment Analysis")
-
-     with gr.Tab("Label Only Inference"):
-         gr.Interface(
-             fn=inference_label,
-             inputs=gr.Audio(type="filepath"),
-             outputs=gr.Label(label="Predicted Sentiment"),
-             title="Audio Sentiment Analysis",
-             description="Upload an audio file or record one to get the predicted sentiment label."
          )
-
-     with gr.Tab("Full Inference"):
-         gr.Interface(
-             fn=inference,
-             inputs=gr.Audio(type="filepath"),
-             outputs=[gr.Label(label="Predicted Sentiment"), gr.Textbox(label="Logits"), gr.Textbox(label="Predicted IDs")],
-             title="Audio Sentiment Analysis (Full)",
-             description="Upload an audio file or record one to analyze sentiment and get detailed results."
          )
-
- demo.launch(share=True)
  import gradio as gr
  import torch
+ import numpy as np
  from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

+ # Initialize model and processor
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model_name = "Hatman/audio-emotion-detection"
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
  model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
+ model.to(device)

+ # Define emotion labels
+ EMOTION_LABELS = {
+     0: "angry",
+     1: "disgust",
+     2: "fear",
+     3: "happy",
+     4: "neutral",
+     5: "sad",
+     6: "surprise"
+ }

+ def process_audio(audio):
+     """Process audio chunk and return emotion"""
+     if audio is None:
+         return ""
+
+     # Get the audio data
+     if isinstance(audio, tuple):
+         audio = audio[1]
+
+     # Convert to numpy array if needed
+     audio = np.array(audio)
+
+     # Ensure we have mono audio
+     if len(audio.shape) > 1:
+         audio = audio.mean(axis=1)
+
+     try:
+         # Prepare input for the model
+         inputs = feature_extractor(
+             audio,
+             sampling_rate=16000,
+             return_tensors="pt",
+             padding=True
          )
+
+         # Move to appropriate device
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Get prediction
+         with torch.no_grad():
+             outputs = model(**inputs)
+             logits = outputs.logits
+             predicted_id = torch.argmax(logits, dim=-1).item()
+
+         emotion = EMOTION_LABELS[predicted_id]
+         return emotion
+
+     except Exception as e:
+         print(f"Error processing audio: {e}")
+         return "Error processing audio"

+ # Create Gradio interface
+ demo = gr.Interface(
+     fn=process_audio,
+     inputs=[
+         gr.Audio(
+             sources=["microphone"],
+             type="numpy",
+             streaming=True,
+             label="Speak into your microphone",
+             show_label=True
          )
+     ],
+     outputs=gr.Textbox(label="Detected Emotion"),
+     title="Live Emotion Detection",
+     description="Speak into your microphone to detect emotions in real-time.",
+     live=True,
+     allow_flagging=False
+ )

+ # Launch with a small queue for better real-time performance
+ demo.queue(max_size=1).launch(share=True)
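
Editor's note (not part of the commit): the new process_audio hands raw microphone samples to feature_extractor with sampling_rate=16000, whereas the removed preprocess_audio resampled to 16 kHz via torchaudio. Gradio's type="numpy" streaming input arrives as a (sample_rate, samples) tuple, typically int16 and often not at 16 kHz, so a preprocessing step like the one that was removed may still be needed. Below is a minimal sketch of such a helper under those assumptions; the name to_model_input is illustrative and does not appear in the committed code.

import numpy as np
import torch
import torchaudio

def to_model_input(audio, target_rate=16000):
    """Convert a Gradio (sample_rate, samples) tuple to mono float32 at 16 kHz.

    Illustrative helper, not part of this commit: it restores the resampling
    that the removed preprocess_audio() performed with torchaudio.
    """
    sample_rate, samples = audio
    samples = np.asarray(samples)
    if samples.dtype == np.int16:            # Gradio microphone audio is typically int16
        samples = samples.astype(np.float32) / 32768.0
    else:
        samples = samples.astype(np.float32)
    if samples.ndim > 1:                     # stereo -> mono
        samples = samples.mean(axis=1)
    if sample_rate != target_rate:           # resample to the rate the feature extractor expects
        waveform = torch.from_numpy(samples).unsqueeze(0)
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)
        samples = waveform.squeeze(0).numpy()
    return samples

Its output can be passed to feature_extractor(..., sampling_rate=16000, ...) exactly as process_audio does. Likewise, the hard-coded EMOTION_LABELS dict can be cross-checked against model.config.id2label, which is the mapping the removed inference_label function relied on.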