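# Encode a collection of captions with CLIP and dump the text embeddings as a
# plain-text MIPS index (one whitespace-separated vector per line), together
# with a JSON dict mapping each index position back to its caption.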
import argparse
import json
import os

import torch
import progressbar

def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument("--clip_name", type=str, default="openai/clip-vit-base-patch32")
    parser.add_argument("--text_file_path", type=str)
    # save configuration
    parser.add_argument("--save_index_prefix", type=str, help="directory in which to save the MIPS index")
    parser.add_argument("--save_index_name", type=str)
    parser.add_argument("--save_mapping_dict_name", type=str,
        help="a json file that stores a dictionary mapping each MIPS index position to its caption text")
    # inference configuration
    parser.add_argument("--batch_size", type=int, help="the batch size used to run inference with CLIP")
    return parser.parse_args()
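
# A hedged usage sketch based on the flags above; the script name and all
# paths/values are hypothetical placeholders, not taken from the source:
#
#   python build_text_index.py \
#       --clip_name openai/clip-vit-base-patch32 \
#       --text_file_path ./data/captions.json \
#       --save_index_prefix ./index/ \
#       --save_index_name text_index.txt \
#       --save_mapping_dict_name mapping_dict.json \
#       --batch_size 128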

def load_batch_text(text_file_path, batch_size):
    with open(text_file_path) as f:
        item_list = json.load(f)

    # flatten every caption of every item into one list
    text_list = []
    for item in item_list:
        for cap in item["captions"]:
            text_list.append(cap)
    print('Number of text instances is {}'.format(len(text_list)))

    # split the captions into fixed-size batches; s_idx/e_idx must advance
    # after every batch, otherwise the first batch would be repeated
    data_num = len(text_list)
    batch_num = data_num // batch_size
    batch_text_list = []
    s_idx, e_idx = 0, batch_size
    for _ in range(batch_num):
        batch_text_list.append(text_list[s_idx:e_idx])
        s_idx += batch_size
        e_idx += batch_size
    # keep the leftover captions that do not fill a full batch
    if s_idx < data_num:
        batch_text_list.append(text_list[s_idx:])
    return batch_text_list
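
# Expected input format, inferred from the loop above: a JSON list of objects,
# each carrying a "captions" field that holds a list of strings, e.g.
#   [{"captions": ["a dog runs on the grass", "a brown dog outdoors"]}, ...]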

if __name__ == '__main__':
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print('Cuda is available.')
    args = parse_config()
    device = torch.device('cuda' if cuda_available else 'cpu')

    # recursively create the output directory if it does not exist yet
    os.makedirs(args.save_index_prefix, exist_ok=True)
    print('Loading CLIP...')
    from clip import CLIP  # project-local wrapper module shipped alongside this script
    model = CLIP(args.clip_name)
    model = model.to(device)
    model.eval()
    print('CLIP loaded!')
    print('Loading text data...')
    batch_text_list = load_batch_text(args.text_file_path, args.batch_size)
    print('Text data loaded.')

    res_text_vec_list, res_text_list = [], []
    batch_num = len(batch_text_list)
    print('Number of batches is {}'.format(batch_num))

    print('Start inference...')
    # assuming the progressbar2 package, where max_value names the total count
    p = progressbar.ProgressBar(max_value=batch_num)
    p.start()
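
    # The loop below assumes the CLIP wrapper exposes
    # compute_batch_index_text_representation(), returning one embedding
    # vector per input caption; this is inferred from usage here, not from
    # any documented interface.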
    with torch.no_grad():
        for p_idx in range(batch_num):
            p.update(p_idx)
            one_text_batch = batch_text_list[p_idx]
            # encode one batch of captions and move the result off the GPU
            one_batch_vec = model.compute_batch_index_text_representation(one_text_batch).detach().cpu()
            one_batch_vec_list = one_batch_vec.unbind(dim=0)
            bsz = len(one_batch_vec_list)
            for k in range(bsz):
                res_text_vec_list.append(one_batch_vec_list[k].numpy())
                res_text_list.append(one_text_batch[k])
    p.finish()
    assert len(res_text_vec_list) == len(res_text_list)
    print('Inference completed!')

    # save the mapping between index positions and caption text
    index_text_mapping_dict = {}
    for k in range(len(res_text_list)):
        index_text_mapping_dict[k] = res_text_list[k]
    mapping_list_save_path = os.path.join(args.save_index_prefix, args.save_mapping_dict_name)
    with open(mapping_list_save_path, 'w') as outfile:
        json.dump(index_text_mapping_dict, outfile, indent=4)
    print('Mapping dictionary saved!')
    print('Start building index...')
    index_save_path = os.path.join(args.save_index_prefix, args.save_index_name)
    with open(index_save_path, 'w', encoding='utf8') as o:
        for vec in res_text_vec_list:
            # one whitespace-separated embedding vector per line
            one_text = ' '.join([str(num) for num in vec]).strip()
            o.write(one_text + '\n')
    print('Index completed!')
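
# A hedged sketch of how the saved artifacts could be read back; numpy is an
# assumption of this snippet, not a requirement of the script itself:
#   import json
#   import numpy as np
#   text_vecs = np.loadtxt(index_save_path)        # shape: (num_texts, dim)
#   with open(mapping_list_save_path) as f:
#       index_text_mapping_dict = json.load(f)     # position -> caption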