tango2-full / diffusers /tests /schedulers /test_scheduler_vq_diffusion.py
soujanyaporia's picture
Upload 1028 files
fd43906 verified
raw
history blame
1.55 kB
import torch
import torch.nn.functional as F
from diffusers import VQDiffusionScheduler
from .test_schedulers import SchedulerCommonTest
class VQDiffusionSchedulerTest(SchedulerCommonTest):
    """Test case for ``VQDiffusionScheduler`` built on ``SchedulerCommonTest``.

    The helpers below produce discrete (integer class index) samples and a
    stand-in model that emits random log-probabilities, which the test
    methods feed to the ``check_over_*`` helpers called on ``self``
    (presumably inherited from ``SchedulerCommonTest`` -- confirm there).
    """

    scheduler_classes = (VQDiffusionScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, with ``kwargs`` overriding it."""
        # Later keys win, so caller-supplied values replace the defaults.
        return {
            "num_vec_classes": 4097,
            "num_train_timesteps": 100,
            **kwargs,
        }

    def dummy_sample(self, num_vec_classes):
        """Return a random integer tensor of shape ``(4, 64)``.

        Values are drawn from ``[0, num_vec_classes)``; the second dimension
        is a flattened 8 x 8 grid.
        """
        batch_size, height, width = 4, 8, 8
        flat_shape = (batch_size, height * width)
        return torch.randint(0, num_vec_classes, flat_shape)

    @property
    def dummy_sample_deter(self):
        """Always fail: no deterministic dummy sample is provided here."""
        assert False

    def dummy_model(self, num_vec_classes):
        """Build a stand-in model that returns random log-probabilities.

        The returned callable takes ``(sample, t, *args)`` and uses only the
        shape of ``sample``. Its output has shape
        ``(batch_size, num_vec_classes - 1, num_latent_pixels)`` and is
        log-softmax-normalised along dimension 1.
        """

        def model(sample, t, *args):
            n_batch, n_pixels = sample.shape
            raw_scores = torch.rand((n_batch, num_vec_classes - 1, n_pixels))
            # Normalise in float64, then cast the result back to float32.
            log_probs = F.log_softmax(raw_scores.double(), dim=1)
            return log_probs.float()

        return model

    def test_timesteps(self):
        """Run ``check_over_configs`` for several ``num_train_timesteps``."""
        for n_steps in (2, 5, 100, 1000):
            self.check_over_configs(num_train_timesteps=n_steps)

    def test_num_vec_classes(self):
        """Run ``check_over_configs`` for several ``num_vec_classes``."""
        for n_classes in (5, 100, 1000, 4000):
            self.check_over_configs(num_vec_classes=n_classes)

    def test_time_indices(self):
        """Run ``check_over_forward`` at several time steps."""
        for step in (0, 50, 99):
            self.check_over_forward(time_step=step)

    def test_add_noise_device(self):
        """Intentionally a no-op.

        NOTE(review): presumably overrides an inherited test that does not
        apply to this scheduler -- confirm against ``SchedulerCommonTest``.
        """