Kosuke-Szk committed on
Commit 27b35f3 · 1 Parent(s): 9e633bc

Update README.md

  1. README.md +42 -0
README.md CHANGED
@@ -9,4 +9,46 @@ tags:
  ---
 
  # Fine-tuned Japanese Wav2Vec2 model for speech recognition using XLSR-53 large
+ Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Japanese using [Common Voice](https://commonvoice.mozilla.org/ja/datasets), [JVS](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_corpus) and [JSUT](https://sites.google.com/site/shinnosuketakamichi/publication/jsut).
+ When using this model, make sure that your speech input is sampled at 16 kHz.
 
+ ## Usage
+ The model can be used directly (without a language model) as follows.
+
+ ```python
+ import torch
+ import librosa
+ from datasets import load_dataset
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+
+ LANG_ID = "ja"
+ MODEL_ID = "Ivydata/wav2vec2-large-xlsr-53-japanese"
+ SAMPLES = 10
+
+ test_dataset = load_dataset("common_voice", LANG_ID, split=f"test[:{SAMPLES}]")
+
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
+ model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
+
+ # Preprocessing the datasets.
+ # We need to read the audio files as arrays
+ def speech_file_to_array_fn(batch):
+     speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
+     batch["speech"] = speech_array
+     batch["sentence"] = batch["sentence"].upper()
+     return batch
+
+ test_dataset = test_dataset.map(speech_file_to_array_fn)
+ inputs = processor(test_dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
+
+ with torch.no_grad():
+     logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_sentences = processor.batch_decode(predicted_ids)
+
+ for i, predicted_sentence in enumerate(predicted_sentences):
+     print("-" * 100)
+     print("Reference: ", test_dataset[i]["sentence"])
+     print("Prediction:", predicted_sentence)
+ ```
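
The usage example decodes without a language model: it takes the `argmax` token at every frame and hands the ids to `batch_decode`. As a rough illustration of what that greedy CTC step amounts to, here is a minimal plain-Python sketch. The toy vocabulary, the blank id, and the `greedy_ctc_decode` helper are hypothetical and are not part of this commit or of the `transformers` API. The real vocabulary and blank token come from the model's processor.

```python
from itertools import groupby

# Hypothetical toy vocabulary, for illustration only.
VOCAB = {0: "<pad>", 1: "こ", 2: "ん", 3: "に", 4: "ち", 5: "は"}
BLANK_ID = 0  # assumption: the padding token doubles as the CTC blank


def greedy_ctc_decode(frame_ids, vocab=VOCAB, blank_id=BLANK_ID):
    """Collapse consecutive repeated ids, then drop blank ids."""
    # Step 1: merge runs of the same id (one token spans several frames).
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # Step 2: remove blanks, which only separate genuine repeats.
    return "".join(vocab[t] for t in collapsed if t != blank_id)


# Per-frame argmax ids, as torch.argmax(logits, dim=-1) would give for one utterance.
frame_ids = [1, 1, 0, 2, 2, 2, 0, 0, 3, 4, 4, 0, 5]
print(greedy_ctc_decode(frame_ids))  # こんにちは
```

A blank between two identical ids keeps them as separate characters, so `[1, 0, 1]` decodes to `ここ` rather than `こ`.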