---
language:
- en
- ta
- fr
- ml
pipeline_tag: audio-classification
base_model: facebook/wav2vec2-base
---
# Model Card for Emotion Classification from Voice
This model performs emotion classification from voice data using a fine-tuned `Wav2Vec2Model` from Facebook. It predicts one of seven emotion labels: Angry, Disgust, Fear, Happy, Neutral, Sad, and Surprise.
## Model Details
- **Developed by:** [Your Name/Organization]
- **Model type:** Fine-tuned Wav2Vec2Model
- **Language(s):** English (en), Tamil (ta), French (fr), Malayalam (ml)
- **License:** [Choose a license]
- **Finetuned from model:** [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base)
### Model Sources
- **Repository:** [Link to your repository]
- **Demo:** [Gradio Demo Link if Available]
## Uses
### Direct Use
This model can be used directly for emotion detection in speech audio files, with applications in call centers, virtual assistants, and mental health monitoring.
### Out-of-Scope Use
The model is not intended for general speech recognition or other NLP tasks outside emotion classification.
## Datasets Used
The model has been trained on a combination of the following datasets:
- **CREMA-D:** 7,442 clips of actors speaking with various emotions
- **Torrento:** Emotional speech in Spanish, captured from various environments
- **RAVDESS:** 24 professional actors, 7 emotions
- **Emo-DB:** 535 German-language utterances covering 7 emotions
The combination of these datasets allows the model to generalize across multiple languages and accents.
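To give an idea of how such a combination can be assembled, the sketch below maps dataset-specific emotion tags onto the shared seven-class label set. The tag spellings and mapping are illustrative assumptions, not the exact preprocessing used for this model.

```python
# Illustrative only: unify dataset-specific emotion tags into the shared label set.
# The tag spellings below are assumptions about the source corpora, not this repo's code.
EMOTIONS = ["Angry", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

LABEL_MAP = {
    # CREMA-D style three-letter codes
    "ANG": "Angry", "DIS": "Disgust", "FEA": "Fear",
    "HAP": "Happy", "NEU": "Neutral", "SAD": "Sad",
    # Spelled-out labels as emitted by some dataset loaders
    "angry": "Angry", "disgust": "Disgust", "fearful": "Fear",
    "happy": "Happy", "neutral": "Neutral", "sad": "Sad", "surprised": "Surprise",
}

def to_class_index(tag: str) -> int:
    """Map a dataset-specific emotion tag to an index into EMOTIONS."""
    return EMOTIONS.index(LABEL_MAP[tag])

assert to_class_index("HAP") == 3  # "Happy"
```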
## Bias, Risks, and Limitations
- **Bias:** The model might underperform on speech data with accents or languages not present in the training data.
- **Limitations:** The model is trained specifically for emotion detection and might not generalize well for other speech tasks.
## How to Get Started with the Model
```python
import torch
import numpy as np
from transformers import Wav2Vec2Model
from torchaudio.transforms import Resample

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained backbone; hidden states are needed to pool an utterance embedding.
wav2vec2_model = Wav2Vec2Model.from_pretrained(
    "facebook/wav2vec2-base", output_hidden_states=True
).to(device)

class FineTunedWav2Vec2Model(torch.nn.Module):
    """Wav2Vec2 backbone with a linear head over the 7 emotion classes."""

    def __init__(self, wav2vec2_model, output_size):
        super().__init__()
        self.wav2vec2 = wav2vec2_model
        self.fc = torch.nn.Linear(self.wav2vec2.config.hidden_size, output_size)
        self.double()  # run the backbone and head in float64

    def forward(self, x):
        outputs = self.wav2vec2(x.double())
        hidden = outputs.hidden_states[-1]   # (batch, time, hidden_size)
        return self.fc(hidden[:, 0, :])      # classify from the first frame embedding

def preprocess_audio(audio):
    """Convert a (sample_rate, waveform) pair to a mono, float, 16 kHz tensor."""
    sample_rate, waveform = audio
    if isinstance(waveform, np.ndarray):
        waveform = torch.from_numpy(waveform)
    if waveform.dim() == 2:
        waveform = waveform.mean(dim=0)  # downmix to mono
    # Normalize integer PCM to [-1, 1]; cast any float input to float32
    if not torch.is_floating_point(waveform):
        waveform = waveform.float() / torch.iinfo(waveform.dtype).max
    else:
        waveform = waveform.float()
    # Resample to 16 kHz, the rate Wav2Vec2 expects
    if sample_rate != 16000:
        waveform = Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    return waveform

def predict(audio):
    model_path = "model.pth"  # Path to your fine-tuned model
    model = FineTunedWav2Vec2Model(wav2vec2_model, 7).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    waveform = preprocess_audio(audio)
    waveform = waveform.unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        output = model(waveform)
    predicted_label = torch.argmax(output, dim=1).item()
    emotion_labels = ["Angry", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]
    return emotion_labels[predicted_label]

# Example usage
audio_data = (sample_rate, waveform)  # Replace with your actual audio data
emotion = predict(audio_data)
print(f"Predicted Emotion: {emotion}")
```
## Training Procedure
- **Preprocessing:** Resampled all audio to 16 kHz.
- **Training:** Fine-tuned `facebook/wav2vec2-base` with the emotion labels (see the sketch below).
- **Hyperparameters:** batch size 16, learning rate 5e-5, 50 epochs.
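A minimal sketch of a fine-tuning loop consistent with these hyperparameters, reusing `FineTunedWav2Vec2Model`, `wav2vec2_model`, and `device` from the snippet above. `train_dataset` is a hypothetical dataset yielding fixed-length, preprocessed (waveform, label) pairs; this is not the exact training script used for the released checkpoint.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical: `train_dataset` yields (waveform, label) pairs of 16 kHz mono audio,
# padded/truncated to one length so the default collation can batch them.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

model = FineTunedWav2Vec2Model(wav2vec2_model, output_size=7).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(50):
    running_loss = 0.0
    for waveforms, labels in train_loader:
        waveforms, labels = waveforms.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(waveforms)          # (batch, 7) emotion scores
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"epoch {epoch + 1}: mean loss {running_loss / len(train_loader):.4f}")
```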
## Evaluation
### Testing Data
Evaluation was performed on a held-out test set from the CREMA-D and RAVDESS datasets.
### Metrics
- **Accuracy:** 85%
- **F1-score:** 82% (weighted average across all classes)
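To illustrate how these metrics can be computed, the sketch below assumes a `test_loader` that yields preprocessed (waveform, label) batches from the held-out split; `model` and `device` come from the snippets above, and scikit-learn is an assumed dependency rather than a stated one.

```python
import torch
from sklearn.metrics import accuracy_score, f1_score

# Hypothetical: `test_loader` yields (waveform, label) batches preprocessed the
# same way as the training data.
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for waveforms, labels in test_loader:
        logits = model(waveforms.to(device))
        all_preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
        all_labels.extend(labels.tolist())

print(f"Accuracy: {accuracy_score(all_labels, all_preds):.2%}")
print(f"Weighted F1: {f1_score(all_labels, all_preds, average='weighted'):.2%}")
```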