Spaces:
Sleeping
Sleeping
import os | |
from tqdm.auto import tqdm | |
from utils.utils import create_client | |
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, utility | |
from utils.get_embeddings import preprocess_image, extract_features, create_resnet18_model | |
# Name of the Milvus collection that holds the image embeddings.
COLLECTION_NAME = "Resnet18"
# Length of every stored embedding vector. Must match the length of the vectors
# returned by extract_features for the ResNet-18 model (presumably its 512-d
# pooled features -- TODO confirm in utils.get_embeddings).
EMBEDDING_DIM = 512
# Directory containing the images to index. Can be overridden with the
# IMAGE_FOLDER environment variable instead of editing this machine-specific
# default path.
IMAGE_FOLDER = os.environ.get(
    "IMAGE_FOLDER", "/home/nampham/Desktop/image-retrieval/data/images_mr"
)
# Created at import time. Presumably this opens the Milvus connection that the
# pymilvus Collection/utility calls in this module rely on -- TODO confirm in
# utils.utils.create_client.
client = create_client()
def load_collection():
    """Return the image collection, building it first if it does not exist.

    If a collection named ``COLLECTION_NAME`` already exists, it is wrapped,
    loaded, and its load state is printed.

    Otherwise a new collection is created. Every image in ``IMAGE_FOLDER`` is
    embedded with a freshly created ResNet-18 model and inserted, and a search
    index is built. ``create_index`` also loads the collection.

    Returns:
        The ``pymilvus.Collection`` for ``COLLECTION_NAME``.
    """
    if not utility.has_collection(COLLECTION_NAME):
        print("Please create a collection and insert data!")
        new_collection = create_collection()
        # Embed every image in the folder and insert the vectors.
        insert_data(create_resnet18_model(), new_collection, IMAGE_FOLDER)
        # Build the search index for the freshly inserted vectors.
        create_index(new_collection)
        return new_collection

    print("Load and use collection right now!")
    existing = Collection(COLLECTION_NAME)
    existing.load()
    print(utility.load_state(COLLECTION_NAME))
    return existing
def create_collection():
    """Create the Milvus collection that stores one embedding per image.

    Schema:
      * ``image_id`` (INT64, primary key): supplied by the caller at insert time.
        ``insert_data`` derives it from each image's file name, so ``auto_id``
        is disabled.
      * ``image_embedding`` (FLOAT_VECTOR of length ``EMBEDDING_DIM``).

    Returns:
        The newly created ``pymilvus.Collection`` named ``COLLECTION_NAME``.
    """
    image_id = FieldSchema(
        name="image_id",
        dtype=DataType.INT64,
        is_primary=True,
        description="Image ID",
    )
    image_embedding = FieldSchema(
        name="image_embedding",
        dtype=DataType.FLOAT_VECTOR,
        # Vector fields must declare their dimension.
        dim=EMBEDDING_DIM,
        description="Image Embedding",
    )
    schema = CollectionSchema(
        fields=[image_id, image_embedding],
        # Primary keys are provided explicitly on insert so that search hits
        # map back to image files. With auto_id=True, pymilvus would reject
        # the caller-supplied ID column.
        auto_id=False,
        description="Image Retrieval using Resnet18",
    )
    return Collection(name=COLLECTION_NAME, schema=schema)
def insert_data(model, collection, image_folder):
    """Embed every image in *image_folder* and insert the vectors into *collection*.

    Image files are expected to be named ``<integer id>.<extension>``. The
    integer becomes the row's primary key, so search hits can be mapped back to
    files. Images are processed in ascending ID order, and the collection is
    flushed after the insert.

    Args:
        model: Feature extractor passed through to ``extract_features``.
        collection: Target ``pymilvus.Collection``. Assumed to have the schema
            ``[image_id, image_embedding]`` with caller-supplied IDs.
        image_folder: Directory containing the image files.

    Raises:
        ValueError: If a file name (without extension) is not an integer.
    """
    # Keep the real file name next to its numeric ID. Images are then opened
    # exactly as listed, instead of assuming every file is "<id>.jpg".
    id_and_file = sorted(
        (int(os.path.splitext(file_name)[0]), file_name)
        for file_name in os.listdir(image_folder)
    )
    if not id_and_file:
        return  # Empty folder: nothing to insert or flush.

    image_ids = []
    image_embeddings = []
    for image_id, file_name in tqdm(id_and_file):
        image_path = os.path.join(image_folder, file_name)
        processed_image = preprocess_image(image_path)
        image_embeddings.append(extract_features(model, processed_image))
        image_ids.append(image_id)

    # Column-ordered data: one list per schema field, in schema order.
    collection.insert([image_ids, image_embeddings])
    collection.flush()
def create_index(collection, field_name="image_embedding", nlist=128):
    """Build an IVF_FLAT / L2 index on a vector field, then load the collection.

    Args:
        collection: The ``pymilvus.Collection`` to index.
        field_name: Name of the vector field to index. Defaults to
            ``"image_embedding"``, the vector field defined in this module's
            schema.
        nlist: Number of IVF cluster units. 128 is Milvus's documented
            default for IVF_FLAT.
    """
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": nlist},
    }
    collection.create_index(field_name=field_name, index_params=index_params)
    # Load the collection into memory so it is ready to be searched.
    collection.load()