{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import pickle\n", "import h5py\n", "from tqdm import tqdm\n", "from transformers import AutoTokenizer\n", "from scipy.special import expit " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def compute_tok_score_cart(doc_reps, doc_input_ids, qry_reps, qry_input_ids, qry_attention_mask):\n", " qry_input_ids = qry_input_ids.unsqueeze(2).unsqueeze(3) # Q * LQ * 1 * 1\n", " doc_input_ids = doc_input_ids.unsqueeze(0).unsqueeze(1) # 1 * 1 * D * LD\n", " exact_match = doc_input_ids == qry_input_ids # Q * LQ * D * LD\n", " exact_match = exact_match.float()\n", " scores_no_masking = torch.matmul(\n", " qry_reps.view(-1, 16), # (Q * LQ) * d\n", " doc_reps.view(-1, 16).transpose(0, 1) # d * (D * LD)\n", " )\n", " scores_no_masking = scores_no_masking.view(\n", " *qry_reps.shape[:2], *doc_reps.shape[:2]) # Q * LQ * D * LD\n", " scores, _ = (scores_no_masking * exact_match).max(dim=3) # Q * LQ * D\n", " tok_scores = (scores * qry_attention_mask.reshape(-1, qry_attention_mask.shape[-1]).unsqueeze(2))[:, 1:].sum(1)\n", " \n", " return tok_scores\n", "\n", "import torch\n", "from typing import Optional\n", "def coil_fast_eval_forward(\n", " input_ids: Optional[torch.Tensor] = None,\n", " doc_reps = None,\n", " logits: Optional[torch.Tensor] = None,\n", " desc_input_ids = None,\n", " desc_attention_mask = None,\n", " lab_reps = None,\n", " label_embeddings = None\n", "):\n", " tok_scores = compute_tok_score_cart(\n", " doc_reps, input_ids,\n", " lab_reps, desc_input_ids.reshape(-1, desc_input_ids.shape[-1]), desc_attention_mask\n", " )\n", " logits = (logits.unsqueeze(0) @ label_embeddings.T)\n", " new_tok_scores = torch.zeros(logits.shape, device = logits.device)\n", " for i in range(tok_scores.shape[1]):\n", " stride = tok_scores.shape[0]//tok_scores.shape[1]\n", " new_tok_scores[i] = tok_scores[i*stride: i*stride + stride ,i]\n", " return (logits + new_tok_scores).squeeze()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "label_list = [x.strip() for x in open('datasets/Amzn13K/all_labels.txt')]\n", "unseen_label_list = [x.strip() for x in open('datasets/Amzn13K/unseen_labels_split6500_2.txt')]\n", "num_labels = len(label_list)\n", "label_list.sort() # For consistency\n", "l2i = {v: i for i, v in enumerate(label_list)}\n", "unseen_label_indexes = [l2i[x] for x in unseen_label_list]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import json\n", "coil_cluster_map = json.load(open('bert_coil_map_dict_lemma255K_isotropic.json')) " ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "label_preds = pickle.load(open('/n/fs/nlp-pranjal/SemSup-LMLC/training/ablation_amzn_1_main_labels_zsl.pkl','rb'))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "label_preds = pickle.load(open('/n/fs/scratch/pranjal/seed_experiments/ablation_amzn_eda_labels_zsl_seed2.pkl','rb'))" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 13330/13330 [00:00<00:00, 64680.71it/s]\n" ] } ], "source": [ "all_lab_reps, all_label_embeddings, all_desc_input_ids, all_desc_attention_mask = [], [], [], []\n", "for l in tqdm(label_list):\n", " ll = label_preds[l]\n", " lab_reps, 
label_embeddings, desc_input_ids, desc_attention_mask = ll[np.random.randint(len(ll))] \n", " all_lab_reps.append(lab_reps.squeeze())\n", " all_label_embeddings.append(label_embeddings.squeeze())\n", " all_desc_input_ids.append(desc_input_ids.squeeze())\n", " all_desc_attention_mask.append(desc_attention_mask.squeeze())\n", "all_lab_reps = torch.stack(all_lab_reps).cpu()\n", "all_label_embeddings = torch.stack(all_label_embeddings).cpu()\n", "all_desc_input_ids = torch.stack(all_desc_input_ids).cpu()\n", "all_desc_attention_mask = torch.stack(all_desc_attention_mask).cpu()\n", "all_desc_input_ids_clus = torch.tensor([[coil_cluster_map[str(x.item())] for x in xx] for xx in all_desc_input_ids])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pickle.dump([all_lab_reps, all_label_embeddings, all_desc_input_ids, all_desc_input_ids_clus, all_desc_attention_mask], open('precomputed/Amzn13K/amzn_base_labels_data1_4.pkl','wb'))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "all_lab_reps1, all_label_embeddings1, _, all_desc_input_ids1, all_desc_attention_mask1 = pickle.load(open('precomputed/Amzn13K/amzn_base_labels_data1.pkl','rb'))\n", "all_lab_reps2, all_label_embeddings2, _, all_desc_input_ids2, all_desc_attention_mask2 = pickle.load(open('precomputed/Amzn13K/amzn_base_labels_data2.pkl','rb'))\n", "all_lab_reps3, all_label_embeddings3, _, all_desc_input_ids3, all_desc_attention_mask3 = pickle.load(open('precomputed/Amzn13K/amzn_base_labels_data3.pkl','rb'))\n", "\n", "\n", "all_lab_reps = [all_lab_reps1.to(device), all_lab_reps2.to(device), all_lab_reps3.to(device)]\n", "all_label_embeddings = [all_label_embeddings1.to(device), all_label_embeddings2.to(device), all_label_embeddings3.to(device)]\n", "all_desc_input_ids = [all_desc_input_ids1.to(device), all_desc_input_ids2.to(device), all_desc_input_ids3.to(device)]\n", "all_desc_attention_mask = [all_desc_attention_mask1.to(device), all_desc_attention_mask2.to(device), all_desc_attention_mask3.to(device)]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Yaml Config is:\n", "--------------------------------------------------------------------------------\n", "{'task_name': 'amazon13k', 'dataset_name': 'amazon13k', 'dataset_config_name': None, 'max_seq_length': 160, 'overwrite_output_dir': False, 'overwrite_cache': False, 'pad_to_max_length': True, 'load_from_local': True, 'max_train_samples': None, 'max_eval_samples': 15000, 'max_predict_samples': None, 'train_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/train_split6500_2.jsonl', 'validation_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/test_unseen_split6500_2.jsonl', 'test_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/test_unseen_split6500_2.jsonl', 'label_max_seq_length': 160, 'descriptions_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/heir_withdescriptions_v3_v3_unseen_edaaug.json', 'test_descriptions_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/heir_withdescriptions_v3_v3.json', 'all_labels': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/all_labels.txt', 'test_labels': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/unseen_labels_split6500_2.txt', 
'contrastive_learning_samples': 1000, 'cl_min_positive_descs': 1, 'coil_cluster_mapping_path': 'bert_coil_map_dict_lemma255K_isotropic.json', 'model_name_or_path': 'bert-base-uncased', 'config_name': None, 'tokenizer_name': None, 'cache_dir': None, 'use_fast_tokenizer': True, 'model_revision': 'main', 'use_auth_token': False, 'ignore_mismatched_sizes': False, 'negative_sampling': 'none', 'semsup': True, 'label_model_name_or_path': 'prajjwal1/bert-small', 'encoder_model_type': 'bert', 'use_custom_optimizer': 'adamw', 'output_learning_rate': 0.0001, 'arch_type': 2, 'add_label_name': True, 'normalize_embeddings': False, 'tie_weights': False, 'coil': True, 'colbert': False, 'token_dim': 16, 'label_frozen_layers': 2, 'do_train': True, 'do_eval': True, 'do_predict': False, 'per_device_train_batch_size': 1, 'gradient_accumulation_steps': 8, 'per_device_eval_batch_size': 1, 'learning_rate': 5e-05, 'num_train_epochs': 2, 'save_steps': 4900, 'evaluation_strategy': 'steps', 'eval_steps': 3000000, 'fp16': True, 'fp16_opt_level': 'O1', 'lr_scheduler_type': 'linear', 'dataloader_num_workers': 16, 'label_names': ['labels'], 'scenario': 'unseen_labels', 'ddp_find_unused_parameters': False, 'ignore_data_skip': True, 'seed': -1, 'EXP_NAME': 'semsup_descs_100ep_newds_cosine', 'EXP_DESC': 'SemSup Descriptions ran for 100 epochs', 'output_dir': 'demo_tmp'}\n", "--------------------------------------------------------------------------------\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Config is BertConfig {\n", " \"_name_or_path\": \"bert-base-uncased\",\n", " \"arch_type\": 2,\n", " \"architectures\": [\n", " \"BertForMaskedLM\"\n", " ],\n", " \"attention_probs_dropout_prob\": 0.1,\n", " \"classifier_dropout\": null,\n", " \"coil\": true,\n", " \"colbert\": false,\n", " \"encoder_model_type\": \"bert\",\n", " \"finetuning_task\": \"amazon13k\",\n", " \"gradient_checkpointing\": false,\n", " \"hidden_act\": \"gelu\",\n", " \"hidden_dropout_prob\": 0.1,\n", " \"hidden_size\": 768,\n", " \"initializer_range\": 0.02,\n", " \"intermediate_size\": 3072,\n", " \"label_hidden_size\": 512,\n", " \"layer_norm_eps\": 1e-12,\n", " \"max_position_embeddings\": 512,\n", " \"model_name_or_path\": \"bert-base-uncased\",\n", " \"model_type\": \"bert\",\n", " \"negative_sampling\": \"none\",\n", " \"num_attention_heads\": 12,\n", " \"num_hidden_layers\": 12,\n", " \"pad_token_id\": 0,\n", " \"position_embedding_type\": \"absolute\",\n", " \"problem_type\": \"multi_label_classification\",\n", " \"semsup\": true,\n", " \"token_dim\": 16,\n", " \"transformers_version\": \"4.20.0\",\n", " \"type_vocab_size\": 2,\n", " \"use_cache\": true,\n", " \"vocab_size\": 30522\n", "}\n", "\n" ] } ], "source": [ "from src import BertForSemanticEmbedding, getLabelModel\n", "from src import DataTrainingArguments, ModelArguments, CustomTrainingArguments, read_yaml_config\n", "from src import dataset_classification_type\n", "from src import SemSupDataset\n", "from transformers import AutoConfig, HfArgumentParser, AutoTokenizer\n", "import torch\n", "\n", "import json\n", "from tqdm import tqdm\n", "\n", "ARGS_FILE = 'configs/ablation_amzn_eda.yml'\n", "\n", "parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))\n", "model_args, data_args, training_args = parser.parse_dict(read_yaml_config(ARGS_FILE, output_dir = 'demo_tmp', extra_args = {}))\n", "\n", "config = AutoConfig.from_pretrained(\n", " model_args.config_name if model_args.config_name else model_args.model_name_or_path,\n", " finetuning_task=data_args.task_name,\n", " cache_dir=model_args.cache_dir,\n", " revision=model_args.model_revision,\n", " use_auth_token=True if model_args.use_auth_token else None,\n", ")\n", "\n", "config.model_name_or_path = model_args.model_name_or_path\n", "config.problem_type = dataset_classification_type[data_args.task_name]\n", "config.negative_sampling = model_args.negative_sampling\n", "config.semsup = model_args.semsup\n", "config.encoder_model_type = model_args.encoder_model_type\n", "config.arch_type = model_args.arch_type\n", "config.coil = model_args.coil\n", "config.token_dim = model_args.token_dim\n", "config.colbert = model_args.colbert\n", "\n", "label_model, label_tokenizer = getLabelModel(data_args, model_args)\n", "config.label_hidden_size = label_model.config.hidden_size\n", "model = BertForSemanticEmbedding(config)\n", "model.label_model = label_model\n", "model.label_tokenizer = label_tokenizer\n", "model.config.label2id = {l: i for i, l in enumerate(label_list)}\n", "model.config.id2label = {id: label for label, id in config.label2id.items()}\n", "\n", "tokenizer = 
AutoTokenizer.from_pretrained('bert-base-uncased')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BertForSemanticEmbedding(\n", " (encoder): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (1): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (2): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (3): BertLayer(\n", 
" (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (4): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (5): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (6): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): 
Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (7): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (8): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (9): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (10): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, 
out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (11): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (pooler): BertPooler(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", " )\n", " (tok_proj): Linear(in_features=768, out_features=16, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (label_projection): Linear(in_features=768, out_features=512, bias=False)\n", " (label_model): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30522, 512, padding_idx=0)\n", " (position_embeddings): Embedding(512, 512)\n", " (token_type_embeddings): Embedding(2, 512)\n", " (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=512, out_features=512, bias=True)\n", " (key): Linear(in_features=512, out_features=512, bias=True)\n", " (value): Linear(in_features=512, out_features=512, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=512, out_features=512, bias=True)\n", " (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=512, out_features=2048, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=2048, out_features=512, bias=True)\n", " (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (1): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): 
Linear(in_features=512, out_features=512, bias=True)\n", " (key): Linear(in_features=512, out_features=512, bias=True)\n", " (value): Linear(in_features=512, out_features=512, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=512, out_features=512, bias=True)\n", " (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=512, out_features=2048, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=2048, out_features=512, bias=True)\n", " (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (2): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=512, out_features=512, bias=True)\n", " (key): Linear(in_features=512, out_features=512, bias=True)\n", " (value): Linear(in_features=512, out_features=512, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=512, out_features=512, bias=True)\n", " (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=512, out_features=2048, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=2048, out_features=512, bias=True)\n", " (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (3): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=512, out_features=512, bias=True)\n", " (key): Linear(in_features=512, out_features=512, bias=True)\n", " (value): Linear(in_features=512, out_features=512, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=512, out_features=512, bias=True)\n", " (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=512, out_features=2048, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=2048, out_features=512, bias=True)\n", " (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (pooler): BertPooler(\n", " (dense): Linear(in_features=512, out_features=512, bias=True)\n", " (activation): Tanh()\n", " )\n", " )\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.to(device)\n", "model.eval()\n", "torch.set_grad_enabled(False)" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.load_state_dict(torch.load('ckpt/Amzn13K/amzn_main_model.bin', map_location = device))" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [], "source": [ 
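"# Demo document: a raw product listing used as the classification input;\n", "# it is tokenized and scored against every label description in the next cell.\n",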
"text = '''SanDisk Cruzer Blade 32GB USB Flash Drive\\nUltra-compact and portable USB flash drive,Capless design\n", "Share your photos, videos, songs and other files between computers with ease,care number:18001205899/18004195592\n", "Protect your private files with included SanDisk SecureAccess software\n", "Includes added protection of secure online backup (up to 2GB optionally available) offered by YuuWaa\n", "Password-protect your sensitive files. Customer care:IndiaSupport@sandisk.com\n", "Importer Details:Rashi Peripherals Pvt. Ltd. Rashi Complex,A Building,Survey186,Dongaripada,Poman Village,Vasai Bhiwandi Road, Dist. Thane,Maharastra 401208, India\n", "Share your work files between computers with ease\n", "Manufacturer Name & Address: SanDisk International LTD, C/O Unit 100, Airside Business Park, Lakeshore Drive, Swords, Co Dublin, Ireland.\n", "Consumer Complaint Details: indiasupport@sandisk.com/18001022055'''" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [], "source": [ "item = tokenizer(text, padding='max_length', max_length=data_args.max_seq_length, truncation=True)\n", "item = {k:torch.tensor(v, device = device).unsqueeze(0) for k,v in item.items()}\n", "\n", "outputs_doc, logits = model.forward_input_encoder(**item)\n", "doc_reps = model.tok_proj(outputs_doc.last_hidden_state)\n", "\n", "input_ids = torch.tensor([coil_cluster_map[str(x.item())] for x in item['input_ids'][0]]).to(device).unsqueeze(0)\n", "all_logits = []\n", "for adi, ada, alr, ale in zip(all_desc_input_ids, all_desc_attention_mask, all_lab_reps, all_label_embeddings):\n", " all_logits.append(coil_fast_eval_forward(input_ids, doc_reps, logits, adi, ada, alr, ale))\n", "\n", "final_logits = sum([expit(x.cpu()) for x in all_logits]) / len(all_logits)\n", "\n", "outs = torch.topk(final_logits, k = 5)\n", "preds_dic = dict()\n", "for i,v in zip(outs.indices, outs.values):\n", " preds_dic[label_list[i]] = v.item()" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'electronics': 0.9989226460456848,\n", " 'computers & accessories': 0.981508731842041,\n", " 'computer components': 0.9518740177154541,\n", " 'computer accessories': 0.7639468312263489,\n", " 'hardware': 0.6584190726280212}" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds_dic" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([13330])" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "final_logits.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "90fcbf6f06d9a30c70fdaff45e14c5534421a599dc22a7267c486c9cb67dea6d" }, "kernelspec": { "display_name": "Python 3.9.12 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }