CarlosMalaga's picture
Upload 201 files
2f044c1 verified
raw
history blame
2.13 kB
from typing import Any, Optional
import lightning
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
class RelikReaderREPLModule(lightning.LightningModule):
    """Lightning wrapper that trains a ``RelikReaderForTripletExtraction`` model.

    The module owns one ``RelikReaderForTripletExtraction`` instance and logs
    its training losses. Optimizer construction is delegated to a factory
    injected through :meth:`set_optimizer_factory`. ``validation_step`` is
    intentionally a no-op.
    """

    def __init__(
        self,
        cfg: dict,
        transformer_model: str,
        additional_special_symbols: int,
        additional_special_symbols_types: Optional[int] = 0,
        entity_type_loss: Optional[bool] = None,
        add_entity_embedding: Optional[bool] = None,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        *args: Any,
        **kwargs: Any
    ) -> None:
        """Build the wrapped triplet-extraction model and record hyperparameters.

        Args:
            cfg: Configuration dict. It is not read directly here; it is
                recorded together with the other init arguments by
                ``save_hyperparameters()``.
            transformer_model: Forwarded positionally to
                ``RelikReaderForTripletExtraction``.
            additional_special_symbols: Forwarded to the wrapped model.
            additional_special_symbols_types: Forwarded to the wrapped model.
            entity_type_loss: Forwarded to the wrapped model.
            add_entity_embedding: Forwarded to the wrapped model.
            num_layers: Forwarded to the wrapped model.
            activation: Forwarded to the wrapped model.
            linears_hidden_size: Forwarded to the wrapped model.
            use_last_k_layers: Forwarded to the wrapped model.
            training: Forwarded to the wrapped model as ``training=``.
            *args: Forwarded to ``LightningModule.__init__``.
            **kwargs: Forwarded to ``LightningModule.__init__`` and also to
                the wrapped model.
        """
        # NOTE(review): kwargs go both to LightningModule.__init__ and to the
        # wrapped model. Recent torch versions of nn.Module.__init__ may reject
        # unexpected keyword arguments. Confirm whether any kwargs are ever
        # actually passed here.
        super().__init__(*args, **kwargs)
        # Called with no arguments, so Lightning records this __init__'s
        # arguments as hyperparameters.
        self.save_hyperparameters()
        self.relik_reader_re_model = RelikReaderForTripletExtraction(
            transformer_model,
            additional_special_symbols,
            additional_special_symbols_types,
            entity_type_loss,
            add_entity_embedding,
            num_layers,
            activation,
            linears_hidden_size,
            use_last_k_layers,
            training=training,
            **kwargs,
        )
        # Populated later through set_optimizer_factory(); it must be set
        # before configure_optimizers() runs.
        self.optimizer_factory = None

    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Run one training forward pass, log the losses, and return the total loss.

        ``batch`` is unpacked as keyword arguments into the wrapped model. The
        model output must provide the keys ``"loss"``, ``"ned_start_loss"``,
        ``"ned_end_loss"`` and ``"re_loss"``.
        """
        relik_output = self.relik_reader_re_model(**batch)
        self.log("train-loss", relik_output["loss"])
        self.log("train-start_loss", relik_output["ned_start_loss"])
        self.log("train-end_loss", relik_output["ned_end_loss"])
        self.log("train-relation_loss", relik_output["re_loss"])
        return relik_output["loss"]

    def validation_step(
        self, batch: dict, *args: Any, **kwargs: Any
    ) -> Optional[STEP_OUTPUT]:
        """Perform no validation work and return ``None``."""
        return None

    def set_optimizer_factory(self, optimizer_factory) -> None:
        """Store the callable used by :meth:`configure_optimizers`.

        Args:
            optimizer_factory: A callable that receives the wrapped
                ``RelikReaderForTripletExtraction`` model and returns whatever
                Lightning accepts from ``configure_optimizers``.
        """
        self.optimizer_factory = optimizer_factory

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build optimizers by calling the injected factory on the wrapped model.

        Returns:
            Whatever the configured optimizer factory returns.

        Raises:
            RuntimeError: If :meth:`set_optimizer_factory` was never called.
        """
        # Fail with an actionable message instead of the opaque
        # "'NoneType' object is not callable" TypeError.
        if self.optimizer_factory is None:
            raise RuntimeError(
                "No optimizer factory configured: call set_optimizer_factory() "
                "before training starts."
            )
        return self.optimizer_factory(self.relik_reader_re_model)