Spaces:
Running
Running
import math | |
import numpy as np | |
from transformers import WhisperProcessor | |
class WhisperPrePostProcessor(WhisperProcessor):
    """``WhisperProcessor`` with batched/chunked pre-processing and ASR post-processing.

    * ``preprocess_batch`` turns a raw audio array (or a ``datasets``-style audio
      dict) into feature-extractor outputs. It can optionally split long audio into
      overlapping chunks that are yielded in batches.
    * ``postprocess`` converts per-chunk strides back to seconds and merges the
      outputs into text via the tokenizer's ``_decode_asr``.
    """

    def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
        """Yield batches of feature-extracted, overlapping chunks of ``inputs``.

        Chunk ``k`` starts at sample ``k * (chunk_len - stride_left - stride_right)``
        and spans up to ``chunk_len`` samples. Neighbouring chunks therefore overlap
        by ``stride_left + stride_right`` samples.

        Args:
            inputs: 1-D audio array, already at the feature extractor's sampling rate.
            chunk_len: Chunk length in samples.
            stride_left: Left overlap in samples (reported as 0 for the first chunk).
            stride_right: Right overlap in samples (reported as 0 for the last chunk(s)).
            batch_size: Maximum number of chunks per yielded batch. ``None`` puts
                every chunk into a single batch.

        Yields:
            dict: The feature extractor's outputs for the batch, plus a ``"stride"``
            list with one ``(chunk_len, stride_left, stride_right)`` tuple of ints
            (in samples) per chunk. Nothing is yielded for empty audio.

        Raises:
            ValueError: If the strides leave no room for new audio in a chunk, or if
                ``batch_size`` is not positive.
        """
        inputs_len = inputs.shape[0]
        step = chunk_len - stride_left - stride_right
        if step <= 0:
            # A zero/negative step cannot advance through the audio (np.arange would
            # fail or return nothing).
            raise ValueError("Chunk length must be superior to stride length.")

        all_chunk_start_idx = np.arange(0, inputs_len, step)
        num_samples = len(all_chunk_start_idx)
        if num_samples == 0:
            # Empty audio: nothing to chunk, and np.array_split rejects 0 sections.
            return

        if batch_size is None:
            batch_size = num_samples
        elif batch_size < 1:
            raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}.")

        # array_split spreads the chunks evenly over the batches, so a batch holds
        # at most `batch_size` chunks but may hold fewer.
        num_batches = math.ceil(num_samples / batch_size)
        batch_idx = np.array_split(np.arange(num_samples), num_batches)

        # NOTE(review): start positions run all the way to `inputs_len`, so after the
        # first chunk that reaches the end of the audio there can be trailing chunks
        # that are short or lie entirely inside their left stride - confirm this is
        # intended for the downstream merge.
        for idx in batch_idx:
            chunk_start_idx = all_chunk_start_idx[idx]
            chunk_end_idx = chunk_start_idx + chunk_len
            # Slicing clamps at the end of `inputs`, so final chunks may be shorter.
            chunks = [inputs[start:end] for start, end in zip(chunk_start_idx, chunk_end_idx)]
            processed = self.feature_extractor(
                chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )

            # The first chunk has no left context to discard.
            _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
            # With a right stride, the next start (end - stride_left - stride_right)
            # is still < inputs_len. A chunk ending exactly at the end of the audio is
            # therefore followed by another chunk and keeps its right stride; only
            # chunks that run past the end count as last.
            is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
            _stride_right = np.where(is_last, 0, stride_right)

            strides = [
                (int(chunk.shape[0]), int(left), int(right))
                for chunk, left, right in zip(chunks, _stride_left, _stride_right)
            ]
            yield {"stride": strides, **processed}

    def preprocess_batch(self, inputs, chunk_length_s=0, stride_length_s=None, batch_size=None):
        """Convert raw audio into (optionally chunked and batched) model inputs.

        Args:
            inputs: Either a 1-D ``np.ndarray`` of samples at the feature extractor's
                sampling rate, or a dict in the ``datasets`` audio format.
                The dict must have a ``"sampling_rate"`` key and the samples under
                ``"raw"`` or ``"array"``. It may also carry a ``"stride"``
                ``(left, right)`` pair in samples at that dict's sampling rate.
                Dict audio at a different rate is resampled with ``librosa``.
                The dict itself is not modified.
            chunk_length_s: Chunk length in seconds; a falsy value disables chunking.
            stride_length_s: Overlap in seconds, either one number for both sides or
                a ``(left, right)`` pair. Defaults to ``chunk_length_s / 6``.
            batch_size: Maximum chunks per yielded batch when chunking; ``None``
                puts every chunk into one batch.

        Yields:
            dict: Feature-extractor outputs.
            When chunking, each dict also has a ``"stride"`` list with one
            ``(chunk_len, stride_left, stride_right)`` tuple (in samples) per chunk.
            Without chunking, a single dict is yielded. Its ``"stride"`` (a
            one-element list in the same format) is present only if the input dict
            supplied one.

        Raises:
            ValueError: If the input dict is malformed, the audio is not a 1-D
                ndarray, the stride is larger than the audio, or the chunk length
                does not exceed the total stride.
            ImportError: If resampling is required but ``librosa`` is missing.
        """
        stride = None
        ratio = 1
        if isinstance(inputs, dict):
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain a "
                    '"raw" or "array" key containing the numpy array representing the audio, and a "sampling_rate" key '
                    "containing the sampling rate associated with the audio array."
                )
            # Read (rather than pop) so the caller's dict, e.g. a `datasets` row,
            # stays intact and can be passed in again.
            stride = inputs.get("stride")
            audio = inputs.get("raw")
            if audio is None:
                audio = inputs.get("array")
            in_sampling_rate = inputs["sampling_rate"]
            inputs = audio

            if in_sampling_rate != self.feature_extractor.sampling_rate:
                try:
                    import librosa
                except ImportError as err:
                    raise ImportError(
                        "To support resampling audio files, please install 'librosa' and 'soundfile'."
                    ) from err
                inputs = librosa.resample(
                    inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
                )
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`.")
        if len(inputs.shape) != 1:
            raise ValueError(
                f"We expect a single channel audio input for the Flax Whisper API, got {len(inputs.shape)} channels."
            )

        if stride is not None:
            # Scale the stride to the (possibly resampled) rate *before* validating,
            # so it is compared with the audio length in the same units.
            scaled_left = int(round(stride[0] * ratio))
            scaled_right = int(round(stride[1] * ratio))
            if scaled_left + scaled_right > inputs.shape[0]:
                raise ValueError("Stride is too large for input.")
            # Stride needs to get the chunk length here, it's going to get
            # swallowed by the `feature_extractor` later, and then batching
            # can add extra data in the inputs, so we need to keep track
            # of the original length in the stride so we can cut properly.
            stride = (inputs.shape[0], scaled_left, scaled_right)

        if chunk_length_s:
            # NOTE: a "stride" supplied in the input dict is not used on this path.
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6
            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            sampling_rate = self.feature_extractor.sampling_rate
            chunk_len = round(chunk_length_s * sampling_rate)
            stride_left = round(stride_length_s[0] * sampling_rate)
            stride_right = round(stride_length_s[1] * sampling_rate)
            # Every chunk must contribute at least one new sample; equality would give
            # a zero step between chunk starts.
            if chunk_len <= stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length.")

            yield from self.chunk_iter_with_batch(inputs, chunk_len, stride_left, stride_right, batch_size)
        else:
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )
            if stride is not None:
                # One entry per batch row, matching the list produced by
                # chunk_iter_with_batch; postprocess zips output values row by row,
                # so a bare tuple would be split into its individual ints.
                processed["stride"] = [stride]
            yield processed

    def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
        """Merge per-batch outputs into the final transcription.

        Args:
            model_outputs: Iterable of dicts, one per batch. Each value is a
                per-row sequence; values are zipped together, so a batch of rows
                becomes one dict per row. A ``"stride"`` entry, when present, must
                be a ``(chunk_len, stride_left, stride_right)`` triple in samples
                per row. Presumably the rest holds the generated tokens expected by
                ``tokenizer._decode_asr`` - confirm against the caller.
            return_timestamps: Forwarded to ``tokenizer._decode_asr``.
            return_language: Forwarded to ``tokenizer._decode_asr``.

        Returns:
            dict: ``{"text": ...}`` merged with any extra outputs that
            ``_decode_asr`` returns.
        """
        # unpack the outputs from list(dict(list)) to list(dict)
        model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]
        time_precision = self.feature_extractor.chunk_length / 1500  # max source positions = 1500

        # Send the chunking back to seconds, it's easier to handle in whisper
        sampling_rate = self.feature_extractor.sampling_rate
        for output in model_outputs:
            if "stride" in output:
                chunk_len, stride_left, stride_right = output["stride"]
                output["stride"] = (
                    chunk_len / sampling_rate,
                    stride_left / sampling_rate,
                    stride_right / sampling_rate,
                )

        text, optional = self.tokenizer._decode_asr(
            model_outputs,
            return_timestamps=return_timestamps,
            return_language=return_language,
            time_precision=time_precision,
        )
        return {"text": text, **optional}