import argparse
import datetime
import logging
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# from tqdm import tqdm
# from torch.utils.tensorboard import SummaryWriter
# from matplotlib import pyplot as plt
# from pytorch_model_summary import summary
# from pathlib import Path

from src.ddpm.diffusion import Diffusion
from src.ddpm.modules import UNet
from src.ddpm.dataset import create_dataloader
from src.gan.gankits import process_onehot, get_decoder
from src.smb.level import MarioLevel, lvlhcat, save_batch
from src.utils.filesys import getpath
from src.utils.img import make_img_sheet
# from utils.plot import get_img_from_level
# Occurrence counts of each tile type in the training levels, tempered by a
# fourth root to soften the imbalance between common and rare tile types.
# sprite_counts = np.power(np.array([102573, 9114, 1017889, 930, 3032, 7330, 2278, 2279, 5227, 5229, 5419]), 1/4)
sprite_counts = np.power(np.array([
    74977, 15252, 572591, 5826, 1216, 7302, 237, 237, 2852, 1074, 235, 304,
    48, 96, 160, 1871, 936, 186, 428, 80, 428
]), 1 / 4)
min_count = np.min(sprite_counts)
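# Worked example of the per-type temperatures derived from these counts in
# train() (temperatures = min_count / sprite_counts, on the tempered values):
# the rarest type (raw count 48) gets 48**0.25 / 48**0.25 = 1.0, while the most
# common type (raw count 572591) gets 48**0.25 / 572591**0.25 ≈ 2.63 / 27.51 ≈ 0.096,
# i.e. very frequent tile types are sampled at a much lower temperature.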
# filepath = Path(__file__).parent.resolve()
# DATA_PATH = os.path.join(filepath, "levels", "ground", "unique_onehot.npz")


def setup_logging(run_name, beta_schedule):
    model_path = os.path.join("models", beta_schedule, run_name)
    result_path = os.path.join("results", beta_schedule, run_name)
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(result_path, exist_ok=True)
    return model_path, result_path
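# Usage sketch (setup_logging is currently unused by train(), which writes
# under args.res_path instead; the run name below is just an illustration):
#     model_path, result_path = setup_logging("2024-01-01_00-00-00", "linear")
#     # -> creates models/linear/2024-01-01_00-00-00
#     #    and results/linear/2024-01-01_00-00-00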
# def plot_images(epoch, sampled_images, result_path):
#     fig = plt.figure(figsize=(30, 15))
#     for i in range(len(sampled_images)):
#         ax1 = fig.add_subplot(4, int(len(sampled_images) / 4), i + 1)
#         ax1.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
#         level = sampled_images[i].argmax(dim=0).cpu().numpy()
#         level_img = get_img_from_level(level)
#         ax1.imshow(level_img)
#     plt.savefig(os.path.join(result_path, f"{epoch:04d}_sample.png"))
#     plt.close()

# def plot_training_images(epoch, original_img, x_t, noise, predicted_noise, reconstructed_img, training_result_path):
#     fig = plt.figure(figsize=(15, 10))
#     for i in range(2):
#         ax1 = fig.add_subplot(2, 5, i * 5 + 1)
#         ax1.imshow(get_img_from_level(original_img[i].cpu().numpy()))
#         ax1.set_title(f"Original {i}")
#         ax2 = fig.add_subplot(2, 5, i * 5 + 2)
#         ax2.imshow(get_img_from_level(noise[i].cpu().numpy()))
#         ax2.set_title(f"Noise {i}")
#         ax3 = fig.add_subplot(2, 5, i * 5 + 3)
#         ax3.imshow(get_img_from_level(x_t.argmax(dim=1).cpu().numpy()[i]))
#         ax3.set_title(f"x_t {i}")
#         ax4 = fig.add_subplot(2, 5, i * 5 + 4)
#         ax4.imshow(get_img_from_level(predicted_noise[i].cpu().numpy()))
#         ax4.set_title(f"Predicted Noise {i}")
#         ax5 = fig.add_subplot(2, 5, i * 5 + 5)
#         ax5.imshow(get_img_from_level(reconstructed_img.probs.argmax(dim=-1).cpu().numpy()[i]))
#         ax5.set_title(f"Reconstructed Image {i}")
#     plt.savefig(os.path.join(training_result_path, f"{epoch:04d}.png"))
#     plt.close()
def train(args):
    # model_path, result_path = setup_logging(args.run_name, args.beta_schedule)
    # training_result_path = os.path.join(result_path, "training")
    path = getpath(args.res_path)
    os.makedirs(path, exist_ok=True)
    dataloader = create_dataloader(batch_size=args.batch_size, shuffle=True, num_workers=0)
    device = 'cpu' if args.gpuid < 0 else f'cuda:{args.gpuid}'
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(device=device, schedule=args.beta_schedule)
    # logger = SummaryWriter(os.path.join("logs", args.beta_schedule, args.run_name))
    # Per-type sampling temperature: 1.0 for the rarest tile type, lower for
    # common ones (see the worked example next to sprite_counts above).
    temperatures = torch.tensor(min_count / sprite_counts, dtype=torch.float32).to(device)
    n_batches = len(dataloader)
    # print(summary(model, torch.zeros((64, MarioLevel.n_types, 14, 14)).to(device), diffusion.sample_timesteps(64).to(device), show_input=True))
    # if args.resume_from != 0:
    #     checkpoint = torch.load(os.path.join(model_path, f'ckpt_{args.resume_from}'))
    #     model.load_state_dict(checkpoint['model_state_dict'])
    #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    for epoch in range(args.resume_from + 1, args.resume_from + args.epochs + 1):
        logging.info(f"Starting epoch {epoch}:")
        epoch_loss = {'rec_loss': 0, 'mse': 0, 'loss': 0}
        # pbar = tqdm(dataloader)
        for i, images in enumerate(dataloader):
            images = images.to(device)  # one-hot levels: batch x n_types x 14 x 14
            t = diffusion.sample_timesteps(images.shape[0]).to(device)  # random int in 1~1000 per image
            x_t, noise = diffusion.noise_images(images, t)  # x_t: image noised to step t; noise: the Gaussian noise added
            predicted_noise = model(x_t.float(), t.float())  # predicted noise eps_theta(x_t, t)
            original_img = images.argmax(dim=1)  # tile-index map: batch x 14 x 14
            reconstructed_img = diffusion.sample_only_final(x_t, t, predicted_noise, temperatures)
            rec_loss = -reconstructed_img.log_prob(original_img).sum(dim=(1, 2)).mean()  # NLL per level, averaged over the batch
            mse_loss = mse(noise.float(), predicted_noise.float())
            loss = 0.001 * rec_loss + mse_loss
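            # The objective combines the standard DDPM epsilon-prediction MSE
            # with a small reconstruction term. sample_only_final appears to
            # return a torch.distributions.Categorical over tile types (it must
            # expose log_prob over the argmax indices for rec_loss to work);
            # the 0.001 weight keeps the MSE term dominant. The exact semantics
            # live in src/ddpm/diffusion.py.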
            epoch_loss['rec_loss'] += rec_loss.item()
            epoch_loss['mse'] += mse_loss.item()
            epoch_loss['loss'] += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # pbar.set_postfix(LOSS=loss.item())
            # logger.add_scalar("Rec_loss", rec_loss.item(), global_step=(epoch - 1) * n_batches + i)
            # logger.add_scalar("MSE", mse_loss.item(), global_step=(epoch - 1) * n_batches + i)
            # logger.add_scalar("LOSS", loss.item(), global_step=(epoch - 1) * n_batches + i)
        # logger.add_scalar("Epoch_Rec_loss", epoch_loss['rec_loss'] / n_batches, global_step=epoch)
        # logger.add_scalar("Epoch_MSE", epoch_loss['mse'] / n_batches, global_step=epoch)
        # logger.add_scalar("Epoch_LOSS", epoch_loss['loss'] / n_batches, global_step=epoch)
        print(
            '\nIteration: %d' % epoch,
            'rec_loss: %.5g' % (epoch_loss['rec_loss'] / n_batches),
            'mse: %.5g' % (epoch_loss['mse'] / n_batches)
        )
        # if epoch % 20 == 19:
        #     sampled_images = diffusion.sample(model, n=50)
        #     imgs = [lvl.to_img() for lvl in process_onehot(sampled_images[-1])]
        #     make_img_sheet(imgs, 10, save_path=f'{args.res_path}/sample{epoch + 1}.png')
        #     plot_images(epoch, sampled_images[-1], result_path)
        #     plot_training_images(epoch, original_img, x_t, noise.argmax(dim=1), predicted_noise.argmax(dim=1), reconstructed_img, training_result_path)
        if epoch % 1000 == 0:
            # torch.save(model.state_dict(), os.path.join(model_path, f"ckpt_{epoch:04d}.pt"))
            # torch.save({
            #     'epoch': epoch,
            #     'model_state_dict': model.state_dict(),
            #     'optimizer_state_dict': optimizer.state_dict(),
            #     'Epoch_Rec_loss': epoch_loss['rec_loss'] / n_batches,
            #     'Epoch_MSE': epoch_loss['mse'] / n_batches,
            #     'Epoch_LOSS': epoch_loss['loss'] / n_batches
            # }, getpath(f"{args.res_path}/ddpm_{epoch}.pt"))
            itpath = getpath(path, f'it{epoch}')
            os.makedirs(itpath, exist_ok=True)
            model.save(getpath(itpath, 'ddpm.pth'))  # itpath is already rooted at path
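            # NOTE: `save` is assumed to be a helper defined on this project's
            # UNet; a plain torch.nn.Module has no save() method (the stock
            # equivalent would be torch.save(model.state_dict(), ...)).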
    # After training: prepend each GAN-decoded initial segment to 25 freshly
    # sampled DDPM segments and concatenate them horizontally into full levels.
    lvls = []
    init_latvecs = torch.tensor(np.load(getpath('analysis/initial_seg.npy')))
    gan = get_decoder()
    init_seg_onehots = gan(init_latvecs.view(*init_latvecs.shape, 1, 1))
    for init_seg_onehot in init_seg_onehots:
        seg_onehots = diffusion.sample(model, n=25)[-1]
        head = init_seg_onehot.view(1, *init_seg_onehot.shape)
        tail = seg_onehots.detach().cpu()
        segs = process_onehot(torch.cat([head, tail], dim=0))
        lvls.append(lvlhcat(segs))
    save_batch(lvls, getpath(path, 'samples.lvls'))
    model.save(getpath(path, 'ddpm.pth'))
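# A minimal standalone sampling sketch, assuming the same project layout and
# that UNet offers a `load` counterpart to the `save` used above (hypothetical;
# adapt to however src/ddpm/modules.py actually restores weights):
#     model = UNet()
#     model.load(getpath('exp_data/DDPM', 'ddpm.pth'))
#     diffusion = Diffusion(device='cpu', schedule='quadratic')
#     seg_onehots = diffusion.sample(model, n=10)[-1]  # last denoising step
#     segs = process_onehot(seg_onehots.detach().cpu())
#     save_batch([lvlhcat(segs)], getpath('exp_data/DDPM', 'preview.lvls'))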
def launch():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10000)
    # parser.add_argument("--data_path", type=str, default=DATA_PATH)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--res_path", type=str, default='exp_data/DDPM')
    # parser.add_argument("--image_size", type=int, default=14)
    # parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--gpuid", type=int, default=0, help="GPU index; pass a negative value to run on CPU")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--beta_schedule", type=str, default="quadratic", choices=['linear', 'quadratic', 'sigmoid'])
    parser.add_argument("--run_name", type=str, default=datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
    parser.add_argument("--resume_from", type=int, default=0)
    args = parser.parse_args()
    train(args)
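# Example invocation (the file name is whatever this script is saved as):
#     python train.py --epochs 10000 --batch_size 256 --gpuid 0 --beta_schedule quadratic
# Use --gpuid -1 to train on CPU, and --resume_from N to continue epoch
# numbering from a previous run (the state-dict reload itself is commented out
# in train() above).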
if __name__ == "__main__":
    launch()