File size: 4,987 Bytes
6be1ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import yaml
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.data.paired_image_dataset import PairedImageDataset
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss

from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
from realesrgan.models.realesrgan_model import RealESRGANModel
from realesrgan.models.realesrnet_model import RealESRNetModel


def test_realesrnet_model():
    """Smoke-test RealESRNetModel: construction, ``feed_data`` and validation.

    Builds the model from ``tests/data/test_realesrnet_model.yml``, checks the
    generator / pixel-loss / optimizer types, then feeds a random GT batch plus
    random 5x5 kernels through ``feed_data`` twice (the second call is meant to
    exercise the dequeue path). It re-feeds the batch with the noise and
    second-blur probabilities forced to 0 so the alternative branches run too.
    Finally it runs ``nondist_validation`` on the paired demo images and checks
    that the model is still flagged as training afterwards.
    """
    # Explicit encoding: the default text encoding is locale-dependent.
    with open('tests/data/test_realesrnet_model.yml', mode='r', encoding='utf-8') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    # build model
    model = RealESRNetModel(opt)
    # test attributes
    assert model.__class__.__name__ == 'RealESRNetModel'
    assert isinstance(model.net_g, RRDBNet)
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.optimizers[0], torch.optim.Adam)

    # prepare data: one random 32x32 GT image and random 5x5 kernels
    gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
    kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
    sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
    data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
    model.feed_data(data)
    # check dequeue
    model.feed_data(data)
    # check data shape (the expected LQ is 1/4 of the GT resolution)
    assert model.lq.shape == (1, 3, 8, 8)
    assert model.gt.shape == (1, 3, 32, 32)

    # change probability to test if-else: disable the optional degradations
    zeroed_prob_keys = (
        'gaussian_noise_prob',
        'gray_noise_prob',
        'second_blur_prob',
        'gaussian_noise_prob2',
        'gray_noise_prob2',
    )
    for key in zeroed_prob_keys:
        model.opt[key] = 0
    model.feed_data(data)
    # check data shape
    assert model.lq.shape == (1, 3, 8, 8)
    assert model.gt.shape == (1, 3, 32, 32)

    # ----------------- test nondist_validation -------------------- #
    # construct dataloader
    dataset_opt = dict(
        name='Demo',
        dataroot_gt='tests/data/gt',
        dataroot_lq='tests/data/lq',
        io_backend=dict(type='disk'),
        scale=4,
        phase='val')
    dataset = PairedImageDataset(dataset_opt)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    # validation must not leave the model's is_train flag changed
    assert model.is_train is True
    model.nondist_validation(dataloader, 1, None, False)
    assert model.is_train is True


def test_realesrgan_model():
    """Smoke-test RealESRGANModel: construction, data, validation, one train step.

    Builds the model from ``tests/data/test_realesrgan_model.yml`` and checks the
    generator, discriminator, loss and optimizer types. It then feeds a random GT
    batch plus random 5x5 kernels through ``feed_data`` twice (the second call is
    meant to exercise the dequeue path). It re-feeds the batch with the noise and
    second-blur probabilities forced to 0 to cover the other branches, and runs
    ``nondist_validation`` on the paired demo images. Finally it performs one
    ``optimize_parameters`` step, checking the output shape and that the
    expected loss / discriminator-output keys appear in ``log_dict``.
    """
    # Explicit encoding: the default text encoding is locale-dependent.
    with open('tests/data/test_realesrgan_model.yml', mode='r', encoding='utf-8') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    # build model
    model = RealESRGANModel(opt)
    # test attributes
    assert model.__class__.__name__ == 'RealESRGANModel'
    assert isinstance(model.net_g, RRDBNet)  # generator
    assert isinstance(model.net_d, UNetDiscriminatorSN)  # discriminator
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.cri_perceptual, PerceptualLoss)
    assert isinstance(model.cri_gan, GANLoss)
    assert isinstance(model.optimizers[0], torch.optim.Adam)
    assert isinstance(model.optimizers[1], torch.optim.Adam)

    # prepare data: one random 32x32 GT image and random 5x5 kernels
    gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
    kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
    sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
    data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
    model.feed_data(data)
    # check dequeue
    model.feed_data(data)
    # check data shape (the expected LQ is 1/4 of the GT resolution)
    assert model.lq.shape == (1, 3, 8, 8)
    assert model.gt.shape == (1, 3, 32, 32)

    # change probability to test if-else: disable the optional degradations
    zeroed_prob_keys = (
        'gaussian_noise_prob',
        'gray_noise_prob',
        'second_blur_prob',
        'gaussian_noise_prob2',
        'gray_noise_prob2',
    )
    for key in zeroed_prob_keys:
        model.opt[key] = 0
    model.feed_data(data)
    # check data shape
    assert model.lq.shape == (1, 3, 8, 8)
    assert model.gt.shape == (1, 3, 32, 32)

    # ----------------- test nondist_validation -------------------- #
    # construct dataloader
    dataset_opt = dict(
        name='Demo',
        dataroot_gt='tests/data/gt',
        dataroot_lq='tests/data/lq',
        io_backend=dict(type='disk'),
        scale=4,
        phase='val')
    dataset = PairedImageDataset(dataset_opt)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    # validation must not leave the model's is_train flag changed
    assert model.is_train is True
    model.nondist_validation(dataloader, 1, None, False)
    assert model.is_train is True

    # ----------------- test optimize_parameters -------------------- #
    model.feed_data(data)
    model.optimize_parameters(1)
    # generator output is expected at GT resolution
    assert model.output.shape == (1, 3, 32, 32)
    assert isinstance(model.log_dict, dict)
    # check returned keys (log_dict may contain extra entries)
    expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake']
    assert set(expected_keys).issubset(set(model.log_dict.keys()))