"""Build a FAISS index of CLIP text embeddings from a file of references.

Usage:
    python <this script> FILE {places,objects} [--batch-size N]

Each non-blank line of FILE is one reference string. The strings are embedded
with ``CLIP.get_text_emb`` and added, together with the strings themselves, to a
``FaissIndex`` configured for ``faiss_indices/<index_type>.index``.
"""
import argparse

import faiss
from tqdm import tqdm

from embeddings import FaissIndex
from models import CLIP

# Number of references embedded per get_text_emb call for large inputs.
BATCH_SIZE = 300
# Reference lists shorter than this are embedded in a single call.
SINGLE_BATCH_LIMIT = 500
# Width of the index vectors; assumed to match the output width of
# CLIP().get_text_emb -- TODO confirm against models.CLIP.
EMBEDDING_SIZE = 768


def main(file, index_type, batch_size=BATCH_SIZE):
    """Embed every reference in *file* with CLIP and add it to a freshly reset index.

    Args:
        file: Path to a UTF-8 text file with one reference per line. Blank lines,
            including the empty "line" after a trailing newline, are skipped.
            Surrounding whitespace (including a Windows ``\\r``) is stripped.
        index_type: Index name. It selects ``faiss_indices/<index_type>.index``.
        batch_size: References per embedding call once the input has at least
            ``SINGLE_BATCH_LIMIT`` entries.

    Raises:
        ValueError: If *file* contains no non-blank lines. This is checked
            before the index is touched, so an existing index is not reset.
    """
    with open(file, encoding="utf-8") as f:
        # Splitting the raw text on "\n" would produce "" for the trailing
        # newline and for blank lines. Those empty strings would be embedded
        # and stored as references.
        references = [line.strip() for line in f if line.strip()]

    if not references:
        raise ValueError(f"no references found in {file!r}")

    clip = CLIP()
    index = FaissIndex(
        embedding_size=EMBEDDING_SIZE,
        faiss_index_location=f"faiss_indices/{index_type}.index",
        indexer=faiss.IndexFlatIP,
    )
    index.reset()

    if len(references) < SINGLE_BATCH_LIMIT:
        batches = [references]
    else:
        batches = [
            references[start : start + batch_size]
            for start in range(0, len(references), batch_size)
        ]

    # The progress bar is only shown when there is more than one batch,
    # which matches the original's small-input path that ran without tqdm.
    for batch in tqdm(batches, disable=len(batches) == 1):
        ref_embeddings = clip.get_text_emb(batch)
        index.add(ref_embeddings.detach().numpy(), batch)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("file", type=str, help="File containing references")
    parser.add_argument("index_type", type=str, choices=["places", "objects"])
    parser.add_argument(
        "--batch-size",
        type=int,
        default=BATCH_SIZE,
        help="References per embedding call for large inputs",
    )
    args = parser.parse_args()
    main(args.file, args.index_type, batch_size=args.batch_size)