# spritediffuser / diffusion_utilities.py
# debisoft's picture
# lib
# 442c1b2
import torch
import torch.nn as nn
import numpy as np
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import os
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
class ResidualConvBlock(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> GELU stages with an optional residual path.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolution stages.
        is_res: When True, the block input is added to the output of the
            second stage and the sum is divided by 1.414.

    When ``is_res`` is True and the channel counts differ, the input first goes
    through a 1x1 convolution stored as ``self.shortcut``.  The projection is a
    registered submodule, so it is built once, its weights are reused on every
    call, and it appears in ``parameters()`` and ``state_dict()``.  Checkpoints
    saved without ``shortcut.*`` entries need
    ``load_state_dict(..., strict=False)``.
    """

    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()

        # The identity shortcut is only usable when the channel counts agree.
        self.same_channels = in_channels == out_channels

        # Flag for whether or not to use the residual connection.
        self.is_res = is_res

        # First convolutional stage: 3x3 kernel, stride 1, padding 1.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        # Second convolutional stage: 3x3 kernel, stride 1, padding 1.
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        # 1x1 projection for the residual path when channel counts differ.
        # Built here (after conv1/conv2, so their initialisation order is
        # untouched) instead of inside forward(), where a fresh, randomly
        # initialised layer used to be created on every call.
        if is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )
        else:
            self.shortcut = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both convolution stages and, if enabled, add the residual path."""
        x1 = self.conv1(x)
        x2 = self.conv2(x1)

        # Without the residual connection the second stage output is the result.
        if not self.is_res:
            return x2

        if self.same_channels:
            out = x + x2
        else:
            if self.shortcut is None:
                # is_res was switched on after construction: build the
                # projection once and keep it so later calls reuse it.
                self.shortcut = nn.Conv2d(
                    x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0
                ).to(x.device)
            out = self.shortcut(x) + x2

        # Scale the summed tensor (1.414 is approximately sqrt(2)).
        return out / 1.414

    def get_out_channels(self):
        """Return the ``out_channels`` attribute of the second stage's Conv2d."""
        return self.conv2[0].out_channels

    def set_out_channels(self, out_channels):
        """Overwrite the channel-count attributes on the two Conv2d layers.

        NOTE: only the ``in_channels`` / ``out_channels`` attributes are
        assigned here; this method does not touch the layers' weight tensors,
        the BatchNorm layers, or the shortcut projection.
        """
        self.conv1[0].out_channels = out_channels
        self.conv2[0].in_channels = out_channels
        self.conv2[0].out_channels = out_channels
class UnetUp(nn.Module):
    """Upsampling stage of the U-Net.

    The stage is a ConvTranspose2d (kernel 2, stride 2) followed by two
    ResidualConvBlock layers, all wrapped in ``self.model``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Transposed convolution first, then two non-residual conv blocks.
        upsample = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
        refine_first = ResidualConvBlock(out_channels, out_channels)
        refine_second = ResidualConvBlock(out_channels, out_channels)

        self.model = nn.Sequential(upsample, refine_first, refine_second)

    def forward(self, x, skip):
        """Join ``x`` and ``skip`` on dim 1 (channels) and run the stage."""
        merged = torch.cat((x, skip), 1)
        return self.model(merged)
class UnetDown(nn.Module):
    """Downsampling stage of the U-Net.

    The stage is two ResidualConvBlock layers followed by ``nn.MaxPool2d(2)``,
    all wrapped in ``self.model``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Two conv blocks (the first changes the channel count), then pooling.
        widen = ResidualConvBlock(in_channels, out_channels)
        refine = ResidualConvBlock(out_channels, out_channels)
        pool = nn.MaxPool2d(2)

        self.model = nn.Sequential(widen, refine, pool)

    def forward(self, x):
        """Run the input through the conv blocks and the pooling layer."""
        return self.model(x)
class EmbedFC(nn.Module):
    """Small feed-forward embedding network: Linear -> GELU -> Linear.

    Maps inputs of dimensionality ``input_dim`` to an embedding space of
    dimensionality ``emb_dim``.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()

        # Remembered so forward() can flatten its input to (-1, input_dim).
        self.input_dim = input_dim

        self.model = nn.Sequential(
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x):
        """Reshape ``x`` to rows of ``input_dim`` values and embed each row."""
        flat = x.view(-1, self.input_dim)
        return self.model(flat)
def unorm(x):
    """Min-max scale ``x`` to [0, 1] using the min/max taken over axes 0 and 1.

    Assumes ``x`` is an image laid out as (h, w, 3), so every channel is scaled
    independently.  A channel whose values are all equal (max == min) is mapped
    to 0 instead of producing NaN from a 0/0 division.

    Returns:
        Array with the same shape as ``x``.
    """
    xmax = x.max((0, 1))
    xmin = x.min((0, 1))

    # Replace a zero range by one so constant channels divide cleanly;
    # ones_like keeps the dtype of the range unchanged.
    span = np.asarray(xmax - xmin)
    span = np.where(span == 0, np.ones_like(span), span)

    return (x - xmin) / span
def norm_all(store, n_t, n_s):
    """Apply ``unorm`` to ``store[t, s]`` for every timestep and sample.

    Args:
        store: Array indexed as ``store[t, s]``.
        n_t: Number of timesteps (first index) to process.
        n_s: Number of samples (second index) to process.

    Returns:
        Array shaped like ``store``; entries outside the processed
        ``n_t`` x ``n_s`` range stay zero.
    """
    normalised = np.zeros_like(store)

    # Walk the (timestep, sample) pairs in row-major order.
    pairs = ((step, sample) for step in range(n_t) for sample in range(n_s))
    for step, sample in pairs:
        normalised[step, sample] = unorm(store[step, sample])

    return normalised
def norm_torch(x_all):
    """Min-max scale every channel of every sample to [0, 1].

    Args:
        x_all: Tensor in the torch image format (n_samples, 3, h, w).

    Returns:
        A CPU tensor of the same shape.  The min/max are taken over axes 2
        and 3, so each (sample, channel) plane is scaled independently.  A
        plane whose values are all equal is mapped to 0 instead of NaN.
    """
    # detach() so tensors that still track gradients can be converted to numpy.
    x = x_all.detach().cpu().numpy()

    # keepdims leaves singleton axes in place so the results broadcast over x.
    xmax = x.max((2, 3), keepdims=True)
    xmin = x.min((2, 3), keepdims=True)

    # Replace a zero range by one so constant planes do not divide 0 by 0.
    span = xmax - xmin
    span = np.where(span == 0, np.ones_like(span), span)

    nstore = (x - xmin) / span
    return torch.from_numpy(nstore)
def gen_tst_context(n_cfeat):
    """Generate test context vectors.

    Builds a cycle of six rows -- the five one-hot vectors (human, non-human,
    food, spell, side-facing) followed by an all-zero vector -- and repeats
    that cycle six times.

    Args:
        n_cfeat: Accepted but not used; the vectors always have five entries.

    Returns:
        Tuple ``(number_of_rows, tensor)`` where the tensor has 36 rows.
    """
    n_classes = 5
    one_hot_rows = [
        [int(col == row) for col in range(n_classes)] for row in range(n_classes)
    ]
    cycle = one_hot_rows + [[0] * n_classes]

    vec = torch.tensor(cycle * 6)
    return len(vec), vec
def plot_grid(x,n_sample,n_rows,save_dir,w):
    """Normalise a batch of images, tile them into a grid and save it as PNG.

    Args:
        x: Image batch, (n_sample, 3, h, w).
        n_sample: Number of images in the batch.
        n_rows: Number of rows wanted in the grid.
        save_dir: Prefix the file name is appended to (plain string
            concatenation, so it should end with a path separator).
        w: Value embedded in the file name ``run_image_w{w}.png``.

    Returns:
        The grid produced by ``make_grid``.
    """
    # make_grid's ``nrow`` is the number of images per row, i.e. the columns.
    images_per_row = n_sample // n_rows
    grid = make_grid(norm_torch(x), nrow=images_per_row)

    out_path = save_dir + f"run_image_w{w}.png"
    save_image(grid, out_path)
    print('saved image at ' + out_path)
    return grid
def plot_sample(x_gen_store,n_sample,nrows,save_dir, fn, w, save=False):
    """Build an animation of a grid of samples evolving over time.

    Args:
        x_gen_store: Array of stored frames; axis 2 is moved to position 4
            before plotting (presumably (n_t, n_sample, channels, h, w) ->
            (n_t, n_sample, h, w, channels) -- confirm against the caller).
        n_sample: Number of samples per frame.
        nrows: Number of rows in the plotted grid.
        save_dir: Prefix the gif file name is appended to (plain string
            concatenation).
        fn: File-name stem of the gif.
        w: Value embedded in the gif file name ``{fn}_w{w}.gif``.
        save: When True, write the animation to disk as a gif.

    Returns:
        The ``FuncAnimation`` object.
    """
    ncols = n_sample // nrows

    # Change to the NumPy image format (h, w, channels) vs (channels, h, w).
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)
    # Unity norm puts values in [0, 1] for imshow.
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample)

    # squeeze=False keeps ``axs`` two-dimensional even when nrows or ncols
    # is 1, so the ``axs[row, col]`` indexing below always works.
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=True,
        sharey=True,
        figsize=(ncols, nrows),
        squeeze=False,
    )

    def animate_diff(i, store):
        """Redraw every cell of the grid with frame ``i`` of ``store``."""
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                ax = axs[row, col]
                ax.clear()
                ax.set_xticks([])
                ax.set_yticks([])
                plots.append(ax.imshow(store[i, (row * ncols) + col]))
        return plots

    ani = FuncAnimation(
        fig,
        animate_diff,
        fargs=[nsx_gen_store],
        interval=200,
        blit=False,
        repeat=True,
        frames=nsx_gen_store.shape[0],
    )
    # Close the figure created above (it is the current figure at this point).
    plt.close(fig)

    if save:
        gif_path = save_dir + f"{fn}_w{w}.gif"
        ani.save(gif_path, dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + gif_path)
    return ani
class CustomDataset(Dataset):
    """Dataset of sprite images and labels loaded from two ``.npy`` files.

    Args:
        sfilename: Path of the file holding the sprite array.
        lfilename: Path of the file holding the label array.
        transform: Callable applied to each sprite in ``__getitem__``; a falsy
            value (e.g. ``None``) returns the sprite unchanged.
        null_context: When True, every item gets the scalar label 0 instead of
            its stored label.
    """

    def __init__(self, sfilename, lfilename, transform, null_context=False):
        self.sprites = np.load(sfilename)
        self.slabels = np.load(lfilename)
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.sprites)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` tuple stored at ``idx``."""
        # Start from the raw sprite so ``image`` is always bound, even when
        # no transform was supplied.
        image = self.sprites[idx]
        if self.transform:
            image = self.transform(image)

        # The label no longer depends on a transform being present.
        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)

    def getshapes(self):
        """Return the shapes of the sprite array and the label array."""
        return self.sprites_shape, self.slabel_shape
# Module-level preprocessing pipeline: ToTensor followed by Normalize with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(), # from [0,255] to range [0.0,1.0]
    transforms.Normalize((0.5,), (0.5,)) # range [-1,1]
])