import torch
import torch.nn as nn
import numpy as np
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import os
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image


class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()

        # Check if input and output channels are the same for the residual connection
        self.same_channels = in_channels == out_channels

        # Flag for whether or not to use residual connection
        self.is_res = is_res

        # First convolutional layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),  # 3x3 kernel with stride 1 and padding 1
            nn.BatchNorm2d(out_channels),                   # Batch normalization
            nn.GELU(),                                      # GELU activation function
        )

        # Second convolutional layer
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),  # 3x3 kernel with stride 1 and padding 1
            nn.BatchNorm2d(out_channels),                    # Batch normalization
            nn.GELU(),                                       # GELU activation function
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If using residual connection
        if self.is_res:
            # Apply first convolutional layer
            x1 = self.conv1(x)

            # Apply second convolutional layer
            x2 = self.conv2(x1)

            # If input and output channels are the same, add residual connection directly
            if self.same_channels:
                out = x + x2
            else:
                # If not, apply a 1x1 convolutional layer to match dimensions before adding residual connection
                shortcut = nn.Conv2d(x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0).to(x.device)
                out = shortcut(x) + x2
            # print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")

            # Scale the sum by 1/1.414 (approximately 1/sqrt(2)) to keep the output variance roughly constant
            return out / 1.414

        # If not using residual connection, return output of second convolutional layer
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2

    # Method to get the number of output channels for this block
    def get_out_channels(self):
        return self.conv2[0].out_channels

    # Method to set the number of output channels for this block
    def set_out_channels(self, out_channels):
        self.conv1[0].out_channels = out_channels
        self.conv2[0].in_channels = out_channels
        self.conv2[0].out_channels = out_channels


class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()

        # Create a list of layers for the upsampling block
        # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]

        # Use the layers to create a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        # Concatenate the input tensor x with the skip connection tensor along the channel dimension
        x = torch.cat((x, skip), 1)

        # Pass the concatenated tensor through the sequential model and return the output
        x = self.model(x)
        return x


class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()

        # Create a list of layers for the downsampling block
        # Each block consists of two ResidualConvBlock layers, followed by a MaxPool2d layer for downsampling
        layers = [
            ResidualConvBlock(in_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        ]

        # Use the layers to create a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Pass the input through the sequential model and return the output
        return self.model(x)
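
# Illustrative sketch (not part of the original utilities): how UnetDown and UnetUp are
# typically paired in a U-Net. The spatial size and channel counts below are assumptions
# chosen for this demo, not values taken from the training notebook.
def _demo_unet_blocks():
    down = UnetDown(3, 64)            # (B, 3, 16, 16) -> (B, 64, 8, 8): two residual blocks then 2x max-pool
    up = UnetUp(128, 64)              # expects torch.cat([x, skip], dim=1) with 64 + 64 = 128 input channels
    x = torch.randn(4, 3, 16, 16)
    d = down(x)                       # (4, 64, 8, 8)
    out = up(d, d)                    # transposed conv doubles resolution -> (4, 64, 16, 16)
    return out.shape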
class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        '''
        This class defines a generic feed-forward neural network with one hidden layer
        for embedding input data of dimensionality input_dim into an embedding space
        of dimensionality emb_dim.
        '''
        self.input_dim = input_dim

        # define the layers for the network
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]

        # create a PyTorch sequential model consisting of the defined layers
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # flatten the input tensor
        x = x.view(-1, self.input_dim)
        # apply the model layers to the flattened tensor
        return self.model(x)


def unorm(x):
    # unity norm: rescales values to the range [0,1]
    # assumes x has shape (h, w, 3)
    xmax = x.max((0, 1))
    xmin = x.min((0, 1))
    return (x - xmin) / (xmax - xmin)


def norm_all(store, n_t, n_s):
    # runs unity norm on all timesteps of all samples
    nstore = np.zeros_like(store)
    for t in range(n_t):
        for s in range(n_s):
            nstore[t, s] = unorm(store[t, s])
    return nstore


def norm_torch(x_all):
    # runs unity norm on all timesteps of all samples
    # input is (n_samples, 3, h, w), the torch image format
    x = x_all.cpu().numpy()
    xmax = x.max((2, 3))
    xmin = x.min((2, 3))
    xmax = np.expand_dims(xmax, (2, 3))
    xmin = np.expand_dims(xmin, (2, 3))
    nstore = (x - xmin) / (xmax - xmin)
    return torch.from_numpy(nstore)


def gen_tst_context(n_cfeat):
    """
    Generate test context vectors
    """
    vec = torch.tensor([
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
    ])
    return len(vec), vec
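
# Illustrative sketch (not part of the original utilities): EmbedFC flattens its input to
# (batch, input_dim) and maps it through Linear -> GELU -> Linear. The embedding width of
# 128 is an assumption for this demo, not a value taken from the training notebook.
def _demo_embed_fc():
    n_cfeat = 5
    _, ctx = gen_tst_context(n_cfeat)                # (36, 5) one-hot test context vectors
    embed = EmbedFC(input_dim=n_cfeat, emb_dim=128)
    emb = embed(ctx.float())                         # (36, 128) context embeddings
    return emb.shape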
def plot_grid(x, n_sample, n_rows, save_dir, w):
    # x: (n_sample, 3, h, w)
    ncols = n_sample // n_rows
    grid = make_grid(norm_torch(x), nrow=ncols)  # note: make_grid's nrow is the number of images per row, i.e. the number of columns
    save_image(grid, save_dir + f"run_image_w{w}.png")
    print('saved image at ' + save_dir + f"run_image_w{w}.png")
    return grid


def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn, w, save=False):
    ncols = n_sample // nrows
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)                            # change to Numpy image format (h,w,channels) vs (channels,h,w)
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample)  # unity norm to put in range [0,1] for plt.imshow

    # create gif of images evolving over time, based on x_gen_store
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True, figsize=(ncols, nrows))

    def animate_diff(i, store):
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                axs[row, col].clear()
                axs[row, col].set_xticks([])
                axs[row, col].set_yticks([])
                plots.append(axs[row, col].imshow(store[i, (row * ncols) + col]))
        return plots

    ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200, blit=False, repeat=True, frames=nsx_gen_store.shape[0])
    plt.close()
    if save:
        ani.save(save_dir + f"{fn}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + save_dir + f"{fn}_w{w}.gif")
    return ani


class CustomDataset(Dataset):
    def __init__(self, sfilename, lfilename, transform, null_context=False):
        self.sprites = np.load(sfilename)
        self.slabels = np.load(lfilename)
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    # Return the number of images in the dataset
    def __len__(self):
        return len(self.sprites)

    # Get the image and label at a given index
    def __getitem__(self, idx):
        # Apply the transform (if any) to the image, then return the image and label as a tuple
        image = self.sprites[idx]
        if self.transform:
            image = self.transform(image)
        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)

    def getshapes(self):
        # return shapes of data and labels
        return self.sprites_shape, self.slabel_shape


transform = transforms.Compose([
    transforms.ToTensor(),                 # from [0,255] to range [0.0,1.0]
    transforms.Normalize((0.5,), (0.5,)),  # range [-1,1]
])
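
# Illustrative sketch (not part of the original utilities): loading the sprite dataset with
# the transform defined above and batching it with a DataLoader. The .npy file names and the
# batch size are placeholders; substitute the actual sprite and label files for your setup.
def _demo_dataset_loading():
    from torch.utils.data import DataLoader

    dataset = CustomDataset(
        sfilename="sprites.npy",        # placeholder path to the sprite image array (H x W x C per image)
        lfilename="sprite_labels.npy",  # placeholder path to the one-hot label array
        transform=transform,
        null_context=False,
    )
    dataloader = DataLoader(dataset, batch_size=100, shuffle=True, num_workers=1)
    images, labels = next(iter(dataloader))  # images are normalized to [-1, 1] by the transform
    return images.shape, labels.shape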