import torch | |
from model import Swin2SR | |
# Build a Swin2SR super-resolution model and load trained weights into it.
# The names below are module-level so other code can import them.

# Path to the checkpoint holding the Swin2SR state_dict. It is a relative
# path, so it resolves against the current working directory, not this file.
model_weights = "model-70.pt"

# Constructor arguments for Swin2SR. They must match the architecture the
# checkpoint was produced with; otherwise load_state_dict() (strict by
# default) raises on missing/unexpected keys or shape mismatches.
model_params = {
    "upscale": 2,
    "in_chans": 4,
    "img_size": 64,
    "window_size": 16,
    "img_range": 1.,
    "depths": [6, 6, 6, 6],
    "embed_dim": 90,
    "num_heads": [6, 6, 6, 6],
    "mlp_ratio": 2,
    "upsampler": "pixelshuffledirect",
    "resi_connection": "1conv"
}

sr_model = Swin2SR(**model_params)

# map_location="cpu": without it, tensors saved from a GPU process are
# restored onto their original CUDA device. That fails on CPU-only machines
# and wastes GPU memory on a temporary copy. load_state_dict() copies the
# values into the model's existing parameters anyway, so the loaded dict
# only needs to live on CPU.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code
# from an untrusted checkpoint. Only load trusted files, or pass
# weights_only=True (torch >= 1.13) once the minimum torch version allows it.
sr_model.load_state_dict(torch.load(model_weights, map_location="cpu"))