File size: 2,125 Bytes
2f044c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from typing import Any, Optional
import lightning
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
class RelikReaderREPLModule(lightning.LightningModule):
    """Lightning wrapper that trains a ``RelikReaderForTripletExtraction`` model.

    The wrapped reader is built in ``__init__`` from the constructor arguments.
    The optimizer setup is injected later through ``set_optimizer_factory`` and
    invoked by ``configure_optimizers``.
    """

    def __init__(
        self,
        cfg: dict,
        transformer_model: str,
        additional_special_symbols: int,
        additional_special_symbols_types: Optional[int] = 0,
        entity_type_loss: Optional[bool] = None,
        add_entity_embedding: Optional[bool] = None,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        *args: Any,
        **kwargs: Any
    ):
        """Build the wrapped triplet-extraction reader.

        Args:
            cfg: Configuration dict. It is not read directly in this class; it
                is recorded (with every other init argument) by
                ``save_hyperparameters``.
            transformer_model: Forwarded positionally to
                ``RelikReaderForTripletExtraction``.
            additional_special_symbols: Forwarded positionally to the reader.
            additional_special_symbols_types: Forwarded positionally to the reader.
            entity_type_loss: Forwarded positionally to the reader.
            add_entity_embedding: Forwarded positionally to the reader.
            num_layers: Forwarded positionally to the reader.
            activation: Forwarded positionally to the reader.
            linears_hidden_size: Forwarded positionally to the reader.
            use_last_k_layers: Forwarded positionally to the reader.
            training: Forwarded to the reader as the ``training=`` keyword.
            *args: Forwarded to the ``LightningModule`` base constructor.
            **kwargs: Forwarded only to ``RelikReaderForTripletExtraction``.
        """
        # Extra keyword arguments are meant for the wrapped reader only. They
        # must not reach LightningModule.__init__: it forwards to
        # torch.nn.Module.__init__, which raises TypeError on unexpected
        # keyword arguments. Forwarding them there (as this code used to) made
        # the **kwargs pass-through below unusable. Positional extras are still
        # forwarded to the base constructor, as before.
        super().__init__(*args)
        # Records the __init__ arguments (including cfg and kwargs) as hparams.
        self.save_hyperparameters()
        self.relik_reader_re_model = RelikReaderForTripletExtraction(
            transformer_model,
            additional_special_symbols,
            additional_special_symbols_types,
            entity_type_loss,
            add_entity_embedding,
            num_layers,
            activation,
            linears_hidden_size,
            use_last_k_layers,
            training=training,
            **kwargs,
        )
        # Set via set_optimizer_factory before training starts.
        self.optimizer_factory = None

    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Run one training step and log the loss components.

        Args:
            batch: Mapping unpacked as keyword arguments into the wrapped reader.

        Returns:
            The ``"loss"`` entry of the reader's output, used by Lightning for
            backpropagation.
        """
        relik_output = self.relik_reader_re_model(**batch)
        self.log("train-loss", relik_output["loss"])
        self.log("train-start_loss", relik_output["ned_start_loss"])
        self.log("train-end_loss", relik_output["ned_end_loss"])
        self.log("train-relation_loss", relik_output["re_loss"])
        return relik_output["loss"]

    def validation_step(
        self, batch: dict, *args: Any, **kwargs: Any
    ) -> Optional[STEP_OUTPUT]:
        """Validation is intentionally a no-op in this module; returns ``None``."""
        return

    def set_optimizer_factory(self, optimizer_factory) -> None:
        """Install the callable used by ``configure_optimizers``.

        Args:
            optimizer_factory: Callable that receives the wrapped reader module
                (not this LightningModule) and returns whatever
                ``configure_optimizers`` should return.
        """
        self.optimizer_factory = optimizer_factory

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build optimizer(s) by calling the installed factory on the wrapped reader.

        Raises:
            RuntimeError: If ``set_optimizer_factory`` has not been called yet.
        """
        if self.optimizer_factory is None:
            # Fail with a clear message instead of "'NoneType' object is not callable".
            raise RuntimeError(
                "RelikReaderREPLModule has no optimizer factory: call "
                "set_optimizer_factory(...) before training."
            )
        return self.optimizer_factory(self.relik_reader_re_model)
|