"""Build a Swin2SR super-resolution model and load its trained weights.

Exposes ``sr_model`` at module level, constructed from ``model_params`` and
populated from the state dict stored at ``model_weights``.
"""
import torch

from model import Swin2SR

# Path to the saved state dict.
# NOTE(review): the "-70" suffix presumably marks a training epoch — confirm.
model_weights = "model-70.pt"

# Constructor arguments for Swin2SR. These must match the architecture the
# checkpoint was trained with: load_state_dict() is strict by default and
# raises on missing/unexpected keys or tensor shape mismatches.
model_params = {
    "upscale": 2,
    "in_chans": 4,  # NOTE(review): 4 input channels — confirm this matches the data fed to the model
    "img_size": 64,
    "window_size": 16,
    "img_range": 1.,
    "depths": [6, 6, 6, 6],
    "embed_dim": 90,
    "num_heads": [6, 6, 6, 6],
    "mlp_ratio": 2,
    "upsampler": "pixelshuffledirect",
    "resi_connection": "1conv",
}

sr_model = Swin2SR(**model_params)

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only machine. The freshly built model lives on CPU, so
# load_state_dict() copies the values into CPU parameters regardless.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints; on torch >= 1.13 consider passing weights_only=True.
sr_model.load_state_dict(torch.load(model_weights, map_location="cpu"))