File size: 448 Bytes
476803e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
from model import Swin2SR
# Path to the trained Swin2SR checkpoint; it is passed to torch.load and the
# result is fed to load_state_dict, so it must hold a state_dict.
model_weights = "model-70.pt"

# Constructor arguments for Swin2SR. These must match the configuration the
# checkpoint was trained with: load_state_dict is strict by default and raises
# on missing, unexpected, or mis-shaped parameters.
model_params = {
    "upscale": 2,
    "in_chans": 4,  # NOTE(review): 4 input channels — confirm the expected input format
    "img_size": 64,
    "window_size": 16,
    "img_range": 1.,
    "depths": [6, 6, 6, 6],
    "embed_dim": 90,
    "num_heads": [6, 6, 6, 6],
    "mlp_ratio": 2,
    "upsampler": "pixelshuffledirect",
    "resi_connection": "1conv",
}

sr_model = Swin2SR(**model_params)
# The model is constructed on the CPU, so map the checkpoint tensors there
# directly. Without map_location, a checkpoint saved from a GPU run fails to
# load on a machine without CUDA. The resulting model state is the same either
# way, because load_state_dict copies into the CPU parameters.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True where the installed torch supports it).
sr_model.load_state_dict(torch.load(model_weights, map_location="cpu"))
# Switch to inference mode so train-only behavior (e.g. dropout) is disabled.
# NOTE(review): assumes this module is used for inference; callers that train
# must call sr_model.train() themselves.
sr_model.eval()
|