# Spaces: Runtime error
# Smoke test for ProjectedDiscriminator: build it, run one forward pass on
# random inputs, and print the shape of the result.
import torch

from score_sde.models.projected_discriminator import ProjectedDiscriminator


def main():
    """Run a single forward pass of ProjectedDiscriminator and print the output shape."""
    # cond_size=768 matches the last dimension of the conditioning tensor below.
    discr = ProjectedDiscriminator(num_discs=4, backbone_kwargs={"cond_size": 768})

    # Random image-shaped input, presumably NCHW (batch=1, 3 channels, 224x224).
    x = torch.randn(1, 3, 224, 224)
    # torch.randint's upper bound is exclusive, so this is always tensor([0]).
    t = torch.randint(0, 1, size=(1,))
    # NOTE(review): assumed layout is (None placeholder, per-token embeddings
    # of shape (1, 77, 768), boolean token mask of shape (1, 77)) — confirm
    # against ProjectedDiscriminator.forward.
    cond = (None, torch.randn(1, 77, 768), torch.ones(1, 77, dtype=torch.bool))

    # NOTE(review): x is passed as both the first and the third positional
    # argument; what the third argument means is not visible here — confirm.
    y = discr(x, t, x, cond=cond)
    print(y.shape)


if __name__ == "__main__":
    main()