metadata
license: apache-2.0
language:
  - sl
  - hr
  - sr
base_model:
  - facebook/w2v-bert-2.0
pipeline_tag: audio-classification
metrics:
  - f1

Frame classification for filled pauses

This model classifies individual 20ms frames of audio according to whether they contain a filled pause ("eee", "errm", ...).

It was trained on the human-annotated Slovenian speech corpus ROG-Artur and achieves an F1 score of 0.95 for the positive class on the test split of the same dataset.

Evaluation

Although the model outputs a series of 0s and 1s, one per 20ms frame, the evaluation was done at the event level: each span of consecutive 1s was bundled into a single event. When a true and a predicted event overlap at least partially, the pair is counted as a true positive.
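The bundling of frames into events can be sketched in Python as below. This is a minimal illustration, not the authors' evaluation code; the helper name `frames_to_events` and the half-open `(start, end)` span convention (end index exclusive) are assumptions.

```python
def frames_to_events(frames):
    """Group runs of consecutive 1s into (start, end) frame spans, end exclusive.

    Hypothetical helper for illustration; not part of the model's API.
    """
    events = []
    start = None
    for i, label in enumerate(frames):
        if label == 1 and start is None:
            start = i  # a new event begins at frame i
        elif label != 1 and start is not None:
            events.append((start, i))  # the current event ends before frame i
            start = None
    if start is not None:
        events.append((start, len(frames)))  # event runs to the last frame
    return events


print(frames_to_events([0, 1, 1, 0, 0, 1, 1, 1]))
# [(1, 3), (5, 8)]
```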

Evaluation on ROG corpus

Only positive events (filled pauses, class 1) are evaluated:

              precision    recall  f1-score   support

           1      0.907     0.987     0.946      1834
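The F1 score is the harmonic mean of precision P and recall R. Plugging in the rounded values from the table above gives approximately 0.945, consistent with the reported 0.946 once the rounding of P and R is taken into account:

```latex
F_1 = \frac{2PR}{P + R} = \frac{2 \cdot 0.907 \cdot 0.987}{0.907 + 0.987} = \frac{1.790418}{1.894} \approx 0.945
```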

Evaluation on ParlaSpeech HR and RS corpora

Evaluation on 800 human-annotated instances from ParlaSpeech-HR and ParlaSpeech-RS produced the following metrics:

Performance on ParlaSpeech-RS (event-level classification report, human vs. model):

              precision    recall  f1-score   support

           1       0.95      0.99      0.97       542

Performance on ParlaSpeech-HR (event-level classification report, human vs. model):

              precision    recall  f1-score   support

           1       0.93      0.98      0.95       531

All metrics are reported at the event level: if a true and a predicted filled pause overlap at least partially, they are counted as a true positive event.
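A minimal sketch of event-level scoring with the partial-overlap rule follows. The card does not say how one-to-many overlaps are handled, so this sketch assumes that precision is computed over predicted events and recall over true events, with an event counted as matched if it overlaps any event on the other side. The helper names `overlaps` and `event_scores` are hypothetical, and events are assumed to be half-open `(start, end)` frame spans.

```python
def overlaps(a, b):
    # Half-open spans (start, end) overlap if each starts before the other ends.
    return a[0] < b[1] and b[0] < a[1]


def event_scores(true_events, pred_events):
    """Return (precision, recall, f1) under a partial-overlap matching rule (assumed)."""
    # A predicted event is correct if it overlaps at least one true event.
    tp_pred = sum(any(overlaps(p, t) for t in true_events) for p in pred_events)
    # A true event is recalled if at least one predicted event overlaps it.
    tp_true = sum(any(overlaps(t, p) for p in pred_events) for t in true_events)
    precision = tp_pred / len(pred_events) if pred_events else 0.0
    recall = tp_true / len(true_events) if true_events else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


true_events = [(1, 3), (5, 8), (10, 12)]
pred_events = [(2, 4), (6, 7), (13, 15)]
print([round(x, 3) for x in event_scores(true_events, pred_events)])
# [0.667, 0.667, 0.667]
```

Because the spans are half-open, two events that merely touch (for example `(1, 3)` and `(3, 5)`) are not counted as overlapping in this sketch.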

Example use:


from transformers import AutoFeatureExtractor, Wav2Vec2BertForAudioFrameClassification
from datasets import Dataset, Audio
import torch
import numpy as np

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "classla/wav2vecbert2-filledPause"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2BertForAudioFrameClassification.from_pretrained(model_name).to(device)

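# Replace the path with your own audio file(s); the Audio feature
# resamples them to 16 kHz mono when they are decoded
ds = Dataset.from_dict(
    {
        "audio": [
            "/cache/peterr/mezzanine_resources/filled_pauses/data/dev/Iriss-J-Gvecg-P500001-avd_2082.293_2112.194.wav"
        ],
    }
).cast_column("audio", Audio(sampling_rate=16_000, mono=True))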


def evaluator(chunks):
    # All examples share the sampling rate set by cast_column (16 kHz)
    sampling_rate = chunks["audio"][0]["sampling_rate"]
    with torch.no_grad():
        inputs = feature_extractor(
            [i["array"] for i in chunks["audio"]],
            return_tensors="pt",
            sampling_rate=sampling_rate,
        ).to(device)
        logits = model(**inputs).logits  # shape: (batch, frames, num_labels)
    # Pick the most likely label (0 or 1) for each 20ms frame
    y_pred = np.array(logits.cpu()).argmax(axis=-1)
    return {"y_pred": y_pred.tolist()}


ds = ds.map(evaluator, batched=True)
print(ds["y_pred"][0])
# Prints one label per 20ms frame, e.g. [0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,....],
# where 1 marks a frame with a filled pause and 0 a frame without one
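To turn the frame labels into time stamps, each run of 1s can be mapped to seconds. This sketch assumes that output frame i covers the interval from i × 20 ms to (i + 1) × 20 ms measured from the start of the recording, with no offset; `frames_to_seconds` and `FRAME_SECONDS` are hypothetical names, not part of the model's API.

```python
from itertools import groupby

FRAME_SECONDS = 0.02  # assumed duration of one output frame (20ms)


def frames_to_seconds(y_pred, frame_seconds=FRAME_SECONDS):
    """Return (start_s, end_s) for each run of 1s in a frame-level prediction."""
    spans = []
    index = 0  # index of the first frame in the current run
    for label, run in groupby(y_pred):
        length = sum(1 for _ in run)  # number of frames in this run
        if label == 1:
            start = round(index * frame_seconds, 2)
            end = round((index + length) * frame_seconds, 2)
            spans.append((start, end))
        index += length
    return spans


print(frames_to_seconds([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]))
# [(0.08, 0.3)]
```

Applied to the output of the example above, `frames_to_seconds(ds["y_pred"][0])` would list the detected filled pauses as start and end times in seconds, under the same 20ms-per-frame assumption.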

Citation

Coming soon.