File size: 18,605 Bytes
92d3638
 
 
 
 
85e192b
92d3638
 
 
 
 
 
 
 
85e192b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92d3638
 
 
 
 
 
 
 
 
 
 
85e192b
92d3638
 
 
 
 
 
 
 
 
85e192b
92d3638
 
 
 
 
 
85e192b
92d3638
85e192b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92d3638
 
 
 
 
 
 
 
 
 
85e192b
92d3638
 
 
 
 
 
 
 
 
 
85e192b
 
92d3638
 
 
 
 
 
 
 
 
 
 
85e192b
92d3638
 
 
85e192b
 
92d3638
 
 
 
 
 
 
 
85e192b
92d3638
 
85e192b
92d3638
 
 
 
 
85e192b
 
92d3638
 
 
 
 
 
85e192b
 
92d3638
 
85e192b
92d3638
 
e9cdb01
92d3638
 
 
 
 
 
 
c686e91
 
 
92d3638
 
 
 
 
 
 
 
 
c686e91
 
92d3638
 
 
85e192b
92d3638
85e192b
92d3638
e9cdb01
92d3638
 
 
 
85e192b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92d3638
c686e91
85e192b
 
 
 
e9cdb01
c686e91
9637da5
85e192b
9637da5
85e192b
e9cdb01
c686e91
92d3638
 
 
 
85e192b
 
 
 
 
 
 
 
 
 
 
92d3638
 
 
 
 
e9cdb01
 
 
92d3638
 
 
 
c686e91
92d3638
 
 
 
 
 
 
c686e91
 
92d3638
 
85e192b
 
 
92d3638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
import copy
from typing import Any, Dict

import torch
from torch import Tensor, nn
from torch.nn.functional import one_hot
from transformers import AutoConfig, AutoModel, PretrainedConfig

from fim.models.blocks import AModel, ModelFactory, RNNEncoder, TransformerEncoder
from fim.models.utils import create_matrix_from_off_diagonal, create_padding_mask, get_off_diagonal_elements
from fim.utils.helper import create_class_instance


class FIMMJPConfig(PretrainedConfig):
    """
    Hugging Face configuration object for the FIMMJP model.

    The sub-module entries (``ts_encoder``, ``pos_encodings``, ``path_attention``,
    ``intensity_matrix_decoder`` and ``initial_distribution_decoder``) are stored exactly as given;
    the model copies them and builds the corresponding modules from them.

    Attributes:
        model_type (str): Registry identifier of the model, ``"fimmjp"``.
        n_states (int): Size of the discrete state space. Default is 2.
        use_adjacency_matrix (bool): Whether the adjacency matrix is fed to the decoders. Default is False.
        ts_encoder (dict): Configuration of the time series encoder. Default is None.
        pos_encodings (dict): Configuration of the positional (time) encodings. Default is None.
        path_attention (dict): Configuration of the path attention mechanism. Default is None.
        intensity_matrix_decoder (dict): Configuration of the intensity matrix decoder. Default is None.
        initial_distribution_decoder (dict): Configuration of the initial distribution decoder. Default is None.
        use_num_of_paths (bool): Whether the number of observed paths is fed to the decoders. Default is True.

    Args:
        n_states (int, optional): See the attribute of the same name.
        use_adjacency_matrix (bool, optional): See the attribute of the same name.
        ts_encoder (dict, optional): See the attribute of the same name.
        pos_encodings (dict, optional): See the attribute of the same name.
        path_attention (dict, optional): See the attribute of the same name.
        intensity_matrix_decoder (dict, optional): See the attribute of the same name.
        initial_distribution_decoder (dict, optional): See the attribute of the same name.
        use_num_of_paths (bool, optional): See the attribute of the same name.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "fimmjp"

    def __init__(
        self,
        n_states: int = 2,
        use_adjacency_matrix: bool = False,
        ts_encoder: dict = None,
        pos_encodings: dict = None,
        path_attention: dict = None,
        intensity_matrix_decoder: dict = None,
        initial_distribution_decoder: dict = None,
        use_num_of_paths: bool = True,
        **kwargs,
    ):
        # Model-specific fields are stored first (in declaration order); the generic
        # PretrainedConfig initialisation then handles whatever is left in kwargs.
        model_fields = {
            "n_states": n_states,
            "use_adjacency_matrix": use_adjacency_matrix,
            "ts_encoder": ts_encoder,
            "pos_encodings": pos_encodings,
            "path_attention": path_attention,
            "intensity_matrix_decoder": intensity_matrix_decoder,
            "initial_distribution_decoder": initial_distribution_decoder,
            "use_num_of_paths": use_num_of_paths,
        }
        for field_name, field_value in model_fields.items():
            setattr(self, field_name, field_value)

        super().__init__(**kwargs)


class FIMMJP(AModel):
    """
    **FIMMJP: A Neural Recognition Model for Zero-Shot Inference of Markov Jump Processes**

    This class implements a neural recognition model for zero-shot inference of Markov jump processes (MJPs)
    on bounded state spaces from noisy and sparse observations. It is the model from the paper
    **"Foundation Inference Models for Markov Jump Processes"** --- https://arxiv.org/abs/2406.06419.

    Markov jump processes are continuous-time stochastic processes which describe dynamical systems evolving in discrete state spaces.
    These processes find wide application in the natural sciences and machine learning, but their inference is known to be far from trivial.
    The paper introduces a methodology for zero-shot inference of MJPs on bounded state spaces, from noisy and sparse observations,
    which consists of two components.

    First, a broad probability distribution over families of MJPs, as well as over possible observation times and noise mechanisms,
    with which a synthetic dataset of hidden MJPs and their noisy observations is simulated. Second, a neural recognition model that
    processes subsets of the simulated observations, and that is trained to output the initial condition and rate matrix of the target
    MJP in a supervised way.

    The paper empirically demonstrates that one and the same (pretrained) recognition model can infer, in a zero-shot fashion,
    hidden MJPs evolving in state spaces of different dimensionalities. Specifically, it infers MJPs which describe
    *(i) discrete flashing ratchet systems*, which are a type of Brownian motor, and the conformational dynamics in
    *(ii) molecular simulations*, *(iii) experimental ion channel data* and *(iv) simple protein folding models*,
    performing on par with state-of-the-art models which are trained on the target datasets.

    **Attributes:**
        total_offdiagonal_transitions (int): Number of off-diagonal rate-matrix entries, ``n_states**2 - n_states``.
        pos_encodings (nn.Module): Positional (time) encoding built from ``config.pos_encodings``.
        ts_encoder (TransformerEncoder | RNNEncoder): Per-path time series encoder built from ``config.ts_encoder``.
        path_attention (nn.Module): Module aggregating the per-path representations.
        intensity_matrix_decoder (nn.Module): Decoder emitting a log-mean and a log-std for every off-diagonal rate.
        initial_distribution_decoder (nn.Module): Decoder emitting ``n_states`` scores for the initial distribution.
        gaussian_nll (nn.GaussianNLLLoss): Gaussian negative log-likelihood loss (``full=True``, unreduced).
        init_cross_entropy (nn.CrossEntropyLoss): Cross-entropy loss for the initial distribution (unreduced).

    **Methods:**
        forward(x, n_states=None, schedulers=None, step=None) -> dict:
            Predict the intensity matrix and the initial condition; add the losses when targets are present in ``x``.
        loss(pred_im, pred_logstd_im, pred_init_cond, target, normalization_constants, schedulers=None, step=None) -> dict:
            Compute the training losses.
        metric(y, y_target) -> Dict:
            Delegate to ``AModel.metric``.
    """

    config_class = FIMMJPConfig

    def __init__(self, config: FIMMJPConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.total_offdiagonal_transitions = self.config.n_states**2 - self.config.n_states
        self.gaussian_nll = nn.GaussianNLLLoss(full=True, reduction="none")
        self.init_cross_entropy = nn.CrossEntropyLoss(reduction="none")

        self.__create_modules()

    def __create_modules(self):
        """Instantiate every sub-module from (deep copies of) the configuration dictionaries."""
        # Deep copies: the dictionaries are mutated below (keys popped / overwritten) and must not alter self.config.
        pos_encodings = copy.deepcopy(self.config.pos_encodings)
        ts_encoder = copy.deepcopy(self.config.ts_encoder)
        path_attention = copy.deepcopy(self.config.path_attention)
        intensity_matrix_decoder = copy.deepcopy(self.config.intensity_matrix_decoder)
        initial_distribution_decoder = copy.deepcopy(self.config.initial_distribution_decoder)

        if ts_encoder["name"] == "fim.models.blocks.base.TransformerEncoder":
            # The encoder input is [positional encoding, one-hot state]; shrinking the positional encoding by
            # n_states keeps the concatenated width equal to the configured out_features.
            pos_encodings["out_features"] -= self.config.n_states
        self.pos_encodings = create_class_instance(pos_encodings.pop("name"), pos_encodings)

        ts_encoder["in_features"] = self.config.n_states + self.pos_encodings.out_features
        self.ts_encoder = create_class_instance(ts_encoder.pop("name"), ts_encoder)

        self.path_attention = create_class_instance(path_attention.pop("name"), path_attention)

        default_in_features = self.__default_decoder_in_features()

        intensity_matrix_decoder["in_features"] = intensity_matrix_decoder.get("in_features", default_in_features)
        # One log-mean and one log-std per off-diagonal entry.
        intensity_matrix_decoder["out_features"] = 2 * self.total_offdiagonal_transitions
        self.intensity_matrix_decoder = create_class_instance(intensity_matrix_decoder.pop("name"), intensity_matrix_decoder)

        initial_distribution_decoder["in_features"] = initial_distribution_decoder.get("in_features", default_in_features)
        initial_distribution_decoder["out_features"] = self.config.n_states
        self.initial_distribution_decoder = create_class_instance(initial_distribution_decoder.pop("name"), initial_distribution_decoder)

    def __default_decoder_in_features(self) -> int:
        """
        Width of the representation produced by ``__encode``.

        This value is used when a decoder config does not set ``in_features``. It mirrors the optional
        features that ``__encode`` appends, so both must be kept in sync.
        """
        in_features = self.ts_encoder.out_features
        if self.config.use_num_of_paths:
            in_features += 1  # single scalar feature: number of paths / 100
        if self.config.use_adjacency_matrix:
            in_features += self.total_offdiagonal_transitions  # off-diagonal adjacency entries
        return in_features

    def forward(self, x: dict[str, Tensor], n_states: int = None, schedulers: dict = None, step: int = None) -> dict:
        """
        Forward pass for the model.

        Args:
            x (dict[str, Tensor]): A dictionary containing the input tensors:
                - "observation_grid": Tensor representing the observation grid.
                - "observation_values": Tensor representing the (integer-valued) observed states.
                - "seq_lengths": Tensor representing the sequence lengths.
                - Optional keys:
                    - "time_normalization_factors": Time normalization constants. When given, "observation_grid"
                      is used as-is (assumed to be already normalized).
                    - "adjacency_matrix" (or "adjacency_matrices"): Required when ``config.use_adjacency_matrix`` is set.
                - Optional keys for loss calculation:
                    - "intensity_matrices": Tensor representing the target intensity matrices.
                    - "initial_distributions": Tensor representing the target initial distributions.
                    - "adjacency_matrices": Tensor representing the adjacency matrices (required by the loss).
                ``x`` is modified in place: "observation_grid_normalized", "observation_values_one_hot" and,
                when absent, "time_normalization_factors" are added.
            n_states (int, optional): Passed as ``n_states`` to ``create_matrix_from_off_diagonal``;
                defaults to ``config.n_states``.
            schedulers (dict, optional): Loss-weight schedulers, see ``loss``. Default is None.
            step (int, optional): The current step in the training process. Default is None.

        Returns:
            dict: A dictionary containing the following keys:
                - "intensity_matrices": Predicted intensity (rate) matrices.
                - "intensity_matrices_variance": Matrices built from exp of the predicted off-diagonal log-std
                  (see the note in ``__prepare_output``).
                - "initial_condition": Raw initial-distribution scores from the decoder.
                - "losses" (optional): Output of ``loss``, present only when the targets are in ``x``.
        """
        norm_constants = self.__normalize_observation_grid(x)

        x["observation_values_one_hot"] = one_hot(x["observation_values"].long().squeeze(-1), num_classes=self.config.n_states)

        h = self.__encode(x)
        pred_offdiag_im_logmean_logstd, init_cond = self.__decode(h)

        pred_offdiag_im_mean, pred_offdiag_im_logstd = self.__denormalize_offdiag_mean_logstd(
            norm_constants, pred_offdiag_im_logmean_logstd
        )

        out = self.__prepare_output(n_states, init_cond, pred_offdiag_im_mean, pred_offdiag_im_logstd)
        self.__calculate_train_loss_if_targe_exists(
            x, schedulers, step, norm_constants, init_cond, pred_offdiag_im_mean, pred_offdiag_im_logstd, out
        )

        return out

    def __calculate_train_loss_if_targe_exists(
        self,
        x: dict[str, Tensor],
        schedulers: dict,
        step: int,
        norm_constants: Tensor,
        init_cond: Tensor,
        pred_offdiag_im_mean: Tensor,
        pred_offdiag_im_logstd: Tensor,
        out: dict,
    ):
        """Store ``loss(...)`` under ``out["losses"]`` when the target matrices and initial distributions are in ``x``."""
        if "intensity_matrices" in x and "initial_distributions" in x:
            out["losses"] = self.loss(
                pred_offdiag_im_mean, pred_offdiag_im_logstd, init_cond, x, norm_constants.view(-1, 1), schedulers, step
            )

    def __prepare_output(self, n_states: int, init_cond: Tensor, pred_offdiag_im_mean: Tensor, pred_offdiag_im_logstd: Tensor) -> dict:
        """Assemble full matrices from the predicted off-diagonal entries and package the model output."""
        target_n_states = self.config.n_states if n_states is None else n_states
        out = {
            "intensity_matrices": create_matrix_from_off_diagonal(
                pred_offdiag_im_mean,
                self.config.n_states,
                mode="negative_sum_row",
                n_states=target_n_states,
            ),
            # NOTE(review): despite the key name, this is exp(log-std), i.e. the predicted standard deviation under
            # the parameterization used in `loss` (which squares exp(pred_logstd_im) to obtain the variance).
            # Kept unchanged for compatibility with existing consumers -- confirm whether std or variance is intended.
            "intensity_matrices_variance": create_matrix_from_off_diagonal(
                torch.exp(pred_offdiag_im_logstd),
                self.config.n_states,
                mode="negative_sum_row",
                n_states=target_n_states,
            ),
            "initial_condition": init_cond,
        }

        return out

    def __normalize_observation_grid(self, x: dict[str, Tensor]) -> Tensor:
        """
        Store the normalized observation grid in ``x["observation_grid_normalized"]`` and return the time normalization constants.

        If ``x`` already contains "time_normalization_factors", those are returned and the grid is used unchanged.
        Otherwise the constants are computed by ``__normalize_obs_grid`` and also written to ``x``.
        """
        obs_grid = x["observation_grid"]
        if "time_normalization_factors" not in x:
            norm_constants, obs_grid = self.__normalize_obs_grid(obs_grid)
            x["time_normalization_factors"] = norm_constants
            x["observation_grid_normalized"] = obs_grid
        else:
            norm_constants = x["time_normalization_factors"]
            x["observation_grid_normalized"] = obs_grid
        return norm_constants

    def __decode(self, h: Tensor) -> tuple[Tensor, Tensor]:
        """Return the (log-mean, log-std) decoder output for the off-diagonal rates and the initial-distribution scores."""
        pred_offdiag_logmean_logstd = self.intensity_matrix_decoder(h)
        init_cond = self.initial_distribution_decoder(h)
        return pred_offdiag_logmean_logstd, init_cond

    def __encode(self, x: dict[str, Tensor]) -> Tensor:
        """
        Encode every observed path and aggregate the paths into one representation per batch element.

        Raises:
            TypeError: If ``self.ts_encoder`` is neither a ``TransformerEncoder`` nor an ``RNNEncoder``.
        """
        obs_grid_normalized = x["observation_grid_normalized"]
        obs_values_one_hot = x["observation_values_one_hot"]
        # Leading dims: batch B, paths per sample P, observations per path L.
        B, P, L = obs_grid_normalized.shape[:3]
        pos_enc = self.pos_encodings(obs_grid_normalized)
        path = torch.cat([pos_enc, obs_values_one_hot], dim=-1)
        if isinstance(self.ts_encoder, TransformerEncoder):
            padding_mask = create_padding_mask(x["seq_lengths"].view(B * P), L)
            # NOTE(review): position 0 is forced in the mask and position 1 is read out below; the meaning depends on
            # create_padding_mask / TransformerEncoder conventions -- confirm.
            padding_mask[:, 0] = True
            h = self.ts_encoder(path.view(B * P, L, -1), padding_mask)[:, 1, :].view(B, P, -1)
            if isinstance(self.path_attention, nn.MultiheadAttention):
                # [0] is the attention output; [:, -1] keeps the last entry along dim 1
                # (NOTE(review): which axis that is depends on the module's batch_first setting).
                h = self.path_attention(h, h, h)[0][:, -1]
            else:
                h = self.path_attention(h, h, h)
        elif isinstance(self.ts_encoder, RNNEncoder):
            h = self.ts_encoder(path.view(B * P, L, -1), x["seq_lengths"].view(B * P))
            # Read the hidden state at the last valid observation of every path.
            last_observation = x["seq_lengths"].view(B * P) - 1
            h = h[torch.arange(B * P), last_observation].view(B, P, -1)
            h = self.path_attention(h, h, h)
        else:
            raise TypeError(
                f"Unsupported time series encoder {type(self.ts_encoder).__name__}; expected TransformerEncoder or RNNEncoder."
            )
        if self.config.use_num_of_paths:
            # Append the number of paths (scaled by 1/100) as one extra feature.
            h = torch.cat([h, torch.ones(B, 1, device=h.device) / 100.0 * P], dim=-1)
        if self.config.use_adjacency_matrix:
            # Accept both the historical key and the one used by the loss / documentation.
            adjacency = x["adjacency_matrix"] if "adjacency_matrix" in x else x["adjacency_matrices"]
            h = torch.cat([h, get_off_diagonal_elements(adjacency)], dim=-1)
        return h

    def __denormalize_offdiag_mean_logstd(self, norm_constants: Tensor, pred_offdiag_im_logmean_logstd: Tensor) -> tuple[Tensor, Tensor]:
        """
        Map the decoder output from normalized time back to the original time scale.

        Observation times were divided by ``norm_constants``, so rates (and their standard deviations) predicted in
        normalized time are divided by the same constants: ``mean = exp(logmean) / c`` and ``logstd = logstd - log(c)``.
        """
        pred_offdiag_im_logmean, pred_offdiag_im_logstd = pred_offdiag_im_logmean_logstd.chunk(2, dim=-1)
        pred_offdiag_im_mean = torch.exp(pred_offdiag_im_logmean) / norm_constants.view(-1, 1)
        pred_offdiag_im_logstd = pred_offdiag_im_logstd - torch.log(norm_constants.view(-1, 1))
        return pred_offdiag_im_mean, pred_offdiag_im_logstd

    def __normalize_obs_grid(self, obs_grid: Tensor) -> tuple[Tensor, Tensor]:
        """
        Divide each sample's observation grid by its maximum time.

        The maximum is taken over the last three dims; this assumes a 4-D grid (B, P, L, 1) -- TODO confirm.
        """
        norm_constants = obs_grid.amax(dim=[-3, -2, -1])
        obs_grid_normalized = obs_grid / norm_constants.view(-1, 1, 1, 1)
        return norm_constants, obs_grid_normalized

    def __loss_weight(self, schedulers: dict, name: str, step: int) -> Tensor:
        """Return ``schedulers[name](step)`` as a tensor on the model device, or 1.0 when no such scheduler is given."""
        scheduler = schedulers.get(name) if schedulers else None
        weight = 1.0 if scheduler is None else scheduler(step)
        return torch.as_tensor(weight, device=self.device)

    def loss(
        self,
        pred_im: Tensor,
        pred_logstd_im: Tensor,
        pred_init_cond: Tensor,
        target: dict,
        normalization_constants: Tensor,
        schedulers: dict = None,
        step: int = None,
    ) -> dict:
        """
        Compute the training losses.

        Args:
            pred_im: Predicted (denormalized) mean of the off-diagonal rates.
            pred_logstd_im: Predicted (denormalized) log standard deviation of the off-diagonal rates.
            pred_init_cond: Initial-distribution scores, used as logits by the cross-entropy loss.
            target: Batch dictionary. The keys "intensity_matrices", "initial_distributions",
                "adjacency_matrices" and "observation_grid" are read.
            normalization_constants: Time normalization constants, broadcast against the off-diagonal entries.
            schedulers: Optional mapping with callables under "gauss_nll", "init_cross_entropy" and
                "missing_link". Each is called with ``step`` to obtain the weight of its loss term. A missing
                entry, or no mapping at all, gives a weight of 1.0.
            step: Training step passed to the schedulers.

        Returns:
            dict: The weighted total "loss", the individual terms, an "rmse_loss" diagnostic (not part of the total),
            the weights that were used and "number_of_paths".
        """
        target_mean = get_off_diagonal_elements(target["intensity_matrices"])
        adjacency = get_off_diagonal_elements(target["adjacency_matrices"])
        # The target initial distribution is reduced to the index of its largest entry (class label).
        target_init_class = torch.argmax(target["initial_distributions"], dim=-1).long()
        P = target["observation_grid"].shape[1]

        pred_im_std = torch.exp(pred_logstd_im)

        # Gaussian NLL restricted to entries present in the adjacency structure, averaged over those entries.
        loss_gauss = adjacency * self.gaussian_nll(pred_im, target_mean, torch.pow(pred_im_std, 2))
        loss_gauss = loss_gauss.sum() / (adjacency.sum() + 1e-8)

        loss_initial = self.init_cross_entropy(pred_init_cond, target_init_class).mean()

        # Missing-link penalty: where the adjacency entry is 0, push both predicted mean and std towards zero.
        zero_entries = 1.0 - adjacency
        loss_missing_link = normalization_constants * zero_entries * (torch.pow(pred_im, 2) + torch.pow(pred_im_std, 2))
        loss_missing_link = loss_missing_link.sum() / (zero_entries.sum() + 1e-8)

        rmse_loss = torch.sqrt(torch.mean((target_mean - pred_im) ** 2))

        gaus_cons = self.__loss_weight(schedulers, "gauss_nll", step)
        init_cons = self.__loss_weight(schedulers, "init_cross_entropy", step)
        missing_link_cons = self.__loss_weight(schedulers, "missing_link", step)

        loss = gaus_cons * loss_gauss + init_cons * loss_initial + missing_link_cons * loss_missing_link
        return {
            "loss": loss,
            "loss_gauss": loss_gauss,
            "loss_initial": loss_initial,
            "loss_missing_link": loss_missing_link,
            "rmse_loss": rmse_loss,
            "beta_gauss_nll": gaus_cons,
            "beta_init_cross_entropy": init_cons,
            "beta_missing_link": missing_link_cons,
            "number_of_paths": torch.tensor(P, device=self.device),
        }

    def metric(self, y: Any, y_target: Any) -> Dict:
        """Delegate metric computation to ``AModel.metric``."""
        return super().metric(y, y_target)


# Import-time registration. ModelFactory (project registry) presumably maps the "fimmjp" model type to FIMMJP;
# the Hugging Face Auto* registries map the model type to FIMMJPConfig and that config class to FIMMJP,
# so AutoConfig / AutoModel can resolve this model.
ModelFactory.register(FIMMJPConfig.model_type, FIMMJP)
AutoConfig.register(FIMMJPConfig.model_type, FIMMJPConfig)
AutoModel.register(FIMMJPConfig, FIMMJP)