File size: 448 Bytes
476803e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from model import Swin2SR

# Path to the trained Swin2SR checkpoint loaded below.
# NOTE(review): assumed to be a bare state_dict (as passed straight to
# load_state_dict below) rather than a wrapper dict such as {"params": ...};
# confirm against the training script that produced it.
model_weights = "model-70.pt"

# Constructor arguments for Swin2SR. These must match the configuration the
# checkpoint was trained with: load_state_dict is strict by default and raises
# on missing/unexpected keys or tensor shape mismatches.
model_params = {
    "upscale": 2,
    "in_chans": 4,  # NOTE(review): 4 input channels — confirm what the inputs carry
    "img_size": 64,
    "window_size": 16,
    "img_range": 1.0,
    "depths": [6, 6, 6, 6],
    "embed_dim": 90,
    "num_heads": [6, 6, 6, 6],
    "mlp_ratio": 2,
    "upsampler": "pixelshuffledirect",
    "resi_connection": "1conv",
}

sr_model = Swin2SR(**model_params)

# map_location="cpu": load_state_dict copies the loaded tensors into
# sr_model's own parameters, so there is no need to restore them onto the
# device they were saved from. Without it, a checkpoint written during a GPU
# run is deserialized onto that CUDA device, which fails on CPU-only machines.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(model_weights, map_location="cpu")
sr_model.load_state_dict(state_dict)
# NOTE(review): nn.Module instances start in training mode; callers using
# sr_model for inference should call sr_model.eval() first.