# Uploaded by CarlosMalaga ("Upload 201 files", commit 2f044c1, verified).
from typing import Any, Optional
import lightning
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
# from relik.reader.relik_reader_core import RelikReaderCoreModel
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
class RelikReaderPLModule(lightning.LightningModule):
    """PyTorch Lightning wrapper around ``RelikReaderForSpanExtraction``.

    The wrapped reader is built in ``__init__``. The optimizer is not created
    here. It comes from an externally supplied factory registered with
    ``set_optimizer_factory``, which must be called before Lightning invokes
    ``configure_optimizers``.
    """

    def __init__(
        self,
        cfg: dict,
        transformer_model: str,
        additional_special_symbols: int,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        *args: Any,
        **kwargs: Any,
    ):
        """Build the span-extraction reader and record hyperparameters.

        Args:
            cfg: Configuration dict. It is not read by this class. It is only
                captured by ``save_hyperparameters()``, presumably so the full
                config ends up in checkpoints (confirm with callers).
            transformer_model: Forwarded to ``RelikReaderForSpanExtraction``.
                Presumably a model name or path (verify against the reader).
            additional_special_symbols: Forwarded to the reader.
            num_layers: Forwarded to the reader.
            activation: Forwarded to the reader.
            linears_hidden_size: Forwarded to the reader.
            use_last_k_layers: Forwarded to the reader.
            training: Forwarded to the reader as the ``training`` keyword.
            *args: Passed through to ``LightningModule.__init__``.
            **kwargs: Passed through to ``LightningModule.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Records the __init__ arguments (including ``cfg``) as hparams.
        self.save_hyperparameters()
        # Arguments are forwarded positionally, so this order must match the
        # reader's constructor signature.
        self.relik_reader_core_model = RelikReaderForSpanExtraction(
            transformer_model,
            additional_special_symbols,
            num_layers,
            activation,
            linears_hidden_size,
            use_last_k_layers,
            training=training,
        )
        # Supplied later through set_optimizer_factory().
        self.optimizer_factory = None

    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Run one forward pass, log the loss and return it.

        ``batch`` is unpacked as keyword arguments into the wrapped reader.
        Its keys must therefore match the reader's ``forward`` parameters.
        The reader's output is expected to expose a ``"loss"`` entry.
        """
        relik_output = self.relik_reader_core_model(**batch)
        loss = relik_output["loss"]
        self.log("train-loss", loss)
        return loss

    def validation_step(
        self, batch: dict, *args: Any, **kwargs: Any
    ) -> Optional[STEP_OUTPUT]:
        """No-op: no validation loss or metrics are computed here.

        NOTE(review): validation is presumably handled elsewhere (e.g. by a
        callback). Confirm this before relying on it.
        """
        return None

    def set_optimizer_factory(self, optimizer_factory) -> None:
        """Register the callable used by ``configure_optimizers``.

        Args:
            optimizer_factory: Callable that is invoked with the wrapped reader
                model (``self.relik_reader_core_model``). It must return what
                ``configure_optimizers`` hands back to Lightning: optimizer(s)
                and/or LR scheduler configuration.
        """
        self.optimizer_factory = optimizer_factory

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build optimizer(s) and scheduler(s) through the registered factory.

        Raises:
            RuntimeError: If ``set_optimizer_factory`` was never called.
        """
        if self.optimizer_factory is None:
            # Without this guard the failure surfaces as an opaque
            # "'NoneType' object is not callable" error deep inside Lightning.
            raise RuntimeError(
                "No optimizer factory set: call set_optimizer_factory() "
                "before training starts."
            )
        return self.optimizer_factory(self.relik_reader_core_model)