# NCERL-Diverse-PCG / test_ddpm.py
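"""Train a DDPM over Super Mario Bros. level segments and export sampled levels.

A UNet noise predictor is trained with the standard noise-prediction MSE plus
a lightly weighted tile-reconstruction NLL. Every 1000 epochs the model is
snapshotted and sample levels are built by concatenating a GAN-decoded initial
segment with diffusion-sampled segments.
"""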
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import logging
# from tqdm import tqdm
# from torch.utils.tensorboard import SummaryWriter
from src.ddpm.diffusion import Diffusion
from src.ddpm.modules import UNet
# from pytorch_model_summary import summary
# from matplotlib import pyplot as plt
from src.ddpm.dataset import create_dataloader
# from utils.plot import get_img_from_level
from pathlib import Path
# from src.smb.level import MarioLevel
import argparse
import datetime
from src.gan.gankits import process_onehot, get_decoder
from src.smb.level import MarioLevel, lvlhcat, save_batch
from src.utils.filesys import getpath
from src.utils.img import make_img_sheet
# sprite_counts = np.power(np.array([102573, 9114, 1017889, 930, 3032, 7330, 2278, 2279, 5227, 5229, 5419]), 1/4)
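# Tile-type occurrence counts, flattened by a fourth root. The ratio
# min_count / sprite_counts computed in train() gives one temperature per tile
# type (1 for the rarest type, smaller for common ones), which is passed to
# Diffusion.sample_only_final when decoding reconstructions.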
sprite_counts = np.power(np.array([
    74977, 15252, 572591, 5826, 1216, 7302, 237, 237, 2852, 1074,
    235, 304, 48, 96, 160, 1871, 936, 186, 428, 80, 428
]), 1 / 4)
min_count = np.min(sprite_counts)
# filepath = Path(__file__).parent.resolve()
# DATA_PATH = os.path.join(filepath, "levels", "ground", "unique_onehot.npz")
def setup_logging(run_name, beta_schedule):
model_path = os.path.join("models", beta_schedule, run_name)
result_path = os.path.join("results", beta_schedule, run_name)
os.makedirs(model_path, exist_ok=True)
os.makedirs(result_path, exist_ok=True)
return model_path, result_path
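# NOTE: setup_logging is only referenced by the commented-out logging and
# plotting code; train() below writes everything under args.res_path instead.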
# def plot_images(epoch, sampled_images, result_path):
# fig = plt.figure(figsize=(30, 15))
# for i in range(len(sampled_images)):
# ax1 = fig.add_subplot(4, int(len(sampled_images)/4), i+1)
# ax1.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
# level = sampled_images[i].argmax(dim=0).cpu().numpy()
# level_img = get_img_from_level(level)
# ax1.imshow(level_img)
# plt.savefig(os.path.join(result_path, f"{epoch:04d}_sample.png"))
# plt.close()
# def plot_training_images(epoch, original_img, x_t, noise, predicted_noise, reconstructed_img, training_result_path):
# fig = plt.figure(figsize=(15, 10))
# for i in range(2):
# ax1 = fig.add_subplot(2, 5, i*5+1)
# ax1.imshow(get_img_from_level(original_img[i].cpu().numpy()))
# ax1.set_title(f"Original {i}")
# ax2 = fig.add_subplot(2, 5, i*5+2)
# ax2.imshow(get_img_from_level(noise[i].cpu().numpy()))
# ax2.set_title(f"Noise {i}")
# ax3 = fig.add_subplot(2, 5, i*5+3)
# ax3.imshow(get_img_from_level(x_t.argmax(dim=1).cpu().numpy()[i]))
# ax3.set_title(f"x_t {i}")
# ax4 = fig.add_subplot(2, 5, i*5+4)
# ax4.imshow(get_img_from_level(predicted_noise[i].cpu().numpy()))
# ax4.set_title(f"Predicted Noise {i}")
# ax5 = fig.add_subplot(2, 5, i*5+5)
# ax5.imshow(get_img_from_level(reconstructed_img.probs.argmax(dim=-1).cpu().numpy()[i]))
# ax5.set_title(f"Reconstructed Image {i}")
# plt.savefig(os.path.join(training_result_path, f"{epoch:04d}.png"))
# plt.close()
def train(args):
# model_path, result_path = setup_logging(args.run_name, args.beta_schedule)
# training_result_path = os.path.join(result_path, "training")
path = getpath(args.res_path)
os.makedirs(path, exist_ok=True)
dataloader = create_dataloader(batch_size=args.batch_size, shuffle=True, num_workers=0)
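    # The dataloader yields batches of one-hot level segments; judging by the
    # commented summary call below, each has shape (N, MarioLevel.n_types, 14, 14).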
device = 'cpu' if args.gpuid < 0 else f'cuda:{args.gpuid}'
model = UNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr)
mse = nn.MSELoss()
diffusion = Diffusion(device=device, schedule=args.beta_schedule)
# logger = SummaryWriter(os.path.join("logs", args.beta_schedule, args.run_name))
temperatures = torch.tensor(min_count / sprite_counts, dtype=torch.float32).to(device)
    n_batches = len(dataloader)
# print(summary(model, torch.zeros((64, MarioLevel.n_types, 14, 14)).to(device), diffusion.sample_timesteps(64).to(device), show_input=True))
# if args.resume_from != 0:
# checkpoint = torch.load(os.path.join(model_path, f'ckpt_{args.resume_from}'))
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
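    # Each step: draw a random timestep per image, add Gaussian noise, predict
    # the noise with the UNet, and decode a categorical reconstruction for the
    # auxiliary NLL term.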
for epoch in range(args.resume_from+1, args.resume_from+args.epochs+1):
logging.info(f"Starting epoch {epoch}:")
epoch_loss = {'rec_loss': 0, 'mse': 0, 'loss': 0}
# pbar = tqdm(dataloader)
for i, images in enumerate(dataloader):
images = images.to(device)
# print(images.shape)
t = diffusion.sample_timesteps(images.shape[0]).to(device) # random int from 1~1000
x_t, noise = diffusion.noise_images(images, t) # x_t: image with noise at t, noise: gaussian noise
predicted_noise = model(x_t.float(), t.float()) # returns predicted noise eps_theta
original_img = images.argmax(dim=1) # batch x 14 x 14
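            # sample_only_final returns a torch.distributions.Categorical over
            # tile types (its .log_prob is used below; .probs appears in the
            # commented plotting code)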
reconstructed_img = diffusion.sample_only_final(x_t, t, predicted_noise, temperatures)
            rec_loss = -reconstructed_img.log_prob(original_img).sum(dim=(1, 2)).mean()  # per-level NLL, averaged over the batch
            mse_loss = mse(noise.float(), predicted_noise.float())
            loss = 0.001 * rec_loss + mse_loss  # noise-prediction MSE plus a lightly weighted reconstruction NLL
epoch_loss['rec_loss'] += rec_loss.item()
epoch_loss['mse'] += mse_loss.item()
epoch_loss['loss'] += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
# pbar.set_postfix(LOSS=loss.item())
# logger.add_scalar("Rec_loss", rec_loss.item(), global_step=(epoch - 1) * l + i)
# logger.add_scalar("MSE", mse_loss.item(), global_step=(epoch - 1) * l + i)
# logger.add_scalar("LOSS", loss.item(), global_step=(epoch - 1) * l + i)
# logger.add_scalar("Epoch_Rec_loss", epoch_loss['rec_loss']/l, global_step=epoch)
# logger.add_scalar("Epoch_MSE", epoch_loss['mse']/l, global_step=epoch)
# logger.add_scalar("Epoch_LOSS", epoch_loss['loss']/l, global_step=epoch)
        print(
            '\nEpoch: %d' % epoch,
            'rec_loss: %.5g' % (epoch_loss['rec_loss'] / n_batches),
            'mse: %.5g' % (epoch_loss['mse'] / n_batches)
        )
# if epoch % 20 == 19:
# sampled_images = diffusion.sample(model, n=50)
# imgs = [lvl.to_img() for lvl in process_onehot(sampled_images[-1])]
# make_img_sheet(imgs, 10, save_path=f'{args.res_path}/sample{epoch+1}.png')
# plot_images(epoch, sampled_images[-1], result_path)
# plot_training_images(epoch, original_img, x_t, noise.argmax(dim=1), predicted_noise.argmax(dim=1), reconstructed_img, training_result_path)
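        # Every 1000 epochs: snapshot the model and render sample levels, each
        # built from a GAN-decoded initial segment followed by 25
        # diffusion-sampled segments.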
if epoch % 1000 == 0:
# torch.save(model.state_dict(), os.path.join(model_path, f"ckpt_{epoch:04d}.pt"))
# torch.save({
# 'epoch': epoch,
# 'model_state_dict': model.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'Epoch_Rec_loss': epoch_loss['rec_loss']/l,
# 'Epoch_MSE': epoch_loss['mse']/l,
# 'Epoch_LOSS': epoch_loss['loss']/l
# }, getpath(f"{args.res_path}/ddpm_{epoch}.pt"))
            itpath = getpath(path, f'it{epoch}')
            os.makedirs(itpath, exist_ok=True)
            model.save(getpath(itpath, 'ddpm.pth'))
            lvls = []
            init_latvecs = torch.tensor(np.load(getpath('analysis/initial_seg.npy')))
            gan = get_decoder()
            init_seg_onehots = gan(init_latvecs.view(*init_latvecs.shape, 1, 1))
            for init_seg_onehot in init_seg_onehots:
                seg_onehots = diffusion.sample(model, n=25)[-1]
                a = init_seg_onehot.view(1, *init_seg_onehot.shape)
                b = seg_onehots.detach().cpu()
                segs = process_onehot(torch.cat([a, b], dim=0))
                level = lvlhcat(segs)
                lvls.append(level)
save_batch(lvls, getpath(path, 'samples.lvls'))
model.save(getpath(path, 'ddpm.pth'))
def launch():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10000, help="number of training epochs")
    # parser.add_argument("--data_path", type=str, default=DATA_PATH)
    parser.add_argument("--batch_size", type=int, default=256, help="training batch size")
    parser.add_argument("--res_path", type=str, default='exp_data/DDPM', help="output folder for models and samples")
    # parser.add_argument("--image_size", type=int, default=14)
    # parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--gpuid", type=int, default=0, help="CUDA device id; negative values run on CPU")
    parser.add_argument("--lr", type=float, default=3e-4, help="AdamW learning rate")
    parser.add_argument("--beta_schedule", type=str, default="quadratic", choices=['linear', 'quadratic', 'sigmoid'], help="noise schedule for the diffusion process")
    parser.add_argument("--run_name", type=str, default=datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'), help="run identifier (used by the commented-out logging code)")
    parser.add_argument("--resume_from", type=int, default=0, help="epoch to resume counting from")
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)  # make the logging.info calls in train() visible
    train(args)
if __name__ == "__main__":
launch()
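# Example:
#   python test_ddpm.py --gpuid 0 --epochs 10000 --beta_schedule quadratic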