{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "from transformers.modeling_outputs import (\n", " Seq2SeqQuestionAnsweringModelOutput,\n", " Seq2SeqSequenceClassifierOutput,\n", " BaseModelOutput,\n", ")\n", "from transformers import (\n", " T5ForQuestionAnswering,\n", " T5PreTrainedModel,\n", " MBartPreTrainedModel,\n", " MBartModel,\n", " T5Config,\n", " T5Model,\n", " T5EncoderModel,\n", " get_scheduler\n", ")\n", "from tqdm import tqdm \n", "from dataclasses import dataclass\n", "from typing import List, Optional, Tuple, Union\n", "\n", "import numpy as np\n", "import random\n", "import os \n", "from datetime import datetime\n", "from torch.utils.data import DataLoader, Dataset\n", "from transformers import AutoTokenizer\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "import yaml\n", "from addict import Dict\n", "\n", "\n", "def load_json(file_path):\n", " with open(file_path, \"r\", encoding=\"utf-8-sig\") as f:\n", " data = json.load(f)\n", " return data\n", "\n", "\n", "def read_config(path):\n", " # read yaml and return contents\n", " with open(path, \"r\") as file:\n", " try:\n", " return Dict(yaml.safe_load(file))\n", " except yaml.YAMLError as exc:\n", " print(exc)\n", "\n", "\n", "def batch_to_device(batch: dict, device: str):\n", " for k in batch:\n", " batch[k] = batch[k].to(device)\n", " return batch\n", "\n", "\n", "def save_json(obj, path):\n", " with open(path, \"w\") as outfile:\n", " json.dump(obj, outfile, ensure_ascii=False, indent=2)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class TokenClassificationOutput:\n", " loss: Optional[torch.FloatTensor] = None\n", " sent_loss: Optional[torch.FloatTensor] = None\n", " token_loss: Optional[torch.FloatTensor] = None\n", " claim_logits: torch.FloatTensor = None\n", " evidence_logits: torch.FloatTensor = None\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def random_seed(value):\n", " torch.backends.cudnn.deterministic = True\n", " torch.manual_seed(value)\n", " torch.cuda.manual_seed(value)\n", " np.random.seed(value)\n", " random.seed(value)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dataclass \n", "class TrainingArguments:\n", " data_path = \"data/ise-dsc01-train.json\"\n", " model_name = \"VietAI/vit5-base\"\n", " tokenizer_name = \"VietAI/vit5-base\"\n", " gradient_accumulation_steps = 8\n", " gradient_checkpointing = False\n", " num_epochs = 10\n", " lr = 3.0e-5\n", " weight_decay = 1.0e-2\n", " scheduler_name = \"cosine\"\n", " warmup_steps = 0\n", " patience = 3\n", " max_seq_length = 1024\n", " seed = 1401\n", " test_size = 0.1\n", " train_batch_size = 1\n", " val_batch_size = 1\n", "\n", " save_best = True\n", "\n", " freeze_backbone = False\n", " freeze_encoder = False\n", " freeze_decoder = False\n", "\n", "training_args = TrainingArguments()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_LABEL_MAPPING = {\"SUPPORTED\": 0, \"NEI\": 1, \"REFUTED\": 2}\n", " \n", "class TokenStanceDataset(Dataset):\n", " def __init__(self, dataset, dataset_keys, tokenizer, max_seq_length=1024) -> None:\n", " super().__init__()\n", " 
self.tokenizer = tokenizer\n",
"        self.max_seq_length = max_seq_length\n",
"        self.dataset = dataset\n",
"        self.dataset_keys = dataset_keys\n",
"\n",
"    def __getitem__(self, idx):\n",
"        data_id = self.dataset_keys[idx]\n",
"        data_item = self.dataset[data_id]\n",
"\n",
"        claim = data_item['claim']\n",
"        evidence = data_item['evidence']\n",
"        context = data_item['context']\n",
"\n",
"        # Encode context and claim as a single pair, padded/truncated to max_seq_length.\n",
"        encodings = self.tokenizer(\n",
"            context,\n",
"            claim,\n",
"            truncation=True,\n",
"            padding=\"max_length\",\n",
"            max_length=self.max_seq_length,\n",
"            return_tensors=\"pt\"\n",
"        )\n",
"\n",
"        if evidence is None:\n",
"            start_position, end_position = 0, 0\n",
"        else:\n",
"            # Locate the evidence span in the context at the character level.\n",
"            start_idx = context.find(evidence)\n",
"            if start_idx == -1:\n",
"                # Evidence is not a substring of the context: fall back to an empty span.\n",
"                start_position, end_position = 0, 0\n",
"            else:\n",
"                evidence_start = start_idx\n",
"                evidence_end = start_idx + len(evidence)\n",
"\n",
"                # Map character offsets to token positions.\n",
"                start_position = encodings.char_to_token(0, evidence_start)\n",
"                end_position = encodings.char_to_token(0, evidence_end)\n",
"\n",
"                # The end offset may fall on whitespace or beyond the truncated input;\n",
"                # walk back until a character that maps to a token is found.\n",
"                trace_back = 1\n",
"                while end_position is None and evidence_end - trace_back >= 0:\n",
"                    end_position = encodings.char_to_token(0, evidence_end - trace_back)\n",
"                    trace_back += 1\n",
"\n",
"                if start_position is None or end_position is None:\n",
"                    start_position, end_position = 0, 0\n",
"\n",
"        # Token-level evidence labels: 1 inside the evidence span, 0 elsewhere.\n",
"        evidence_labels = torch.zeros(self.max_seq_length,)\n",
"        if end_position > 0:\n",
"            evidence_labels[start_position: end_position] = 1\n",
"        evidence_labels = evidence_labels.long()\n",
"\n",
"        #print(\"====\")\n",
"        #print(evidence)\n",
"        #print(self.tokenizer.decode(encodings.input_ids[0][evidence_labels.bool()]))\n",
"\n",
"        label = torch.tensor(_LABEL_MAPPING[data_item[\"verdict\"]], dtype=torch.long)\n",
"\n",
"        return {\n",
"            \"input_ids\": encodings.input_ids.squeeze(0),\n",
"            \"attention_mask\": encodings.attention_mask.squeeze(0),\n",
"            \"evidence_labels\": evidence_labels,\n",
"            \"labels\": label\n",
"        }\n",
"\n",
"    def __len__(self):\n",
"        return len(self.dataset)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"random_seed(training_args.seed)\n",
"\n",
"data = load_json(training_args.data_path)\n",
"\n",
"data_keys = list(data.keys())\n",
"\n",
"train_keys, dev_keys = train_test_split(\n",
"    data_keys,\n",
"    test_size=training_args.test_size,\n",
"    random_state=training_args.seed,\n",
"    shuffle=True,\n",
")\n",
"\n",
"train_set = {k: data[k] for k in train_keys}\n",
"dev_set = {k: data[k] for k in dev_keys}\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
"    training_args.tokenizer_name, use_fast=True\n",
")\n",
"\n",
"train_dataset = TokenStanceDataset(\n",
"    train_set, train_keys, tokenizer, training_args.max_seq_length\n",
")\n",
"val_dataset = TokenStanceDataset(\n",
"    dev_set, dev_keys, tokenizer, training_args.max_seq_length\n",
")\n",
"\n",
"train_dataloader = DataLoader(\n",
"    train_dataset, batch_size=training_args.train_batch_size, shuffle=True\n",
")\n",
"val_dataloader = DataLoader(\n",
"    val_dataset, batch_size=training_args.val_batch_size, shuffle=False\n",
")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class T5FeedForwardHead(nn.Module):\n",
"    \"\"\"Head for 
sentence-level classification tasks.\"\"\"\n", "\n", " def __init__(self, config, out_dim):\n", " super().__init__()\n", " self.dense = nn.Linear(config.d_model, config.d_model)\n", " self.dropout = nn.Dropout(p=config.classifier_dropout)\n", " self.out_proj = nn.Linear(config.d_model, out_dim)\n", "\n", " def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n", " hidden_states = self.dropout(hidden_states)\n", " hidden_states = self.dense(hidden_states)\n", " hidden_states = torch.relu(hidden_states)\n", " hidden_states = self.dropout(hidden_states)\n", " hidden_states = self.out_proj(hidden_states)\n", " return hidden_states\n", "\n", "\n", "\n", "class ViT5ForTokenClassification(T5PreTrainedModel):\n", " def __init__(self, config):\n", " super().__init__(config)\n", " self.transformer = T5Model(config)\n", " self.num_labels = 2\n", " self.num_verdicts = 3\n", " \n", " self.verdict_head = T5FeedForwardHead(config, self.num_verdicts)\n", " self.evidence_head = T5FeedForwardHead(config, self.num_labels)\n", " \n", " def forward(\n", " self,\n", " input_ids: torch.LongTensor = None,\n", " attention_mask: Optional[torch.Tensor] = None,\n", " decoder_input_ids: Optional[torch.LongTensor] = None,\n", " decoder_attention_mask: Optional[torch.LongTensor] = None,\n", " head_mask: Optional[torch.Tensor] = None,\n", " decoder_head_mask: Optional[torch.Tensor] = None,\n", " cross_attn_head_mask: Optional[torch.Tensor] = None,\n", " encoder_outputs: Optional[List[torch.FloatTensor]] = None,\n", " inputs_embeds: Optional[torch.FloatTensor] = None,\n", " decoder_inputs_embeds: Optional[torch.FloatTensor] = None,\n", " labels: Optional[torch.LongTensor] = None,\n", " evidence_labels: Optional[torch.LongTensor] = None,\n", " use_cache: Optional[bool] = None,\n", " output_attentions: Optional[bool] = None,\n", " output_hidden_states: Optional[bool] = None,\n", " return_dict: Optional[bool] = None,\n", " ):\n", " r\"\"\"\n", " labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n", " Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n", " config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n", " Returns:\n", " \"\"\"\n", " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n", " if labels is not None:\n", " use_cache = False\n", "\n", " if input_ids is None and inputs_embeds is not None:\n", " raise NotImplementedError(\n", " f\"Passing input embeddings is currently not supported for {self.__class__.__name__}\"\n", " )\n", "\n", " # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates\n", " # decoder_input_ids from input_ids if no decoder_input_ids are provided\n", " if decoder_input_ids is None and decoder_inputs_embeds is None:\n", " if input_ids is None:\n", " raise ValueError(\n", " \"If no `decoder_input_ids` or `decoder_inputs_embeds` are \"\n", " \"passed, `input_ids` cannot be `None`. 
Please pass either \"\n", " \"`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`.\"\n", " )\n", " decoder_input_ids = self._shift_right(input_ids)\n", "\n", " outputs = self.transformer(\n", " input_ids,\n", " attention_mask=attention_mask,\n", " decoder_input_ids=decoder_input_ids,\n", " decoder_attention_mask=decoder_attention_mask,\n", " head_mask=head_mask,\n", " decoder_head_mask=decoder_head_mask,\n", " cross_attn_head_mask=cross_attn_head_mask,\n", " encoder_outputs=encoder_outputs,\n", " inputs_embeds=inputs_embeds,\n", " decoder_inputs_embeds=decoder_inputs_embeds,\n", " use_cache=use_cache,\n", " output_attentions=output_attentions,\n", " output_hidden_states=output_hidden_states,\n", " return_dict=return_dict,\n", " )\n", " sequence_output = outputs[0] # (bsz, max_length, hidden_size)\n", " \n", " token_logits = self.evidence_head(sequence_output) # (bsz, max_length, 2)\n", " token_loss = None\n", " if evidence_labels is not None:\n", " evidence_labels = evidence_labels.to(token_logits.device)\n", " loss_fct = nn.CrossEntropyLoss()\n", " token_loss = loss_fct(token_logits.view(-1, self.num_labels), evidence_labels.view(-1))\n", " \n", " eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)\n", "\n", " if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:\n", " raise ValueError(\"All examples must have the same number of tokens.\")\n", " batch_size, _, hidden_size = sequence_output.shape\n", " sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :] # (bsz, hidden_size)\n", " sent_logits = self.verdict_head(sentence_representation)\n", "\n", " sent_loss = None\n", " if labels is not None:\n", " labels = labels.to(sent_logits.device)\n", " if self.config.problem_type is None:\n", " if self.config.num_labels == 1:\n", " self.config.problem_type = \"regression\"\n", " elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n", " self.config.problem_type = \"single_label_classification\"\n", " else:\n", " self.config.problem_type = \"multi_label_classification\"\n", "\n", " if self.config.problem_type == \"regression\":\n", " loss_fct = nn.MSELoss()\n", " if self.config.num_labels == 1:\n", " sent_loss = loss_fct(sent_logits.squeeze(), labels.squeeze())\n", " else:\n", " sent_loss = loss_fct(sent_logits, labels)\n", " elif self.config.problem_type == \"single_label_classification\":\n", " loss_fct = nn.CrossEntropyLoss()\n", " sent_loss = loss_fct(sent_logits.view(-1, self.num_verdicts), labels.view(-1))\n", " elif self.config.problem_type == \"multi_label_classification\":\n", " loss_fct = nn.BCEWithLogitsLoss()\n", " sent_loss = loss_fct(sent_logits, labels)\n", " \n", " \n", " total_loss = None\n", " if sent_loss is not None and token_loss is not None:\n", " total_loss = 0.7*sent_loss + 0.3*token_loss\n", " \n", " \n", " return TokenClassificationOutput(\n", " loss=total_loss,\n", " token_loss=token_loss,\n", " sent_loss=sent_loss,\n", " claim_logits=sent_logits,\n", " evidence_logits=token_logits\n", " )\n", " \n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train(model, train_dataloader, val_dataloader, args):\n", " print(f\"Mem needed: {model.get_memory_footprint() / 1024 / 1024 / 1024:.2f} GB\")\n", " \n", " # creating a tmp directory to save the models\n", " out_dir = os.path.abspath(os.path.join(os.path.curdir, \"tmp-runs\", datetime.today().strftime('%a-%d-%b-%Y-%I:%M:%S%p')))\n", "\n", " # 
checkpoint and early-stopping bookkeeping\n",
"    min_loss = float('inf')\n",
"    sub_cycle = 0\n",
"    best_path = None\n",
"    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"    # Optionally freeze parts of the model. ViT5ForTokenClassification does not define\n",
"    # freeze_* helpers, so freeze the corresponding parameter groups directly.\n",
"    if args.freeze_backbone:\n",
"        for param in model.transformer.parameters():\n",
"            param.requires_grad = False\n",
"\n",
"    if args.freeze_encoder:\n",
"        for param in model.transformer.encoder.parameters():\n",
"            param.requires_grad = False\n",
"\n",
"    if args.freeze_decoder:\n",
"        for param in model.transformer.decoder.parameters():\n",
"            param.requires_grad = False\n",
"\n",
"    if args.gradient_checkpointing:\n",
"        model.gradient_checkpointing_enable()\n",
"\n",
"    total_num_steps = (len(train_dataloader) / args.gradient_accumulation_steps) * args.num_epochs\n",
"    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n",
"\n",
"    sched = get_scheduler(\n",
"        name=args.scheduler_name,\n",
"        optimizer=opt,\n",
"        num_warmup_steps=args.warmup_steps,\n",
"        num_training_steps=total_num_steps,\n",
"    )\n",
"\n",
"    model.to(device)\n",
"\n",
"    print(\"Start Training\")\n",
"\n",
"    for ep in range(args.num_epochs):\n",
"        model.train()\n",
"        train_loss = 0.0\n",
"\n",
"        for step, batch in enumerate(pbar := tqdm(train_dataloader, desc=f\"Epoch {ep} - training\")):\n",
"            # transfer data to training device (gpu/cpu)\n",
"            batch = batch_to_device(batch, device)\n",
"\n",
"            # forward\n",
"            outputs = model(**batch)\n",
"\n",
"            # compute loss\n",
"            loss = outputs.loss\n",
"\n",
"            # gather metrics\n",
"            train_loss += loss.item()\n",
"\n",
"            # progress bar logging\n",
"            pbar.set_postfix(loss=loss.item(), sent_loss=outputs.sent_loss.item(), token_loss=outputs.token_loss.item())\n",
"\n",
"            # backward and optimize\n",
"            loss.backward()\n",
"\n",
"            if (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):\n",
"                opt.step()\n",
"                sched.step()\n",
"                opt.zero_grad()\n",
"\n",
"        train_loss /= len(train_dataloader)\n",
"\n",
"        # Evaluate at the end of the epoch\n",
"        model.eval()\n",
"        val_loss = 0.0\n",
"\n",
"        for batch in (pbar := tqdm(val_dataloader, desc=f\"Epoch {ep} - validation\")):\n",
"            with torch.no_grad():\n",
"                batch = batch_to_device(batch, device)\n",
"                # forward\n",
"                outputs = model(**batch)\n",
"\n",
"                # compute loss\n",
"                loss = outputs.loss\n",
"\n",
"                # gather metrics\n",
"                val_loss += loss.item()\n",
"\n",
"            pbar.set_postfix(loss=loss.item(), sent_loss=outputs.sent_loss.item(), token_loss=outputs.token_loss.item())\n",
"\n",
"        val_loss /= len(val_dataloader)\n",
"\n",
"        print(f\"Summary epoch {ep}:\\n\"\n",
"              f\"\\ttrain_loss: {train_loss:.4f} \\t val_loss: {val_loss:.4f}\")\n",
"\n",
"        if val_loss < min_loss:\n",
"            min_loss = val_loss\n",
"            sub_cycle = 0\n",
"\n",
"            best_path = os.path.join(out_dir, f\"epoch_{ep}\")\n",
"            print(f\"Save current model to {best_path}\")\n",
"\n",
"            try:\n",
"                model.push_to_hub('hduc-le/VyT5-Siamese-Fact-Check', private=True)\n",
"            except Exception:\n",
"                print(\"Failed to push model to hub\")\n",
"\n",
"            model.save_pretrained(best_path)\n",
"\n",
"        else:\n",
"            sub_cycle += 1\n",
"            if sub_cycle == args.patience:\n",
"                print(\"Early stopping!\")\n",
"                break\n",
"\n",
"    print(\"End of training. Restore the best weights\")\n",
"    best_model = ViT5ForTokenClassification.from_pretrained(best_path)\n",
"\n",
"    if args.save_best:\n",
"        # save the best model to a permanent location\n",
"        out_dir = os.path.abspath(os.path.join(os.path.curdir, \"saved-runs\", datetime.today().strftime('%a-%d-%b-%Y-%I:%M:%S%p')))\n",
"\n",
"        best_path = os.path.join(out_dir, 'best')\n",
"        try:\n",
"            best_model.push_to_hub('hduc-le/VyT5-SentToken-Classification', private=True)\n",
"        except Exception:\n",
"            print(\"Failed to push model to hub\")\n",
"\n",
"        print(f\"Save best model to {best_path}\")\n",
"\n",
"        best_model.save_pretrained(best_path)\n",
"\n",
"    return\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = ViT5ForTokenClassification.from_pretrained(\n", "    training_args.model_name, use_cache=False, output_hidden_states=True\n", ")\n", "print(model)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train(model, train_dataloader, val_dataloader, args=training_args)" ] } ], "metadata": { "kernelspec": { "display_name": "mlds", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 2 }