Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import numpy as np | |
from torchvision.utils import save_image, make_grid | |
import matplotlib.pyplot as plt | |
from matplotlib.animation import FuncAnimation, PillowWriter | |
import os | |
import torchvision.transforms as transforms | |
from torch.utils.data import Dataset | |
from PIL import Image | |
class ResidualConvBlock(nn.Module):
    """Two 3x3 Conv-BatchNorm-GELU layers with an optional residual connection.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by both convolutions.
        is_res: If True, add the block input to the output of the second
            convolution and divide the sum by 1.414. When
            ``in_channels != out_channels`` the input is first passed through a
            learned 1x1 convolution (``self.shortcut``) to match channel counts.
    """

    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        # The identity shortcut only works when input and output channels match.
        self.same_channels = in_channels == out_channels
        # Flag for whether or not to use the residual connection.
        self.is_res = is_res
        # First convolution: 3x3 kernel, stride 1, padding 1 (keeps h and w).
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        # Second convolution: same geometry, out_channels -> out_channels.
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        # 1x1 projection that matches channel counts on the residual path.
        # It is built once and registered as a submodule so its weights are
        # trained, saved in state_dict and moved by .to(). Building it inside
        # forward() would draw fresh, untrainable random weights on every call.
        if is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )
        else:
            self.shortcut = None

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Checkpoints written before the shortcut was registered contain no
        # "shortcut.*" entries. Fill them with the current (freshly initialised)
        # values so strict loading of such weights keeps working. Keys that are
        # present in the checkpoint are left untouched by setdefault.
        if self.shortcut is not None:
            for key, value in self.shortcut.state_dict().items():
                state_dict.setdefault(f"{prefix}shortcut.{key}", value)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both convolutions and, if ``is_res``, the scaled residual sum."""
        x2 = self.conv2(self.conv1(x))
        if not self.is_res:
            return x2
        # Identity shortcut when channels match, otherwise the 1x1 projection.
        residual = x if self.same_channels else self.shortcut(x)
        # Scale the sum by 1/1.414 (about 1/sqrt(2)), as the original block did.
        return (residual + x2) / 1.414

    def get_out_channels(self):
        """Return the ``out_channels`` attribute of the second convolution."""
        return self.conv2[0].out_channels

    def set_out_channels(self, out_channels):
        """Overwrite the channel-count attributes of both convolutions.

        NOTE(review): this only reassigns the ``in_channels``/``out_channels``
        attributes. The convolution weight tensors and the BatchNorm layers are
        not resized, so the block's actual computation is unchanged.
        """
        self.conv1[0].out_channels = out_channels
        self.conv2[0].in_channels = out_channels
        self.conv2[0].out_channels = out_channels
class UnetUp(nn.Module):
    """Upsampling stage of the U-Net.

    ``forward`` concatenates the input with its skip connection along the
    channel axis, upsamples 2x with a transposed convolution (kernel 2,
    stride 2), then refines with two ResidualConvBlocks.

    Args:
        in_channels: Channel count of the concatenated (x, skip) tensor.
        out_channels: Channel count produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
        refine = [ResidualConvBlock(out_channels, out_channels) for _ in range(2)]
        self.model = nn.Sequential(upsample, *refine)

    def forward(self, x, skip):
        merged = torch.cat((x, skip), 1)
        return self.model(merged)
class UnetDown(nn.Module):
    """Downsampling stage of the U-Net.

    Two ResidualConvBlocks (in_channels -> out_channels -> out_channels)
    followed by 2x2 max pooling, which halves the spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            ResidualConvBlock(in_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        downsampled = self.model(x)
        return downsampled
class EmbedFC(nn.Module):
    """Small MLP that embeds ``input_dim`` features into ``emb_dim`` dimensions.

    Layers: Linear(input_dim, emb_dim) -> GELU -> Linear(emb_dim, emb_dim).
    Inputs are reshaped to (-1, input_dim) before the MLP is applied.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x):
        flat = x.view(-1, self.input_dim)
        return self.model(flat)
def unorm(x):
    """Min-max ("unity") normalise an image to [0, 1], independently per channel.

    Args:
        x: array shaped (h, w, c); min and max are taken over axes 0 and 1.

    Returns:
        Array of the same shape with each channel rescaled to [0, 1].
        A channel whose values are all equal maps to 0 instead of NaN.
    """
    xmax = x.max((0, 1))
    xmin = x.min((0, 1))
    span = xmax - xmin
    # Avoid 0/0 -> NaN (and a RuntimeWarning) for constant channels.
    span = np.where(span == 0, np.ones_like(span), span)
    return (x - xmin) / span
def norm_all(store, n_t, n_s):
    """Apply ``unorm`` to every (timestep, sample) image in ``store``.

    Args:
        store: array indexed as ``store[t, s]`` -> one (h, w, c) image.
        n_t: number of leading-axis entries (timesteps) to normalise.
        n_s: number of second-axis entries (samples) to normalise.

    Returns:
        Array with the shape of ``store``. Entries outside ``[:n_t, :n_s]`` are
        left at 0. Floating inputs keep their dtype. Other dtypes produce
        float64 so the [0, 1] values are not truncated to integers.
    """
    if np.issubdtype(store.dtype, np.floating):
        dtype = store.dtype
    else:
        dtype = np.float64
    nstore = np.zeros(store.shape, dtype=dtype)
    for t in range(n_t):
        for s in range(n_s):
            nstore[t, s] = unorm(store[t, s])
    return nstore
def norm_torch(x_all):
    """Min-max normalise a batch of images to [0, 1], per image and per channel.

    Args:
        x_all: tensor shaped (n_samples, c, h, w), the torch image format. Min
            and max are taken over the spatial axes (2, 3). It may live on any
            device and may require grad.

    Returns:
        CPU tensor of the same shape. A constant channel maps to 0 instead of
        NaN.
    """
    # detach() so tensors that require grad can be converted to NumPy.
    x = x_all.detach().cpu().numpy()
    xmax = x.max(axis=(2, 3), keepdims=True)
    xmin = x.min(axis=(2, 3), keepdims=True)
    span = xmax - xmin
    # Avoid 0/0 -> NaN for constant channels.
    span = np.where(span == 0, np.ones_like(span), span)
    return torch.from_numpy((x - xmin) / span)
def gen_tst_context(n_cfeat, n_repeats=6):
    """Generate test context vectors.

    Builds one one-hot row per context feature followed by an all-zero row,
    and stacks that group ``n_repeats`` times. For ``n_cfeat=5`` the feature
    order is human, non-human, food, spell, side-facing, and the default
    ``n_repeats=6`` yields the original 36 x 5 table.

    Args:
        n_cfeat: number of context features (width of each vector).
        n_repeats: how many times the (one-hot rows + zero row) group repeats.

    Returns:
        Tuple ``(n_rows, vec)`` where ``vec`` is an int64 tensor of shape
        ``(n_rows, n_cfeat)`` and ``n_rows == n_repeats * (n_cfeat + 1)``.
    """
    # One-hot rows for each feature, then the all-zero row.
    group = torch.cat(
        [
            torch.eye(n_cfeat, dtype=torch.int64),
            torch.zeros(1, n_cfeat, dtype=torch.int64),
        ]
    )
    vec = group.repeat(n_repeats, 1)
    return len(vec), vec
def plot_grid(x, n_sample, n_rows, save_dir, w):
    """Normalise a batch of images, save it as one grid PNG and return the grid.

    Args:
        x: images shaped (n_sample, 3, h, w).
        n_sample: total number of images in ``x``.
        n_rows: number of grid rows; the column count is ``n_sample // n_rows``.
        save_dir: string prefix for the output path (include the trailing
            separator).
        w: value embedded in the output file name.
    """
    out_path = save_dir + f"run_image_w{w}.png"
    # make_grid's `nrow` is the number of images per row, i.e. the column count.
    columns = n_sample // n_rows
    grid = make_grid(norm_torch(x), nrow=columns)
    save_image(grid, out_path)
    print('saved image at ' + out_path)
    return grid
def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn, w, save=False):
    """Animate how generated samples evolve over the stored timesteps.

    Args:
        x_gen_store: array indexed [timestep, sample, channel, h, w].
        n_sample: number of samples shown, laid out as nrows x (n_sample // nrows).
        nrows: number of rows in the subplot grid.
        save_dir: string prefix for the output path (include the trailing
            separator).
        fn: base file name of the GIF.
        w: value embedded in the output file name.
        save: if True, write the animation to ``save_dir + f"{fn}_w{w}.gif"``.

    Returns:
        The matplotlib ``FuncAnimation`` object.
    """
    ncols = n_sample // nrows
    # (t, s, c, h, w) -> (t, s, h, w, c): channels-last image layout for imshow.
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)
    # Unity-normalise every frame into [0, 1] for imshow.
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample)
    # squeeze=False keeps `axs` a 2-D array even when nrows or ncols is 1.
    # With the default squeeze=True, axs[row, col] fails for single-row or
    # single-column grids.
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=True,
        sharey=True,
        figsize=(ncols, nrows),
        squeeze=False,
    )

    def animate_diff(i, store):
        # Redraw frame i: one sample per grid cell, filled row-major.
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                ax = axs[row, col]
                ax.clear()
                ax.set_xticks([])
                ax.set_yticks([])
                plots.append(ax.imshow(store[i, (row * ncols) + col]))
        return plots

    ani = FuncAnimation(
        fig,
        animate_diff,
        fargs=[nsx_gen_store],
        interval=200,
        blit=False,
        repeat=True,
        frames=nsx_gen_store.shape[0],
    )
    # Close this specific figure rather than whichever figure is current.
    plt.close(fig)
    if save:
        ani.save(save_dir + f"{fn}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + save_dir + f"{fn}_w{w}.gif")
    return ani
class CustomDataset(Dataset):
    """Dataset of sprite images and their context labels loaded from .npy files.

    Args:
        sfilename: path to a .npy file of sprites, indexed along axis 0.
        lfilename: path to a .npy file of labels, indexed along axis 0.
        transform: callable applied to each sprite. If falsy (e.g. None), the
            raw array element is returned unchanged.
        null_context: if True, every label is replaced by ``tensor(0)``.
    """

    def __init__(self, sfilename, lfilename, transform, null_context=False):
        self.sprites = np.load(sfilename)
        self.slabels = np.load(lfilename)
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    def __len__(self):
        """Return the number of sprites in the dataset."""
        return len(self.sprites)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; ``label`` is an int64 tensor."""
        # Bind the raw sprite first so a missing transform does not leave
        # `image` unbound (which raised UnboundLocalError).
        image = self.sprites[idx]
        if self.transform:
            image = self.transform(image)
        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)

    def getshapes(self):
        """Return the shapes of the sprite and label arrays."""
        return self.sprites_shape, self.slabel_shape
# Default preprocessing for sprite images.
transform = transforms.Compose(
    [
        # ToTensor: pixel values from [0, 255] to [0.0, 1.0].
        transforms.ToTensor(),
        # (v - 0.5) / 0.5 maps [0.0, 1.0] onto [-1.0, 1.0].
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)