Mehdi Cherti
update
e96a195
from score_sde.models.projected_discriminator import ProjectedDiscriminator
import torch
def make_dummy_inputs(batch_size=1, image_size=224, seq_len=77, cond_dim=768):
    """Build random inputs for one ProjectedDiscriminator forward pass.

    Args:
        batch_size: Number of samples in the batch.
        image_size: Height and width of the square 3-channel image.
        seq_len: Length of the conditioning sequence.
        cond_dim: Width of each conditioning embedding.

    Returns:
        A tuple ``(x, t, cond)``:
        - ``x``: standard-normal float tensor of shape
          ``(batch_size, 3, image_size, image_size)``.
        - ``t``: int64 tensor of zeros with shape ``(batch_size,)``.
        - ``cond``: the 3-tuple ``(None, embeddings, mask)``, where
          ``embeddings`` is standard-normal with shape
          ``(batch_size, seq_len, cond_dim)`` and ``mask`` is an all-True
          bool tensor of shape ``(batch_size, seq_len)``.
    """
    x = torch.randn(batch_size, 3, image_size, image_size)
    # Every timestep is 0. This is what torch.randint(0, 1, ...) produced,
    # since randint's upper bound is exclusive; spell it out explicitly.
    t = torch.zeros(batch_size, dtype=torch.long)
    # NOTE(review): presumably (pooled embedding, per-token embeddings,
    # token mask) from a text encoder -- confirm against
    # ProjectedDiscriminator.forward.
    cond = (
        None,
        torch.randn(batch_size, seq_len, cond_dim),
        torch.ones(batch_size, seq_len, dtype=torch.bool),
    )
    return x, t, cond


def main():
    """Smoke-test ProjectedDiscriminator: one forward pass, print output shape."""
    # Single source for the conditioning width; presumably the backbone's
    # cond_size must equal the last dim of the cond embeddings -- confirm.
    cond_size = 768
    discr = ProjectedDiscriminator(
        num_discs=4, backbone_kwargs={"cond_size": cond_size}
    )
    x, t, cond = make_dummy_inputs(cond_dim=cond_size)
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        # NOTE(review): the same image is passed as both the first and the
        # third positional argument -- TODO confirm the intended signature.
        y = discr(x, t, x, cond=cond)
    print(y.shape)


if __name__ == "__main__":
    main()