conex / espnet2 /samplers /build_batch_sampler.py
tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
6.21 kB
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union
from typeguard import check_argument_types
from typeguard import check_return_type
from espnet2.samplers.abs_sampler import AbsSampler
from espnet2.samplers.folded_batch_sampler import FoldedBatchSampler
from espnet2.samplers.length_batch_sampler import LengthBatchSampler
from espnet2.samplers.num_elements_batch_sampler import NumElementsBatchSampler
from espnet2.samplers.sorted_batch_sampler import SortedBatchSampler
from espnet2.samplers.unsorted_batch_sampler import UnsortedBatchSampler
BATCH_TYPES = dict(
unsorted="UnsortedBatchSampler has nothing in paticular feature and "
"just creates mini-batches which has constant batch_size. "
"This sampler doesn't require any length "
"information for each feature. "
"'key_file' is just a text file which describes each sample name."
"\n\n"
" utterance_id_a\n"
" utterance_id_b\n"
" utterance_id_c\n"
"\n"
"The fist column is referred, so 'shape file' can be used, too.\n\n"
" utterance_id_a 100,80\n"
" utterance_id_b 400,80\n"
" utterance_id_c 512,80\n",
sorted="SortedBatchSampler sorts samples by the length of the first input "
" in order to make each sample in a mini-batch has close length. "
"This sampler requires a text file which describes the length for each sample "
"\n\n"
" utterance_id_a 1000\n"
" utterance_id_b 1453\n"
" utterance_id_c 1241\n"
"\n"
"The first element of feature dimensions is referred, "
"so 'shape_file' can be also used.\n\n"
" utterance_id_a 1000,80\n"
" utterance_id_b 1453,80\n"
" utterance_id_c 1241,80\n",
folded="FoldedBatchSampler supports variable batch_size. "
"The batch_size is decided by\n"
" batch_size = base_batch_size // (L // fold_length)\n"
"L is referred to the largest length of samples in the mini-batch. "
"This samples requires length information as same as SortedBatchSampler\n",
length="LengthBatchSampler supports variable batch_size. "
"This sampler makes mini-batches which have same number of 'bins' as possible "
"counting by the total lengths of each feature in the mini-batch. "
"This sampler requires a text file which describes the length for each sample. "
"\n\n"
" utterance_id_a 1000\n"
" utterance_id_b 1453\n"
" utterance_id_c 1241\n"
"\n"
"The first element of feature dimensions is referred, "
"so 'shape_file' can be also used.\n\n"
" utterance_id_a 1000,80\n"
" utterance_id_b 1453,80\n"
" utterance_id_c 1241,80\n",
numel="NumElementsBatchSampler supports variable batch_size. "
"Just like LengthBatchSampler, this sampler makes mini-batches"
" which have same number of 'bins' as possible "
"counting by the total number of elements of each feature "
"instead of the length. "
"Thus this sampler requires the full information of the dimension of the features. "
"\n\n"
" utterance_id_a 1000,80\n"
" utterance_id_b 1453,80\n"
" utterance_id_c 1241,80\n",
)
def build_batch_sampler(
type: str,
batch_size: int,
batch_bins: int,
shape_files: Union[Tuple[str, ...], List[str]],
sort_in_batch: str = "descending",
sort_batch: str = "ascending",
drop_last: bool = False,
min_batch_size: int = 1,
fold_lengths: Sequence[int] = (),
padding: bool = True,
utt2category_file: str = None,
) -> AbsSampler:
"""Helper function to instantiate BatchSampler.
Args:
type: mini-batch type. "unsorted", "sorted", "folded", "numel", or, "length"
batch_size: The mini-batch size. Used for "unsorted", "sorted", "folded" mode
batch_bins: Used for "numel" model
shape_files: Text files describing the length and dimension
of each features. e.g. uttA 1330,80
sort_in_batch:
sort_batch:
drop_last:
min_batch_size: Used for "numel" or "folded" mode
fold_lengths: Used for "folded" mode
padding: Whether sequences are input as a padded tensor or not.
used for "numel" mode
"""
assert check_argument_types()
if len(shape_files) == 0:
raise ValueError("No shape file are given")
if type == "unsorted":
retval = UnsortedBatchSampler(
batch_size=batch_size, key_file=shape_files[0], drop_last=drop_last
)
elif type == "sorted":
retval = SortedBatchSampler(
batch_size=batch_size,
shape_file=shape_files[0],
sort_in_batch=sort_in_batch,
sort_batch=sort_batch,
drop_last=drop_last,
)
elif type == "folded":
if len(fold_lengths) != len(shape_files):
raise ValueError(
f"The number of fold_lengths must be equal to "
f"the number of shape_files: "
f"{len(fold_lengths)} != {len(shape_files)}"
)
retval = FoldedBatchSampler(
batch_size=batch_size,
shape_files=shape_files,
fold_lengths=fold_lengths,
sort_in_batch=sort_in_batch,
sort_batch=sort_batch,
drop_last=drop_last,
min_batch_size=min_batch_size,
utt2category_file=utt2category_file,
)
elif type == "numel":
retval = NumElementsBatchSampler(
batch_bins=batch_bins,
shape_files=shape_files,
sort_in_batch=sort_in_batch,
sort_batch=sort_batch,
drop_last=drop_last,
padding=padding,
min_batch_size=min_batch_size,
)
elif type == "length":
retval = LengthBatchSampler(
batch_bins=batch_bins,
shape_files=shape_files,
sort_in_batch=sort_in_batch,
sort_batch=sort_batch,
drop_last=drop_last,
padding=padding,
min_batch_size=min_batch_size,
)
else:
raise ValueError(f"Not supported: {type}")
assert check_return_type(retval)
return retval