|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
|
|
|
|
class AbstractPermuter(nn.Module):
    """Common base for modules that reorder their input and can undo it.

    This class carries no state of its own; subclasses are expected to
    override :meth:`forward`.
    """

    def __init__(self, *args, **kwargs):
        # Any positional/keyword arguments are accepted and ignored; only
        # the nn.Module initialisation is performed.
        super(AbstractPermuter, self).__init__()

    def forward(self, x, reverse=False):
        """Placeholder for the permutation; must be overridden.

        Args:
            x: Input to permute.
            reverse: Flag selecting the inverse direction in subclasses.

        Raises:
            NotImplementedError: Always, since the base class defines no
                permutation.
        """
        raise NotImplementedError()
|
|
|
|
|
class Identity(AbstractPermuter):
    """Permuter that performs no reordering at all."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x, reverse=False):
        """Return ``x`` itself.

        ``reverse`` is accepted for interface compatibility with the other
        permuters but has no effect: the same object is returned either way.
        """
        return x
|
|
|
|
|
class Subsample(AbstractPermuter):
    """Permutation obtained by repeatedly splitting an H x W grid into 2x2 blocks.

    Starting from the row-major index grid, each pass regroups the grid so
    that the position inside every 2x2 block (row offset, column offset)
    becomes a leading "channel" axis and the remaining grid is halved in
    both directions. The passes repeat until the grid is 1x1, and the
    flattened result is used as the forward index.

    Args:
        H: Grid height.
        W: Grid width.

    Raises:
        AssertionError: If the grid does not reduce to exactly 1x1, i.e.
            ``H`` and ``W`` are not the same power of two.
        ValueError: Raised by NumPy's ``reshape`` when an intermediate
            grid side is odd and cannot be split into 2x2 blocks.
    """

    def __init__(self, H, W):
        super().__init__()
        full_h, full_w = H, W  # kept only for the error message below
        C = 1
        indices = np.arange(H * W).reshape(C, H, W)
        while min(H, W) > 1:
            # (C, H, W) -> (C, H/2, 2, W/2, 2): expose the 2x2 block offsets.
            indices = indices.reshape(C, H // 2, 2, W // 2, 2)
            # Move both offsets next to the channel axis.
            indices = indices.transpose(0, 2, 4, 1, 3)
            # Merge channel and offsets: four times the channels, half the grid.
            indices = indices.reshape(C * 4, H // 2, W // 2)
            H, W, C = H // 2, W // 2, C * 4
        assert H == W == 1, (
            f"Subsample needs H and W to be the same power of two, "
            f"got H={full_h}, W={full_w}"
        )
        idx = torch.tensor(indices.ravel())
        # Register plain tensors: buffers do not need to be wrapped in
        # nn.Parameter, and the other permuters in this file register
        # plain tensors as well.
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        """Reorder ``x`` along dimension 1.

        Args:
            x: Tensor indexed as ``x[:, idx]``; presumably of shape
                (batch, H*W) -- confirm against callers.
            reverse: When true, apply the inverse permutation instead.

        Returns:
            The tensor with dimension 1 gathered by the selected index.
        """
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        return x[:, self.backward_shuffle_idx]
|
|
|
|
|
def mortonify(i, j):
    """Interleave the bits of ``(i, j)`` into a linear Morton (Z-order) code.

    For each of the lowest 32 bit positions, bit ``p`` of ``j`` is placed at
    bit ``2p`` of the result and bit ``p`` of ``i`` at bit ``2p + 1``.

    Args:
        i: Row index, converted to ``np.uint64``.
        j: Column index, converted to ``np.uint64``.

    Returns:
        The interleaved code as a NumPy unsigned integer scalar.
    """
    row = np.uint64(i)
    col = np.uint64(j)
    one = np.uint64(1)

    code = np.uint(0)
    for bit in range(32):
        shift = np.uint64(bit)
        mask = one << shift
        # Shifting the masked bit by its own position lands it on 2*bit.
        col_part = (col & mask) << shift
        # One extra position for the row bit: 2*bit + 1.
        row_part = (row & mask) << np.uint64(bit + 1)
        code = code | col_part | row_part
    return code
|
|
|
|
|
class ZCurve(AbstractPermuter):
    """Permutation that visits an H x W grid in Morton (Z-curve) order.

    The forward index lists the row-major positions sorted by their Morton
    code; the backward index is its inverse permutation.

    Args:
        H: Grid height.
        W: Grid width.
    """

    def __init__(self, H, W):
        super().__init__()
        # Morton code of every grid position, in row-major order.
        codes = [np.int64(mortonify(i, j)) for i in range(H) for j in range(W)]
        idx = torch.tensor(np.argsort(codes))
        self.register_buffer('forward_shuffle_idx', idx)
        # The inverse is derived from the forward index rather than taken
        # from the raw Morton codes. The codes are a valid inverse only when
        # they are exactly 0..H*W-1 (e.g. 2^k x 2^k grids, where the result
        # here is identical). For other sizes they can exceed H*W-1 -- on a
        # 3x3 grid, position (2, 2) has code 12 -- and indexing with them
        # would go out of range.
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        """Reorder ``x`` along dimension 1.

        Args:
            x: Tensor indexed as ``x[:, idx]``; presumably of shape
                (batch, H*W) -- confirm against callers.
            reverse: When true, apply the inverse permutation instead.

        Returns:
            The tensor with dimension 1 gathered by the selected index.
        """
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        return x[:, self.backward_shuffle_idx]
|
|
|
|
|
class SpiralOut(AbstractPermuter):
    """Permutation that walks a square grid in a spiral from the centre outwards.

    The walk starts at ``(size // 2, size // 2 - 1)`` and proceeds up,
    right, down, left with growing run lengths; the final downward run is
    one step shorter so the walk ends inside the grid. The length check at
    the end passes for even sizes (and for size 1); odd sizes greater than 1
    fail it.

    Args:
        H: Grid height; must equal ``W``.
        W: Grid width.

    Raises:
        AssertionError: If ``H != W`` or the walk did not visit
            ``size * size`` cells.
    """

    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        grid = np.arange(size * size).reshape(size, size)
        half = size // 2

        row, col = half, half - 1
        visited = [grid[row, col]]

        def walk(d_row, d_col, steps):
            # Take `steps` unit moves in one direction, recording each cell.
            nonlocal row, col
            for _ in range(steps):
                row += d_row
                col += d_col
                visited.append(grid[row, col])

        run = 0
        for ring in range(1, half + 1):
            run += 1
            walk(-1, 0, run)  # up
            walk(0, 1, run)   # right

            run += 1
            if ring < half:
                walk(1, 0, run)   # down
                walk(0, -1, run)  # left
            else:
                # Outermost ring: only the shortened downward run remains.
                walk(1, 0, run - 1)

        assert len(visited) == size * size
        order = torch.tensor(visited)
        self.register_buffer('forward_shuffle_idx', order)
        self.register_buffer('backward_shuffle_idx', torch.argsort(order))

    def forward(self, x, reverse=False):
        """Gather ``x`` along dimension 1 with the forward or inverse index.

        ``x`` is indexed as ``x[:, idx]``; presumably (batch, H*W) --
        confirm against callers.
        """
        order = self.backward_shuffle_idx if reverse else self.forward_shuffle_idx
        return x[:, order]
|
|
|
|
|
class SpiralIn(AbstractPermuter):
    """Permutation that walks a square grid in a spiral from the outside inwards.

    The cell order is built as an outward spiral starting at
    ``(size // 2, size // 2 - 1)`` and then reversed. The length check
    passes for even sizes (and for size 1); odd sizes greater than 1 fail it.

    Args:
        H: Grid height; must equal ``W``.
        W: Grid width.

    Raises:
        AssertionError: If ``H != W`` or the walk did not visit
            ``size * size`` cells.
    """

    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        grid = np.arange(size * size).reshape(size, size)
        half = size // 2

        # Describe the outward walk as ((d_row, d_col), run_length) segments.
        up, right, down, left = (-1, 0), (0, 1), (1, 0), (0, -1)
        segments = []
        run = 0
        for ring in range(1, half + 1):
            run += 1
            segments.append((up, run))
            segments.append((right, run))
            run += 1
            if ring < half:
                segments.append((down, run))
                segments.append((left, run))
            else:
                # Outermost ring: a final downward run, one step shorter.
                segments.append((down, run - 1))

        # Follow the segments from the starting cell near the centre.
        row, col = half, half - 1
        outward = [grid[row, col]]
        for (d_row, d_col), steps in segments:
            for _ in range(steps):
                row += d_row
                col += d_col
                outward.append(grid[row, col])

        assert len(outward) == size * size
        # Reversing the outward walk yields the inward one.
        order = torch.tensor(outward[::-1])
        self.register_buffer('forward_shuffle_idx', order)
        self.register_buffer('backward_shuffle_idx', torch.argsort(order))

    def forward(self, x, reverse=False):
        """Gather ``x`` along dimension 1 with the forward or inverse index.

        ``x`` is indexed as ``x[:, idx]``; presumably (batch, H*W) --
        confirm against callers.
        """
        order = self.backward_shuffle_idx if reverse else self.forward_shuffle_idx
        return x[:, order]
|
|
|
|
|
class Random(AbstractPermuter):
    """Fixed pseudo-random permutation of ``H * W`` positions.

    The permutation is drawn from ``np.random.RandomState(1)``, so every
    instance built with the same ``H * W`` gets the same ordering.

    Derives from ``AbstractPermuter`` like the other permuters in this
    file (it is still an ``nn.Module``).

    Args:
        H: Grid height.
        W: Grid width.
    """

    def __init__(self, H, W):
        super().__init__()
        # Seed is fixed on purpose: the ordering must be reproducible.
        indices = np.random.RandomState(1).permutation(H * W)
        idx = torch.tensor(indices)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        """Reorder ``x`` along dimension 1.

        Args:
            x: Tensor indexed as ``x[:, idx]``; presumably of shape
                (batch, H*W) -- confirm against callers.
            reverse: When true, apply the inverse permutation instead.

        Returns:
            The tensor with dimension 1 gathered by the selected index.
        """
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        return x[:, self.backward_shuffle_idx]
|
|
|
|
|
class AlternateParsing(AbstractPermuter):
    """Permutation that reads an H x W grid row by row in alternating directions.

    Rows 0, 2, 4, ... are read left to right and rows 1, 3, 5, ... right to
    left.

    Args:
        H: Grid height.
        W: Grid width.
    """

    def __init__(self, H, W):
        super().__init__()
        grid = np.arange(W * H).reshape(H, W)
        # Mirror every odd-numbered row in one strided assignment; the copy
        # keeps the source independent of the rows being overwritten.
        grid[1::2] = grid[1::2, ::-1].copy()
        order = grid.ravel()
        assert len(order) == H * W
        order = torch.tensor(order)
        self.register_buffer('forward_shuffle_idx', order)
        self.register_buffer('backward_shuffle_idx', torch.argsort(order))

    def forward(self, x, reverse=False):
        """Gather ``x`` along dimension 1 with the forward or inverse index.

        ``x`` is indexed as ``x[:, idx]``; the demo at the bottom of this
        file uses an (11, 256) tensor with a 16 x 16 permuter.
        """
        order = self.backward_shuffle_idx if reverse else self.forward_shuffle_idx
        return x[:, order]
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test 1: AlternateParsing on a 16 x 16 grid must round-trip.
    zigzag = AlternateParsing(16, 16)
    print(zigzag.forward_shuffle_idx)
    print(zigzag.backward_shuffle_idx)

    # 11 random sequences of 256 integers drawn from [0, 768).
    tokens = torch.randint(0, 768, size=(11, 256))
    shuffled = zigzag(tokens)
    restored = zigzag(shuffled, reverse=True)
    assert torch.equal(tokens, restored)

    # Smoke test 2: print the indices of the smallest even-sized spiral.
    spiral = SpiralOut(2, 2)
    print(spiral.forward_shuffle_idx)
    print(spiral.backward_shuffle_idx)
|
|