import logging
from abc import abstractmethod
from typing import Any, Dict

import torch
from torch import nn

logger = logging.getLogger(__name__)


class BaseNLPAug(nn.Module):
    """Base class for NLP augmentation"""

    def __init__(self, cfg: Any):
        """
        Args:
            cfg: config with all the hyperparameters
        """
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(self, batch: Dict) -> Dict:
        """Augment the input batch by randomly masking tokens.

        Args:
            batch: current batch
        Returns:
            augmented batch
        """
        if self.cfg.augmentation.token_mask_probability > 0:
            input_ids = batch["input_ids"].clone()
            # special_mask = ~batch["special_tokens_mask"].clone().bool()
            # Sample a Bernoulli mask over every token position.
            mask = (
                torch.bernoulli(
                    torch.full(
                        input_ids.shape,
                        float(self.cfg.augmentation.token_mask_probability),
                    )
                )
                .to(input_ids.device)
                .bool()
                # & special_mask
            ).bool()
            # Replace selected tokens with the mask token, hide them from
            # attention, and ignore them in the loss when labels are
            # token-level (same sequence length as input_ids).
            input_ids[mask] = self.cfg._tokenizer_mask_token_id
            batch["input_ids"] = input_ids.clone()
            batch["attention_mask"][mask] = 0

            if batch["labels"].shape[1] == batch["input_ids"].shape[1]:
                batch["labels"][mask] = -100

        return batch
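

if __name__ == "__main__":
    # Minimal usage sketch: a subclass delegating to the shared masking logic
    # above. The config values and the mask-token id are placeholder
    # assumptions; real projects build `cfg` elsewhere and take
    # `_tokenizer_mask_token_id` from the tokenizer.
    from types import SimpleNamespace

    class TokenMaskAug(BaseNLPAug):
        def forward(self, batch: Dict) -> Dict:
            # Reuse the base token-masking implementation.
            return super().forward(batch)

    cfg = SimpleNamespace(
        augmentation=SimpleNamespace(token_mask_probability=0.15),
        _tokenizer_mask_token_id=103,  # assumed [MASK] id, e.g. BERT-style tokenizers
    )
    aug = TokenMaskAug(cfg)

    batch = {
        "input_ids": torch.randint(5, 1000, (2, 8)),
        "attention_mask": torch.ones(2, 8, dtype=torch.long),
        "labels": torch.randint(5, 1000, (2, 8)),
    }
    augmented = aug(batch)
    print(augmented["input_ids"])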