haor commited on
Commit
cc4db04
·
verified ·
1 Parent(s): d22b376

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import logging
5
+ from PIL import Image
6
+ from tensorflow.keras.preprocessing import image as keras_image
7
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input as resnet_preprocess
8
+ from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg_preprocess
9
+ import scipy.fftpack
10
+ import time
11
+ import clip
12
+ import torch
13
+
14
+ # Set up logging
15
+ logging.basicConfig(level=logging.INFO)
16
+
17
+ # Load models
18
+ resnet_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
19
+ vgg_model = VGG16(weights='imagenet', include_top=False, pooling='avg')
20
+ clip_model, preprocess_clip = clip.load("ViT-B/32", device="cpu")
21
+
22
+ # Preprocess function
23
+ def preprocess_img(img_path, target_size=(224, 224), preprocess_func=resnet_preprocess):
24
+ start_time = time.time()
25
+ img = keras_image.load_img(img_path, target_size=target_size)
26
+ img_array = keras_image.img_to_array(img)
27
+ img_array = np.expand_dims(img_array, axis=0)
28
+ img_array = preprocess_func(img_array)
29
+ logging.info(f"Image preprocessed in {time.time() - start_time:.4f} seconds")
30
+ return img_array
31
+
32
+ # Feature extraction function
33
+ def extract_features(img_path, model, preprocess_func):
34
+ img_array = preprocess_img(img_path, preprocess_func=preprocess_func)
35
+ start_time = time.time()
36
+ features = model.predict(img_array)
37
+ logging.info(f"Features extracted in {time.time() - start_time:.4f} seconds")
38
+ return features.flatten()
39
+
40
+ # Calculate cosine similarity
41
+ def cosine_similarity(vec1, vec2):
42
+ return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
43
+
44
+ # pHash related functions
45
+ def phashstr(image, hash_size=8, highfreq_factor=4):
46
+ img_size = hash_size * highfreq_factor
47
+ image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
48
+ pixels = np.asarray(image)
49
+ dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
50
+ dctlowfreq = dct[:hash_size, :hash_size]
51
+ med = np.median(dctlowfreq)
52
+ diff = dctlowfreq > med
53
+ return _binary_array_to_hex(diff.flatten())
54
+
55
+ def _binary_array_to_hex(arr):
56
+ h = 0
57
+ s = []
58
+ for i, v in enumerate(arr):
59
+ if v:
60
+ h += 2**(i % 8)
61
+ if (i % 8) == 7:
62
+ s.append(hex(h)[2:].rjust(2, '0'))
63
+ h = 0
64
+ return ''.join(s)
65
+
66
+ def hamming_distance(hash1, hash2):
67
+ if len(hash1) != len(hash2):
68
+ raise ValueError("Hashes must be of the same length")
69
+ return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))
70
+
71
+ def hamming_to_similarity(distance, hash_length):
72
+ return (1 - distance / hash_length) * 100
73
+
74
+ # CLIP related functions
75
+ def extract_clip_features(image_path, model, preprocess):
76
+ image = preprocess(Image.open(image_path)).unsqueeze(0).to("cpu")
77
+ with torch.no_grad():
78
+ features = model.encode_image(image)
79
+ return features.cpu().numpy().flatten()
80
+
81
+ # Main function
82
+ def compare_images(image1, image2, method):
83
+ start_time = time.time()
84
+ if method == 'pHash':
85
+ img1 = Image.open(image1)
86
+ img2 = Image.open(image2)
87
+ hash1 = phashstr(img1)
88
+ hash2 = phashstr(img2)
89
+ distance = hamming_distance(hash1, hash2)
90
+ similarity = hamming_to_similarity(distance, len(hash1) * 4)
91
+ elif method == 'ResNet50':
92
+ features1 = extract_features(image1, resnet_model, resnet_preprocess)
93
+ features2 = extract_features(image2, resnet_model, resnet_preprocess)
94
+ similarity = cosine_similarity(features1, features2)
95
+ elif method == 'VGG16':
96
+ features1 = extract_features(image1, vgg_model, vgg_preprocess)
97
+ features2 = extract_features(image2, vgg_model, vgg_preprocess)
98
+ similarity = cosine_similarity(features1, features2)
99
+ elif method == 'CLIP':
100
+ features1 = extract_clip_features(image1, clip_model, preprocess_clip)
101
+ features2 = extract_clip_features(image2, clip_model, preprocess_clip)
102
+ similarity = cosine_similarity(features1, features2)
103
+
104
+ logging.info(f"Image comparison using {method} completed in {time.time() - start_time:.4f} seconds")
105
+ return similarity
106
+
107
+ # Gradio interface
108
+ demo = gr.Interface(
109
+ fn=compare_images,
110
+ inputs=[
111
+ gr.Image(type="filepath", label="Upload First Image"),
112
+ gr.Image(type="filepath", label="Upload Second Image"),
113
+ gr.Radio(["pHash", "ResNet50", "VGG16", "CLIP"], label="Select Comparison Method")
114
+ ],
115
+ outputs=gr.Textbox(label="Similarity"),
116
+ title="Image Similarity Comparison",
117
+ description="Upload two images and select the comparison method.",
118
+ examples=[
119
+ ["example1.png", "example2.png", "pHash"],
120
+ ["example1.png", "example2.png", "ResNet50"],
121
+ ["example1.png", "example2.png", "VGG16"],
122
+ ["example1.png", "example2.png", "CLIP"]
123
+ ]
124
+ )
125
+
126
+ demo.launch()