from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import hydra
import lightning as pl
import torch
from lightning.pytorch.trainer.states import RunningStage
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Dataset

from relik.common.log import get_logger
from relik.retriever.data.base.datasets import BaseDataset

logger = get_logger(__name__)


# Backward-compatibility map from plain string stage names to Lightning
# RunningStage members.
STAGES_COMPATIBILITY_MAP = {
    "train": RunningStage.TRAINING,
    "val": RunningStage.VALIDATING,
    "test": RunningStage.TESTING,
}

# Stages at which predictions are computed when none are specified.
DEFAULT_STAGES = {
    RunningStage.VALIDATING,
    RunningStage.TESTING,
    RunningStage.SANITY_CHECKING,
    RunningStage.PREDICTING,
}


class PredictionCallback(pl.Callback):
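    """
    Base Lightning callback that runs inference at the end of validation and
    test epochs and forwards the resulting predictions to a list of
    ``NLPTemplateCallback``s (e.g. for metric computation or serialization).

    Subclasses must implement ``__call__`` to produce the predictions.

    Args:
        batch_size (`int`):
            Batch size to use when building prediction dataloaders.
        stages (`Optional[Set[Union[str, RunningStage]]]`):
            Stages at which to run. Plain strings ("train", "val", "test")
            are accepted. Defaults to `DEFAULT_STAGES`.
        other_callbacks (`Optional[Union[List[DictConfig], List[NLPTemplateCallback]]]`):
            Callbacks invoked with the predictions. `DictConfig` entries are
            instantiated with Hydra.
        datasets (`Optional[Union[DictConfig, BaseDataset]]`):
            Datasets to predict on. If `None`, the datamodule's datasets are
            used.
        dataloaders (`Optional[Union[DataLoader, List[DataLoader]]]`):
            Dataloaders to predict on. If `None`, they are built from the
            datasets or taken from the trainer.
    """
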
    def __init__(
        self,
        batch_size: int = 32,
        stages: Optional[Set[Union[str, RunningStage]]] = None,
        other_callbacks: Optional[
            Union[List[DictConfig], List["NLPTemplateCallback"]]
        ] = None,
        datasets: Optional[Union[DictConfig, BaseDataset]] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        *args,
        **kwargs,
    ):
        super().__init__()
        # parameters
        self.batch_size = batch_size
        self.datasets = datasets
        self.dataloaders = dataloaders

        # callback initialization
        if stages is None:
            stages = DEFAULT_STAGES

        # compatibility: accept plain strings and normalize them to RunningStage
        stages = {STAGES_COMPATIBILITY_MAP.get(stage, stage) for stage in stages}
        self.stages = [RunningStage(stage) for stage in stages]
        self.other_callbacks = other_callbacks or []
        for i, callback in enumerate(self.other_callbacks):
            if isinstance(callback, DictConfig):
                self.other_callbacks[i] = hydra.utils.instantiate(
                    callback, _recursive_=False
                )

    @torch.no_grad()
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        *args,
        **kwargs,
    ) -> Any:
        # subclasses must run inference here and return the predictions
        raise NotImplementedError

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        predictions = self(trainer, pl_module)
        for callback in self.other_callbacks:
            callback(
                trainer=trainer,
                pl_module=pl_module,
                callback=self,
                predictions=predictions,
            )

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        predictions = self(trainer, pl_module)
        for callback in self.other_callbacks:
            callback(
                trainer=trainer,
                pl_module=pl_module,
                callback=self,
                predictions=predictions,
            )

    @staticmethod
    def _get_datasets_and_dataloaders(
        dataset: Optional[Union[Dataset, DictConfig]],
        dataloader: Optional[DataLoader],
        trainer: pl.Trainer,
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn: Optional[Callable] = None,
        collate_fn_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[Dataset], List[DataLoader]]:
        """
        Get the datasets and dataloaders from the datamodule or from the dataset provided.

        Args:
            dataset (`Optional[Union[Dataset, DictConfig]]`):
                The dataset to use. If `None`, the datamodule is used.
            dataloader (`Optional[DataLoader]`):
                The dataloader to use. If `None`, the datamodule is used.
            trainer (`pl.Trainer`):
                The trainer that contains the datamodule.
            dataloader_kwargs (`Optional[Dict[str, Any]]`):
                The kwargs to pass to the dataloader.
            collate_fn (`Optional[Callable]`):
                The collate function to use.
            collate_fn_kwargs (`Optional[Dict[str, Any]]`):
                The kwargs to pass to the collate function.

        Returns:
            `Tuple[List[Dataset], List[DataLoader]]`: The datasets and dataloaders.
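
        Example:
            A minimal sketch (``my_dataset`` is a hypothetical ``BaseDataset``
            instance and ``trainer`` an attached ``pl.Trainer``)::

                datasets, dataloaders = PredictionCallback._get_datasets_and_dataloaders(
                    dataset=my_dataset,
                    dataloader=None,
                    trainer=trainer,
                    dataloader_kwargs={"batch_size": 32},
                )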
        """
        # if a dataset is provided, use it
        if dataset is not None:
            dataloader_kwargs = dataloader_kwargs or {}
            # get dataset
            if isinstance(dataset, DictConfig):
                dataset = hydra.utils.instantiate(dataset, _recursive_=False)
            datasets = [dataset] if not isinstance(dataset, list) else dataset
            if dataloader is not None:
                dataloaders = (
                    [dataloader] if isinstance(dataloader, DataLoader) else dataloader
                )
            else:
                # fall back to the dataset's own collate_fn; guard against
                # `collate_fn_kwargs` being None
                collate_fn = collate_fn or partial(
                    datasets[0].collate_fn, **(collate_fn_kwargs or {})
                )
                dataloader_kwargs["collate_fn"] = collate_fn
                # only the first dataset is wrapped in a new DataLoader here
                dataloaders = [DataLoader(datasets[0], **dataloader_kwargs)]
        else:
            # get the dataloaders and datasets from the datamodule
            datasets = (
                trainer.datamodule.test_datasets
                if trainer.state.stage == RunningStage.TESTING
                else trainer.datamodule.val_datasets
            )
            dataloaders = (
                trainer.test_dataloaders
                if trainer.state.stage == RunningStage.TESTING
                else trainer.val_dataloaders
            )
        return datasets, dataloaders


class NLPTemplateCallback:
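    """
    Interface for callbacks that consume the predictions produced by a
    ``PredictionCallback``, e.g. to compute and log metrics.

    Example:
        A minimal sketch of a subclass (``compute_metrics`` is a hypothetical
        helper that maps predictions to a dict of scalar metrics)::

            class MetricsCallback(NLPTemplateCallback):
                def __call__(self, trainer, pl_module, callback, predictions, **kwargs):
                    metrics = compute_metrics(predictions)
                    pl_module.log_dict(metrics)
                    return metrics
    """
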
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        callback: PredictionCallback,
        predictions: Dict[str, Any],
        *args,
        **kwargs,
    ) -> Any:
        raise NotImplementedError
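

# Usage sketch (illustrative only; ``MyPredictionCallback`` is a hypothetical
# subclass implementing ``__call__``, ``MetricsCallback`` a hypothetical
# ``NLPTemplateCallback``):
#
#     prediction_callback = MyPredictionCallback(
#         batch_size=64,
#         stages={"val", "test"},
#         other_callbacks=[MetricsCallback()],
#     )
#     trainer = pl.Trainer(callbacks=[prediction_callback])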