# Spaces: Runtime error (page-header residue from the captured source page; not part of the module)
import numbers
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch
import transformers
from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that dynamically pads the audio inputs it receives and batches the labels.

    Each feature is a dict with an ``"input_values"`` entry, which is padded through
    ``feature_extractor.pad``, and a ``"labels"`` entry. Labels are NOT padded. They are
    converted with ``torch.tensor``, so they must be scalars or equal-length sequences.
    The label tensor has dtype ``torch.long`` when every label is an integer, including
    NumPy integer scalars, and ``torch.float`` otherwise.

    Args:
        feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
            The feature extractor whose ``pad`` method is used to batch the ``input_values``.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Currently unused by ``__call__``; labels are never padded.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, pads the ``input_values`` sequences to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA GPUs (Volta and newer).
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            Currently unused by ``__call__``; kept for interface compatibility.
    """

    # String annotation: the class can be defined without evaluating the transformers name.
    feature_extractor: "Wav2Vec2FeatureExtractor"
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``input_values`` with the feature extractor and attach the labels as a tensor.

        Args:
            features: One dict per example, each with ``"input_values"`` and ``"labels"`` keys.

        Returns:
            The batch returned by ``feature_extractor.pad`` (requested as PyTorch tensors),
            with an added ``"labels"`` tensor.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            # Without this guard the dtype inference below fails with an opaque IndexError.
            raise ValueError("DataCollatorCTCWithPadding received an empty batch of features.")

        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [feature["labels"] for feature in features]

        # numbers.Integral also matches NumPy integer scalars (e.g. np.int64), which a plain
        # isinstance(x, int) check misses. Checking every label rather than only the first
        # avoids silently truncating a mixed int/float batch to integers.
        labels_are_integral = all(isinstance(label, numbers.Integral) for label in label_features)
        d_type = torch.long if labels_are_integral else torch.float

        batch = self.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        batch["labels"] = torch.tensor(label_features, dtype=d_type)
        return batch