asigalov61's picture
Update pytorch_utils.py
5179cdc verified
raw
history blame contribute delete
No virus
1.6 kB
import os
import numpy as np
import time
import torch
from utilities import pad_truncate_sequence
def move_data_to_device(x, device):
    """Convert an array to a torch tensor and place it on ``device``.

    The conversion is chosen from the textual name of ``x.dtype``:
    names containing ``'float'`` go through ``torch.Tensor`` and names
    containing ``'int'`` go through ``torch.LongTensor``. Anything else
    is handed back untouched (and is NOT moved to ``device``).

    Args:
      x: object exposing a ``dtype`` attribute (presumably a numpy
        array; verify against callers).
      device: target passed to ``Tensor.to``.

    Returns:
      The converted tensor on ``device``, or ``x`` itself when its dtype
      name matches neither marker.
    """
    dtype_name = str(x.dtype)

    # Order matters: 'float' is tested before 'int', as in an if/elif chain.
    converters = (
        ('float', torch.Tensor),
        ('int', torch.LongTensor),
    )
    for marker, make_tensor in converters:
        if marker in dtype_name:
            return make_tensor(x).to(device)

    return x
def append_to_dict(dict, key, value):
    """Append ``value`` to the list stored under ``key``, creating it if absent.

    The mapping is modified in place; nothing is returned.

    Args:
      dict: mutable mapping from keys to lists. (The name shadows the
        builtin but is kept so keyword-argument callers keep working.)
      key: hashable key whose list receives the value.
      value: object to append.
    """
    # setdefault does the membership test and the insertion in one lookup,
    # replacing the former `key in dict.keys()` check plus second access.
    dict.setdefault(key, []).append(value)
def forward(model, x, batch_size):
    """Forward data to model in mini-batch.

    The model is switched to evaluation mode (and left in it), then ``x``
    is fed through it ``batch_size`` rows at a time under
    ``torch.inference_mode``. Each output of every batch is copied to a
    numpy array on the CPU and the per-batch arrays are concatenated
    along axis 0.

    Args:
      model: object
      x: (N, segment_samples)
      batch_size: int, must be >= 1

    Returns:
      output_dict: dict, e.g. {
        'frame_output': (segments_num, frames_num, classes_num),
        'onset_output': (segments_num, frames_num, classes_num),
        ...}
      Empty when ``x`` has no rows.

    Raises:
      ValueError: if ``batch_size`` is smaller than 1 (it previously led to
        a ZeroDivisionError or a loop that never terminated).
    """
    if batch_size < 1:
        raise ValueError(
            'batch_size must be a positive integer, got {}'.format(batch_size))

    output_dict = {}
    device = next(model.parameters()).device

    num_rows = len(x)
    # Integer ceiling division: number of mini-batches needed to cover x.
    total_batches = (num_rows + batch_size - 1) // batch_size

    # Loop-invariant: switching to eval mode once is enough.
    model.eval()

    for batch_index, pointer in enumerate(range(0, num_rows, batch_size), start=1):
        # Report the 1-based batch number so the counter and the total use
        # the same unit (the raw row offset used to be printed against a
        # batch count, e.g. "Segment 8 / 3").
        print('Segment {} / {}'.format(batch_index, total_batches))

        batch_waveform = move_data_to_device(x[pointer : pointer + batch_size], device)

        with torch.inference_mode():
            batch_output_dict = model(batch_waveform)

        for key, value in batch_output_dict.items():
            # detach() is the supported replacement for the legacy `.data`.
            append_to_dict(output_dict, key, value.detach().cpu().numpy())

    # Stitch the per-batch arrays of every output back together.
    return {
        key: np.concatenate(chunks, axis=0)
        for key, chunks in output_dict.items()
    }