xguman committed
Commit
875ae89
1 Parent(s): 37f6b71
Files changed (4)
  1. app.py +57 -0
  2. create_image_embeddnigs.py +47 -0
  3. download_dataset.py +24 -0
  4. requirements.txt +226 -0
app.py ADDED
@@ -0,0 +1,57 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from transformers import CLIPProcessor, CLIPModel
+ import os
+ import numpy as np
+ import warnings
+ from create_image_embeddnigs import create_embeddings
+ from download_dataset import download_images
+
+ warnings.filterwarnings("ignore", category=UserWarning)
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"  # or "false" to silence fork warnings
+
+ download_images()
+
+ # Get image embeddings: if "image_embeddings.npy" exists, load it; otherwise create it.
+ if os.path.exists("image_embeddings.npy"):
+     image_embeddings = np.load("image_embeddings.npy")
+ else:
+     image_dir = "data/pictures"
+     batch_size = 32
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     image_embeddings = create_embeddings(image_dir, batch_size, device)
+
+ # L2-normalize so cosine similarity against a unit text vector is a plain dot product.
+ image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
+
+ def get_text_embeddings(input_text):
+     # Reloaded on every call; caching at module level would speed up queries.
+     model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+     processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+     inputs = processor(text=input_text, return_tensors="pt", padding=True, truncation=True)
+     embeddings = model.get_text_features(**inputs)
+     vector = embeddings.detach().numpy().ravel()
+     return vector / np.linalg.norm(vector)
+
+ def cosine_similarity(a, b):
+     return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+
+ def find_similar_images(text_embedding, image_embeddings, top_k=4):
+     similarities = np.array([cosine_similarity(text_embedding, image_embedding) for image_embedding in image_embeddings])
+     top_k_indices = np.argsort(similarities)[-top_k:][::-1]
+     return top_k_indices
+
+ def get_similar_images(input_text):
+     text_embedding = get_text_embeddings(input_text)
+     top_k_indices = find_similar_images(text_embedding, image_embeddings)
+     # sorted() keeps this list in the same order the embeddings were created in
+     # (os.listdir order is not guaranteed to be stable).
+     image_paths = sorted(os.path.join("data/pictures", f) for f in os.listdir("data/pictures") if f.endswith(('.png', '.jpg', '.jpeg')))
+     similar_images = [image_paths[i] for i in top_k_indices]
+     return [Image.open(image_path) for image_path in similar_images]
+
+
+ if __name__ == "__main__":
+     iface = gr.Interface(fn=get_similar_images, inputs="text", outputs="gallery", title="Find Similar Images")
+     iface.launch(share=True)
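Since `image_embeddings` is L2-normalized up front and `get_text_embeddings` returns a unit vector, the per-image Python loop in `find_similar_images` can be collapsed into a single matrix-vector product. A minimal sketch of that equivalent (the function name is ours, not part of the commit):

```python
import numpy as np

def find_similar_images_vectorized(text_embedding, image_embeddings, top_k=4):
    # Rows of image_embeddings and text_embedding are unit-length,
    # so the dot product equals the cosine similarity.
    similarities = image_embeddings @ text_embedding   # shape: (num_images,)
    return np.argsort(similarities)[-top_k:][::-1]     # indices, most similar first
```

For the ~200 images used here the difference is negligible, but the vectorized form scales to much larger galleries.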
create_image_embeddnigs.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from PIL import Image
+ from transformers import CLIPProcessor, CLIPModel
+ from torch.utils.data import Dataset, DataLoader
+ import os
+ import numpy as np
+ import warnings
+
+ warnings.filterwarnings("ignore", category=UserWarning)
+
+ class ImageDataset(Dataset):
+     def __init__(self, image_dir, processor):
+         # sorted() so the embedding order matches the sorted path list used in app.py
+         self.image_paths = sorted(os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg')))
+         self.processor = processor
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         image = Image.open(self.image_paths[idx])
+         return self.processor(images=image, return_tensors="pt")['pixel_values'][0]
+
+ def get_clip_embeddings_batch(image_dir, batch_size=32, device='cuda'):
+     # Load the CLIP model and processor
+     model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+     processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+     # Create dataset and dataloader
+     dataset = ImageDataset(image_dir, processor)
+     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+     all_embeddings = []
+
+     model.eval()
+     with torch.no_grad():
+         for batch in dataloader:
+             batch = batch.to(device)
+             embeddings = model.get_image_features(pixel_values=batch)
+             all_embeddings.append(embeddings.cpu().numpy())
+
+     return np.concatenate(all_embeddings)
+
+ def create_embeddings(image_dir, batch_size, device):
+     embeddings = get_clip_embeddings_batch(image_dir, batch_size, device)
+     np.save("image_embeddings.npy", embeddings)
+     return embeddings
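The module can also be exercised on its own; a hedged usage sketch, assuming the same directory and batch size that app.py uses:

```python
import torch
from create_image_embeddnigs import create_embeddings

device = "cuda" if torch.cuda.is_available() else "cpu"
embeddings = create_embeddings("data/pictures", batch_size=32, device=device)
print(embeddings.shape)  # (num_images, 512) for clip-vit-base-patch32
```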
download_dataset.py ADDED
@@ -0,0 +1,24 @@
+ import os
+ from datasets import load_dataset
+
+ def download_images(name="food101"):
+     # Load a 1% slice of the train split of the requested dataset (default: food101)
+     dataset = load_dataset(name, split="train[:1%]")
+
+     # Create a directory to save the images if it doesn't exist
+     output_dir = "data/pictures"
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Limit to 200 images
+     num_images = 200
+     count = 0
+
+     # Iterate over the dataset and save the images
+     for example in dataset:
+         if count >= num_images:
+             break
+         image = example['image']
+         image.save(os.path.join(output_dir, f"image_{count}.jpg"))  # save as JPG
+         count += 1
+
+     print(f"Downloaded and saved {count} images to the folder '{output_dir}'")
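`download_images` is only ever invoked from app.py, so the script has no entry point of its own. If you also want to run it standalone, a conventional guard (our addition, not in the commit) would do:

```python
if __name__ == "__main__":
    download_images()
```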
requirements.txt ADDED
@@ -0,0 +1,226 @@
+ # Conda-managed packages from the original environment (not pip-installable; kept for reference):
+ # _libgcc_mutex=0.1
+ # _openmp_mutex=5.1
+ # bzip2=1.0.8
+ # ca-certificates=2024.7.2
+ # ld_impl_linux-64=2.38
+ # libffi=3.4.4
+ # libgcc-ng=11.2.0
+ # libgomp=11.2.0
+ # libstdcxx-ng=11.2.0
+ # libuuid=1.41.5
+ # ncurses=6.4
+ # openssl=1.1.1w
+ # pip=24.2
+ # python=3.10.10
+ # readline=8.2
+ # setuptools=72.1.0
+ # sqlite=3.45.3
+ # tk=8.6.14
+ # wheel=0.43.0
+ # xz=5.4.6
+ # zlib=1.2.13
+ # The +cu121 torch/torchvision builds below are not on PyPI; pull them from the PyTorch wheel index:
+ --extra-index-url https://download.pytorch.org/whl/cu121
+ absl-py==2.1.0
+ aiofiles==23.2.1
+ aiohappyeyeballs==2.4.0
+ aiohttp==3.10.5
+ aiosignal==1.3.1
+ annotated-types==0.7.0
+ anyio==4.4.0
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.3.0
+ asttokens==2.4.1
+ async-lru==2.0.4
+ async-timeout==4.0.3
+ attrs==24.2.0
+ babel==2.16.0
+ backoff==2.2.1
+ beautifulsoup4==4.12.3
+ bleach==6.1.0
+ boto3==1.35.12
+ botocore==1.35.12
+ cachetools==5.5.0
+ certifi==2024.8.30
+ cffi==1.17.1
+ charset-normalizer==3.3.2
+ click==8.1.7
+ comm==0.2.2
+ contourpy==1.3.0
+ cycler==0.12.1
+ datasets==3.0.1
+ debugpy==1.8.5
+ decorator==5.1.1
+ defusedxml==0.7.1
+ dill==0.3.8
+ exceptiongroup==1.2.2
+ executing==2.1.0
+ fastapi==0.113.0
+ fastjsonschema==2.20.0
+ ffmpy==0.4.0
+ filelock==3.15.4
+ fire==0.6.0
+ fonttools==4.53.1
+ fqdn==1.5.1
+ frozenlist==1.4.1
+ fsspec==2024.6.1
+ google-auth==2.34.0
+ google-auth-oauthlib==1.2.1
+ gradio==5.1.0
+ gradio-client==1.4.0
+ grpcio==1.66.1
+ h11==0.14.0
+ httpcore==1.0.5
+ httpx==0.27.2
+ huggingface-hub==0.26.0
+ idna==3.8
+ ipykernel==6.26.0
+ ipython==8.17.2
+ ipywidgets==8.1.1
+ isoduration==20.11.0
+ jedi==0.19.1
+ jinja2==3.1.4
+ jmespath==1.0.1
+ joblib==1.4.2
+ json5==0.9.25
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2023.12.1
+ jupyter-client==8.6.2
+ jupyter-core==5.7.2
+ jupyter-events==0.10.0
+ jupyter-lsp==2.2.5
+ jupyter-server==2.14.2
+ jupyter-server-terminals==0.5.3
+ jupyterlab==4.2.0
+ jupyterlab-pygments==0.3.0
+ jupyterlab-server==2.27.3
+ jupyterlab-widgets==3.0.13
+ kiwisolver==1.4.7
+ lightning==2.4.0
+ lightning-cloud==0.5.70
+ lightning-sdk==0.1.15
+ lightning-utilities==0.11.7
+ litdata==0.2.19
+ litserve==0.2.2
+ markdown==3.7
+ markdown-it-py==3.0.0
+ markupsafe==2.1.5
+ matplotlib==3.8.2
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ mistune==3.0.2
+ mpmath==1.3.0
+ multidict==6.0.5
+ multiprocess==0.70.16
+ nbclient==0.10.0
+ nbconvert==7.16.4
+ nbformat==5.10.4
+ nest-asyncio==1.6.0
+ networkx==3.3
+ notebook-shim==0.2.4
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.19.3
+ nvidia-nvjitlink-cu12==12.6.68
+ nvidia-nvtx-cu12==12.1.105
+ oauthlib==3.2.2
+ orjson==3.10.7
+ overrides==7.7.0
+ packaging==24.1
+ pandas==2.1.4
+ pandocfilters==1.5.1
+ parso==0.8.4
+ pexpect==4.9.0
+ pillow==10.4.0
+ platformdirs==4.2.2
+ prometheus-client==0.20.0
+ prompt-toolkit==3.0.47
+ protobuf==4.23.4
+ psutil==6.0.0
+ ptyprocess==0.7.0
+ pure-eval==0.2.3
+ pyarrow==17.0.0
+ pyasn1==0.6.0
+ pyasn1-modules==0.4.0
+ pycparser==2.22
+ pydantic==2.9.0
+ pydantic-core==2.23.2
+ pydub==0.25.1
+ pygments==2.18.0
+ pyjwt==2.9.0
+ pyparsing==3.1.4
+ python-dateutil==2.9.0.post0
+ python-json-logger==2.0.7
+ python-multipart==0.0.9
+ pytorch-lightning==2.4.0
+ pytz==2024.1
+ pyyaml==6.0.2
+ pyzmq==26.2.0
+ referencing==0.35.1
+ regex==2024.9.11
+ requests==2.32.3
+ requests-oauthlib==2.0.0
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rich==13.8.0
+ rpds-py==0.20.0
+ rsa==4.9
+ ruff==0.7.0
+ s3transfer==0.10.2
+ safetensors==0.4.5
+ scikit-learn==1.3.2
+ scipy==1.11.4
+ semantic-version==2.10.0
+ send2trash==1.8.3
+ shellingham==1.5.4
+ simple-term-menu==1.6.4
+ six==1.16.0
+ sniffio==1.3.1
+ soupsieve==2.6
+ stack-data==0.6.3
+ starlette==0.38.4
+ sympy==1.13.2
+ tensorboard==2.15.1
+ tensorboard-data-server==0.7.2
+ termcolor==2.4.0
+ terminado==0.18.1
+ threadpoolctl==3.5.0
+ tinycss2==1.3.0
+ tokenizers==0.20.1
+ tomli==2.0.1
+ tomlkit==0.12.0
+ torch==2.2.1+cu121
+ torchmetrics==1.3.1
+ torchvision==0.17.1+cu121
+ tornado==6.4.1
+ tqdm==4.66.5
+ traitlets==5.14.3
+ transformers==4.45.2
+ triton==2.2.0
+ typer==0.12.5
+ types-python-dateutil==2.9.0.20240821
+ typing-extensions==4.12.2
+ tzdata==2024.1
+ uri-template==1.3.0
+ urllib3==2.2.2
+ uvicorn==0.30.6
+ wcwidth==0.2.13
+ webcolors==24.8.0
+ webencodings==0.5.1
+ websocket-client==1.8.0
+ websockets==12.0
+ werkzeug==3.0.4
+ widgetsnbextension==4.0.13
+ xxhash==3.5.0
+ yarl==1.9.11
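With the PyTorch index line above, `pip install -r requirements.txt` can resolve the `+cu121` builds of torch and torchvision; on a CPU-only Space, the plain `torch==2.2.1` and `torchvision==0.17.1` wheels from PyPI would work as well.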