Spaces:
Runtime error
Runtime error
File size: 5,397 Bytes
ee21b96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import re
import sys
import torch
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol
from fairseq.data import Dictionary
from fairseq.tasks import LegacyFairseqTask, register_task
def get_asr_dataset_from_json(data_json_path, tgt_dict):
    """
    Parse data json and create dataset.
    See scripts/asr_prep_json.py which pack json from raw files

    Json example:
    {
    "utts": {
        "4771-29403-0025": {
            "input": {
                "length_ms": 170,
                "path": "/tmp/file1.flac"
            },
            "output": {
                "text": "HELLO \n",
                "token": "HE LLO",
                "tokenid": "4815, 861"
            }
        },
        "1564-142299-0096": {
            ...
        }
    }

    Args:
        data_json_path: path to the json manifest described above.
        tgt_dict: target dictionary; only its ``eos()`` index is used here,
            and the object itself is forwarded to ``AsrDataset``.

    Returns:
        An ``AsrDataset`` built from the utterances, ordered by decreasing
        ``length_ms``.

    Raises:
        FileNotFoundError: if ``data_json_path`` is not an existing file.
        ValueError: if the manifest holds no utterances, or if an utterance
            id does not contain the two dash separators the speaker label
            is derived from.
    """
    if not os.path.isfile(data_json_path):
        raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
    with open(data_json_path, "rb") as f:
        data_samples = json.load(f)["utts"]

    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not data_samples:
        raise ValueError(
            "Dataset contains no utterances: {}".format(data_json_path)
        )

    # Longest utterances first.
    sorted_samples = sorted(
        data_samples.items(),
        key=lambda sample: int(sample[1]["input"]["length_ms"]),
        reverse=True,
    )
    aud_paths = [sample["input"]["path"] for _, sample in sorted_samples]
    ids = [utt_id for utt_id, _ in sorted_samples]

    # The speaker label is made of the first two dash-separated fields of the
    # utterance id, e.g. "4771-29403-0025" -> "4771_29403".
    utt_id_pattern = re.compile(r"(.+?)-(.+?)-(.+?)")
    speakers = []
    for utt_id in ids:
        m = utt_id_pattern.search(utt_id)
        if m is None:
            raise ValueError(
                "Cannot derive speaker from utterance id {!r}; expected an id "
                "of the form '<a>-<b>-<c>'".format(utt_id)
            )
        speakers.append(m.group(1) + "_" + m.group(2))

    frame_sizes = [sample["input"]["length_ms"] for _, sample in sorted_samples]

    # int() ignores surrounding whitespace, so splitting on a bare comma
    # accepts both "4815, 861" and "4815,861"; eos is appended to each target.
    eos = tgt_dict.eos()
    tgt = [
        [int(tok) for tok in sample["output"]["tokenid"].split(",")] + [eos]
        for _, sample in sorted_samples
    ]
    return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
@register_task("speech_recognition")
class SpeechRecognitionTask(LegacyFairseqTask):
    """
    Speech recognition training task.

    The data directory given on the command line is expected to hold a
    ``dict.txt`` target dictionary and one ``<split>.json`` manifest per
    dataset split.
    """

    @staticmethod
    def add_args(parser):
        """Register the command-line options understood by this task."""
        parser.add_argument("data", help="path to data directory")
        parser.add_argument(
            "--silence-token",
            help="token for silence (used by w2l)",
            default="\u2581",
        )
        parser.add_argument(
            "--max-source-positions",
            type=int,
            metavar="N",
            default=sys.maxsize,
            help="max number of frames in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            type=int,
            metavar="N",
            default=1024,
            help="max number of tokens in the target sequence",
        )

    def __init__(self, args, tgt_dict):
        super().__init__(args)
        # Target-side dictionary, exposed through `target_dictionary`.
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Load ``dict.txt`` from the data directory and build the task.

        Criterion-specific symbols are added to the dictionary: a
        ``<ctc_blank>`` entry for ``ctc_loss``, and ``args.max_replabel``
        replabel symbols for ``asg_loss``.

        Raises:
            FileNotFoundError: if ``dict.txt`` is absent from ``args.data``.
        """
        dictionary_file = os.path.join(args.data, "dict.txt")
        if not os.path.isfile(dictionary_file):
            raise FileNotFoundError("Dict not found: {}".format(dictionary_file))

        dictionary = Dictionary.load(dictionary_file)

        criterion = args.criterion
        if criterion == "ctc_loss":
            dictionary.add_symbol("<ctc_blank>")
        elif criterion == "asg_loss":
            # Replabel indices are 1-based.
            for offset in range(args.max_replabel):
                dictionary.add_symbol(replabel_symbol(offset + 1))

        print("| dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)

    def load_dataset(self, split, combine=False, **kwargs):
        """Build the dataset for one split from ``<split>.json``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        manifest = os.path.join(self.args.data, "{}.json".format(split))
        dataset = get_asr_dataset_from_json(manifest, self.tgt_dict)
        self.datasets[split] = dataset

    def build_generator(self, models, args, **unused):
        """Return the decoder selected by ``args.w2l_decoder``.

        The w2l decoders are imported only when requested; any other value
        (including an absent attribute) defers to the parent class.
        """
        choice = getattr(args, "w2l_decoder", None)
        if choice == "viterbi":
            from examples.speech_recognition.w2l_decoder import (
                W2lViterbiDecoder as decoder_cls,
            )
        elif choice == "kenlm":
            from examples.speech_recognition.w2l_decoder import (
                W2lKenLMDecoder as decoder_cls,
            )
        elif choice == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import (
                W2lFairseqLMDecoder as decoder_cls,
            )
        else:
            return super().build_generator(models, args)
        return decoder_cls(args, self.target_dictionary)

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary` held by this
        task."""
        return self.tgt_dict

    @property
    def source_dictionary(self):
        """Return ``None``: this task defines no source
        :class:`~fairseq.data.Dictionary`."""
        return None

    def max_positions(self):
        """Return the ``(max source, max target)`` lengths taken from the
        task arguments."""
        args = self.args
        return (args.max_source_positions, args.max_target_positions)
|