update model card
Browse files
README.md
CHANGED
@@ -26,12 +26,12 @@ from datasets import load_dataset
|
|
26 |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
27 |
|
28 |
model_name = "elgeish/wav2vec2-base-timit-asr"
|
29 |
-
processor = Wav2Vec2Processor.from_pretrained(model_name
|
30 |
model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
31 |
model.eval()
|
32 |
|
33 |
dataset = load_dataset("timit_asr", split="test").shuffle().select(range(10))
|
34 |
-
char_translations = str.maketrans({"-": " ", ".": "", "?": ""})
|
35 |
|
36 |
def prepare_example(example):
|
37 |
example["speech"], _ = sf.read(example["file"])
|
@@ -47,6 +47,7 @@ with torch.no_grad():
|
|
47 |
predicted_ids = torch.argmax(model(inputs.input_values).logits, dim=-1)
|
48 |
predicted_ids[predicted_ids == -100] = processor.tokenizer.pad_token_id # see fine-tuning script
|
49 |
predicted_transcripts = processor.tokenizer.batch_decode(predicted_ids)
|
|
|
50 |
for reference, predicted in zip(dataset["text"], predicted_transcripts):
|
51 |
print("reference:", reference)
|
52 |
print("predicted:", predicted)
|
|
|
26 |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
27 |
|
28 |
model_name = "elgeish/wav2vec2-base-timit-asr"
|
29 |
+
processor = Wav2Vec2Processor.from_pretrained(model_name)
|
30 |
model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
31 |
model.eval()
|
32 |
|
33 |
dataset = load_dataset("timit_asr", split="test").shuffle().select(range(10))
|
34 |
+
char_translations = str.maketrans({"-": " ", ",": "", ".": "", "?": ""})
|
35 |
|
36 |
def prepare_example(example):
|
37 |
example["speech"], _ = sf.read(example["file"])
|
|
|
47 |
predicted_ids = torch.argmax(model(inputs.input_values).logits, dim=-1)
|
48 |
predicted_ids[predicted_ids == -100] = processor.tokenizer.pad_token_id # see fine-tuning script
|
49 |
predicted_transcripts = processor.tokenizer.batch_decode(predicted_ids)
|
50 |
+
|
51 |
for reference, predicted in zip(dataset["text"], predicted_transcripts):
|
52 |
print("reference:", reference)
|
53 |
print("predicted:", predicted)
|