---
license: apache-2.0
language:
- sl
- hr
- sr
base_model:
- facebook/w2v-bert-2.0
pipeline_tag: audio-classification
metrics:
- f1
---
# Frame classification for filled pauses
This model classifies individual 20ms frames of audio based on the presence of filled pauses ("eee", "errm", ...).
It was trained on the human-annotated Slovenian speech corpus ROG-Artur and achieves an F1 of 0.95 for the positive class on
the test split of the same dataset.
# Evaluation
Although the model outputs a sequence of 0s and 1s, one per 20ms frame, the evaluation was done at the
event level: each run of consecutive 1s was merged into a single event. When a true and a predicted
event at least partially overlap, this is counted as a true positive.
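The event-level scoring described above can be sketched as follows. This is a minimal illustration, not the authors' evaluation script: the card does not say how multiple overlapping events are matched, so this sketch assumes precision and recall are counted independently (a predicted event is correct if it overlaps any true event, and a true event is recalled if any predicted event overlaps it). The function names `frames_to_events` and `event_level_scores` are hypothetical.

```python
from typing import List, Tuple

Span = Tuple[int, int]  # (start_frame, end_frame), end exclusive


def frames_to_events(frames: List[int]) -> List[Span]:
    """Merge each run of consecutive 1s into one (start, end) frame span."""
    events = []
    start = None
    for i, label in enumerate(frames):
        if label == 1 and start is None:
            start = i
        elif label != 1 and start is not None:
            events.append((start, i))
            start = None
    if start is not None:  # close a run that reaches the last frame
        events.append((start, len(frames)))
    return events


def overlaps(a: Span, b: Span) -> bool:
    """True if two half-open spans share at least one frame."""
    return a[0] < b[1] and b[0] < a[1]


def event_level_scores(y_true: List[int], y_pred: List[int]) -> Tuple[float, float, float]:
    """Precision, recall and F1 for the positive class, counting partial overlap as a hit.

    Assumption: hits are counted independently on the precision and recall side.
    """
    true_events = frames_to_events(y_true)
    pred_events = frames_to_events(y_pred)
    # Predicted events that touch at least one true event
    tp_pred = sum(any(overlaps(p, t) for t in true_events) for p in pred_events)
    # True events touched by at least one predicted event
    tp_true = sum(any(overlaps(t, p) for p in pred_events) for t in true_events)
    precision = tp_pred / len(pred_events) if pred_events else 0.0
    recall = tp_true / len(true_events) if true_events else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


y_true = [0, 1, 1, 1, 0, 0, 0, 1, 1, 0]
y_pred = [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
print(event_level_scores(y_true, y_pred))
```

Here the single predicted event overlaps the first true event, so precision is 1.0, while the second true event is missed, so recall is 0.5.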
## Evaluation on ROG corpus
Only the positive class (filled-pause events) is evaluated:
```
precision recall f1-score support
1 0.907 0.987 0.946 1834
```
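For reference, the F1 score in these reports is the harmonic mean of precision $P$ and recall $R$. Plugging in the rounded ROG values gives approximately the reported figure; the small difference presumably comes from rounding $P$ and $R$ to three decimals:

$$
F_1 = \frac{2 \cdot P \cdot R}{P + R}, \qquad \frac{2 \cdot 0.907 \cdot 0.987}{0.907 + 0.987} \approx 0.945
$$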
## Evaluation on ParlaSpeech [HR](https://huggingface.co/datasets/classla/ParlaSpeech-HR) and [RS](https://huggingface.co/datasets/classla/ParlaSpeech-RS) corpora
Evaluation on 800 human-annotated instances from ParlaSpeech-HR and ParlaSpeech-RS produced the following metrics:
```
Performance on RS:
Classification report for human vs model on event level:
precision recall f1-score support
1 0.95 0.99 0.97 542
Performance on HR:
Classification report for human vs model on event level:
precision recall f1-score support
1 0.93 0.98 0.95 531
```
The reported metrics are event-level: if a true and a predicted filled pause
at least partially overlap, they are counted as a true positive event.
# Example use:
```python
from transformers import AutoFeatureExtractor, Wav2Vec2BertForAudioFrameClassification
from datasets import Dataset, Audio
import torch

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "classla/wav2vecbert2-filledPause"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2BertForAudioFrameClassification.from_pretrained(model_name).to(device)

# Replace with paths to your own audio files; they are resampled to 16 kHz mono on load
ds = Dataset.from_dict(
    {
        "audio": [
            "/cache/peterr/mezzanine_resources/filled_pauses/data/dev/Iriss-J-Gvecg-P500001-avd_2082.293_2112.194.wav"
        ],
    }
).cast_column("audio", Audio(sampling_rate=16_000, mono=True))


def evaluator(chunks):
    sampling_rate = chunks["audio"][0]["sampling_rate"]
    with torch.no_grad():
        # Extract features for the whole batch and move them to the model's device
        inputs = feature_extractor(
            [i["array"] for i in chunks["audio"]],
            return_tensors="pt",
            sampling_rate=sampling_rate,
        ).to(device)
        logits = model(**inputs).logits
    # Pick the most likely class (0 or 1) for every 20ms frame
    y_pred = logits.cpu().numpy().argmax(axis=-1)
    return {"y_pred": y_pred.tolist()}


ds = ds.map(evaluator, batched=True)
print(ds["y_pred"][0])
# Returns a list of 20ms frames: [0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,....]
# with 0 indicating no filled pause detected in that frame
```
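To turn the frame-level output into time stamps, each run of 1s can be mapped to a start and end time. The sketch below assumes that frame `i` covers the interval from `i * 0.02` s to `(i + 1) * 0.02` s relative to the start of the clip; `frames_to_intervals` and `FRAME_SECONDS` are hypothetical names, not part of the model or library API.

```python
from typing import List, Tuple

FRAME_SECONDS = 0.02  # assumption: each output frame covers 20 ms, as stated in this card


def frames_to_intervals(y_pred: List[int], frame_seconds: float = FRAME_SECONDS) -> List[Tuple[float, float]]:
    """Convert a 0/1 frame sequence into (start_s, end_s) filled-pause intervals."""
    intervals = []
    start = None
    # Append a sentinel 0 so that a run reaching the last frame is also closed
    for i, label in enumerate(list(y_pred) + [0]):
        if label == 1 and start is None:
            start = i
        elif label != 1 and start is not None:
            # Round to milliseconds to avoid floating-point noise such as 0.14000000000000001
            intervals.append((round(start * frame_seconds, 3), round(i * frame_seconds, 3)))
            start = None
    return intervals


print(frames_to_intervals([0, 0, 0, 0, 1, 1, 1, 0, 1]))
```

With the example above, `frames_to_intervals(ds["y_pred"][0])` would list the filled pauses in the clip in seconds. When several files of different lengths are processed in one batch, the feature extractor pads them, so trailing predictions for padded frames may need to be trimmed before conversion.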
# Citation
Coming soon.