Spaces:
Runtime error
Runtime error
File size: 2,227 Bytes
5f65b55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
from sklearn.neighbors import NearestNeighbors
import numpy as np
# Load ResNet-50 with ImageNet pre-trained weights.
# `pretrained=True` is the keyword understood by the pinned v0.6.0 hub entry
# point (and still accepted, as a deprecated alias, by newer torchvision).
# The previous `weights=None` either raised TypeError on the old API or,
# where accepted, produced randomly initialised weights.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
# Replace the classification head with an identity so the forward pass
# returns the pooled penultimate-layer features instead of class logits.
model.fc = torch.nn.Identity()
# Inference mode: freezes batch-norm statistics and disables dropout.
model.eval()
# Preprocessing applied to every image before feature extraction:
# resize to a fixed 256x256 (aspect ratio is not preserved), keep the
# central 224x224 patch, convert to a float tensor, then normalise each
# channel with the standard ImageNet mean / standard deviation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_resize = transforms.Resize((256, 256))
_crop = transforms.CenterCrop((224, 224))
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)
transform = transforms.Compose([_resize, _crop, _to_tensor, _normalize])
# Directory containing the images to index.
images_dir = "picture/"
# Extensions treated as images; matched case-insensitively so files such as
# "photo.JPG" or "scan.jpeg" are not silently ignored.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
# Image filenames in the directory, sorted because os.listdir returns entries
# in arbitrary order and the printed results should be reproducible.
image_files = sorted(
    f for f in os.listdir(images_dir) if f.lower().endswith(_IMAGE_EXTENSIONS)
)
if not image_files:
    print("No images found in directory")
else:
    # Maps each successfully processed filename to its feature vector.
    feature_dict = {}
    # Gradients are never needed here, so disable autograd for the whole pass.
    with torch.no_grad():
        for filename in image_files:
            image_path = os.path.join(images_dir, filename)
            try:
                with Image.open(image_path) as img:
                    # Force 3 channels: PNGs may be RGBA/palette and JPEGs may be
                    # grayscale, which would not match the 3-channel Normalize.
                    batch = transform(img.convert('RGB')).unsqueeze(0)
            except OSError as err:
                # PIL signals unreadable/unsupported files with OSError
                # (UnidentifiedImageError is a subclass); skip them instead of
                # aborting the whole run.
                print(f"Skipping {filename}: {err}")
                continue
            # Drop the batch dimension to get one flat vector per image.
            features = torch.squeeze(model(batch))
            feature_dict[filename] = features.numpy()
    # Filenames in the same order as the rows of feature_array, so a neighbor
    # index maps back to the right file even if some images were skipped.
    indexed_files = list(feature_dict)
    feature_array = np.array(list(feature_dict.values()))
    if len(feature_array) == 0:
        print("No feature vectors extracted")
    else:
        # Up to 10 neighbors (the query itself plus up to 9 others); sklearn
        # raises ValueError if n_neighbors exceeds the number of samples.
        n_neighbors = min(10, len(feature_array))
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='auto').fit(feature_array)
        for query_index, query_filename in enumerate(indexed_files):
            query_feature = feature_dict[query_filename]
            _, indices = nbrs.kneighbors(query_feature.reshape(1, -1))
            print("Query image:", query_filename)
            print("Most similar images:")
            for idx in indices[0]:
                # Exclude the query's own row by index rather than assuming it
                # is ranked first: duplicate images tie at distance 0.
                if idx == query_index:
                    continue
                print(indexed_files[idx])
            print("-----")
|