metadata
license: apache-2.0
language:
  - sl
  - hr
  - sr
base_model:
  - facebook/w2v-bert-2.0
pipeline_tag: audio-classification
metrics:
  - f1

Frame classification for filled pauses

This model classifies individual 20 ms frames of audio according to whether they contain a filled pause ("eee", "errm", ...).

It was trained on the human-annotated Slovenian speech corpus ROG-Artur and achieves an F1 score of 0.952868.

Example use:


from transformers import AutoFeatureExtractor, Wav2Vec2BertForAudioFrameClassification
from datasets import Dataset, Audio
import torch
import numpy as np
from pathlib import Path

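# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")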
model_name = "5roop/wav2vecbert2-filledPause"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2BertForAudioFrameClassification.from_pretrained(model_name).to(device)

ds = Dataset.from_dict(
    {
        "audio": [
            "/cache/peterr/mezzanine_resources/filled_pauses/data/dev/Iriss-J-Gvecg-P500001-avd_2082.293_2112.194.wav"
        ],
    }
).cast_column("audio", Audio(sampling_rate=16_000, mono=True))


def evaluator(chunks):
    sampling_rate = chunks["audio"][0]["sampling_rate"]
    with torch.no_grad():
        inputs = feature_extractor(
            [i["array"] for i in chunks["audio"]],
            return_tensors="pt",
            sampling_rate=sampling_rate,
        ).to(device)
        logits = model(**inputs).logits
    y_pred = np.array(logits.cpu()).argmax(axis=-1)
    return {"y_pred": y_pred.tolist()}


ds = ds.map(evaluator, batched=True)
print(ds["y_pred"][0])
# Prints one label per 20 ms frame: [0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,....]
# 0 means no filled pause was detected in that frame; 1 means one was.
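
The per-frame labels are often easier to use as time spans. The following is a minimal sketch of one way to do that conversion, not part of the model or of the original example. It assumes that frames are contiguous and non-overlapping, that the first frame starts at 0 s, and that label 1 marks a filled pause. The helper name `frames_to_intervals` and the constant `FRAME_SECONDS` are hypothetical.

```python
from itertools import groupby

# Frame length as stated in this model card (20 ms); treated as an assumption here.
FRAME_SECONDS = 0.020


def frames_to_intervals(y_pred, frame_seconds=FRAME_SECONDS):
    """Collapse per-frame 0/1 labels into (start_s, end_s) spans where the label is 1."""
    intervals = []
    position = 0  # index of the first frame in the current run
    for label, run in groupby(y_pred):
        length = sum(1 for _ in run)  # number of consecutive frames with this label
        if label == 1:
            start = round(position * frame_seconds, 3)
            end = round((position + length) * frame_seconds, 3)
            intervals.append((start, end))
        position += length
    return intervals


# Toy input for illustration only, not real model output.
example = [0, 0, 1, 1, 1, 0, 1]
print(frames_to_intervals(example))
# [(0.04, 0.1), (0.12, 0.14)]
```

With the example above, the same helper could be applied to `ds["y_pred"][0]`. If the exact frame alignment matters for your use case, check the resulting timestamps against the audio before relying on them.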

Citation

Coming soon.