v1
- app.py +54 -0
- create_image_embeddnigs.py +46 -0
- download_dataset.py +25 -0
- requirements.txt +223 -0
app.py
ADDED
@@ -0,0 +1,54 @@
import gradio as gr
import torch
import os
import numpy as np
import warnings
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from create_image_embeddnigs import create_embeddings
from download_dataset import download_images

warnings.filterwarnings("ignore", category=UserWarning)
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # avoid tokenizer fork warnings

IMAGE_DIR = "data/pictures"

download_images()

# Get image embeddings: if image_embeddings.npy exists, load it, otherwise create it.
if os.path.exists("image_embeddings.npy"):
    image_embeddings = np.load("image_embeddings.npy")
else:
    batch_size = 32
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_embeddings = create_embeddings(IMAGE_DIR, batch_size, device)

# L2-normalize rows so dot products against a unit text vector are cosine similarities.
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)

# Load the model and processor once at startup rather than on every query.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def get_text_embeddings(input_text):
    inputs = processor(text=input_text, return_tensors="pt", padding=True, truncation=True)
    embeddings = model.get_text_features(**inputs)
    vector = embeddings.detach().numpy().ravel()
    return vector / np.linalg.norm(vector)

def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

def find_similar_images(text_embedding, image_embeddings, top_k=4):
    similarities = np.array([cosine_similarity(text_embedding, e) for e in image_embeddings])
    return np.argsort(similarities)[-top_k:][::-1]

def get_similar_images(input_text):
    text_embedding = get_text_embeddings(input_text)
    top_k_indices = find_similar_images(text_embedding, image_embeddings)
    # Sort the paths so indices line up with the (sorted) order used to build the embeddings.
    image_paths = sorted(
        os.path.join(IMAGE_DIR, f)
        for f in os.listdir(IMAGE_DIR)
        if f.endswith((".png", ".jpg", ".jpeg"))
    )
    return [Image.open(image_paths[i]) for i in top_k_indices]

if __name__ == "__main__":
    iface = gr.Interface(fn=get_similar_images, inputs="text", outputs="gallery", title="Find Similar Images")
    iface.launch(share=True)
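Since both the embedding matrix and the query vector are unit-normalized before retrieval, the per-image Python loop in find_similar_images can be collapsed into a single matrix-vector product. A minimal drop-in sketch (the name find_similar_images_fast is ours, not part of the commit):

import numpy as np

def find_similar_images_fast(text_embedding, image_embeddings, top_k=4):
    # Rows and query are unit vectors, so the dot product is the cosine similarity.
    similarities = image_embeddings @ text_embedding   # shape: (num_images,)
    return np.argsort(similarities)[-top_k:][::-1]     # best match first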
create_image_embeddnigs.py
ADDED
@@ -0,0 +1,46 @@
import torch
import os
import numpy as np
import warnings
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings("ignore", category=UserWarning)

class ImageDataset(Dataset):
    def __init__(self, image_dir, processor):
        # Sort the paths so the embedding order is deterministic and matches app.py.
        self.image_paths = sorted(
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.endswith((".png", ".jpg", ".jpeg"))
        )
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx])
        return self.processor(images=image, return_tensors="pt")["pixel_values"][0]

def get_clip_embeddings_batch(image_dir, batch_size=32, device="cuda"):
    # Load the CLIP model and processor
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Create dataset and dataloader
    dataset = ImageDataset(image_dir, processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())

    return np.concatenate(all_embeddings)

def create_embeddings(image_dir, batch_size, device):
    embeddings = get_clip_embeddings_batch(image_dir, batch_size, device)
    np.save("image_embeddings.npy", embeddings)
    return embeddings
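create_embeddings can also be run ahead of time so the app starts from the cached .npy file. A minimal standalone entry point, assuming the data/pictures layout that download_dataset.py produces (this __main__ block is a suggestion, not part of the commit):

import torch
from create_image_embeddnigs import create_embeddings
from download_dataset import download_images

if __name__ == "__main__":
    download_images()  # make sure data/pictures is populated
    device = "cuda" if torch.cuda.is_available() else "cpu"
    embeddings = create_embeddings("data/pictures", batch_size=32, device=device)
    print(f"Saved {embeddings.shape[0]} embeddings of dimension {embeddings.shape[1]}")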
download_dataset.py
ADDED
@@ -0,0 +1,25 @@
import os
from datasets import load_dataset

def download_images(name="food101", num_images=200):
    output_dir = "data/pictures"

    # Skip the download if the images are already on disk (app.py calls this on every launch).
    if os.path.isdir(output_dir) and os.listdir(output_dir):
        print(f"Images already present in '{output_dir}', skipping download")
        return

    # Load a small slice of the dataset (1% of the train split)
    dataset = load_dataset(name, split="train[:1%]")

    # Create a directory to save the images if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Iterate over the dataset and save up to num_images images as JPEGs
    count = 0
    for example in dataset:
        if count >= num_images:
            break
        example["image"].convert("RGB").save(os.path.join(output_dir, f"image_{count}.jpg"))
        count += 1

    print(f"Downloaded and saved {count} images to the folder '{output_dir}'")
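Note that split="train[:1%]" still downloads and prepares the full food101 train split before slicing it. When only a couple hundred images are needed, the datasets library's streaming mode avoids the full download; a sketch under the assumption that the dataset supports streaming, keeping the same output layout (download_images_streaming is our name):

import os
from itertools import islice
from datasets import load_dataset

def download_images_streaming(name="food101", num_images=200):
    output_dir = "data/pictures"
    os.makedirs(output_dir, exist_ok=True)
    # streaming=True yields examples one at a time instead of fetching the whole archive.
    dataset = load_dataset(name, split="train", streaming=True)
    count = 0
    for example in islice(dataset, num_images):
        example["image"].convert("RGB").save(os.path.join(output_dir, f"image_{count}.jpg"))
        count += 1
    print(f"Saved {count} images to '{output_dir}'")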
requirements.txt
ADDED
@@ -0,0 +1,223 @@
# This file was exported from a conda environment; the conda-level packages below
# are not pip-installable and are kept only as a record (note conda's single "=" pins):
# _libgcc_mutex=0.1
# _openmp_mutex=5.1
# bzip2=1.0.8
# ca-certificates=2024.7.2
# ld_impl_linux-64=2.38
# libffi=3.4.4
# libgcc-ng=11.2.0
# libgomp=11.2.0
# libstdcxx-ng=11.2.0
# libuuid=1.41.5
# ncurses=6.4
# openssl=1.1.1w
# pip=24.2
# python=3.10.10
# readline=8.2
# setuptools=72.1.0
# sqlite=3.45.3
# tk=8.6.14
# wheel=0.43.0
# xz=5.4.6
# zlib=1.2.13

# The +cu121 torch/torchvision builds come from the PyTorch wheel index:
--extra-index-url https://download.pytorch.org/whl/cu121

absl-py==2.1.0
aiofiles==23.2.1
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.4.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==24.2.0
babel==2.16.0
backoff==2.2.1
beautifulsoup4==4.12.3
bleach==6.1.0
boto3==1.35.12
botocore==1.35.12
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.3.2
click==8.1.7
comm==0.2.2
contourpy==1.3.0
cycler==0.12.1
datasets==3.0.1
debugpy==1.8.5
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
exceptiongroup==1.2.2
executing==2.1.0
fastapi==0.113.0
fastjsonschema==2.20.0
ffmpy==0.4.0
filelock==3.15.4
fire==0.6.0
fonttools==4.53.1
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.6.1
google-auth==2.34.0
google-auth-oauthlib==1.2.1
gradio==5.1.0
gradio-client==1.4.0
grpcio==1.66.1
h11==0.14.0
httpcore==1.0.5
httpx==0.27.2
huggingface-hub==0.26.0
idna==3.8
ipykernel==6.26.0
ipython==8.17.2
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
json5==0.9.25
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter-client==8.6.2
jupyter-core==5.7.2
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter-server==2.14.2
jupyter-server-terminals==0.5.3
jupyterlab==4.2.0
jupyterlab-pygments==0.3.0
jupyterlab-server==2.27.3
jupyterlab-widgets==3.0.13
kiwisolver==1.4.7
lightning==2.4.0
lightning-cloud==0.5.70
lightning-sdk==0.1.15
lightning-utilities==0.11.7
litdata==0.2.19
litserve==0.2.2
markdown==3.7
markdown-it-py==3.0.0
markupsafe==2.1.5
matplotlib==3.8.2
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.0.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
notebook-shim==0.2.4
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.6.68
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
orjson==3.10.7
overrides==7.7.0
packaging==24.1
pandas==2.1.4
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.2.2
prometheus-client==0.20.0
prompt-toolkit==3.0.47
protobuf==4.23.4
psutil==6.0.0
ptyprocess==0.7.0
pure-eval==0.2.3
pyarrow==17.0.0
pyasn1==0.6.0
pyasn1-modules==0.4.0
pycparser==2.22
pydantic==2.9.0
pydantic-core==2.23.2
pydub==0.25.1
pygments==2.18.0
pyjwt==2.9.0
pyparsing==3.1.4
python-dateutil==2.9.0.post0
python-json-logger==2.0.7
python-multipart==0.0.9
pytorch-lightning==2.4.0
pytz==2024.1
pyyaml==6.0.2
pyzmq==26.2.0
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
requests-oauthlib==2.0.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.8.0
rpds-py==0.20.0
rsa==4.9
ruff==0.7.0
s3transfer==0.10.2
safetensors==0.4.5
scikit-learn==1.3.2
scipy==1.11.4
semantic-version==2.10.0
send2trash==1.8.3
shellingham==1.5.4
simple-term-menu==1.6.4
six==1.16.0
sniffio==1.3.1
soupsieve==2.6
stack-data==0.6.3
starlette==0.38.4
sympy==1.13.2
tensorboard==2.15.1
tensorboard-data-server==0.7.2
termcolor==2.4.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.3.0
tokenizers==0.20.1
tomli==2.0.1
tomlkit==0.12.0
torch==2.2.1+cu121
torchmetrics==1.3.1
torchvision==0.17.1+cu121
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
transformers==4.45.2
triton==2.2.0
typer==0.12.5
types-python-dateutil==2.9.0.20240821
typing-extensions==4.12.2
tzdata==2024.1
uri-template==1.3.0
urllib3==2.2.2
uvicorn==0.30.6
wcwidth==0.2.13
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
websockets==12.0
werkzeug==3.0.4
widgetsnbextension==4.0.13
xxhash==3.5.0
yarl==1.9.11