from typing import Dict, Optional, List, Any, NewType
from collections.abc import Mapping
from os.path import join
import os
import json
import random
import itertools
import time

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import gensim.downloader
import h5py
from tqdm import tqdm

def getTokenizedLabelDescriptions(data_args, desc_file, tokenizer):
    padding = "max_length" if data_args.pad_to_max_length else False
    max_seq_length = min(data_args.label_max_seq_length, tokenizer.model_max_length)
    label_descs = json.load(open(desc_file, encoding='utf-8'))
    # Each JSON value is expected to be [num_descriptions, [description strings]];
    # descs[1] therefore holds the list of description strings for the label.
    return {label_key: [
        tokenizer(
            desc,
            truncation=True,
            padding=padding,
            max_length=max_seq_length,
            return_tensors='pt'
        )
        for desc in descs[1]] for label_key, descs in label_descs.items()}
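
# Illustrative sketch (not from the original file): the description JSON consumed above and by
# SemSupDataset.tokenize_class_descs is assumed to map each label name to a pair of
# [description count, list of description strings], as implied by the descs[0] / descs[1] indexing.
# Example layout with hypothetical labels:
# {
#     "Astronomy": [2, ["The study of celestial objects.", "Covers stars, planets and galaxies."]],
#     "Cooking":   [1, ["Articles about preparing food."]]
# }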

class SemSupDataset(Dataset):
    def __init__(self, input_dataset, data_args, label_descriptions_file, label_to_id, id_to_label, tokenizer,
                 class_descs_len=None, return_desc_embeddings=False, sampleRandom: int = -1,
                 cl_min_positive_descs=20, useSemSup=True, seen_labels=None, add_label_name=False,
                 max_descs_per_label=999999, use_precomputed_embeddings='', bm_short_file='',
                 ignore_pos_labels_file='', isTrain=True, class_descs_tokenized=None, choice_indexes=None):
        self.input_dataset = input_dataset
        self.sampleRandom = sampleRandom
        self.cl_min_positive_descs = cl_min_positive_descs
        self.semsup = useSemSup
        self.seen_labels = seen_labels
        self.add_label_name = add_label_name
        self.max_descs_per_label = max_descs_per_label
        self.use_precomputed_embeddings = use_precomputed_embeddings
        self.choice_indexes = choice_indexes
        self.bmshortfile = bm_short_file
        self.useBMShort = self.bmshortfile != ''
        self.data_args = data_args
        self.tok_format = 0
        self.isTrain = isTrain
        # Optional COIL token-clustering map: a JSON file mapping each vocabulary token id (as a string)
        # to a cluster id, used later to build 'clustered_input_ids' / 'clustered_desc_ids'.
        self.coil_cluster_map = None
        try:
            if data_args.coil_cluster_mapping_path:
                self.coil_cluster_map = json.load(open(data_args.coil_cluster_mapping_path))
        except Exception:
            print('Failed to load the COIL cluster map; continuing without it')
            self.coil_cluster_map = None
        self.ignore_pos_labels_file = ignore_pos_labels_file
        if self.ignore_pos_labels_file:
            self.ignored_labels = [[y.strip() for y in x.split('\t') if y.strip() != ''] for x in open(self.ignore_pos_labels_file).readlines()]
        else:
            self.ignored_labels = False
        if self.useBMShort and not data_args.large_dset:
            self.shortlists = [[y.strip() for y in x.split('\t')] for x in open(self.bmshortfile).readlines()]
        if self.semsup and not data_args.large_dset:
            self.data_args = data_args
            self.label_descriptions_file = label_descriptions_file
            self.label_to_id = label_to_id
            self.id_to_label = id_to_label
            if self.seen_labels is not None and isinstance(self.seen_labels[0], str):
                self.seen_labels = np.array([self.label_to_id[x] for x in self.seen_labels])
            self.tokenizer = tokenizer
            if class_descs_len is None:
                js_file = json.load(open(self.label_descriptions_file, encoding='utf-8'))
                self.class_descs_len = self.tokenize_class_descs(js_file, return_lengths=True)
                self.class_descs = self.tokenize_class_descs(js_file)
            else:
                self.class_descs_len = class_descs_len
            self.return_desc_embeddings = return_desc_embeddings
            self.label_max_seq_length = data_args.label_max_seq_length
            if return_desc_embeddings:
                self.save_tokenized_descs(self.add_label_name)
            if self.use_precomputed_embeddings:
                self.computed_desc_inputs_embeds = torch.from_numpy(np.load(self.use_precomputed_embeddings))
        if self.semsup and data_args.large_dset:
            self.data_args = data_args
            self.label_descriptions_file = label_descriptions_file
            self.label_to_id = label_to_id
            self.id_to_label = id_to_label
            # No concept of seen labels over here, directly load the shortlists
            self.tokenizer = tokenizer
            self.return_desc_embeddings = return_desc_embeddings
            self.label_max_seq_length = data_args.label_max_seq_length
            to_save = True
            if os.path.exists(data_args.tokenized_descs_file):
                print('Path Exists')
                if data_args.tok_format == 1:
                    self.tok_format = 1
                if class_descs_tokenized is not None:
                    self.class_descs_tokenized = class_descs_tokenized
                else:
                    if data_args.tokenized_descs_file.endswith('h5'):
                        self.class_descs_tokenized = h5py.File(data_args.tokenized_descs_file)
                        self.tok_format = 1
                    else:
                        self.class_descs_tokenized = np.load(data_args.tokenized_descs_file, allow_pickle=True)
                # TODO: Fix this hardcoding
                # if len(arr) < int(1e6):
                #     to_save = True  # Possibly corrupt file
                # # All set, load the file
                # else:
                to_save = False
            js_file = json.load(open(self.label_descriptions_file, encoding='utf-8'))
            print('Loaded js File')
            self.class_descs_len = self.tokenize_class_descs(js_file, return_lengths=True)
            if to_save:
                self.class_descs = self.tokenize_class_descs(js_file)
                print('Begin Tokenization Process')
                self.save_tokenized_descs(self.add_label_name)
                print('Saving Tokenized Descriptions')
                import pickle
                pickle.dump(self.class_descs_tokenized, open(data_args.tokenized_descs_file, 'wb'))
                print(len(self.class_descs_tokenized))
                # Deliberately stop after dumping the tokenized descriptions (the original code crashed
                # here with `3/0`); rerun so the saved file is loaded instead. Note that this makes the
                # h5 export below unreachable on this path.
                raise RuntimeError('Tokenized descriptions saved to %s; rerun to load them' % data_args.tokenized_descs_file)
                file = h5py.File(data_args.tokenized_descs_file, 'w')
                for key in tqdm(self.class_descs_tokenized):
                    key_h5 = key
                    if key.find('/') != -1:
                        print('There may be issue with', key)
                        # Written as '\\/' to avoid the invalid-escape warning of the original '\/';
                        # note h5py still treats the remaining '/' as a group separator.
                        key_h5 = key.replace('/', '\\/')
                    file.create_dataset(key_h5 + '/' + 'input_ids', data=np.array(self.class_descs_tokenized[key]['input_ids']))
                    file[key_h5].create_dataset('attention_mask', data=np.array(self.class_descs_tokenized[key]['attention_mask']))
            # else:
            #     self.class_descs_tokenized = np.load(data_args.tokenized_descs_file).item()
            if isTrain:
                self.shortlists = h5py.File(data_args.train_tfidf_short)['data']
            else:
                print('Test File Loaded')
                self.shortlists = h5py.File(data_args.test_tfidf_short)['data']
            try:
                del self.class_descs
            except AttributeError:
                pass
            if self.tok_format != 1:
                # Keep only input_ids and attention_mask: drop the second entry of the tokenizer output
                # (token_type_ids for BERT-style tokenizers) and store arrays in a per-label DataFrame.
                self.class_descs_tokenized = pd.DataFrame({k: [np.array(x) for i, x in enumerate(v.values()) if i != 1] for k, v in self.class_descs_tokenized.items()})

    def tokenize_class_descs(self, label_descs, return_lengths=False):
        # Each JSON value is [num_descriptions, [description strings]].
        if return_lengths:
            # descs[0] is the number of descriptions available for the label
            return {
                label_key: min(descs[0], self.max_descs_per_label) for label_key, descs in label_descs.items()
            }
        else:
            return {
                label_key: descs[1][:self.max_descs_per_label] for label_key, descs in label_descs.items()
            }

    def save_tokenized_descs(self, add_label_name=False):
        self.class_descs_tokenized = dict()
        for label_key in tqdm(list(self.class_descs.keys())):
            descs = self.class_descs[label_key]
            self.class_descs_tokenized[label_key] = self.tokenizer(
                [label_key + ". " + x for x in descs] if add_label_name else descs,
                max_length=self.label_max_seq_length, padding='max_length', truncation=True)
            # del self.class_descs_tokenized[label_key]['token_type_ids']

    def __len__(self):
        return len(self.input_dataset)

    def get_item_for_large_dset(self, idx, item):
        if self.choice_indexes is not None:
            idx = int(self.choice_indexes[idx])
        shortlists = self.shortlists[idx]
        labels_new = item['label']
        if self.sampleRandom != -1:
            if self.sampleRandom < len(shortlists):
                shortlists = np.random.choice(shortlists, self.sampleRandom, replace=False)
            elif self.sampleRandom > len(shortlists):
                # randomly choose from all remaining labels
                shortlists = shortlists.tolist() + [self.label_to_id[x] for x in np.random.choice(self.seen_labels, self.sampleRandom - len(shortlists), replace=False)]
        if self.isTrain:
            pos_labels = np.where(np.array(labels_new) == 1)[0]
            item['all_candidate_labels'] = np.unique(np.concatenate([pos_labels, shortlists]))[:len(shortlists)]
        else:
            item['all_candidate_labels'] = np.unique(shortlists)
        if self.sampleRandom != -1:
            if len(item['all_candidate_labels']) < self.sampleRandom:
                # Duplicates were removed by np.unique; pad back up to sampleRandom with repeated entries
                item['all_candidate_labels'] = np.concatenate([item['all_candidate_labels'], item['all_candidate_labels'][len(item['all_candidate_labels']) - self.sampleRandom:]])
            item['all_candidate_labels'] = item['all_candidate_labels'][:self.sampleRandom]
        l1 = len(item['all_candidate_labels'])
        if self.ignored_labels:
            # Remove this example's ignored labels; afterwards pad back to l1 by duplicating elements
            ignore_list = {self.label_to_id[x] for x in self.ignored_labels[idx]}
            if len(ignore_list) > 0:
                item['all_candidate_labels'] = set(item['all_candidate_labels'].tolist()).difference(ignore_list)
                item['all_candidate_labels'] = sorted(list(item['all_candidate_labels']))
                if len(item['all_candidate_labels']) < l1:
                    item['all_candidate_labels'] += item['all_candidate_labels'][:l1 - len(item['all_candidate_labels'])]
                item['all_candidate_labels'] = np.array(item['all_candidate_labels'])
        item['label'] = np.array(item['label'])[item['all_candidate_labels']]
        # One randomly chosen description id per candidate label
        item['label_desc_ids'] = [np.random.randint(0, self.class_descs_len[self.id_to_label[label_key]]) for label_key in item['all_candidate_labels']]
        if self.tok_format == 1:
            item['desc_input_ids'] = [self.class_descs_tokenized['input_ids'][label_key][item['label_desc_ids'][i]].astype(np.int32) for i, label_key in enumerate(item['all_candidate_labels'])]
            item['desc_attention_mask'] = [self.class_descs_tokenized['attention_mask'][label_key][item['label_desc_ids'][i]].astype(np.int32) for i, label_key in enumerate(item['all_candidate_labels'])]
        else:
            item['desc_input_ids'] = [self.class_descs_tokenized[self.id_to_label[label_key]][0][item['label_desc_ids'][i]] for i, label_key in enumerate(item['all_candidate_labels'])]
            item['desc_attention_mask'] = [self.class_descs_tokenized[self.id_to_label[label_key]][1][item['label_desc_ids'][i]] for i, label_key in enumerate(item['all_candidate_labels'])]
        pos_pts = item['label'].nonzero()[0]
        # if len(pos_pts) > 0:
        #     print(idx, item['desc_input_ids'][pos_pts[0]])
        if self.coil_cluster_map:
            map_to_cluster = lambda x: self.coil_cluster_map[str(x)]
            if isinstance(item['input_ids'], list):
                item['clustered_input_ids'] = [self.coil_cluster_map[str(x)] for x in item['input_ids']]
            else:
                # np.vectorize is used here; the original called a nonexistent ndarray.vectorize method
                item['clustered_input_ids'] = np.vectorize(map_to_cluster)(item['input_ids'])
            item['clustered_desc_ids'] = [[self.coil_cluster_map[str(x)] for x in xx] for xx in item['desc_input_ids']]
        return item

    def __getitem__(self, idx):
        item = self.input_dataset.__getitem__(idx)
        if self.data_args.large_dset:
            return self.get_item_for_large_dset(idx, item)
        # Iterate over all the labels of input_dataset
        # and add a random label description to the item in the same order
        if self.ignored_labels:
            ignored_labels = self.ignored_labels[idx]
        if self.sampleRandom != -1:
            # Create all_candidate_labels
            if self.seen_labels is None:
                labels_new = item['label']
            else:
                labels_new = np.array(item['label'])[self.seen_labels]
            if self.useBMShort:
                # Take the top entries from the BM25 shortlist (up to 80% of sampleRandom),
                # then fill the rest with randomly sampled seen labels
                if self.seen_labels is not None:
                    all_candidate_labels = [self.seen_labels.tolist().index(self.label_to_id[x]) for x in self.shortlists[idx] if self.label_to_id[x] in self.seen_labels][:int(0.8 * self.sampleRandom)]
                    # Choose the remaining labels randomly from the seen labels not already chosen
                    all_candidate_labels += np.random.choice(list({x for x in range(len(self.seen_labels))}.difference(set(all_candidate_labels))), self.sampleRandom - len(all_candidate_labels), replace=False).tolist()
            else:
                all_candidate_labels = np.random.choice(range(len(labels_new)), self.sampleRandom, replace=False)
            # prepend positive labels
            pos_labels = np.where(np.array(labels_new) == 1)[0]
            all_candidate_labels = np.concatenate([pos_labels, all_candidate_labels])
            # Remove duplicates
            all_candidate_labels = np.unique(all_candidate_labels)[:self.sampleRandom]
            if len(pos_labels) < self.cl_min_positive_descs:
                addn_pos_labels = np.random.choice(pos_labels, self.cl_min_positive_descs - len(pos_labels))
                all_candidate_labels = np.concatenate([addn_pos_labels, all_candidate_labels])[:self.sampleRandom]
            np.random.shuffle(all_candidate_labels)
            item['all_candidate_labels'] = all_candidate_labels
            # NOTE: ids will be according to seen labels
            # Now update the labels based on all_candidate_labels
        if self.semsup:
            if 'all_candidate_labels' not in item:
                # One randomly chosen description id per label
                item['label_desc_ids'] = [np.random.randint(0, self.class_descs_len[self.id_to_label[label_key]]) for label_key in range(len(item['label']))]
                if self.return_desc_embeddings:
                    item['desc_input_ids'] = [self.class_descs_tokenized[self.id_to_label[label_key]][0][item['label_desc_ids'][label_key]] for label_key in range(len(item['label']))]
                    item['desc_attention_mask'] = [self.class_descs_tokenized[self.id_to_label[label_key]][1][item['label_desc_ids'][label_key]] for label_key in range(len(item['label']))]
                if self.use_precomputed_embeddings:
                    # The precomputed embedding matrix is laid out with (apparently) 5 descriptions per label
                    new_indices = [i * 5 + x for i, x in enumerate(item['label_desc_ids'])]
                    if self.seen_labels is not None:
                        new_indices = [x for i, x in enumerate(new_indices) if i in self.seen_labels]
                    item['desc_inputs_embeds'] = self.computed_desc_inputs_embeds[new_indices]
                item['all_candidate_labels'] = range(len(item['label']))
                if self.seen_labels is not None:
                    item['label_desc_ids'] = (np.array(item['label_desc_ids'])[self.seen_labels]).tolist()
                    if self.return_desc_embeddings:
                        item['desc_input_ids'] = (np.array(item['desc_input_ids']))[self.seen_labels].tolist()
                        item['desc_attention_mask'] = (np.array(item['desc_attention_mask']))[self.seen_labels].tolist()
                    item['all_candidate_labels'] = (np.array(item['all_candidate_labels']))[self.seen_labels].tolist()
                    item['label'] = (np.array(item['label']))[self.seen_labels].tolist()
            elif 'all_candidate_labels' in item:
                item['label_desc_ids'] = [np.random.randint(0, self.class_descs_len[self.id_to_label[label_key]]) for label_key in range(len(item['label']))]
                if self.seen_labels is not None:
                    if self.return_desc_embeddings:
                        item['desc_input_ids'] = [self.class_descs_tokenized[self.id_to_label[label_key]][0][item['label_desc_ids'][label_key]] for label_key in range(len(item['label']))]
                        item['desc_attention_mask'] = [self.class_descs_tokenized[self.id_to_label[label_key]][1][item['label_desc_ids'][label_key]] for label_key in range(len(item['label']))]
                    if self.use_precomputed_embeddings:
                        # Same assumed layout of 5 precomputed description embeddings per label
                        new_indices = [i * 5 + x for i, x in enumerate(item['label_desc_ids'])]
                        # Of all labels, keep only the seen labels
                        new_indices = [x for i, x in enumerate(new_indices) if i in self.seen_labels]
                        # Then restrict to the candidate labels
                        new_indices = [new_indices[x] for x in sorted(item['all_candidate_labels'])]
                        item['desc_inputs_embeds'] = self.computed_desc_inputs_embeds[new_indices]
                    item['label_desc_ids'] = np.array(item['label_desc_ids'])[self.seen_labels].tolist()
                    item['label'] = np.array(item['label'])[self.seen_labels].tolist()
                    item['label'] = np.array(item['label'])[all_candidate_labels].tolist()
                    item['desc_input_ids'] = np.array(item['desc_input_ids'])[self.seen_labels][item['all_candidate_labels']].tolist()
                    item['desc_attention_mask'] = np.array(item['desc_attention_mask'])[self.seen_labels][item['all_candidate_labels']].tolist()
                else:
                    item['label'] = np.array(item['label'])[all_candidate_labels].tolist()
                    if self.return_desc_embeddings:
                        item['desc_input_ids'] = [self.class_descs_tokenized[self.id_to_label[label_key]][0][item['label_desc_ids'][label_key]] for label_key in np.array(item['all_candidate_labels'])]
                        item['desc_attention_mask'] = [self.class_descs_tokenized[self.id_to_label[label_key]][1][item['label_desc_ids'][label_key]] for label_key in np.array(item['all_candidate_labels'])]
                    if self.use_precomputed_embeddings:
                        item['desc_inputs_embeds'] = [self.computed_desc_inputs_embeds[item['label_desc_ids'][label_key], self.label_to_id[self.id_to_label[label_key]]] for label_key in np.array(item['all_candidate_labels'])]
            if self.ignored_labels:
                if self.sampleRandom != -1 and self.seen_labels is not None:
                    ignored_labels = [self.seen_labels.tolist().index(self.label_to_id[x]) for x in self.ignored_labels[idx]]
                    item['all_candidate_labels'] = item['all_candidate_labels'].tolist()
                else:
                    ignored_labels = [self.label_to_id[x] for x in self.ignored_labels[idx]]
                remove_pts = [item['all_candidate_labels'].index(x) for x in ignored_labels if x in item['all_candidate_labels']]
                keep_pts = [x for x in range(len(item['all_candidate_labels'])) if x not in remove_pts]
                # keep_pts can be shorter than sampleRandom; pad by re-sampling from the kept positions
                if self.sampleRandom != -1 and len(keep_pts) < self.sampleRandom:
                    keep_pts += np.random.choice(keep_pts, self.sampleRandom - len(keep_pts), replace=False).tolist()
                item['desc_input_ids'] = np.array(item['desc_input_ids'])[keep_pts].tolist()
                item['desc_attention_mask'] = np.array(item['desc_attention_mask'])[keep_pts].tolist()
                if 'desc_inputs_embeds' in item:
                    item['desc_inputs_embeds'] = np.array(item['desc_inputs_embeds'])[keep_pts].tolist()
                item['label_desc_ids'] = np.array(item['label_desc_ids'])[keep_pts].tolist()
                item['label'] = np.array(item['label'])[keep_pts].tolist()
            if self.coil_cluster_map:
                map_to_cluster = lambda x: self.coil_cluster_map[str(x)]
                if isinstance(item['input_ids'], list):
                    item['clustered_input_ids'] = [self.coil_cluster_map[str(x)] for x in item['input_ids']]
                else:
                    # np.vectorize is used here; the original called a nonexistent ndarray.vectorize method
                    item['clustered_input_ids'] = np.vectorize(map_to_cluster)(item['input_ids'])
                item['clustered_desc_ids'] = [[self.coil_cluster_map[str(x)] for x in xx] for xx in item['desc_input_ids']]
            return item
        else:
            return item
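

# Minimal usage sketch (not part of the original file); names such as `my_data_args`, `my_collator`,
# the tokenizer checkpoint, and the file paths below are illustrative assumptions, not values from this repo.
#
# from transformers import AutoTokenizer
# from torch.utils.data import DataLoader
#
# tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# dataset = SemSupDataset(
#     input_dataset=my_tokenized_train_set,           # any Dataset returning dicts with 'input_ids' and 'label'
#     data_args=my_data_args,                         # must expose the fields used above, e.g. label_max_seq_length, large_dset
#     label_descriptions_file='datasets/label_descriptions.json',
#     label_to_id=label_to_id, id_to_label=id_to_label,
#     tokenizer=tokenizer,
#     return_desc_embeddings=True,
#     sampleRandom=1000,
# )
# loader = DataLoader(dataset, batch_size=8, collate_fn=my_collator)
# batch = next(iter(loader))  # contains 'desc_input_ids', 'desc_attention_mask', 'all_candidate_labels', ...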