csaybar's picture
Upload 5 files
476803e verified
raw
history blame
448 Bytes
import torch
from model import Swin2SR
# Checkpoint holding the trained Swin2SR state dict.
# NOTE(review): this relative path resolves against the current working
# directory, not against this script's location — confirm that is intended.
model_weights = "model-70.pt"

# Keyword arguments for the Swin2SR constructor. They must describe the same
# architecture the checkpoint was trained with, otherwise load_state_dict
# below cannot match the stored tensors to the model's parameters.
model_params = {
    "upscale": 2,  # presumably the super-resolution factor — confirm in model.Swin2SR
    "in_chans": 4,  # presumably the number of input channels/bands — confirm
    "img_size": 64,
    "window_size": 16,
    "img_range": 1.,
    "depths": [6, 6, 6, 6],
    "embed_dim": 90,
    "num_heads": [6, 6, 6, 6],
    "mlp_ratio": 2,
    "upsampler": "pixelshuffledirect",
    "resi_connection": "1conv",
}

sr_model = Swin2SR(**model_params)

# map_location="cpu": a checkpoint saved from CUDA tensors would otherwise
# fail to load on a CPU-only machine; load_state_dict copies the values into
# the model's own parameters, so the loaded weights are the same either way.
# weights_only=True: restrict unpickling to tensors and plain containers so a
# tampered checkpoint cannot run arbitrary code (needs torch >= 1.13; it is
# the default from torch 2.6 on). The file is consumed as a state dict, so
# nothing else needs to be unpickled.
sr_model.load_state_dict(
    torch.load(model_weights, map_location="cpu", weights_only=True)
)
# NOTE(review): no .eval() call here; callers running inference presumably
# need sr_model.eval() first — confirm with the consumers of this module.