File size: 2,227 Bytes
5f65b55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
from sklearn.neighbors import NearestNeighbors
import numpy as np


# Load ResNet-50 with ImageNet-pretrained weights.
# NOTE: the previous call passed ``weights=None`` (random initialisation, so
# the "similarity" was meaningless) to the ``v0.6.0`` hub entry point, which
# does not accept a ``weights`` argument at all. The string form of
# ``weights`` is the documented way to request pretrained torchvision
# weights through ``torch.hub`` (torchvision >= 0.13).
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
# Replace the classification head with an identity mapping so a forward pass
# returns the pooled activations that previously fed ``fc`` (the penultimate
# layer) instead of the 1000 ImageNet class logits.
model.fc = torch.nn.Identity()
# Inference mode for BatchNorm/Dropout layers.
model.eval()

# Standard ImageNet per-channel statistics used by torchvision's pretrained
# classification models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every PIL image before the forward pass:
#   1. resize to exactly 256x256 (a tuple size, so aspect ratio is NOT kept),
#   2. centre-crop to 224x224,
#   3. convert to a float tensor in [0, 1],
#   4. normalise each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Directory containing images
images_dir = "picture/"

# File extensions treated as images. Matching is case-insensitive so names
# such as ``IMG_001.JPG`` or ``photo.jpeg`` are not silently skipped.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Regular files in ``images_dir`` with an image extension, sorted so that the
# processing and report order is deterministic (``os.listdir`` returns
# entries in arbitrary order). Sub-directories with image-like names are
# excluded because ``Image.open`` cannot read them.
image_files = sorted(
    name
    for name in os.listdir(images_dir)
    if name.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(images_dir, name))
)

if not image_files:
    print("No images found in directory")
else:
    # filename -> 1-D feature vector. Insertion order defines the row order of
    # ``feature_array`` below, so ``indexed_files[i]`` names row ``i``.
    feature_dict = {}

    for filename in image_files:
        image_path = os.path.join(images_dir, filename)
        try:
            with Image.open(image_path) as img:
                # Force 3 channels: grayscale ("L"), palette ("P") and RGBA
                # images would otherwise break the 3-channel Normalize step
                # and the model's first convolution.
                img = transform(img.convert("RGB")).unsqueeze(0)
        except OSError as err:  # includes PIL.UnidentifiedImageError / truncated files
            print(f"Skipping unreadable image {filename}: {err}")
            continue

        # Forward pass without autograd bookkeeping (so no .detach() needed);
        # drop only the batch dimension added by unsqueeze(0).
        with torch.no_grad():
            features = model(img).squeeze(0).numpy()

        feature_dict[filename] = features

    # Only images that were actually processed can be indexed/reported.
    indexed_files = list(feature_dict)
    feature_array = np.array(list(feature_dict.values()))

    if len(feature_array) == 0:
        print("No feature vectors extracted")
    else:
        # Up to 10 neighbours (the query itself plus 9 others), but never more
        # than the number of fitted samples: NearestNeighbors raises a
        # ValueError when n_neighbors exceeds n_samples.
        n_neighbors = min(10, len(feature_array))
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="auto").fit(feature_array)

        # One batched query for every image. Because X is passed explicitly,
        # each point is allowed to appear among its own neighbours.
        _, all_indices = nbrs.kneighbors(feature_array)

        for query_idx, neighbour_idxs in enumerate(all_indices):
            print("Query image:", indexed_files[query_idx])
            print("Most similar images:")
            # Drop the query by index, not by rank: with duplicate images
            # several rows sit at distance 0 and the query is not guaranteed
            # to be returned first.
            others = [idx for idx in neighbour_idxs if idx != query_idx]
            for idx in others[: n_neighbors - 1]:
                print(indexed_files[idx])
            print("-----")