In [1]:
import os
import numpy as np
import pickle
import h5py
from tqdm import tqdm
from transformers import AutoTokenizer
from scipy.special import expit 

In [2]:
def compute_tok_score_cart(doc_reps, doc_input_ids, qry_reps, qry_input_ids, qry_attention_mask):
    """Exact-match token scores for every (query, document) pair.

    For each query token, the best dot-product against document tokens that
    share the same input id is taken; those per-token maxima are then summed
    over the query sequence (skipping position 0) for every document.

    Args:
        doc_reps: document token representations, D * LD * d.
        doc_input_ids: document token ids, D * LD.
        qry_reps: query token representations, Q * LQ * d.
        qry_input_ids: query token ids, Q * LQ.
        qry_attention_mask: query attention mask; reshaped to (-1, LQ) so any
            leading layout with Q * LQ elements is accepted.

    Returns:
        Tensor of shape Q * D with the summed token scores.
    """
    # Take the token dimension from the inputs instead of hard-coding 16, so the
    # function works for any projection size.
    token_dim = qry_reps.shape[-1]

    qry_input_ids = qry_input_ids.unsqueeze(2).unsqueeze(3)  # Q * LQ * 1 * 1
    doc_input_ids = doc_input_ids.unsqueeze(0).unsqueeze(1)  # 1 * 1 * D * LD
    exact_match = (doc_input_ids == qry_input_ids).float()  # Q * LQ * D * LD

    # reshape (not view) so non-contiguous representations are accepted too.
    scores_no_masking = torch.matmul(
        qry_reps.reshape(-1, token_dim),  # (Q * LQ) * d
        doc_reps.reshape(-1, token_dim).transpose(0, 1),  # d * (D * LD)
    )
    scores_no_masking = scores_no_masking.view(
        *qry_reps.shape[:2], *doc_reps.shape[:2])  # Q * LQ * D * LD

    # Non-matching pairs are zeroed before the max, so a query token with no
    # match (or only negative matches) contributes 0.
    scores, _ = (scores_no_masking * exact_match).max(dim=3)  # Q * LQ * D

    query_mask = qry_attention_mask.reshape(-1, qry_attention_mask.shape[-1]).unsqueeze(2)
    # [:, 1:] drops the first query position (presumably the [CLS] token -- TODO confirm).
    tok_scores = (scores * query_mask)[:, 1:].sum(1)

    return tok_scores

import torch
from typing import Optional
def coil_fast_eval_forward(
    input_ids: Optional[torch.Tensor] = None,
    doc_reps = None,
    logits: Optional[torch.Tensor] = None,
    desc_input_ids = None,
    desc_attention_mask = None,
    lab_reps = None,
    label_embeddings = None
):
    """Combine the embedding-level score with the exact-match token score.

    The semantic part is `logits.unsqueeze(0) @ label_embeddings.T`; the
    lexical part comes from `compute_tok_score_cart`, with the label
    descriptions acting as queries and the inputs as documents. The two are
    added and singleton dimensions are squeezed away before returning.
    """
    # Flatten any leading dimensions of the description ids down to (N, L).
    flat_desc_ids = desc_input_ids.reshape(-1, desc_input_ids.shape[-1])
    tok_scores = compute_tok_score_cart(
        doc_reps, input_ids, lab_reps, flat_desc_ids, desc_attention_mask
    )

    logits = logits.unsqueeze(0) @ label_embeddings.T
    new_tok_scores = torch.zeros(logits.shape, device=logits.device)

    n_rows = tok_scores.shape[0]
    n_cols = tok_scores.shape[1]
    for col in range(n_cols):
        # Each column owns one contiguous slice of `stride` rows.
        stride = n_rows // n_cols
        first = col * stride
        new_tok_scores[col] = tok_scores[first:first + stride, col]

    return (logits + new_tok_scores).squeeze()

In [3]:
# Read the complete label vocabulary and the list of "unseen" labels.
# `with` closes each file as soon as it has been read instead of leaving the
# handle open until garbage collection.
with open('datasets/Amzn13K/all_labels.txt') as labels_file:
    label_list = [x.strip() for x in labels_file]
with open('datasets/Amzn13K/unseen_labels_split6500_2.txt') as unseen_labels_file:
    unseen_label_list = [x.strip() for x in unseen_labels_file]
num_labels = len(label_list)
label_list.sort() # For consistency
# Label name -> index in the sorted label list.
l2i = {v: i for i, v in enumerate(label_list)}
unseen_label_indexes = [l2i[x] for x in unseen_label_list]

In [4]:
import json
# Lookup table keyed by the string form of a token id (see how it is indexed
# with str(token_id) elsewhere in this notebook). The file is closed right
# after parsing.
with open('bert_coil_map_dict_lemma255K_isotropic.json') as cluster_map_file:
    coil_cluster_map = json.load(cluster_map_file)

In [22]:
# NOTE: another cell in this notebook also assigns `label_preds` from a
# different file; whichever cell ran last determines the value used below.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('/n/fs/nlp-pranjal/SemSup-LMLC/training/ablation_amzn_1_main_labels_zsl.pkl', 'rb') as label_preds_file:
    label_preds = pickle.load(label_preds_file)

In [20]:
# NOTE: another cell in this notebook also assigns `label_preds` from a
# different file; whichever cell ran last determines the value used below.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('/n/fs/scratch/pranjal/seed_experiments/ablation_amzn_eda_labels_zsl_seed2.pkl', 'rb') as label_preds_file:
    label_preds = pickle.load(label_preds_file)

In [38]:
# For every label (in label_list order) pick one of its stored entries at
# random, then stack the per-label pieces into CPU tensors.
picked_reps = []
picked_embeddings = []
picked_input_ids = []
picked_attention_masks = []
for label_name in tqdm(label_list):
    candidates = label_preds[label_name]
    chosen = candidates[np.random.randint(len(candidates))]
    reps, embedding, ids, mask = chosen
    picked_reps.append(reps.squeeze())
    picked_embeddings.append(embedding.squeeze())
    picked_input_ids.append(ids.squeeze())
    picked_attention_masks.append(mask.squeeze())

all_lab_reps = torch.stack(picked_reps).cpu()
all_label_embeddings = torch.stack(picked_embeddings).cpu()
all_desc_input_ids = torch.stack(picked_input_ids).cpu()
all_desc_attention_mask = torch.stack(picked_attention_masks).cpu()

# Translate every description token id through coil_cluster_map (string keys).
all_desc_input_ids_clus = torch.tensor(
    [
        [coil_cluster_map[str(token.item())] for token in row]
        for row in all_desc_input_ids
    ]
)

100%|██████████| 13330/13330 [00:00<00:00, 64680.71it/s]


In [None]:
# Persist the sampled label data. Entry order matters: the loader cell in this
# notebook unpacks five entries and ignores the third (raw description ids).
# `with` guarantees the file is flushed and closed once the dump finishes.
with open('precomputed/Amzn13K/amzn_base_labels_data1_4.pkl', 'wb') as dump_file:
    pickle.dump(
        [all_lab_reps, all_label_embeddings, all_desc_input_ids, all_desc_input_ids_clus, all_desc_attention_mask],
        dump_file,
    )

In [6]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [7]:
def _load_label_data(path):
    """Load one precomputed label pickle and return the four tensors used for scoring.

    The pickle holds five entries; the third one is discarded and the fourth is
    returned as the description input ids.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(path, 'rb') as label_data_file:
        lab_reps, label_embeddings, _unused, desc_input_ids, desc_attention_mask = pickle.load(label_data_file)
    return lab_reps, label_embeddings, desc_input_ids, desc_attention_mask


all_lab_reps1, all_label_embeddings1, all_desc_input_ids1, all_desc_attention_mask1 = _load_label_data('precomputed/Amzn13K/amzn_base_labels_data1.pkl')
all_lab_reps2, all_label_embeddings2, all_desc_input_ids2, all_desc_attention_mask2 = _load_label_data('precomputed/Amzn13K/amzn_base_labels_data2.pkl')
all_lab_reps3, all_label_embeddings3, all_desc_input_ids3, all_desc_attention_mask3 = _load_label_data('precomputed/Amzn13K/amzn_base_labels_data3.pkl')


# From here on each `all_*` name is a list with one tensor per loaded file,
# moved to the selected device.
all_lab_reps = [t.to(device) for t in (all_lab_reps1, all_lab_reps2, all_lab_reps3)]
all_label_embeddings = [t.to(device) for t in (all_label_embeddings1, all_label_embeddings2, all_label_embeddings3)]
all_desc_input_ids = [t.to(device) for t in (all_desc_input_ids1, all_desc_input_ids2, all_desc_input_ids3)]
all_desc_attention_mask = [t.to(device) for t in (all_desc_attention_mask1, all_desc_attention_mask2, all_desc_attention_mask3)]

In [8]:
from src import BertForSemanticEmbedding, getLabelModel
from src import DataTrainingArguments, ModelArguments, CustomTrainingArguments, read_yaml_config
from src import dataset_classification_type
from src import SemSupDataset
from transformers import AutoConfig, HfArgumentParser, AutoTokenizer
import torch

import json
from tqdm import tqdm

ARGS_FILE = 'configs/ablation_amzn_eda.yml'

# Parse the YAML experiment config into the three argument dataclasses.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
model_args, data_args, training_args = parser.parse_dict(read_yaml_config(ARGS_FILE, output_dir = 'demo_tmp', extra_args = {}))

config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    finetuning_task=data_args.task_name,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)

# Copy the project-specific options from the parsed arguments onto the config.
config.model_name_or_path = model_args.model_name_or_path
config.problem_type = dataset_classification_type[data_args.task_name]
config.negative_sampling = model_args.negative_sampling
config.semsup = model_args.semsup
config.encoder_model_type = model_args.encoder_model_type
config.arch_type = model_args.arch_type
config.coil = model_args.coil
config.token_dim = model_args.token_dim
config.colbert = model_args.colbert

label_model, label_tokenizer = getLabelModel(data_args, model_args)
config.label_hidden_size = label_model.config.hidden_size
model = BertForSemanticEmbedding(config)
model.label_model = label_model
model.label_tokenizer = label_tokenizer
model.config.label2id = {l: i for i, l in enumerate(label_list)}
# Invert the mapping that was just stored on model.config (rather than reading
# it back through `config`), and avoid shadowing the builtin `id`.
model.config.id2label = {index: label for label, index in model.config.label2id.items()}

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

Yaml Config is:
--------------------------------------------------------------------------------
{'task_name': 'amazon13k', 'dataset_name': 'amazon13k', 'dataset_config_name': None, 'max_seq_length': 160, 'overwrite_output_dir': False, 'overwrite_cache': False, 'pad_to_max_length': True, 'load_from_local': True, 'max_train_samples': None, 'max_eval_samples': 15000, 'max_predict_samples': None, 'train_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/train_split6500_2.jsonl', 'validation_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/test_unseen_split6500_2.jsonl', 'test_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/test_unseen_split6500_2.jsonl', 'label_max_seq_length': 160, 'descriptions_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/heir_withdescriptions_v3_v3_unseen_edaaug.json', 'test_descriptions_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/heir_withdescriptions_v3_v3.json', 'all_labels': '/n/fs/n

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.b

Config is BertConfig {
 "_name_or_path": "bert-base-uncased",
 "arch_type": 2,
 "architectures": [
 "BertForMaskedLM"
 ],
 "attention_probs_dropout_prob": 0.1,
 "classifier_dropout": null,
 "coil": true,
 "colbert": false,
 "encoder_model_type": "bert",
 "finetuning_task": "amazon13k",
 "gradient_checkpointing": false,
 "hidden_act": "gelu",
 "hidden_dropout_prob": 0.1,
 "hidden_size": 768,
 "initializer_range": 0.02,
 "intermediate_size": 3072,
 "label_hidden_size": 512,
 "layer_norm_eps": 1e-12,
 "max_position_embeddings": 512,
 "model_name_or_path": "bert-base-uncased",
 "model_type": "bert",
 "negative_sampling": "none",
 "num_attention_heads": 12,
 "num_hidden_layers": 12,
 "pad_token_id": 0,
 "position_embedding_type": "absolute",
 "problem_type": "multi_label_classification",
 "semsup": true,
 "token_dim": 16,
 "transformers_version": "4.20.0",
 "type_vocab_size": 2,
 "use_cache": true,
 "vocab_size": 30522
}



In [9]:
# Move the model to the selected device and switch it to evaluation mode
# (nn.Module.to returns the module itself, so the calls can be chained).
model.to(device).eval()
# Inference only: turn off gradient tracking globally for the rest of the session.
torch.set_grad_enabled(False)

BertForSemanticEmbedding(
 (encoder): BertModel(
 (embeddings): BertEmbeddings(
 (word_embeddings): Embedding(30522, 768, padding_idx=0)
 (position_embeddings): Embedding(512, 768)
 (token_type_embeddings): Embedding(2, 768)
 (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 (dropout): Dropout(p=0.1, inplace=False)
 )
 (encoder): BertEncoder(
 (layer): ModuleList(
 (0): BertLayer(
 (attention): BertAttention(
 (self): BertSelfAttention(
 (query): Linear(in_features=768, out_features=768, bias=True)
 (key): Linear(in_features=768, out_features=768, bias=True)
 (value): Linear(in_features=768, out_features=768, bias=True)
 (dropout): Dropout(p=0.1, inplace=False)
 )
 (output): BertSelfOutput(
 (dense): Linear(in_features=768, out_features=768, bias=True)
 (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 (dropout): Dropout(p=0.1, inplace=False)
 )
 )
 (intermediate): BertIntermediate(
 (dense): Linear(in_features=768, out_features=3072, bias=True)
 (in

In [65]:
# Load the checkpoint weights into the model, mapping tensors onto `device`.
checkpoint_path = 'ckpt/Amzn13K/amzn_main_model.bin'
model.load_state_dict(torch.load(checkpoint_path, map_location = device))



In [88]:
# Example product listing that the following cells tokenize and classify.
text = '''SanDisk Cruzer Blade 32GB USB Flash Drive\nUltra-compact and portable USB flash drive,Capless design
Share your photos, videos, songs and other files between computers with ease,care number:18001205899/18004195592
Protect your private files with included SanDisk SecureAccess software
Includes added protection of secure online backup (up to 2GB optionally available) offered by YuuWaa
Password-protect your sensitive files. Customer care:IndiaSupport@sandisk.com
Importer Details:Rashi Peripherals Pvt. Ltd. Rashi Complex,A Building,Survey186,Dongaripada,Poman Village,Vasai Bhiwandi Road, Dist. Thane,Maharastra 401208, India
Share your work files between computers with ease
Manufacturer Name & Address: SanDisk International LTD, C/O Unit 100, Airside Business Park, Lakeshore Drive, Swords, Co Dublin, Ireland.
Consumer Complaint Details: indiasupport@sandisk.com/18001022055'''

In [89]:
# Tokenize the example text and add a batch dimension of 1 to every field.
item = tokenizer(text, padding='max_length', max_length=data_args.max_seq_length, truncation=True)
item = {k: torch.tensor(v, device = device).unsqueeze(0) for k, v in item.items()}

# Encode the document and project its token states for token-level matching.
outputs_doc, logits = model.forward_input_encoder(**item)
doc_reps = model.tok_proj(outputs_doc.last_hidden_state)

# Map the document token ids through coil_cluster_map (string keys). tolist()
# transfers the ids once instead of calling .item() per token.
input_ids = torch.tensor(
    [coil_cluster_map[str(token_id)] for token_id in item['input_ids'][0].tolist()]
).to(device).unsqueeze(0)

# One score vector per precomputed label-description set.
all_logits = [
    coil_fast_eval_forward(input_ids, doc_reps, logits, adi, ada, alr, ale)
    for adi, ada, alr, ale in zip(all_desc_input_ids, all_desc_attention_mask, all_lab_reps, all_label_embeddings)
]

# Average the per-set probabilities. torch.sigmoid is used directly on the
# tensors; scipy's expit only worked here through an implicit NumPy round trip.
final_logits = sum(torch.sigmoid(x).cpu() for x in all_logits) / len(all_logits)

# Keep the five highest-scoring labels as {label name: probability}.
outs = torch.topk(final_logits, k = 5)
preds_dic = {
    label_list[index]: value
    for index, value in zip(outs.indices.tolist(), outs.values.tolist())
}

In [90]:
# Display the top-5 predicted labels with their averaged sigmoid scores.
preds_dic

{'electronics': 0.9989226460456848,
 'computers & accessories': 0.981508731842041,
 'computer components': 0.9518740177154541,
 'computer accessories': 0.7639468312263489,
 'hardware': 0.6584190726280212}

In [78]:
# Sanity check: shape of the averaged score tensor computed in the prediction cell.
final_logits.shape

torch.Size([13330])