"""Build (or reuse) a Milvus collection of ResNet-18 image embeddings."""

import os

from tqdm.auto import tqdm
from utils.utils import create_client
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, utility
from utils.get_embeddings import preprocess_image, extract_features, create_resnet18_model

COLLECTION_NAME = "Resnet18"
EMBEDDING_DIM = 512
# Name of the vector field; shared by the schema and the index so they cannot drift apart.
EMBEDDING_FIELD = "image_embedding"
# Folder holding images named "<integer id><extension>"; overridable via the
# IMAGE_FOLDER environment variable.
IMAGE_FOLDER = os.environ.get(
    "IMAGE_FOLDER", "/home/nampham/Desktop/image-retrieval/data/images_mr"
)

client = create_client()


def load_collection():
    """Return the image collection, creating and populating it if needed.

    If a collection named ``COLLECTION_NAME`` exists, it is loaded and its
    load state is printed. Otherwise a new collection is created, every image
    in ``IMAGE_FOLDER`` is embedded with a ResNet-18 model and inserted, and
    an index is built (``create_index`` also loads the collection).
    """
    if utility.has_collection(COLLECTION_NAME):
        print("Load and use collection right now!")
        collection = Collection(COLLECTION_NAME)
        collection.load()
        print(utility.load_state(COLLECTION_NAME))
        return collection

    print("Please create a collection and insert data!")
    collection = create_collection()
    # insert data into collection
    model = create_resnet18_model()
    insert_data(model, collection, IMAGE_FOLDER)
    # create index for search
    create_index(collection)
    return collection


def create_collection():
    """Create the ``COLLECTION_NAME`` collection with an id + embedding schema.

    Primary keys are supplied by the caller (taken from the image file
    names), so ``auto_id`` is disabled.
    """
    image_id = FieldSchema(
        name="image_id",
        dtype=DataType.INT64,
        is_primary=True,
        description="Image ID",
    )
    image_embedding = FieldSchema(
        name=EMBEDDING_FIELD,
        dtype=DataType.FLOAT_VECTOR,
        dim=EMBEDDING_DIM,  # vector fields must declare their dimension
        description="Image Embedding",
    )
    schema = CollectionSchema(
        fields=[image_id, image_embedding],
        # ids are inserted explicitly by insert_data, which Milvus rejects
        # when auto_id is enabled
        auto_id=False,
        description="Image Retrieval using Resnet18",
    )
    return Collection(name=COLLECTION_NAME, schema=schema)


def insert_data(model, collection, image_folder, extension=".jpg"):
    """Embed every ``<integer id><extension>`` image in a folder and insert it.

    Args:
        model: model passed to ``extract_features``.
        collection: target Milvus collection (schema from ``create_collection``).
        image_folder: directory to scan (not recursive).
        extension: file extension to include, matched case-insensitively.

    Entries whose stem is not a decimal integer, or whose extension does not
    match, are skipped. Rows are inserted in ascending id order and the
    collection is flushed afterwards.
    """
    extension = extension.lower()
    images = []
    for file_name in os.listdir(image_folder):
        stem, ext = os.path.splitext(file_name)
        if ext.lower() != extension or not stem.isdecimal():
            continue
        # keep the real file name so zero-padded ids ("0012.jpg") still resolve
        images.append((int(stem), file_name))
    images.sort()

    if not images:
        print(f"No '{extension}' images with numeric names found in {image_folder}")
        return

    image_ids = [image_id for image_id, _ in images]
    image_embeddings = []
    for _, file_name in tqdm(images):
        image_path = os.path.join(image_folder, file_name)
        processed_image = preprocess_image(image_path)
        # NOTE(review): assumes extract_features returns a flat EMBEDDING_DIM-long
        # vector accepted by Milvus (list / 1-D array) — confirm in utils.get_embeddings
        image_embeddings.append(extract_features(model, processed_image))

    collection.insert([image_ids, image_embeddings])
    collection.flush()


def create_index(collection, nlist=128):
    """Build an IVF_FLAT / L2 index on the embedding field and load the collection.

    Args:
        collection: collection created by ``create_collection``.
        nlist: number of IVF cluster units (IVF_FLAT build parameter).
    """
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": nlist},
    }
    collection.create_index(field_name=EMBEDDING_FIELD, index_params=index_params)
    # load collection so it is ready for search
    collection.load()