from abc import ABC
from abc import abstractmethod
from torch.utils.data import DataLoader
from torch import set_grad_enabled
from torch import Tensor
import torch

from ..utils import TransformerTestingDataset
from ..model import ADRDModel

class BaseExplainer(ABC):
    """ Base class for SHAP-style feature-attribution explainers of ADRDModel. """
    def __init__(self, model: ADRDModel) -> None:
        """ Store the model to be explained. """
        self.model = model
    def shap_values(self,
        x,
        is_embedding: dict[str, bool] | None = None,
    ) -> list[dict[str, dict[str, float]]]:
        """ Compute SHAP values for every (target, source) modality pair of each instance. """
        # result placeholder: one attribution dict per instance in x
        phi = [
            {
                tgt_k: {
                    src_k: 0.0 for src_k in self.model.src_modalities
                } for tgt_k in self.model.tgt_modalities
            } for _ in range(len(x))
        ]

        # set nn to eval mode
        set_grad_enabled(False)
        self.model.net_.eval()

        # initialize dataset and dataloader object
        dat = TransformerTestingDataset(x, self.model.src_modalities, is_embedding)
        ldr = DataLoader(
            dataset = dat,
            batch_size = 1,
            shuffle = False,
            drop_last = False,
            num_workers = 0,
            collate_fn = TransformerTestingDataset.collate_fn,
        )

        # loop through instances and compute shap values
        for idx, (smp, mask) in enumerate(ldr):
            # number of unmasked features in this instance
            mask_flat = torch.cat(list(mask.values()))
            n_unmasked = torch.logical_not(mask_flat).sum().item()
            if n_unmasked == 0:
                # no feature available, nothing to attribute
                pass
            elif n_unmasked == 1:
                # a single feature available, attribution is trivial
                pass
            else:
                self._shap_values_core(smp, mask, phi[idx], is_embedding)

        return phi

    @abstractmethod
    def _shap_values_core(self,
        smp: dict[str, Tensor],
        mask: dict[str, Tensor],
        phi_: dict[str, dict[str, float]],
        is_embedding: dict[str, bool] | None = None,
    ) -> None:
        """ Subclasses implement a specific SHAP estimation algorithm here,
        writing the attributions for one instance into phi_ in place.
        """
        pass
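
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal concrete
# subclass and call pattern showing how an explainer would plug into
# BaseExplainer. The uniform attribution below is a placeholder, not a real
# SHAP estimator; `model`, `x_test`, and `is_embedding` are assumed to come
# from the surrounding training/testing pipeline.
#
# class UniformExplainer(BaseExplainer):
#     def _shap_values_core(self, smp, mask, phi_, is_embedding=None):
#         # placeholder: assign an equal, constant attribution to every
#         # (target, source) pair instead of estimating true SHAP values
#         for tgt_k in phi_:
#             for src_k in phi_[tgt_k]:
#                 phi_[tgt_k][src_k] = 1.0 / len(phi_[tgt_k])
#
# explainer = UniformExplainer(model)
# phi = explainer.shap_values(x_test, is_embedding=is_embedding)
# print(phi[0])  # per-source attributions for each target of the first instance
# ---------------------------------------------------------------------------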