import torch

from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN

def test_unetdiscriminatorsn():
    """Test arch: UNetDiscriminatorSN.

    Builds a small UNetDiscriminatorSN (3 input channels, 4 base features,
    skip connections enabled) and runs a forward pass on a random
    1x3x32x32 float32 image. The test checks that the output has shape
    (1, 1, 32, 32). The check runs on CPU and, when CUDA is available,
    again on GPU.
    """
    net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
    img = torch.rand((1, 3, 32, 32), dtype=torch.float32)

    devices = ['cpu']
    if torch.cuda.is_available():
        devices.append('cuda')

    for device in devices:
        # Module.to moves parameters/buffers in place and returns the module.
        net = net.to(device)
        # Only the output shape is asserted, so no autograd graph is needed.
        with torch.no_grad():
            output = net(img.to(device))
        assert output.shape == (1, 1, 32, 32), (
            f'unexpected output shape on {device}: {tuple(output.shape)}')