elsayedissa commited on
Commit
484b1d2
1 Parent(s): f3b7d4f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +56 -2
README.md CHANGED
@@ -79,17 +79,71 @@ commonvoice_eval = commonvoice_eval.cast_column("audio", Audio(sampling_rate=160
79
  sample = next(iter(commonvoice_eval))["audio"]
80
 
81
  # features and generate token ids
82
- input_features = processor(sample["array"], sampling_rate=input_speech["sampling_rate"], return_tensors="pt").input_features
83
  predicted_ids = model.generate(input_features.to(device), forced_decoder_ids=forced_decoder_ids)
84
 
85
  # decode
86
- transcription = processor.batch_decode(predicted_ids)
87
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
88
 
89
  print(transcription)
90
 
91
  ```
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  ### Framework versions
94
 
95
  - Transformers 4.26.0.dev0
 
79
  sample = next(iter(commonvoice_eval))["audio"]
80
 
81
  # features and generate token ids
82
+ input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
83
  predicted_ids = model.generate(input_features.to(device), forced_decoder_ids=forced_decoder_ids)
84
 
85
  # decode
 
86
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
87
 
88
  print(transcription)
89
 
90
  ```
91
 
92
+ ### Evaluation
93
+
94
+ Evaluates this model on the `mozilla-foundation/common_voice_11_0` test split.
95
+
96
+ ```python
97
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
98
+ from datasets import load_dataset, Audio
99
+ import evaluate
100
+ import torch
101
+ import re
102
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
103
+
104
+ # device
105
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
+
107
+ # metric
108
+ wer_metric = evaluate.load("wer")
109
+
110
+ # model
111
+ processor = WhisperProcessor.from_pretrained("clu-ling/whisper-large-v2-japanese-5k-steps")
112
+ model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-japanese-5k-steps")
113
+
114
+ # dataset
115
+ dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ja", split="test", ) #cache_dir=args.cache_dir
116
+ dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
117
+
118
+ # for debugging: select only a few examples
119
+ #dataset = dataset.shard(num_shards=7000, index=0)
120
+ #print(dataset)
121
+
122
+ def normalize(batch):
123
+ batch["gold_text"] = whisper_norm(batch['sentence'])
124
+ return batch
125
+
126
+ def map_wer(batch):
127
+ model.to(device)
128
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language = "ja", task = "transcribe")
129
+ inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
130
+ with torch.no_grad():
131
+ generated_ids = model.generate(inputs=inputs.to(device), forced_decoder_ids=forced_decoder_ids)
132
+ transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
133
+ batch["predicted_text"] = whisper_norm(transcription)
134
+ return batch
135
+
136
+ # process GOLD text
137
+ processed_dataset = dataset.map(normalize)
138
+ # get predictions
139
+ predicted = processed_dataset.map(map_wer)
140
+
141
+ # word error rate
142
+ wer = wer_metric.compute(references=predicted['gold_text'], predictions=predicted['predicted_text'])
143
+ wer = round(100 * wer, 2)
144
+ print("WER:", wer)
145
+ ```
146
+
147
  ### Framework versions
148
 
149
  - Transformers 4.26.0.dev0