Spaces:
Sleeping
Sleeping
File size: 4,601 Bytes
eaf2e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import torch
from torch.distributions import Categorical
from tqdm import trange
from src.ddpm.dataset import create_dataloader
import numpy as np
import os
from pathlib import Path
from src.smb.level import MarioLevel
# Path (as a plain string) to the level archive shipped next to this module:
# <module dir>/levels/ground/unique_onehot.npz
_MODULE_DIR = Path(__file__).parent.resolve()
DATAPATH = str(_MODULE_DIR / "levels" / "ground" / "unique_onehot.npz")
class Diffusion:
    """Gaussian diffusion process (forward noising and reverse sampling).

    The instance stores a per-step ``beta`` schedule together with the derived
    ``alpha = 1 - beta`` and ``alpha_hat = cumprod(alpha)``. All three are 1-D
    tensors of length ``noise_steps`` living on ``device``.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=MarioLevel.seg_width, device="cuda", schedule='linear'):
        """Store the hyper-parameters and precompute the noise schedule.

        Args:
            noise_steps: Number of diffusion steps (length of the schedule).
            beta_start: Smallest beta of the schedule.
            beta_end: Largest beta of the schedule.
            img_size: Side length used for both spatial axes in ``sample``.
            device: Torch device the schedule tensors and samples are put on.
            schedule: One of ``'linear'``, ``'quadratic'`` or ``'sigmoid'``.

        Raises:
            ValueError: If ``schedule`` is not a supported name.
        """
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device
        self.schedule = schedule

        # Prepare noise schedule according to the selected schedule type
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    @staticmethod
    def _per_sample(coefficients):
        """Reshape a 1-D tensor of length B to (B, 1, 1, 1) for broadcasting."""
        return coefficients[:, None, None, None]

    def prepare_noise_schedule(self):
        """Build the 1-D beta tensor for ``self.schedule``.

        Returns:
            A tensor of ``noise_steps`` betas running from ``beta_start``
            towards ``beta_end`` (created on the default device; the caller
            moves it).

        Raises:
            ValueError: If ``self.schedule`` is not a supported name.
        """
        if self.schedule == 'linear':
            # Betas evenly spaced between the two end points.
            return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)
        if self.schedule == 'quadratic':
            # Evenly spaced in sqrt(beta), then squared.
            return (torch.linspace(self.beta_start**0.5, self.beta_end**0.5, self.noise_steps) ** 2)
        if self.schedule == 'sigmoid':
            # Sigmoid over [-10, 10], rescaled into [beta_start, beta_end].
            half_range = 10.0
            betas = torch.sigmoid(torch.linspace(-half_range, half_range, self.noise_steps))
            return betas * (self.beta_end - self.beta_start) + self.beta_start
        raise ValueError("Invalid schedule type. Supported schedules: 'linear', 'quadratic', and 'sigmoid'.")

    def noise_images(self, x_0, t):
        """Apply the closed-form forward process to ``x_0`` at timesteps ``t``.

        Args:
            x_0: Clean input tensor; must broadcast against shape (B, 1, 1, 1).
            t: 1-D integer tensor of B timestep indices in
                ``[0, noise_steps - 1]`` (used to index ``alpha_hat``).

        Returns:
            A tuple ``(x_t, eps)`` where ``eps`` is standard normal noise shaped
            like ``x_0`` and
            ``x_t = sqrt(alpha_hat[t]) * x_0 + sqrt(1 - alpha_hat[t]) * eps``.
        """
        sqrt_alpha_hat = self._per_sample(torch.sqrt(self.alpha_hat[t]))
        sqrt_one_minus_alpha_hat = self._per_sample(torch.sqrt(1. - self.alpha_hat[t]))
        eps = torch.randn_like(x_0)
        return sqrt_alpha_hat * x_0 + sqrt_one_minus_alpha_hat * eps, eps

    def sample_timesteps(self, n):
        """Draw ``n`` uniform random timesteps from ``[1, noise_steps - 1]``.

        The result is created on the default device, not on ``self.device``.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample_only_final(self, x_t, t, noise, temperature=None):
        """Estimate x_0 in a single step and wrap it as a categorical distribution.

        Inverts the forward-process formula,
        ``x_0 = (x_t - sqrt(1 - alpha_hat[t]) * noise) / sqrt(alpha_hat[t])``,
        and uses the result as logits over the channel axis.

        Args:
            x_t: Noised tensor with the channel axis at dim 1 (4-D).
            t: 1-D integer tensor of timestep indices, one per batch element.
            noise: Noise estimate shaped like ``x_t``.
            temperature: Optional tensor with ``MarioLevel.n_types`` elements;
                each channel's logits are divided by its temperature.

        Returns:
            ``torch.distributions.Categorical`` whose logits have the channel
            axis moved last (dims permuted to 0, 2, 3, 1).
        """
        sqrt_alpha_hat = self._per_sample(torch.sqrt(self.alpha_hat[t]))
        sqrt_one_minus_alpha_hat = self._per_sample(torch.sqrt(1. - self.alpha_hat[t]))
        logits = (1.0/sqrt_alpha_hat) * (x_t - sqrt_one_minus_alpha_hat * noise)

        if temperature is not None:
            # One temperature per channel, broadcast over batch and space.
            temperature = temperature.view(1, MarioLevel.n_types, 1, 1)
            logits = logits / temperature

        return Categorical(logits=logits.permute(0, 2, 3, 1))

    def sample(self, model, n):
        """Run the reverse diffusion chain and return intermediate snapshots.

        Starts from standard normal noise of shape
        ``(n, MarioLevel.n_types, img_size, img_size)`` and denoises from step
        ``noise_steps - 1`` down to step 1 using ``model(x, t)`` as the noise
        prediction.

        The model is put in eval mode for the duration of the call. Afterwards
        it is switched back to training mode only if it was in training mode on
        entry (objects without a ``training`` attribute are treated as
        training), so an eval-mode model is no longer flipped to train mode.
        The mode is restored even if the model raises.

        Args:
            model: Callable ``model(x, t)`` exposing ``eval()`` and ``train()``.
            n: Number of samples to generate.

        Returns:
            A list of tensors: the initial noise followed by the state after
            every step ``i`` with ``i % 100 == 1`` (the last entry is the
            result of the final step, ``i == 1``).
        """
        print(f"Sampling {n} new images....")
        snapshots = []

        was_training = getattr(model, "training", True)
        model.eval()
        try:
            with torch.no_grad():
                x = torch.randn((n, MarioLevel.n_types, self.img_size, self.img_size)).to(self.device)
                snapshots.append(x)
                for i in trange(self.noise_steps - 1, 0, -1, position=0):
                    t = torch.full((n,), i, dtype=torch.long, device=self.device)
                    predicted_noise = model(x, t)

                    alpha = self._per_sample(self.alpha[t])
                    alpha_hat = self._per_sample(self.alpha_hat[t])
                    beta = self._per_sample(self.beta[t])

                    # No fresh noise is injected on the very last step.
                    noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)

                    x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
                    if i % 100 == 1:
                        snapshots.append(x)
        finally:
            # Restore the caller's mode instead of forcing training mode.
            if was_training:
                model.train()
        return snapshots
if __name__ == '__main__':
    # Manual smoke test: noise one level from the dataset at increasing
    # timesteps and collapse each result back to per-tile class indices.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    diffusion = Diffusion(schedule='quadratic', device=device)
    data = create_dataloader(DATAPATH).dataset.data_
    level = torch.tensor(data[-15]).to(device)
    for t in [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]:
        # The schedule tensors have noise_steps entries (valid indices
        # 0..noise_steps-1), so the nominal final step must be clamped to
        # avoid indexing out of bounds.
        step = min(t, diffusion.noise_steps - 1)
        noised_level = diffusion.noise_images(level, torch.tensor([step]).to(device))[0]
        # noise_images broadcasts a leading batch axis of size 1; drop it.
        noised_level = noised_level[0].cpu().numpy()
        # argmax over the first (channel) axis gives one class index per tile.
        noised_level = np.argmax(noised_level, axis=0)
        # plt.imshow(get_img_from_level(noised_level))
        # plt.show()