# Web-page residue captured along with this file (not part of the program):
# yuwd's picture / init / 03f6091 / raw / history blame / 309 Bytes
import torch
def collate_fn(batch):
    """Collate a batch of samples field by field.

    Dispatch depends on the container type of ``batch``:

    * a ``tuple`` whose first element is a ``list`` is returned unchanged
      (the very same object), so it is never handed to the default collate;
    * a ``list`` is treated as a list of samples: it is transposed with
      ``zip`` so the values of each field are grouped together, and every
      group (a tuple) is collated by a recursive call;
    * anything else is delegated to ``torch.utils.data.default_collate``.

    Args:
        batch: A list of samples, a tuple holding the values of one field,
            or any other input accepted by
            ``torch.utils.data.default_collate``.

    Returns:
        A list with one collated entry per field when ``batch`` is a list
        (an empty list yields ``[]``); ``batch`` itself for a tuple whose
        first element is a list; otherwise the result of
        ``torch.utils.data.default_collate``.

    Raises:
        IndexError: If ``batch`` is an empty tuple.
    """
    if isinstance(batch, tuple) and isinstance(batch[0], list):
        # Only the first element is inspected; the tuple passes through as is.
        return batch
    if isinstance(batch, list):
        # zip(*batch) regroups per-sample fields into per-field tuples; it
        # stops at the shortest sample, so extra trailing fields are dropped.
        return [collate_fn(field_values) for field_values in zip(*batch)]
    return torch.utils.data.default_collate(batch)