# Load a packed two-part brain-encoding model for one subject and run a test forward pass.
import torch

from config_utils import load_from_yaml
from model import _load_one_model, TowPartModel, BrainEncodingModel

# Subject and paths to the packed config and the per-part checkpoints.
subject = 'subj01'
cfg_path = "/workspace/model_packed2/config.yaml"
model_path1 = f"/workspace/model_packed2/ckpts/{subject}_part1.pth"
model_path2 = f"/workspace/model_packed2/ckpts/{subject}_part2.pth"

# Load the two halves of the encoding model from their per-part checkpoints.
model1: BrainEncodingModel = _load_one_model(model_path1, subject, cfg_path)
model2: BrainEncodingModel = _load_one_model(model_path2, subject, cfg_path)

# part1_voxel_indices.pt records, per subject, which voxel positions the part-1
# model predicts; TowPartModel combines the two parts into one model over all voxels.
voxel_indices_path = "/workspace/model_packed2/ckpts/part1_voxel_indices.pt"
voxel_indices = torch.load(voxel_indices_path)[subject]
model = TowPartModel(model1, model2, voxel_indices)

model = model.cuda().eval()
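
# How the two parts are likely stitched back together (a sketch of an assumption,
# not TowPartModel's actual code): part 1's outputs go to the positions listed in
# voxel_indices and part 2 fills the remaining positions.
def combine_parts_sketch(out1, out2, part1_idx, n_voxels):
    full = out1.new_zeros(out1.shape[0], n_voxels)
    part1_mask = torch.zeros(n_voxels, dtype=torch.bool, device=out1.device)
    part1_mask[part1_idx.to(out1.device)] = True
    full[:, part1_mask] = out1
    full[:, ~part1_mask] = out2
    return full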

# Standard ImageNet mean/std normalization; these constants assume RGB input already
# scaled to [0, 1].
def transform_image(x):
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]
    return (x - torch.tensor(means).view(1, 3, 1, 1)) / torch.tensor(stds).view(1, 3, 1, 1)

# Dummy 224x224 RGB batch standing in for a real, preprocessed stimulus image.
x = torch.randn(1, 3, 224, 224)
x = transform_image(x)
x = x.cuda()
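
# A real stimulus would instead be loaded, resized, and scaled to [0, 1] before
# transform_image, e.g. with Pillow + torchvision (a sketch; assumes torchvision is
# installed and uses a blank in-memory image in place of a real file path):
from PIL import Image
from torchvision import transforms

to_unit_tensor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
demo_img = Image.new("RGB", (300, 300))  # stand-in for Image.open(<stimulus path>)
x_real = transform_image(to_unit_tensor(demo_img).unsqueeze(0)).cuda()  # same shape as x above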

# Forward pass with gradients disabled.
with torch.no_grad():
    out = model(x)

print(out.shape)
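
# `out` is presumably a (batch, n_voxels) tensor of predicted voxel responses for
# this subject (an assumption based on the wrapper above). Move it to the CPU before
# any downstream analysis or saving:
pred = out.squeeze(0).cpu()
print(pred.shape)
# torch.save(pred, f"{subject}_pred.pt")  # optional: persist the prediction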