
wav2vec2-bert-uk

🇺🇦 Join our Discord server (https://discord.gg/yVAjkBgmt4), where we talk about Data Science, Machine Learning, Deep Learning, and Artificial Intelligence.

🇺🇦 Join our Speech Recognition Group in Telegram: https://t.me/speech_recognition_uk

Google Colab

You can run this model using a Google Colab notebook: https://colab.research.google.com/drive/1QoKw2DWo5a5XYw870cfGE3dJf1WjZgrj?usp=sharing

Metrics

  • AM (acoustic model only):
    • WER: 0.0727
    • CER: 0.0151
    • Accuracy (1 - WER): 92.73%
  • AM + LM (acoustic model with a language model):
    • WER: 0.0655
    • CER: 0.0139
    • Accuracy (1 - WER): 93.45%
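The card does not say which tool produced these numbers. As a minimal sketch, assuming the standard Levenshtein-based definitions, WER is the word-level edit distance divided by the number of reference words, CER is the same at character level, and the reported accuracy equals (1 - WER) × 100 (for example, 1 - 0.0727 = 0.9273). The helper names and the Ukrainian sample sentences are illustrative only, and corpus-level aggregation (summing errors over all utterances) is an assumption.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (unit cost for
    substitution, insertion and deletion)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion of r
                curr[j - 1] + 1,         # insertion of h
                prev[j - 1] + (r != h),  # substitution (0 if the items match)
            )
        prev = curr
    return prev[-1]


def corpus_wer(references, hypotheses):
    """Word error rate: total word edits / total reference words."""
    errors = sum(edit_distance(r.split(), h.split())
                 for r, h in zip(references, hypotheses))
    return errors / sum(len(r.split()) for r in references)


def corpus_cer(references, hypotheses):
    """Character error rate: same idea at character level (spaces included)."""
    errors = sum(edit_distance(r, h) for r, h in zip(references, hypotheses))
    return errors / sum(len(r) for r in references)


def accuracy_percent(wer):
    """Accuracy as reported on this card: (1 - WER) * 100."""
    return (1.0 - wer) * 100


# Hypothetical reference/hypothesis pairs
refs = ["привіт як справи", "добрий день"]
hyps = ["привіт як справа", "добрий день"]
print(corpus_wer(refs, hyps))              # 0.2 (1 substituted word / 5 words)
print(round(accuracy_percent(0.0727), 2))  # 92.73
```

Production evaluations usually also normalize text (case, punctuation) before scoring; that step is omitted here.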

Hyperparameters

This model was trained with the following hyperparameters on two RTX A4000 GPUs:

torchrun --standalone --nnodes=1 --nproc-per-node=2 ../train_w2v2_bert.py \
  --custom_set ~/cv10/train.csv \
  --custom_set_eval ~/cv10/test.csv \
  --num_train_epochs 15 \
  --tokenize_config . \
  --w2v2_bert_model facebook/w2v-bert-2.0 \
  --batch 4 \
  --num_proc 5 \
  --grad_accum 1 \
  --learning_rate 3e-5 \
  --logging_steps 20 \
  --eval_step 500 \
  --group_by_length \
  --attention_dropout 0.0 \
  --activation_dropout 0.05 \
  --feat_proj_dropout 0.05 \
  --feat_quantizer_dropout 0.0 \
  --hidden_dropout 0.05 \
  --layerdrop 0.0 \
  --final_dropout 0.0 \
  --mask_time_prob 0.0 \
  --mask_time_length 10 \
  --mask_feature_prob 0.0 \
  --mask_feature_length 10
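The training script itself is not shown, so the following is only a sketch under one assumption: that `--batch` is the per-device batch size (as with `per_device_train_batch_size` in Hugging Face `TrainingArguments`). Under that assumption, two processes (`--nproc-per-node=2`) with `--grad_accum 1` give an effective batch of 4 × 1 × 2 = 8 samples per optimizer step. The dataset size below is hypothetical; the card does not state how many clips `train.csv` contains.

```python
import math


def effective_batch_size(per_device_batch, grad_accum, num_gpus):
    """Samples that contribute to one optimizer step across all DDP workers."""
    return per_device_batch * grad_accum * num_gpus


def optimizer_steps(num_examples, epochs, eff_batch):
    """Approximate total optimizer steps, keeping the last partial batch."""
    return math.ceil(num_examples / eff_batch) * epochs


# Values taken from the torchrun command (assuming --batch is per device)
eff = effective_batch_size(per_device_batch=4, grad_accum=1, num_gpus=2)
print(eff)  # 8

# Hypothetical dataset size, for illustration only
print(optimizer_steps(num_examples=10_000, epochs=15, eff_batch=eff))  # 18750
```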

Usage

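# pip install -U torch soundfile transformers

import torch
import soundfile as sf
from transformers import AutoModelForCTC, Wav2Vec2BertProcessor

# Config
model_name = 'Yehor/w2v-bert-2.0-uk'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sampling_rate = 16_000  # the model expects 16 kHz audio

# Load the model and its processor (feature extractor + CTC tokenizer)
asr_model = AutoModelForCTC.from_pretrained(model_name).to(device)
asr_model.eval()
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)

paths = [
    'sample1.wav',
]

# Read the audio files as mono float32 arrays
audio_inputs = []
for path in paths:
    audio_input, file_sr = sf.read(path, dtype='float32')
    if audio_input.ndim > 1:
        # Down-mix multi-channel audio to mono
        audio_input = audio_input.mean(axis=1)
    if file_sr != sampling_rate:
        raise ValueError(f'{path}: expected {sampling_rate} Hz audio, got {file_sr} Hz; resample it first')
    audio_inputs.append(audio_input)

# Extract features; clips of different lengths are padded into one batch
inputs = processor(audio_inputs, sampling_rate=sampling_rate, return_tensors='pt').to(device)

# Transcribe the audio
with torch.no_grad():
    logits = asr_model(**inputs).logits

# Greedy CTC decoding: most likely token per frame, then convert ids to text
predicted_ids = torch.argmax(logits, dim=-1)
predictions = processor.batch_decode(predicted_ids)

# Log results
print('Predictions:')
print(predictions)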
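`processor.batch_decode(predicted_ids)` turns the frame-level argmax ids into text. Assuming the tokenizer is a standard `Wav2Vec2CTCTokenizer` (pad token `<pad>` used as the CTC blank, `|` as the word delimiter), the greedy decoding it performs roughly amounts to the sketch below. The function name and the toy vocabulary are hypothetical, not this model's real `vocab.json`.

```python
from itertools import groupby


def ctc_greedy_decode(frame_ids, id_to_token, blank_id, word_delimiter="|"):
    """Greedy CTC decoding of one sequence of per-frame argmax ids."""
    # 1. Collapse consecutive duplicates: one symbol often spans several frames.
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # 2. Drop the blank; a blank between two equal ids keeps a genuine double letter.
    tokens = [id_to_token[i] for i in collapsed if i != blank_id]
    # 3. Map the word delimiter to spaces.
    return "".join(tokens).replace(word_delimiter, " ").strip()


# Hypothetical toy vocabulary
vocab = {0: "<pad>", 1: "|", 2: "т", 3: "а", 4: "к"}
frames = [2, 2, 0, 3, 4, 4, 1, 1, 2, 0, 3, 3, 4, 0]
print(ctc_greedy_decode(frames, vocab, blank_id=0))  # так так
```

The "AM + LM" metrics above come from decoding with a language model instead of this per-frame argmax, which is why they differ from the acoustic-model-only numbers.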


Model size: 606M params (Safetensors, tensor type F32)
Fine-tuned from: facebook/w2v-bert-2.0
