|
import argparse |
|
from tqdm import tqdm |
|
|
|
import faiss |
|
|
|
from embeddings import FaissIndex |
|
from models import CLIP |
|
|
|
|
|
def _embed_and_add(clip, index, batch):
    """Embed one batch of reference strings and add it to ``index``."""
    embeddings = clip.get_text_emb(batch)
    # NOTE(review): .detach().numpy() presumes get_text_emb returns a CPU
    # torch tensor -- confirm against models.CLIP.
    index.add(embeddings.detach().numpy(), batch)


def main(file, index_type):
    """Rebuild the FAISS index for ``index_type`` from a file of references.

    Each non-blank line of ``file`` is treated as one reference string. The
    references are embedded with ``CLIP.get_text_emb`` and added, together
    with their text, to a freshly reset ``FaissIndex`` whose location is
    ``faiss_indices/{index_type}.index``.

    Args:
        file: Path to a text file with one reference per line.
        index_type: Name used to build the index file location.
    """
    single_pass_limit = 500  # below this, embed everything in one call
    batch_size = 300

    clip = CLIP()

    with open(file) as f:
        # Skip blank lines: a trailing newline (or an empty file) used to
        # yield an empty-string "reference" that was embedded and indexed.
        references = [line.rstrip("\n") for line in f if line.strip()]

    index = FaissIndex(
        # NOTE(review): 768 presumably matches the CLIP text embedding
        # width -- confirm against models.CLIP.
        embedding_size=768,
        faiss_index_location=f"faiss_indices/{index_type}.index",
        indexer=faiss.IndexFlatIP,
    )
    # Start from an empty index so a rebuild never keeps stale entries.
    index.reset()
    # NOTE(review): nothing here explicitly saves the index; presumably
    # FaissIndex persists to faiss_index_location on add -- confirm.

    if not references:
        # Nothing to embed; leave the index reset and empty.
        return

    if len(references) < single_pass_limit:
        _embed_and_add(clip, index, references)
        return

    # Larger inputs are embedded in fixed-size slices, with a progress bar.
    batches = [
        references[start : start + batch_size]
        for start in range(0, len(references), batch_size)
    ]
    for batch in tqdm(batches):
        _embed_and_add(clip, index, batch)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: two positional arguments, forwarded to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "file",
        type=str,
        help="File containing references",
    )
    cli.add_argument(
        "index_type",
        type=str,
        choices=["places", "objects"],
    )
    options = cli.parse_args()

    main(file=options.file, index_type=options.index_type)
|
|