File size: 28,508 Bytes
2b46659
1
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"dockerImageVersionId":30747,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import pandas as pd\n\nsplits = {'train': 'data/train-00000-of-00001.parquet', \n          'validation_matched': 'data/validation_matched-00000-of-00001.parquet', \n          'validation_mismatched': 'data/validation_mismatched-00000-of-00001.parquet'}\n          \ndf = pd.read_parquet(\"hf://datasets/nyu-mll/multi_nli/\" + splits[\"train\"])","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2024-07-20T13:04:43.497190Z","iopub.execute_input":"2024-07-20T13:04:43.497536Z","iopub.status.idle":"2024-07-20T13:04:51.222716Z","shell.execute_reply.started":"2024-07-20T13:04:43.497506Z","shell.execute_reply":"2024-07-20T13:04:51.221916Z"},"trusted":true},"execution_count":1,"outputs":[]},{"cell_type":"code","source":"df = df[['label', 'premise', 'hypothesis']].iloc[:13000]\ndf","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:52.891990Z","iopub.execute_input":"2024-07-20T13:04:52.892933Z","iopub.status.idle":"2024-07-20T13:04:53.013716Z","shell.execute_reply.started":"2024-07-20T13:04:52.892902Z","shell.execute_reply":"2024-07-20T13:04:53.012747Z"},"trusted":true},"execution_count":2,"outputs":[{"execution_count":2,"output_type":"execute_result","data":{"text/plain":"       label                                            premise  \\\n0          1  Conceptually cream skimming has two basic dime...   
\n1          0  you know during the season and i guess at at y...   \n2          0  One of our number will carry out your instruct...   \n3          0  How do you know? All this is their information...   \n4          1  yeah i tell you what though if you go price so...   \n...      ...                                                ...   \n12995      1  right you have to question you have to wonder ...   \n12996      2  Reviewers may not be familiar with the charact...   \n12997      1  yeah it was Twins was good too  because when i...   \n12998      0                         The Jews are Neanderthals.   \n12999      1  25--to get a copy of my book legally from my W...   \n\n                                              hypothesis  \n0      Product and geography are what make cream skim...  \n1      You lose the things to the following level if ...  \n2      A member of my team will execute your orders w...  \n3                      This information belongs to them.  \n4               The tennis shoes have a range of prices.  \n...                                                  ...  \n12995  I would not mind living on an island to find out.  \n12996  Typically, reviewers are fully aware of an eme...  \n12997          Twins was the best movie I saw last year.  \n12998               Jewish people are like Neanderthals.  \n12999                        My book is free on my site.  
\n\n[13000 rows x 3 columns]","text/html":"<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>label</th>\n      <th>premise</th>\n      <th>hypothesis</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>1</td>\n      <td>Conceptually cream skimming has two basic dime...</td>\n      <td>Product and geography are what make cream skim...</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>0</td>\n      <td>you know during the season and i guess at at y...</td>\n      <td>You lose the things to the following level if ...</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>0</td>\n      <td>One of our number will carry out your instruct...</td>\n      <td>A member of my team will execute your orders w...</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>0</td>\n      <td>How do you know? 
All this is their information...</td>\n      <td>This information belongs to them.</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>1</td>\n      <td>yeah i tell you what though if you go price so...</td>\n      <td>The tennis shoes have a range of prices.</td>\n    </tr>\n    <tr>\n      <th>...</th>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n    </tr>\n    <tr>\n      <th>12995</th>\n      <td>1</td>\n      <td>right you have to question you have to wonder ...</td>\n      <td>I would not mind living on an island to find out.</td>\n    </tr>\n    <tr>\n      <th>12996</th>\n      <td>2</td>\n      <td>Reviewers may not be familiar with the charact...</td>\n      <td>Typically, reviewers are fully aware of an eme...</td>\n    </tr>\n    <tr>\n      <th>12997</th>\n      <td>1</td>\n      <td>yeah it was Twins was good too  because when i...</td>\n      <td>Twins was the best movie I saw last year.</td>\n    </tr>\n    <tr>\n      <th>12998</th>\n      <td>0</td>\n      <td>The Jews are Neanderthals.</td>\n      <td>Jewish people are like Neanderthals.</td>\n    </tr>\n    <tr>\n      <th>12999</th>\n      <td>1</td>\n      <td>25--to get a copy of my book legally from my W...</td>\n      <td>My book is free on my site.</td>\n    </tr>\n  </tbody>\n</table>\n<p>13000 rows × 3 columns</p>\n</div>"},"metadata":{}}]},{"cell_type":"code","source":"import torch\nfrom torch.utils.data import Dataset, TensorDataset, DataLoader\nfrom torch.nn.utils.rnn import pad_sequence\nimport pickle\nimport os\nfrom transformers import BertTokenizer\nfrom sklearn.model_selection import 
train_test_split","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:53.472075Z","iopub.execute_input":"2024-07-20T13:04:53.472883Z","iopub.status.idle":"2024-07-20T13:04:57.918248Z","shell.execute_reply.started":"2024-07-20T13:04:53.472853Z","shell.execute_reply":"2024-07-20T13:04:57.917380Z"},"trusted":true},"execution_count":3,"outputs":[]},{"cell_type":"code","source":"class MNLIDataBert(Dataset):\n  # Tokenizes MNLI train/validation DataFrames into BERT sentence-pair TensorDatasets.\n\n  def __init__(self, train_df, val_df):\n\n    self.train_df = train_df\n    self.val_df = val_df\n\n    self.base_path = '/content/'\n    self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) # Using a pre-trained BERT tokenizer to encode sentences\n    self.train_data = None\n    self.val_data = None\n    self.init_data()\n\n  def init_data(self):\n    self.train_data = self.load_data(self.train_df)\n    self.val_data = self.load_data(self.val_df)\n\n  def load_data(self, df):\n    # Encodes each (premise, hypothesis, label) row as [CLS] premise [SEP] hypothesis [SEP]\n    # and returns a TensorDataset of (token_ids, attention_mask, segment_ids, label).\n    MAX_LEN = 512  # upper bound on the full pair length, special tokens included\n    token_ids = []\n    mask_ids = []\n    seg_ids = []\n    y = []\n\n    premise_list = df['premise'].to_list()\n    hypothesis_list = df['hypothesis'].to_list()\n    label_list = df['label'].to_list()\n\n    for (premise, hypothesis, label) in zip(premise_list, hypothesis_list, label_list):\n      premise_id = self.tokenizer.encode(premise, add_special_tokens = False)\n      hypothesis_id = self.tokenizer.encode(hypothesis, add_special_tokens = False)\n\n      # Trim the longer side one token at a time until the pair plus [CLS] and two [SEP] fits in MAX_LEN;\n      # without this, an over-long pair would overflow BERT's 512 position embeddings.\n      while len(premise_id) + len(hypothesis_id) > MAX_LEN - 3:\n        if len(premise_id) >= len(hypothesis_id):\n          premise_id.pop()\n        else:\n          hypothesis_id.pop()\n\n      pair_token_ids = [self.tokenizer.cls_token_id] + premise_id + [self.tokenizer.sep_token_id] + hypothesis_id + [self.tokenizer.sep_token_id]\n      premise_len = len(premise_id)\n      hypothesis_len = len(hypothesis_id)\n\n      segment_ids = torch.tensor([0] * (premise_len + 2) + [1] * (hypothesis_len + 1))  # segment 0: [CLS] premise [SEP]; segment 1: hypothesis [SEP]\n      attention_mask_ids = torch.tensor([1] * (premise_len + hypothesis_len + 3))  # 1 for every real token; padding added later gets 0\n\n      token_ids.append(torch.tensor(pair_token_ids))\n      seg_ids.append(segment_ids)\n      
mask_ids.append(attention_mask_ids)\n      y.append(label)\n    \n    # Right-pad every sequence with 0 to the longest one in this split; 0 is BERT's [PAD] id,\n    # and a 0 in the attention mask hides the padded positions from the model.\n    token_ids = pad_sequence(token_ids, batch_first=True)\n    mask_ids = pad_sequence(mask_ids, batch_first=True)\n    seg_ids = pad_sequence(seg_ids, batch_first=True)\n    y = torch.tensor(y)\n    dataset = TensorDataset(token_ids, mask_ids, seg_ids, y)\n    print(len(dataset))\n    return dataset\n\n  def get_data_loaders(self, batch_size=32, shuffle=True):\n    # `shuffle` applies to the training loader only; validation order is kept fixed so evaluation is deterministic.\n    train_loader = DataLoader(\n      self.train_data,\n      shuffle=shuffle,\n      batch_size=batch_size\n    )\n\n    val_loader = DataLoader(\n      self.val_data,\n      shuffle=False,\n      batch_size=batch_size\n    )\n\n    return train_loader, val_loader","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:57.919963Z","iopub.execute_input":"2024-07-20T13:04:57.920822Z","iopub.status.idle":"2024-07-20T13:04:57.935843Z","shell.execute_reply.started":"2024-07-20T13:04:57.920788Z","shell.execute_reply":"2024-07-20T13:04:57.934836Z"},"trusted":true},"execution_count":4,"outputs":[]},{"cell_type":"code","source":"train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:57.937187Z","iopub.execute_input":"2024-07-20T13:04:57.937623Z","iopub.status.idle":"2024-07-20T13:04:57.953101Z","shell.execute_reply.started":"2024-07-20T13:04:57.937589Z","shell.execute_reply":"2024-07-20T13:04:57.952066Z"},"trusted":true},"execution_count":5,"outputs":[]},{"cell_type":"code","source":"val_df","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:57.955406Z","iopub.execute_input":"2024-07-20T13:04:57.955788Z","iopub.status.idle":"2024-07-20T13:04:57.967894Z","shell.execute_reply.started":"2024-07-20T13:04:57.955763Z","shell.execute_reply":"2024-07-20T13:04:57.966921Z"},"trusted":true},"execution_count":6,"outputs":[{"execution_count":6,"output_type":"execute_result","data":{"text/plain":"       label                                            premise  
\\\n3615       2  An ambitious plan for a hexagonally based chur...   \n2536       1                               for for city use and   \n5397       0  The really valuable estate cannot be touched b...   \n9982       0  isn't that the truth it's funny in fact it's i...   \n1498       0  Most drivers will be able to point out the Bok...   \n...      ...                                                ...   \n11872      1  As the road rises, the rugged countryside beco...   \n9264       0  The monastery rests in a fertile valley and is...   \n7277       2  Since everyone who matters presumably knows al...   \n3752       2             so what type of restaurant do you like   \n6292       2  right that that's actually the part that that ...   \n\n                                              hypothesis  \n3615   The complex plan of the hospital, came to noth...  \n2536                           Only the city can use it.  \n5397   The death tax is unable to reach the most impo...  \n9982   I love that music from my childhood has return...  \n1498     The Bok House is transformed into a restaurant.  \n...                                                  ...  \n11872         The hillsides are full of ferns and trees.  \n9264   In a fertile valley surrounded by plane and pi...  \n7277   People who matter no nothing about who backs t...  \n3752         You don't eat at restaurants at all, right?  \n6292      I don't believe muslims are hated by Israelis.  
\n\n[2600 rows x 3 columns]","text/html":"<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>label</th>\n      <th>premise</th>\n      <th>hypothesis</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>3615</th>\n      <td>2</td>\n      <td>An ambitious plan for a hexagonally based chur...</td>\n      <td>The complex plan of the hospital, came to noth...</td>\n    </tr>\n    <tr>\n      <th>2536</th>\n      <td>1</td>\n      <td>for for city use and</td>\n      <td>Only the city can use it.</td>\n    </tr>\n    <tr>\n      <th>5397</th>\n      <td>0</td>\n      <td>The really valuable estate cannot be touched b...</td>\n      <td>The death tax is unable to reach the most impo...</td>\n    </tr>\n    <tr>\n      <th>9982</th>\n      <td>0</td>\n      <td>isn't that the truth it's funny in fact it's i...</td>\n      <td>I love that music from my childhood has return...</td>\n    </tr>\n    <tr>\n      <th>1498</th>\n      <td>0</td>\n      <td>Most drivers will be able to point out the Bok...</td>\n      <td>The Bok House is transformed into a restaurant.</td>\n    </tr>\n    <tr>\n      <th>...</th>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n    </tr>\n    <tr>\n      <th>11872</th>\n      <td>1</td>\n      <td>As the road rises, the rugged countryside beco...</td>\n      <td>The hillsides are full of ferns and trees.</td>\n    </tr>\n    <tr>\n      <th>9264</th>\n      <td>0</td>\n      <td>The monastery rests in a fertile valley and is...</td>\n      <td>In a fertile valley surrounded by plane and pi...</td>\n    </tr>\n    <tr>\n      <th>7277</th>\n      <td>2</td>\n      <td>Since everyone who matters presumably knows 
al...</td>\n      <td>People who matter no nothing about who backs t...</td>\n    </tr>\n    <tr>\n      <th>3752</th>\n      <td>2</td>\n      <td>so what type of restaurant do you like</td>\n      <td>You don't eat at restaurants at all, right?</td>\n    </tr>\n    <tr>\n      <th>6292</th>\n      <td>2</td>\n      <td>right that that's actually the part that that ...</td>\n      <td>I don't believe muslims are hated by Israelis.</td>\n    </tr>\n  </tbody>\n</table>\n<p>2600 rows × 3 columns</p>\n</div>"},"metadata":{}}]},{"cell_type":"code","source":"mnli_dataset = MNLIDataBert(train_df, val_df)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:57.969018Z","iopub.execute_input":"2024-07-20T13:04:57.969361Z","iopub.status.idle":"2024-07-20T13:05:17.499250Z","shell.execute_reply.started":"2024-07-20T13:04:57.969331Z","shell.execute_reply":"2024-07-20T13:05:17.498209Z"},"trusted":true},"execution_count":7,"outputs":[{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"12f337f451524f179a9216fbe3066aa7"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"12e0a457b18e489383f90d593015db87"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"97c24925a5b643bc8a82d90efe5daf0f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"config.json:   0%|          | 0.00/570 [00:00<?, 
?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"030a47a6b43b4875bbf7b65c76190618"}},"metadata":{}},{"name":"stdout","text":"10400\n2600\n","output_type":"stream"}]},{"cell_type":"code","source":"train_loader, val_loader = mnli_dataset.get_data_loaders()","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:05:17.500643Z","iopub.execute_input":"2024-07-20T13:05:17.501044Z","iopub.status.idle":"2024-07-20T13:05:17.506066Z","shell.execute_reply.started":"2024-07-20T13:05:17.500988Z","shell.execute_reply":"2024-07-20T13:05:17.505097Z"},"trusted":true},"execution_count":8,"outputs":[]},{"cell_type":"code","source":"from transformers import BertForSequenceClassification\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nmodel = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\", num_labels=3)\nmodel.to(device)\n\n# Weight decay 0.01 on all parameters except biases and LayerNorm weights, which get none.\noptimizer_grouped_parameters = [\n    {'params': [p for n, p in model.named_parameters() if 'bias' not in n and 'LayerNorm.weight' not in n], 'weight_decay': 0.01},\n    {'params': [p for n, p in model.named_parameters() if 'bias' in n or 'LayerNorm.weight' in n], 'weight_decay': 0.0}\n]\n\n# transformers.AdamW is deprecated and slated for removal; torch.optim.AdamW is the supported replacement.\n# Unlike the old correct_bias=False setting, torch's AdamW always applies Adam's bias correction.\noptimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=2e-5)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:05:17.507296Z","iopub.execute_input":"2024-07-20T13:05:17.507927Z","iopub.status.idle":"2024-07-20T13:05:21.748203Z","shell.execute_reply.started":"2024-07-20T13:05:17.507894Z","shell.execute_reply":"2024-07-20T13:05:21.747249Z"},"trusted":true},"execution_count":9,"outputs":[{"output_type":"display_data","data":{"text/plain":"model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e507bb06d8e046d6bc13f059dcf0bea0"}},"metadata":{}},{"name":"stderr","text":"Some weights of BertForSequenceClassification were not initialized from the 
model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n/opt/conda/lib/python3.10/site-packages/transformers/optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n  warnings.warn(\n","output_type":"stream"}]},{"cell_type":"code","source":"def multi_acc(y_pred, y_test):\n  acc = (torch.log_softmax(y_pred, dim=1).argmax(dim=1) == y_test).sum().float() / float(y_test.size(0))\n  return acc","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:05:21.751451Z","iopub.execute_input":"2024-07-20T13:05:21.752046Z","iopub.status.idle":"2024-07-20T13:05:21.758221Z","shell.execute_reply.started":"2024-07-20T13:05:21.751990Z","shell.execute_reply":"2024-07-20T13:05:21.757094Z"},"trusted":true},"execution_count":10,"outputs":[]},{"cell_type":"code","source":"import time","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:05:21.759529Z","iopub.execute_input":"2024-07-20T13:05:21.759906Z","iopub.status.idle":"2024-07-20T13:05:21.780569Z","shell.execute_reply.started":"2024-07-20T13:05:21.759873Z","shell.execute_reply":"2024-07-20T13:05:21.779531Z"},"trusted":true},"execution_count":11,"outputs":[]},{"cell_type":"code","source":"EPOCHS = 2\n\ndef train(model, train_loader, val_loader, optimizer):  \n  total_step = len(train_loader)\n\n  for epoch in range(EPOCHS):\n    start = time.time()\n    model.train()\n    total_train_loss = 0\n    total_train_acc  = 0\n    for batch_idx, (pair_token_ids, mask_ids, seg_ids, y) in enumerate(train_loader):\n      optimizer.zero_grad()\n      pair_token_ids = pair_token_ids.to(device)\n      mask_ids = mask_ids.to(device)\n      seg_ids = seg_ids.to(device)\n      labels = 
y.to(device)\n\n      # Read output fields by name: unpacking .values() depends on which optional fields\n      # (hidden states, attentions) the model returns, and breaks if any of them are enabled.\n      outputs = model(pair_token_ids, \n                      token_type_ids=seg_ids, \n                      attention_mask=mask_ids, \n                      labels=labels)\n      loss, prediction = outputs.loss, outputs.logits\n\n      acc = multi_acc(prediction, labels)\n\n      loss.backward()\n      optimizer.step()\n      \n      # loss and acc are batch means; weight them by batch size so a short final batch\n      # does not skew the epoch average.\n      total_train_loss += loss.item() * labels.size(0)\n      total_train_acc  += acc.item() * labels.size(0)\n\n    train_acc  = total_train_acc/len(train_loader.dataset)\n    train_loss = total_train_loss/len(train_loader.dataset)\n    model.eval()\n    total_val_acc  = 0\n    total_val_loss = 0\n    with torch.no_grad():\n      for batch_idx, (pair_token_ids, mask_ids, seg_ids, y) in enumerate(val_loader):\n        pair_token_ids = pair_token_ids.to(device)\n        mask_ids = mask_ids.to(device)\n        seg_ids = seg_ids.to(device)\n        labels = y.to(device)\n        \n        outputs = model(pair_token_ids, \n                        token_type_ids=seg_ids, \n                        attention_mask=mask_ids, \n                        labels=labels)\n        loss, prediction = outputs.loss, outputs.logits\n        \n        acc = multi_acc(prediction, labels)\n\n        total_val_loss += loss.item() * labels.size(0)\n        total_val_acc  += acc.item() * labels.size(0)\n\n    val_acc  = total_val_acc/len(val_loader.dataset)\n    val_loss = total_val_loss/len(val_loader.dataset)\n    end = time.time()\n    hours, rem = divmod(end-start, 3600)\n    minutes, seconds = divmod(rem, 60)\n\n    print(f'Epoch {epoch+1}: train_loss: {train_loss:.4f} train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} val_acc: {val_acc:.4f}')\n    
print(\"{:0>2}:{:0>2}:{:05.2f}\".format(int(hours),int(minutes),seconds))","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:05:21.782196Z","iopub.execute_input":"2024-07-20T13:05:21.783344Z","iopub.status.idle":"2024-07-20T13:05:21.797911Z","shell.execute_reply.started":"2024-07-20T13:05:21.783317Z","shell.execute_reply":"2024-07-20T13:05:21.797059Z"},"trusted":true},"execution_count":12,"outputs":[]},{"cell_type":"code","source":"train(model, train_loader, val_loader, optimizer)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:05:21.799438Z","iopub.execute_input":"2024-07-20T13:05:21.800060Z","iopub.status.idle":"2024-07-20T13:21:51.983474Z","shell.execute_reply.started":"2024-07-20T13:05:21.800009Z","shell.execute_reply":"2024-07-20T13:21:51.982480Z"},"trusted":true},"execution_count":13,"outputs":[{"name":"stdout","text":"Epoch 1: train_loss: 0.8012 train_acc: 0.6405 | val_loss: 0.6349 val_acc: 0.7367\n00:08:12.58\nEpoch 2: train_loss: 0.4223 train_acc: 0.8425 | val_loss: 0.6711 val_acc: 0.7416\n00:08:17.60\n","output_type":"stream"}]},{"cell_type":"code","source":"import torch\nfrom transformers import BertTokenizer\nimport torch.nn.functional as F\n\nmodel.eval()\n\n# Load the tokenizer\ntokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n\n# Function to predict entailment for a single premise-hypothesis pair\ndef predict_entailment(premise, hypothesis):\n    # Tokenize and encode the inputs\n    premise_id = tokenizer.encode(premise, add_special_tokens=False)\n    hypothesis_id = tokenizer.encode(hypothesis, add_special_tokens=False)\n    pair_token_ids = [tokenizer.cls_token_id] + premise_id + [tokenizer.sep_token_id] + hypothesis_id + [tokenizer.sep_token_id]\n    \n    segment_ids = torch.tensor([0] * (len(premise_id) + 2) + [1] * (len(hypothesis_id) + 1)).unsqueeze(0)  # Add batch dimension\n    attention_mask_ids = torch.tensor([1] * (len(premise_id) + len(hypothesis_id) + 3)).unsqueeze(0)  # Add 
batch dimension\n    token_ids = torch.tensor(pair_token_ids).unsqueeze(0)  # Add batch dimension\n    \n    # Move to device\n    token_ids = token_ids.to(device)\n    segment_ids = segment_ids.to(device)\n    attention_mask_ids = attention_mask_ids.to(device)\n    \n    # Run the model\n    with torch.no_grad():\n        outputs = model(token_ids, token_type_ids=segment_ids, attention_mask=attention_mask_ids)\n        logits = outputs.logits\n    \n    # Apply softmax to get probabilities\n    probs = F.softmax(logits, dim=1)\n    \n    # Get the predicted label\n    predicted_label = torch.argmax(probs, dim=1).item()\n    \n    return predicted_label, probs","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:51.984900Z","iopub.execute_input":"2024-07-20T13:21:51.985354Z","iopub.status.idle":"2024-07-20T13:21:52.134564Z","shell.execute_reply.started":"2024-07-20T13:21:51.985319Z","shell.execute_reply":"2024-07-20T13:21:52.133707Z"},"trusted":true},"execution_count":14,"outputs":[]},{"cell_type":"code","source":"label_map = {0: 'Entailment', 1: 'Neutral', 2: 'Contradiction'}","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:52.136154Z","iopub.execute_input":"2024-07-20T13:21:52.136726Z","iopub.status.idle":"2024-07-20T13:21:52.141308Z","shell.execute_reply.started":"2024-07-20T13:21:52.136692Z","shell.execute_reply":"2024-07-20T13:21:52.140368Z"},"trusted":true},"execution_count":15,"outputs":[]},{"cell_type":"code","source":"# Example premises and hypotheses\npremises = [\n    \"A man is playing a guitar.\",\n    \"Laura likes to go to restaurants every weekend.\",\n    \"Messi is a proffesional football player.\"\n]\n\nhypotheses = [\n    \"A person is making music.\",\n    \"Laura doesn't eat at restaurants at all.\",\n    \"Akash is doing his 
homework.\"\n]","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:52.142479Z","iopub.execute_input":"2024-07-20T13:21:52.142762Z","iopub.status.idle":"2024-07-20T13:21:52.150447Z","shell.execute_reply.started":"2024-07-20T13:21:52.142738Z","shell.execute_reply":"2024-07-20T13:21:52.149469Z"},"trusted":true},"execution_count":16,"outputs":[]},{"cell_type":"code","source":"\n# Predict entailment for each pair\nfor premise, hypothesis in zip(premises, hypotheses):\n    label, probs = predict_entailment(premise, hypothesis)\n    print(f\"Premise: {premise}\")\n    print(f\"Hypothesis: {hypothesis}\")\n    print(f\"Predicted label: {label_map[label]}\")\n    print(f\"Probabilities: {probs}\")\n    print('-'*80)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:52.151433Z","iopub.execute_input":"2024-07-20T13:21:52.151688Z","iopub.status.idle":"2024-07-20T13:21:52.340741Z","shell.execute_reply.started":"2024-07-20T13:21:52.151665Z","shell.execute_reply":"2024-07-20T13:21:52.339852Z"},"trusted":true},"execution_count":17,"outputs":[{"name":"stdout","text":"Premise: A man is playing a guitar.\nHypothesis: A person is making music.\nPredicted label: Entailment\nProbabilities: tensor([[0.9668, 0.0200, 0.0132]], device='cuda:0')\n--------------------------------------------------------------------------------\nPremise: Laura likes to go to restaurants every weekend.\nHypothesis: Laura doesn't eat at restaurants at all.\nPredicted label: Contradiction\nProbabilities: tensor([[0.0016, 0.0022, 0.9962]], device='cuda:0')\n--------------------------------------------------------------------------------\nPremise: Messi is a proffesional football player.\nHypothesis: Akash is doing his homework.\nPredicted label: Neutral\nProbabilities: tensor([[0.0153, 0.6406, 0.3441]], device='cuda:0')\n--------------------------------------------------------------------------------\n","output_type":"stream"}]},{"cell_type":"code","source":"model_path = 
\"./ema_task_model\"\ntokenizer_path = \"./ema_task_tokenizer\"\n\n# Save the model and tokenizer\nmodel.save_pretrained(model_path)\ntokenizer.save_pretrained(tokenizer_path)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:52.344190Z","iopub.execute_input":"2024-07-20T13:21:52.344492Z","iopub.status.idle":"2024-07-20T13:21:53.357043Z","shell.execute_reply.started":"2024-07-20T13:21:52.344467Z","shell.execute_reply":"2024-07-20T13:21:53.356070Z"},"trusted":true},"execution_count":18,"outputs":[{"execution_count":18,"output_type":"execute_result","data":{"text/plain":"('./ema_task_tokenizer/tokenizer_config.json',\n './ema_task_tokenizer/special_tokens_map.json',\n './ema_task_tokenizer/vocab.txt',\n './ema_task_tokenizer/added_tokens.json')"},"metadata":{}}]},{"cell_type":"code","source":"!zip -r ema_task_model.zip ema_task_model\n!zip -r ema_task_tokenizer.zip ema_task_tokenizer","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:53.358178Z","iopub.execute_input":"2024-07-20T13:21:53.358464Z","iopub.status.idle":"2024-07-20T13:22:18.652819Z","shell.execute_reply.started":"2024-07-20T13:21:53.358439Z","shell.execute_reply":"2024-07-20T13:22:18.651602Z"},"trusted":true},"execution_count":19,"outputs":[{"name":"stdout","text":"  adding: ema_task_model/ (stored 0%)\n  adding: ema_task_model/config.json (deflated 51%)\n  adding: ema_task_model/model.safetensors (deflated 7%)\n  adding: ema_task_tokenizer/ (stored 0%)\n  adding: ema_task_tokenizer/vocab.txt (deflated 53%)\n  adding: ema_task_tokenizer/special_tokens_map.json (deflated 42%)\n  adding: ema_task_tokenizer/tokenizer_config.json (deflated 75%)\n","output_type":"stream"}]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]}