""" @author:jishnuprakash """ # This file consists of constants, attributes and classes used for training import re import nltk import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup import pytorch_lightning as pl from pytorch_lightning.metrics.functional import auroc from nltk.stem import WordNetLemmatizer from nltk.corpus import stopwords from transformers import AutoTokenizer, AutoModel random_seed = 42 num_epochs = 10 batch = 1 threshold = 0.5 max_tokens = 512 clean_text = False # bert_model = "bert-base-uncased" bert_model = "nlpaueb/legal-bert-base-uncased" checkpoint_dir = "checkpoints" check_filename = "legal-full-data" earlystop_monitor = "val_loss" earlystop_patience = 2 lex_classes = ["Article 2", "Article 3", "Article 5", "Article 6", "Article 8", "Article 9", "Article 10", "Article 11", "Article 14", "Article 1 of Protocol 1", "No Violation"] num_classes = len(lex_classes) #Stop words stop_words = stopwords.words("english") lemmatizer = WordNetLemmatizer() def preprocess_text(text, remove_stopwords, stop_words): """ Clean text """ text = text.lower() # remove special chars and numbers text = re.sub("[^A-Za-z]+", " ", text) # remove stopwords if remove_stopwords: # 1. tokenize tokens = nltk.word_tokenize(text) # 2. check if stopword tokens = [w for w in tokens if not w.lower() in stop_words] # 3. Lemmatize tokens = [lemmatizer.lemmatize(i) for i in tokens] # 4. join back together text = " ".join(tokens) # return text in lower case and stripped of whitespaces text = text.lower().strip() return text def preprocess_data(df, clean=False): """ Perform basic data preprocessing """ df = df[df['text'].map(len)>0] df['labels'] = df.labels.apply(lambda x: x if len(x)>0 else [10]) df.dropna(inplace=True) if clean: df['text'] = df.apply(lambda x: [preprocess_text(i, True, stop_words) for i in x['text']], axis=1) return df class LexGlueDataset(Dataset): """ Lex GLUE Dataset as pytorch dataset """ def __init__(self, data, tokenizer, max_tokens=512): super().__init__() self.tokenizer = tokenizer self.data = data self.max_tokens = max_tokens def __len__(self): # return len(self.data) return self.data.__len__() def generateLabels(self, labels): out = [0] * num_classes for i in labels: out[i] = 1 return out def __getitem__(self, index): data_row = self.data.iloc[index] lex_text = data_row.text multi_labels = self.generateLabels(data_row.labels) encoding = self.tokenizer.encode_plus(lex_text, add_special_tokens=True, max_length=self.max_tokens, return_token_type_ids=False, padding="max_length", truncation=True, return_attention_mask=True, is_split_into_words=True, return_tensors='pt',) return dict(text = lex_text, input_ids = encoding["input_ids"].flatten(), attention_mask = encoding["attention_mask"].flatten(), labels = torch.FloatTensor(multi_labels)) class LexGlueDataModule(pl.LightningDataModule): """ Data module to load LexGlueDataset for training, validating and testing """ def __init__(self, train, test, tokenizer, batch_size=8, max_tokens=512): super().__init__() self.batch_size = batch_size self.train = train self.test = test self.tokenizer = tokenizer self.max_tokens = max_tokens def setup(self, stage=None): self.train_dataset = LexGlueDataset(self.train, self.tokenizer, self.max_tokens) self.test_dataset = LexGlueDataset(self.test, self.tokenizer, self.max_tokens) def train_dataloader(self): return DataLoader(self.train_dataset, 
    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)


class LexGlueTagger(pl.LightningModule):
    """
    Model and Training instance as LexGlueTagger class for Pytorch Lightning module
    """

    def __init__(self, num_classes, training_steps=None, warmup_steps=None):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.training_steps = training_steps
        self.warmup_steps = warmup_steps
        # sigmoid outputs scored with binary cross-entropy for multi-label classification
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """
        Forward pass
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        # classify from the pooled [CLS] representation
        output = self.classifier(output.pooler_output)
        output = torch.sigmoid(output)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_epoch_end(self, outputs):
        # collect predictions and labels from all training batches
        labels = []
        predictions = []
        for output in outputs:
            for out_labels in output["labels"].detach().cpu():
                labels.append(out_labels)
            for out_predictions in output["predictions"].detach().cpu():
                predictions.append(out_predictions)
        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)
        # log per-class ROC-AUC on the training set at the end of each epoch
        for i, name in enumerate(lex_classes):
            class_roc_auc = auroc(predictions[:, i], labels[:, i])
            self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)

    def configure_optimizers(self):
        """
        Optimizer and Learning rate scheduler
        """
        optimizer = AdamW(self.parameters(), lr=2e-5)
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=self.warmup_steps,
                                                    num_training_steps=self.training_steps)
        return dict(optimizer=optimizer, lr_scheduler=dict(scheduler=scheduler, interval='step'))
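

# ---------------------------------------------------------------------------
# Minimal training sketch (illustrative, not part of the original module):
# it shows one way the constants and classes above could be wired together
# with a pytorch_lightning Trainer. The input files "train.jsonl" and
# "test.jsonl", the JSONL loading, and the one-fifth warm-up ratio are
# assumptions for illustration; the DataFrames are assumed to carry 'text'
# (list of paragraphs) and 'labels' (list of class indices) columns.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pandas as pd
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

    pl.seed_everything(random_seed)
    tokenizer = AutoTokenizer.from_pretrained(bert_model)

    # assumed JSONL splits with 'text' and 'labels' columns
    train_df = preprocess_data(pd.read_json("train.jsonl", lines=True), clean=clean_text)
    test_df = preprocess_data(pd.read_json("test.jsonl", lines=True), clean=clean_text)

    data_module = LexGlueDataModule(train_df, test_df, tokenizer,
                                    batch_size=batch, max_tokens=max_tokens)

    # total optimisation steps drive the linear warm-up schedule
    steps_per_epoch = len(train_df) // batch
    total_steps = steps_per_epoch * num_epochs
    model = LexGlueTagger(num_classes=num_classes,
                          training_steps=total_steps,
                          warmup_steps=total_steps // 5)

    checkpoint = ModelCheckpoint(dirpath=checkpoint_dir, filename=check_filename,
                                 monitor=earlystop_monitor, mode="min", save_top_k=1)
    early_stop = EarlyStopping(monitor=earlystop_monitor, patience=earlystop_patience)

    trainer = pl.Trainer(max_epochs=num_epochs, callbacks=[checkpoint, early_stop])
    trainer.fit(model, datamodule=data_module)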