In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_outputs import (
 Seq2SeqQuestionAnsweringModelOutput,
 Seq2SeqSequenceClassifierOutput,
 BaseModelOutput,
)
from transformers import (
 T5ForQuestionAnswering,
 T5PreTrainedModel,
 MBartPreTrainedModel,
 MBartModel,
 T5Config,
 T5Model,
 T5EncoderModel,
 get_scheduler
)
from tqdm import tqdm 
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import random
import os 
from datetime import datetime
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split

In [2]:
import json
import yaml
from addict import Dict


def load_json(file_path):
    """Read the JSON document at *file_path* and return the decoded object.

    The ``utf-8-sig`` codec is used so that a leading UTF-8 byte-order mark,
    if present, is silently stripped before parsing.
    """
    with open(file_path, mode="r", encoding="utf-8-sig") as handle:
        return json.load(handle)


def read_config(path):
    """Parse the YAML file at *path* and wrap it in an ``addict.Dict``.

    If the file is not valid YAML the parser error is printed and the function
    returns ``None`` implicitly, so callers must be prepared for that.
    """
    with open(path, "r") as config_file:
        try:
            parsed = yaml.safe_load(config_file)
        except yaml.YAMLError as err:
            print(err)
        else:
            return Dict(parsed)


def batch_to_device(batch: dict, device: str):
    """Move every value of *batch* to *device*, replacing entries in place.

    Each value must provide a ``.to(device)`` method (e.g. a ``torch.Tensor``).
    The same, now-mutated, dict object is returned for convenience.
    """
    for key, value in list(batch.items()):
        batch[key] = value.to(device)
    return batch


def save_json(obj, path):
    """Serialize *obj* to *path* as pretty-printed (2-space indent) JSON.

    ``ensure_ascii=False`` keeps non-ASCII text (e.g. Vietnamese) readable in
    the output, so the file is opened with an explicit UTF-8 encoding; the
    platform default (e.g. cp1252 on Windows) cannot represent such characters
    and would raise ``UnicodeEncodeError``.
    """
    with open(path, "w", encoding="utf-8") as outfile:
        json.dump(obj, outfile, ensure_ascii=False, indent=2)


In [None]:
@dataclass
class TokenClassificationOutput:
    """Output container for the joint verdict / evidence-tagging model.

    Every field defaults to ``None`` so that a forward pass without labels can
    still return the logits alone.

    Attributes:
        loss: overall training loss.
        sent_loss: sentence-level (verdict) classification loss.
        token_loss: token-level (evidence tagging) loss.
        claim_logits: verdict logits.
        evidence_logits: per-token evidence logits.
    """

    loss: Optional[torch.FloatTensor] = None
    sent_loss: Optional[torch.FloatTensor] = None
    token_loss: Optional[torch.FloatTensor] = None
    # The two logits fields also default to None, so their type is Optional too.
    claim_logits: Optional[torch.FloatTensor] = None
    evidence_logits: Optional[torch.FloatTensor] = None


In [None]:
def random_seed(value):
    """Seed every RNG used in this notebook so that runs are reproducible.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN to use deterministic kernels.
    """
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed_all(value)
    # deterministic=True alone is not sufficient: with benchmark mode on,
    # cuDNN's auto-tuner may pick a different algorithm from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
@dataclass
class TrainingArguments:
    """Paths and hyper-parameters for the fine-tuning run.

    Every attribute carries a type annotation, which makes it a real dataclass
    field: individual values can therefore be overridden at construction time,
    e.g. ``TrainingArguments(lr=1e-4, num_epochs=3)``.
    """

    # data / model
    data_path: str = "data/ise-dsc01-train.json"
    model_name: str = "VietAI/vit5-base"
    tokenizer_name: str = "VietAI/vit5-base"

    # optimisation
    gradient_accumulation_steps: int = 8
    gradient_checkpointing: bool = False
    num_epochs: int = 10
    lr: float = 3.0e-5
    weight_decay: float = 1.0e-2
    scheduler_name: str = "cosine"
    warmup_steps: int = 0
    patience: int = 3  # epochs without val-loss improvement before early stopping
    max_seq_length: int = 1024
    seed: int = 1401
    test_size: float = 0.1  # passed to train_test_split for the validation split
    train_batch_size: int = 1
    val_batch_size: int = 1

    save_best: bool = True

    # parameter freezing switches
    freeze_backbone: bool = False
    freeze_encoder: bool = False
    freeze_decoder: bool = False


training_args = TrainingArguments()

In [None]:
_LABEL_MAPPING = {"SUPPORTED": 0, "NEI": 1, "REFUTED": 2}


class TokenStanceDataset(Dataset):
    """Claim-verification dataset with sentence- and token-level targets.

    Each item encodes the ``(context, claim)`` pair and returns a dict with:

    * ``input_ids`` / ``attention_mask`` - padded to ``max_seq_length``;
    * ``evidence_labels`` - a ``(max_seq_length,)`` long tensor that is 1 on
      every context token overlapping the evidence span and 0 elsewhere. It is
      all zeros when the evidence is missing, is not a verbatim substring of
      the context, or was truncated away entirely;
    * ``labels`` - the verdict id from ``_LABEL_MAPPING``.

    Args:
        dataset: mapping ``id -> record``; each record needs the keys
            ``claim``, ``evidence`` (str or None), ``context`` and ``verdict``.
        dataset_keys: the ids exposed by this dataset, in index order.
        tokenizer: a *fast* Hugging Face tokenizer (``char_to_token`` is used
            on its output).
        max_seq_length: padding / truncation length.
    """

    def __init__(self, dataset, dataset_keys, tokenizer, max_seq_length=1024) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.dataset = dataset
        self.dataset_keys = dataset_keys

    @staticmethod
    def _char_span_to_token_span(encodings, char_start, char_end):
        """Map the context character span ``[char_start, char_end)`` to an
        inclusive token span ``(first, last)``; ``None`` if no char maps.

        Characters without a token (whitespace, or text cut off by truncation)
        are skipped by scanning inward from each end of the span, so both the
        first and the last evidence token are always included.
        """
        first = None
        for pos in range(char_start, char_end):
            first = encodings.char_to_token(0, pos)
            if first is not None:
                break
        last = None
        for pos in range(char_end - 1, char_start - 1, -1):
            last = encodings.char_to_token(0, pos)
            if last is not None:
                break
        if first is None or last is None:
            return None
        return first, last

    def __getitem__(self, idx):
        data_item = self.dataset[self.dataset_keys[idx]]

        claim = data_item['claim']
        evidence = data_item['evidence']
        context = data_item['context']

        encodings = self.tokenizer(
            context,
            claim,
            truncation=True,
            padding="max_length",
            max_length=self.max_seq_length,
            return_tensors="pt",
        )

        evidence_labels = torch.zeros(self.max_seq_length, dtype=torch.long)
        # str.find returns -1 when the evidence is not a verbatim substring of
        # the context; such items get no evidence labels instead of a bogus span.
        char_start = context.find(evidence) if evidence else -1
        if char_start >= 0:
            token_span = self._char_span_to_token_span(
                encodings, char_start, char_start + len(evidence)
            )
            if token_span is not None:
                first, last = token_span
                evidence_labels[first:last + 1] = 1  # `last` is inclusive

        label = torch.tensor(_LABEL_MAPPING[data_item["verdict"]], dtype=torch.long)

        return {
            "input_ids": encodings.input_ids.squeeze(0),
            "attention_mask": encodings.attention_mask.squeeze(0),
            "evidence_labels": evidence_labels,
            "labels": label,
        }

    def __len__(self):
        # Items are addressed through `dataset_keys`, so that defines the length.
        return len(self.dataset_keys)
 
 

In [None]:
random_seed(training_args.seed)

data = load_json(training_args.data_path)
data_keys = list(data.keys())

train_keys, dev_keys = train_test_split(
    data_keys,
    test_size=training_args.test_size,
    random_state=training_args.seed,
    shuffle=True,
)

# Build each split by direct key lookup; testing `k in <list of keys>` for
# every record would be quadratic in the dataset size.
train_set = {k: data[k] for k in train_keys}
dev_set = {k: data[k] for k in dev_keys}
assert len(train_set) + len(dev_set) == len(data), "train/dev split lost or duplicated records"

tokenizer = AutoTokenizer.from_pretrained(
    training_args.tokenizer_name, use_fast=True
)

train_dataset = TokenStanceDataset(
    train_set, train_keys, tokenizer, training_args.max_seq_length
)
val_dataset = TokenStanceDataset(
    dev_set, dev_keys, tokenizer, training_args.max_seq_length
)

train_dataloader = DataLoader(
    train_dataset, batch_size=training_args.train_batch_size, shuffle=True
)
val_dataloader = DataLoader(
    val_dataset, batch_size=training_args.val_batch_size, shuffle=False
)


In [None]:
class T5FeedForwardHead(nn.Module):
    """Two-layer MLP head: dropout -> dense -> ReLU -> dropout -> out_proj.

    ``nn.Linear`` acts on the last dimension, so inputs of shape
    ``(..., d_model)`` produce outputs of shape ``(..., out_dim)``.

    Args:
        config: model config providing ``d_model`` and ``classifier_dropout``.
        out_dim: size of the final projection.
    """

    def __init__(self, config, out_dim):
        super().__init__()
        width = config.d_model
        # Submodule names are state_dict keys of saved checkpoints - keep them.
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.out_proj = nn.Linear(width, out_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden = torch.relu(self.dense(self.dropout(hidden_states)))
        return self.out_proj(self.dropout(hidden))



class ViT5ForTokenClassification(T5PreTrainedModel):
    """ViT5 encoder-decoder with a verdict head and an evidence-tagging head.

    * ``verdict_head`` maps the backbone's final hidden state (``outputs[0]``
      of ``T5Model``) at each example's last eos position to
      ``num_verdicts`` (3) verdict logits.
    * ``evidence_head`` maps every position to ``num_labels`` (2) logits
      (not-evidence / evidence).

    When both label kinds are given, the loss is
    ``sent_loss_weight * sent_loss + token_loss_weight * token_loss``. The
    weights default to 0.7 / 0.3 and can be overridden by setting attributes
    of the same name on the config.
    """

    def __init__(self, config):
        super().__init__(config)
        self.transformer = T5Model(config)
        self.num_labels = 2      # evidence tag classes per token
        self.num_verdicts = 3    # verdict classes

        self.sent_loss_weight = getattr(config, "sent_loss_weight", 0.7)
        self.token_loss_weight = getattr(config, "token_loss_weight", 0.3)

        self.verdict_head = T5FeedForwardHead(config, self.num_verdicts)
        self.evidence_head = T5FeedForwardHead(config, self.num_labels)

        # Standard Hugging Face final step (weight init / final processing).
        self.post_init()

    def freeze_backbone(self):
        """Disable gradients for the whole T5 backbone; only the heads train."""
        for param in self.transformer.parameters():
            param.requires_grad = False

    def freeze_encoder(self):
        """Disable gradients for the T5 encoder.

        NOTE(review): in HF T5 the encoder's token embedding is normally the
        embedding shared with the decoder, so this likely also freezes the
        decoder input embeddings - confirm if that matters.
        """
        for param in self.transformer.get_encoder().parameters():
            param.requires_grad = False

    def freeze_decoder(self):
        """Disable gradients for the T5 decoder (see the note on freeze_encoder)."""
        for param in self.transformer.get_decoder().parameters():
            param.requires_grad = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        evidence_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Verdict ids in `[0, ..., num_verdicts - 1]` for the
                sentence-level cross-entropy loss.
            evidence_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Per-token evidence targets in `[0, num_labels - 1]`. When
                `attention_mask` is given, padded positions (mask == 0) are
                excluded from the token loss.

        Returns:
            `TokenClassificationOutput` with `claim_logits` of shape
            `(batch_size, num_verdicts)` and `evidence_logits` of shape
            `(batch_size, sequence_length, num_labels)`. `loss` is the weighted
            sum when both label kinds are given, otherwise whichever single
            loss could be computed (or `None`).

        Raises:
            NotImplementedError: if only `inputs_embeds` is passed.
            ValueError: if examples in the batch have differing (or zero)
                numbers of eos tokens, or no decoder inputs can be derived.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        if input_ids is None and inputs_embeds is not None:
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
            )

        # As in BART/T5 sequence classification, build decoder inputs by
        # shifting input_ids right when none are provided.
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = self._shift_right(input_ids)

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]  # (bsz, seq_len, d_model)

        # ---- token-level evidence tagging ----
        token_logits = self.evidence_head(sequence_output)  # (bsz, seq_len, num_labels)
        token_loss = None
        if evidence_labels is not None:
            evidence_labels = evidence_labels.to(token_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            if attention_mask is not None:
                # Padding positions would otherwise dominate the mean with
                # trivially "not evidence" targets; ignore them instead.
                evidence_labels = evidence_labels.masked_fill(
                    attention_mask.to(evidence_labels.device) == 0, loss_fct.ignore_index
                )
            token_loss = loss_fct(token_logits.view(-1, self.num_labels), evidence_labels.view(-1))

        # ---- sentence-level verdict classification ----
        eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
        eos_counts = eos_mask.sum(1)
        if len(torch.unique_consecutive(eos_counts)) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        if eos_counts.numel() > 0 and int(eos_counts[0]) == 0:
            raise ValueError("Every example must contain at least one <eos> token.")
        batch_size, _, hidden_size = sequence_output.shape
        # Representation of each example = hidden state at its last eos token.
        sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
        sent_logits = self.verdict_head(sentence_representation)  # (bsz, num_verdicts)

        sent_loss = None
        if labels is not None:
            labels = labels.to(sent_logits.device)
            if self.config.problem_type is None:
                # Decide from the verdict head's width, not config.num_labels
                # (which describes nothing in this model).
                if self.num_verdicts == 1:
                    self.config.problem_type = "regression"
                elif labels.dtype == torch.long or labels.dtype == torch.int:
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_verdicts == 1:
                    sent_loss = loss_fct(sent_logits.squeeze(), labels.squeeze())
                else:
                    sent_loss = loss_fct(sent_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                sent_loss = loss_fct(sent_logits.view(-1, self.num_verdicts), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                sent_loss = loss_fct(sent_logits, labels)

        if sent_loss is not None and token_loss is not None:
            total_loss = self.sent_loss_weight * sent_loss + self.token_loss_weight * token_loss
        elif sent_loss is not None:
            total_loss = sent_loss
        else:
            total_loss = token_loss  # may be None when no labels were given

        return TokenClassificationOutput(
            loss=total_loss,
            token_loss=token_loss,
            sent_loss=sent_loss,
            claim_logits=sent_logits,
            evidence_logits=token_logits,
        )
 
 

In [None]:
def _try_push_to_hub(model, repo_id):
    """Best-effort private upload of *model* to the Hugging Face Hub.

    Any failure (network, authentication, ...) is reported and swallowed so it
    never interrupts training.
    """
    try:
        model.push_to_hub(repo_id, private=True)
    except Exception as exc:
        print(f"Failed to push model to hub: {exc}")


def _train_one_epoch(model, dataloader, opt, sched, device, accum_steps, epoch):
    """Run one training epoch with gradient accumulation; return the mean loss."""
    model.train()
    running_loss = 0.0
    num_batches = len(dataloader)
    opt.zero_grad()

    pbar = tqdm(dataloader, desc=f"Epoch {epoch} - training")
    for step, batch in enumerate(pbar):
        batch = batch_to_device(batch, device)
        outputs = model(**batch)
        loss = outputs.loss

        running_loss += loss.item()
        pbar.set_postfix(loss=loss.item(), sent_loss=outputs.sent_loss.item(), token_loss=outputs.token_loss.item())

        # Scale so the accumulated gradient is a mean over the group, not a sum.
        (loss / accum_steps).backward()

        # Step every `accum_steps` batches and flush a trailing partial group.
        if (step + 1) % accum_steps == 0 or (step + 1) == num_batches:
            opt.step()
            sched.step()
            opt.zero_grad()

    return running_loss / max(num_batches, 1)


def _evaluate(model, dataloader, device, epoch):
    """Return the mean validation loss, computed without gradient tracking."""
    model.eval()
    running_loss = 0.0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch} - validation")
    with torch.no_grad():
        for batch in pbar:
            batch = batch_to_device(batch, device)
            outputs = model(**batch)
            loss = outputs.loss
            running_loss += loss.item()
            pbar.set_postfix(loss=loss.item(), sent_loss=outputs.sent_loss.item(), token_loss=outputs.token_loss.item())
    return running_loss / max(len(dataloader), 1)


def train(model, train_dataloader, val_dataloader, args):
    """Fine-tune *model* with gradient accumulation and early stopping.

    After every epoch whose validation loss improves on the best so far, a Hub
    push is attempted and the model is saved under
    ``tmp-runs/<timestamp>/epoch_<n>``. Training stops after ``args.patience``
    consecutive epochs without improvement. The best checkpoint is then
    reloaded and, if ``args.save_best`` is set, pushed and saved under
    ``saved-runs/<timestamp>/best``.

    Args:
        model: a ``ViT5ForTokenClassification`` instance.
        train_dataloader: yields dicts of tensors accepted by ``model(**batch)``.
        val_dataloader: same format, used for the per-epoch validation loss.
        args: a ``TrainingArguments``-like object.

    Returns:
        The reloaded best model, or ``None`` if no checkpoint was ever saved
        (e.g. the validation loss was NaN on every epoch).
    """
    print(f"Mem needed: {model.get_memory_footprint() / 1024 / 1024 / 1024:.2f} GB")

    # Per-run directory for the intermediate (per-improvement) checkpoints.
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "tmp-runs", datetime.today().strftime('%a-%d-%b-%Y-%I:%M:%S%p')))

    min_loss = float('inf')
    sub_cycle = 0  # consecutive epochs without validation improvement
    best_path = None
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.freeze_backbone:
        model.freeze_backbone()
    if args.freeze_encoder:
        model.freeze_encoder()
    if args.freeze_decoder:
        model.freeze_decoder()
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Move the model before constructing the optimizer (torch.optim guidance).
    model.to(device)

    accum_steps = args.gradient_accumulation_steps
    # Optimizer steps per epoch, counting the trailing partial group (ceil).
    steps_per_epoch = -(-len(train_dataloader) // accum_steps)
    total_num_steps = steps_per_epoch * args.num_epochs

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
    sched = get_scheduler(
        name=args.scheduler_name,
        optimizer=opt,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_num_steps,
    )

    print("Start Training")
    for ep in range(args.num_epochs):
        train_loss = _train_one_epoch(model, train_dataloader, opt, sched, device, accum_steps, ep)
        val_loss = _evaluate(model, val_dataloader, device, ep)

        print(f"Summary epoch {ep}:\n"
              f"\ttrain_loss: {train_loss:.4f} \t val_loss: {val_loss:.4f}")

        if val_loss < min_loss:
            min_loss = val_loss
            sub_cycle = 0

            best_path = os.path.join(out_dir, f"epoch_{ep}")
            print(f"Save cur model to {best_path}")
            _try_push_to_hub(model, 'hduc-le/VyT5-Siamese-Fact-Check')
            model.save_pretrained(best_path)
        else:
            sub_cycle += 1
            if sub_cycle == args.patience:
                print("Early stopping!")
                break

    print("End of training. Restore the best weights")
    if best_path is None:
        print("No checkpoint was saved (validation loss never improved); nothing to restore.")
        return None
    best_model = ViT5ForTokenClassification.from_pretrained(best_path)

    if args.save_best:
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "saved-runs", datetime.today().strftime('%a-%d-%b-%Y-%I:%M:%S%p')))
        best_path = os.path.join(out_dir, 'best')
        # Upload the restored best weights, not the last-epoch `model`.
        _try_push_to_hub(best_model, 'hduc-le/VyT5-SentToken-Classification')

        print(f"Save best model to {best_path}")
        best_model.save_pretrained(best_path)

    return best_model


In [None]:
# Extra keyword arguments forwarded to from_pretrained as config overrides.
pretrained_overrides = dict(use_cache=False, output_hidden_states=True)
model = ViT5ForTokenClassification.from_pretrained(training_args.model_name, **pretrained_overrides)
print(model)

In [None]:
best_model = train(
    model,
    train_dataloader,
    val_dataloader,
    args=training_args,
)