""" Layer/Module Helpers | |
Hacked together by / Copyright 2020 Ross Wightman | |
""" | |
import collections.abc | |
from itertools import repeat | |
# From PyTorch internals | |
def _ntuple(n): | |
def parse(x): | |
if isinstance(x, collections.abc.Iterable): | |
return x | |
return tuple(repeat(x, n)) | |
return parse | |
# Ready-made converters for the common 1- to 4-element cases.
to_1tuple, to_2tuple, to_3tuple, to_4tuple = map(_ntuple, (1, 2, 3, 4))

# Public alias of the factory, for building a converter with any repeat count.
to_ntuple = _ntuple