callum-canavan's picture
Add helpers, change to hot dog example
954caab
import torch
from einops import rearrange
from .permutations import get_inv_perm
from .view_base import BaseView
class PermuteView(BaseView):
    """View that scrambles an image with an arbitrary, fixed pixel permutation.

    Two permutations are held: one over the 64*64 pixels of stage 1 and one
    over the 256*256 pixels of stage 2. Which one is applied is decided by
    the spatial size of the input, which must be square 64x64 or 256x256.
    """

    def __init__(self, perm_64, perm_256):
        '''
        Implements arbitrary pixel permutations, for a given permutation.
        We need two permutations. One of size 64x64 for stage 1, and
        one of size 256x256 for stage 2.

        perm_64 (torch.tensor) :
            Tensor of integer indexes, defining a permutation, of size 64*64
        perm_256 (torch.tensor) :
            Tensor of integer indexes, defining a permutation, of size 256*256

        Raises ValueError if either tensor does not have the expected shape.
        '''
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if perm_64.shape != torch.Size([64 * 64]):
            raise ValueError("`perm_64` must be a permutation tensor of size 64*64")
        if perm_256.shape != torch.Size([256 * 256]):
            raise ValueError("`perm_256` must be a permutation tensor of size 256*256")

        # Store the caller-supplied stage-1 permutation and its inverse
        # (computed by `get_inv_perm`; presumably so inverse_view undoes view).
        self.perm_64 = perm_64
        self.perm_64_inv = get_inv_perm(self.perm_64)
        # Store the caller-supplied stage-2 permutation and its inverse.
        self.perm_256 = perm_256
        self.perm_256_inv = get_inv_perm(self.perm_256)

    def _perm_for(self, im, inverse=False):
        """Return the permutation (or its inverse) matching `im`'s spatial size.

        Raises ValueError unless `im` is square with side 64 or 256. Previously
        any other size silently fell through to the 256 permutation and failed
        later with an unrelated indexing/reshape error.
        """
        h, w = im.shape[-2], im.shape[-1]
        if h != w or w not in (64, 256):
            raise ValueError(
                "PermuteView only supports square 64x64 or 256x256 images, "
                f"got spatial size {h}x{w}"
            )
        if w == 64:
            return self.perm_64_inv if inverse else self.perm_64
        return self.perm_256_inv if inverse else self.perm_256

    @staticmethod
    def _permute_pixels(im, perm):
        """Return `im` with its pixels reordered so that output pixel k
        (row-major) is input pixel perm[k]; channels move together.

        Expects a 3-D (c, h, w) tensor, as required by the rearrange pattern.
        """
        h = im.shape[-2]
        # Flatten the spatial dims in row-major order: column index = row * w + col.
        flat = rearrange(im, 'c h w -> c (h w)')
        # Gather pixels along the flattened spatial axis.
        flat = flat[:, perm]
        # Restore the (c, h, w) layout.
        return rearrange(flat, 'c (h w) -> c h w', h=h)

    def view(self, im):
        """Apply the forward pixel permutation to a (c, h, w) image."""
        return self._permute_pixels(im, self._perm_for(im))

    def inverse_view(self, noise):
        """Apply the inverse pixel permutation to a (c, h, w) tensor."""
        return self._permute_pixels(noise, self._perm_for(noise, inverse=True))

    def make_frame(self, im, t):
        # TODO: Implement this, as just moving pixels around
        raise NotImplementedError()