import os
import time

import numpy as np
import torch

from utilities import pad_truncate_sequence


def move_data_to_device(x, device):
    """Convert array-like data to a torch tensor and move it to ``device``.

    The conversion is chosen from the textual form of ``x.dtype``:

    * a dtype whose name contains ``'float'`` is converted with
      ``torch.Tensor`` (the default floating-point tensor type);
    * a dtype whose name contains ``'int'`` is converted with
      ``torch.LongTensor``;
    * anything else is returned unchanged and is NOT moved to ``device``.

    Args:
        x: object exposing a ``dtype`` attribute (e.g. a numpy array).
        device: target device, passed straight to ``Tensor.to``.

    Returns:
        A tensor on ``device``, or ``x`` itself when its dtype is neither
        a float nor an int type.
    """
    dtype_name = str(x.dtype)

    if 'float' in dtype_name:
        x = torch.Tensor(x)
    elif 'int' in dtype_name:
        x = torch.LongTensor(x)
    else:
        # Unsupported dtype: hand the input back untouched.
        return x

    return x.to(device)


def append_to_dict(dict, key, value):
    """Append ``value`` to the list stored under ``key``, creating it if needed.

    Args:
        dict: mapping of key -> list, modified in place. The parameter name
            shadows the builtin; it is kept so that existing keyword callers
            continue to work.
        key: hashable key.
        value: object appended to ``dict[key]``.
    """
    dict.setdefault(key, []).append(value)


def forward(model, x, batch_size):
    """Forward data to model in mini-batch.

    The model is switched to evaluation mode (and left in that mode), then
    ``x`` is fed through it ``batch_size`` segments at a time under
    ``torch.inference_mode()``. Every entry of the model's output dict is
    copied to host memory as a numpy array, and the per-batch arrays are
    concatenated along axis 0.

    A progress line ``Segment <done> / <total>`` is printed before each batch
    and once more after the last batch; both numbers count segments.

    Args:
        model: object
        x: (N, segment_samples)
        batch_size: int

    Returns:
        output_dict: dict, e.g. {
            'frame_output': (segments_num, frames_num, classes_num),
            'onset_output': (segments_num, frames_num, classes_num),
            ...}
        An empty dict is returned when ``x`` contains no segments.

    Raises:
        ValueError: if ``batch_size`` is zero.
    """
    output_dict = {}
    device = next(model.parameters()).device
    num_segments = len(x)

    # Evaluation mode is loop-invariant, so set it once up front.
    model.eval()

    for start in range(0, num_segments, batch_size):
        # Report progress in segments on both sides of the slash so the two
        # numbers are directly comparable.
        print('Segment {} / {}'.format(start, num_segments))

        batch_waveform = move_data_to_device(
            x[start:start + batch_size], device)

        with torch.inference_mode():
            batch_output_dict = model(batch_waveform)

        for key, value in batch_output_dict.items():
            append_to_dict(output_dict, key, value.detach().cpu().numpy())

    # Final progress line marking completion.
    print('Segment {} / {}'.format(num_segments, num_segments))

    return {
        key: np.concatenate(chunks, axis=0)
        for key, chunks in output_dict.items()
    }