# nsd_model/model.py
#%%
from functools import partial
import logging
from einops import rearrange, repeat
from typing import Dict, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from config import AutoConfig
from backbone import (
build_backbone,
AdaLNLoRADiNOv2ViT,
)
from blocks import (
build_conv_blocks,
build_class_token_mlp,
DictConvBlocks,
ClassTokenMLPs,
)
from config_utils import load_from_yaml
from topyneck import (
build_coords_mlp,
CachedCoordsMLP,
build_voxelouts_weight,
CoordsMLPLinearWeight,
VoxelNonShareLinearWeight,
)
import numpy as np
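
# Pipeline: an image is encoded by an AdaLN-LoRA DiNOv2 ViT backbone; grid
# features from several layers pass through conv blocks and class tokens
# through per-layer MLPs. For each voxel, a coords-conditioned layer selector
# weighs the layers, a retina mapper chooses a grid_sample location, and a
# voxel-wise linear head maps the pooled feature to the predicted response.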
class BrainEncodingModel(nn.Module):
def __init__(
self,
cfg: AutoConfig,
        n_voxel_dict: Dict[str, int] = {'subj01': 327684},
):
super().__init__()
self.subject_list = list(n_voxel_dict.keys())
assert len(self.subject_list) == 1, "Only one subject is supported"
self.layers = cfg.MODEL.BACKBONE.LAYERS
self.layers_small = cfg.MODEL.BACKBONE_SMALL.LAYERS
self.n_layers = len(self.layers)
        r = cfg.MODEL.WIDTH_RATIO
        # scale the conv-head width by the global width ratio (mutates cfg in place)
        cfg.MODEL.CONV_HEAD.WIDTH = int(cfg.MODEL.CONV_HEAD.WIDTH * r)
self.cfg = cfg
self.backbone: AdaLNLoRADiNOv2ViT = build_backbone(cfg)
self.conv_blocks: DictConvBlocks = build_conv_blocks(cfg)
self.cls_blocks: ClassTokenMLPs = build_class_token_mlp(cfg)
def build_each_subject(fn, subject_list):
return nn.ModuleDict({subject: fn() for subject in subject_list})
self.layer_selector: Dict[str, CachedCoordsMLP] = build_each_subject(
partial(
build_coords_mlp,
cfg=cfg,
in_dim=cfg.POSITION_ENCODING.IN_DIM,
out_dim=self.n_layers,
act_fn=partial(nn.Softmax, dim=-1),
),
self.subject_list,
)
self.retina_mapper: Dict[str, CachedCoordsMLP] = build_each_subject(
partial(
build_coords_mlp,
cfg=cfg,
in_dim=cfg.POSITION_ENCODING.IN_DIM,
out_dim=2,
act_fn=nn.Tanh,
),
self.subject_list,
)
self.mu_sigma = cfg.MODEL.RETINA_MAPPER.CONSTANT_SIGMA
# voxel-wise output
d_model = self.cfg.MODEL.CONV_HEAD.WIDTH
self.n_voxel_dict = n_voxel_dict
self.d_model = d_model
self.voxel_outs_weight: Dict[
str, Union[VoxelNonShareLinearWeight, CoordsMLPLinearWeight]
] = nn.ModuleDict(
{
subject: build_voxelouts_weight(cfg, self.n_voxel_dict[subject], self.d_model)
for subject in self.subject_list
}
)
        # voxel coordinates [N, IN_DIM], populated externally (see _load_one_model)
        self.coords: Optional[nn.Parameter] = None
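
    # A construction sketch (assumes a compatible YAML config; the coordinate
    # shape [N, IN_DIM] is hypothetical):
    #
    #   cfg = load_from_yaml("config.yaml")
    #   model = BrainEncodingModel(cfg, n_voxel_dict={'subj01': 327684})
    #   model.coords = nn.Parameter(torch.rand(327684, cfg.POSITION_ENCODING.IN_DIM))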
def forward(
self,
x: Tensor, # [B, C, H, W]
voxel_indices: Optional[Tensor] = None,
chunk_size=4096,
**kwargs,
):
coords = self.coords
subject = self.subject_list[0]
bsz = x.shape[0]
device = x.device
dtype = x.dtype
x_retina_grid, x_cls_dict = self.backbone.get_intermediate_layers(
x, n=self.layers, c=None
)
x_retina_grid = self.conv_blocks(x_retina_grid)
x_cls_dict = self.cls_blocks(x_cls_dict)
x_cls = torch.stack(list(x_cls_dict.values()), dim=-1) # [B, D, 4]
#############################
### voxel-wise prediction ###
#############################
# divide voxels into chunks to avoid OOM
n_voxels = coords.shape[0]
        if voxel_indices is None or voxel_indices is Ellipsis:
voxel_indices = torch.arange(n_voxels, device=coords.device)
voxel_indices_chunks = torch.split(voxel_indices, chunk_size)
out_ys, reg_layers = [], []
for voxel_indices_chunk in voxel_indices_chunks:
out_y, reg_layer = self._forward_voxels(
x_retina_grid,
x_cls,
subject,
coords,
voxel_indices_chunk,
bsz,
device,
dtype
)
out_ys.append(out_y)
reg_layers.append(reg_layer)
        out_y = torch.cat(out_ys, dim=1)  # [B, N]
        reg_layer = torch.cat(reg_layers, dim=0).mean()  # scalar
# if self.training:
# return out_y, reg_layer
# else:
return out_y
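
    # A forward-pass sketch (hypothetical batch; self.coords must be set
    # first, as done in _load_one_model below):
    #
    #   x = torch.randn(2, 3, 224, 224)
    #   out = model(x)                                    # [2, N], all voxels
    #   out = model(x, voxel_indices=torch.arange(1024))  # [2, 1024]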
def _forward_voxels(
self,
x_retina_grid: Dict[str, Tensor], # {layer: [B, D, H/k, W/k], ...}
x_cls: Tensor, # [B, D, 4]
subject: str,
coords: Tensor,
voxel_indices: Tensor,
bsz,
device,
dtype,
):
N = len(voxel_indices)
## Layer Selector
w_layer = self.layer_selector[subject](coords, voxel_indices) # [N, 4]
# regularization
        # note: this is sum(x * log x) over layers, i.e. the *negative* entropy
        # of the layer-selection weights, returned as the regularization term
        def entropy(x):
            return (x * x.log()).sum(dim=1)
if self.training and next(self.layer_selector.parameters()).requires_grad:
reg_layer = entropy(w_layer) # [N]
else:
reg_layer = torch.zeros_like(w_layer[:, 0]) # [N]
x_cls = repeat(x_cls, "b d l -> b n d l", n=1)
_w_layer = repeat(w_layer, "n l -> b n d l", b=1, d=1)
x_cls = (x_cls * _w_layer).sum(dim=-1) # [B, N, D]
## Retina Mapper
mu = self.retina_mapper[subject](coords, voxel_indices) # [N, 2]
        # shrink mu so that mu + noise stays (approximately) within [-1, 1],
        # the coordinate range expected by F.grid_sample
        mu = mu * (1 - self.mu_sigma)
        if self.training:
            # jitter the sampling location during training (constant sigma)
            norm = torch.normal(0, torch.ones_like(mu) * self.mu_sigma)
            mu = mu + norm
bsz = x_cls.shape[0]
mu = repeat(mu, "n d -> b n d", b=bsz)
mu = rearrange(mu, "b n (d c) -> b n d c", d=1, c=2)
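        # mu is now a [B, N, 1, 2] sampling grid for F.grid_sample: normalized
        # (x, y) coordinates in [-1, 1], sampled below as [B, C, N, 1]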
if self.cfg.EXPERIMENTAL.USE_LAYER_SELECTOR:
_w_layer = repeat(w_layer, "n l -> b n l", b=1)
x_retina = None # [B, N, D]
        for i, layer in enumerate(self.layers):
x = x_retina_grid[str(layer)]
_x_retina = F.grid_sample(
x,
mu,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
) # [B, C, N, D] (C=D_model, D=1, N=N_voxels)
_x_retina = rearrange(_x_retina, "b c n d -> b n (c d)")
if self.cfg.EXPERIMENTAL.USE_LAYER_SELECTOR:
_x_retina = _x_retina * _w_layer[:, :, i : i + 1]
if x_retina is None:
x_retina = _x_retina
else:
x_retina += _x_retina
# x_retina: [B, N, D]
        x_y = x_retina + x_cls  # [B, N, D]  (T=0)
        w, b = self.voxel_outs_weight[subject](coords, voxel_indices)  # [N, D], [N]
        out_y = (x_y * w.unsqueeze(0)).mean(-1) + b.unsqueeze(0)  # [B, N]
return out_y, reg_layer # [B, N], [N]
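
    # Written out for one voxel n (shapes as annotated above), the readout is
    #   out_y[b, n] = mean_d( x_y[b, n, d] * w[n, d] ) + bias[n]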
def _load_one_model(model_path: str, subject: str = 'subj01', cfg_path: Optional[str] = None):
cfg = load_from_yaml(cfg_path)
# load model weights
sd = torch.load(model_path, map_location='cpu')
n_voxels = sd[f'model.voxel_outs_weight.{subject}.weight'].shape[0]
# create model
model = BrainEncodingModel(cfg, {subject: n_voxels})
# save voxel's coordinates to model
coords = sd[f'coord_dict.{subject}']
model.coords = nn.Parameter(coords)
# load weights
    # keep only the encoder weights, stripping the 'model.' prefix from the keys
    filtered_sd = {k[len('model.'):]: v for k, v in sd.items() if k.startswith('model.')}
    filtered_sd['coords'] = model.coords  # add coordinates of voxels
model.load_state_dict(filtered_sd)
model = model.eval()
return model
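
# A loading sketch (hypothetical paths; the checkpoint must contain the
# 'model.*' and 'coord_dict.*' keys used above):
#
#   model = _load_one_model("ckpts/subj01_part1.pth", subject='subj01',
#                           cfg_path="config.yaml")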
class TowPartModel(nn.Module):
def __init__(self, model_part1, model_part2, part1_voxel_indices):
super().__init__()
self.model_part1 = model_part1
self.model_part2 = model_part2
self.part1_voxel_indices = part1_voxel_indices
def forward(self, x):
# x: [B, 3, 224, 224] # image after normalization
out1 = self.model_part1(x)
out2 = self.model_part2(x)
        # start from part2's full prediction, then overwrite the voxels
        # covered by part1 (written in-place into out2)
        out = out2
        out[:, self.part1_voxel_indices] = out1
return out
# %%
if __name__ == '__main__':
# model_path = "/nfscc/alg23/xalex_distill2/high/t826c6_00016_DATASET.SUBJECT_LIST=subj01,LOSS.DARK.MAX_EPOCH=90,/soup.pth"
subject = 'subj01'
cfg_path = "/workspace/model_packed2/config.yaml"
model_path1 = f"/workspace/model_packed2/ckpts/{subject}_part1.pth"
model_path2 = f"/workspace/model_packed2/ckpts/{subject}_part2.pth"
model1 = _load_one_model(model_path1, subject, cfg_path)
model2 = _load_one_model(model_path2, subject, cfg_path)
voxel_indices_path = "/workspace/model_packed2/ckpts/part1_voxel_indices.pt"
voxel_indices = torch.load(voxel_indices_path)[subject]
model = TowPartModel(model1, model2, voxel_indices)
    x = torch.randn(1, 3, 224, 224)
    # fall back to CPU when no GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = x.to(device)
    model = model.to(device)
    out = model(x)
print(out.shape)
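    # expected: torch.Size([1, N]) with N the total voxel count
    # (e.g. 327684, the n_voxel_dict default above)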