from typing import Any, Callable, Optional

import lightning
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler

# from relik.reader.relik_reader_core import RelikReaderCoreModel
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction


class RelikReaderPLModule(lightning.LightningModule):
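    """Lightning wrapper around ``RelikReaderForSpanExtraction``.

    Handles training-loop plumbing (loss logging, optimizer wiring) while
    delegating the forward pass to the underlying span-extraction reader.
    The optimizer is not built here; inject one via ``set_optimizer_factory``
    before ``configure_optimizers`` is called.
    """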
    def __init__(
        self,
        cfg: dict,
        transformer_model: str,
        additional_special_symbols: int,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        *args: Any,
        **kwargs: Any
    ):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.relik_reader_core_model = RelikReaderForSpanExtraction(
            transformer_model,
            additional_special_symbols,
            num_layers,
            activation,
            linears_hidden_size,
            use_last_k_layers,
            training=training,
        )
        self.optimizer_factory: Optional[Callable[..., OptimizerLRScheduler]] = None

    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        # Forward pass through the core reader; the returned dict carries the loss.
        relik_output = self.relik_reader_core_model(**batch)
        self.log("train-loss", relik_output["loss"])
        return relik_output["loss"]

    def validation_step(
        self, batch: dict, *args: Any, **kwargs: Any
    ) -> Optional[STEP_OUTPUT]:
        # Intentionally a no-op: validation metrics are computed outside this
        # step (e.g., by an evaluation callback).
        return None

    def set_optimizer_factory(
        self, optimizer_factory: Callable[..., OptimizerLRScheduler]
    ) -> None:
        self.optimizer_factory = optimizer_factory

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Fail with a clear message instead of "'NoneType' object is not
        # callable" when no factory was injected.
        if self.optimizer_factory is None:
            raise RuntimeError(
                "No optimizer factory set: call set_optimizer_factory() "
                "before training."
            )
        return self.optimizer_factory(self.relik_reader_core_model)
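

# Minimal usage sketch (illustrative only): the encoder name, special-symbol
# count, and learning rate below are placeholder assumptions, not values
# prescribed by this module.
if __name__ == "__main__":
    import torch

    module = RelikReaderPLModule(
        cfg={},
        transformer_model="microsoft/deberta-v3-base",  # assumed encoder
        additional_special_symbols=101,  # assumed vocabulary extension
        training=True,
    )
    # configure_optimizers() delegates to this factory; returning a bare
    # optimizer is a valid OptimizerLRScheduler in Lightning.
    module.set_optimizer_factory(
        lambda model: torch.optim.AdamW(model.parameters(), lr=1e-5)
    )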