RoboApocalypse committed on
Commit 1272949
1 Parent(s): d228b46

Add OpenCLIP embedding generator app and dependencies

Files changed (3)
  1. .gitignore +23 -0
  2. app.py +166 -0
  3. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1,23 @@
+ # Python virtual environment
+ venv/
+ .venv/
+
+ # Compiled Python files
+ *.pyc
+
+ # Logs
+ *.log
+
+ # Gradio app output files
+ output/
+ flagged/
+
+ # IDE and editor files
+ .vscode/
+ .idea/
+ *.iml
+
+ # Dependency directories
+ __pycache__/
+ dist/
+ build/
app.py ADDED
@@ -0,0 +1,166 @@
+ import gradio as gr
+ import open_clip
+ import torch
+ import json
+ import PIL
+
+ # Set device to GPU if available
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Load the OpenCLIP model and the necessary preprocessors
+ # openclip_model = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
+ # openclip_model = 'laion/CLIP-ViT-B-16-laion2B-s34B-b88K'
+ openclip_model = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'
+ openclip_model = 'hf-hub:' + openclip_model
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
+     model_name=openclip_model,
+     device=device
+ )
+
+
+ def generate_embedding(text_data, image_data):
+     """
+     Generate embeddings for text and image data using the OpenCLIP model.
+
+     Parameters
+     ----------
+     text_data : str or list/tuple of str
+         Text data to embed.
+     image_data : PIL.Image.Image or list/tuple of PIL.Image.Image
+         Image data to embed.
+
+     Returns
+     -------
+     text_embeddings : list
+         Text embeddings as nested lists of floats ("" for empty inputs).
+     image_embeddings : list
+         Image embeddings as nested lists of floats ("" for missing images).
+     similarity : list of str
+         Cosine similarity of each text/image embedding pair, formatted as a percentage.
+     """
+
+     # Embed text data
+     text_embeddings = []
+     empty_text_indices = []
+     if text_data:
+         # If text_data is a string, convert to list of strings
+         if isinstance(text_data, str):
+             text_data = [text_data]
+
+         # If text_data is a tuple of strings, convert to list of strings
+         if isinstance(text_data, tuple):
+             text_data = list(text_data)
+
+         # Keep track of indices of empty text strings
+         empty_text_indices = [i for i, text in enumerate(text_data) if text == ""]
+
+         # Remove empty text strings
+         text_data = [text for text in text_data if text != ""]
+
+         if text_data:
+             # Tokenize text_data and convert to tensor
+             text_data = open_clip.tokenize(text_data).to(device)
+
+             # Generate text embeddings
+             with torch.no_grad():
+                 text_embeddings = model.encode_text(text_data)
+
+             # Convert each embedding tensor to a plain Python list
+             text_embeddings = [embedding.detach().cpu().numpy().tolist() for embedding in text_embeddings]
+
+         # Insert empty strings at indices of empty text strings
+         for i in empty_text_indices:
+             text_embeddings.insert(i, "")
+
+     # Embed image data
+     image_embeddings = []
+     empty_image_indices = []
+     if image_data:
+         # If image_data is a single PIL image, convert to list of PIL images
+         if isinstance(image_data, PIL.Image.Image):
+             image_data = [image_data]
+
+         # If image_data is a tuple of images, convert to list of images
+         if isinstance(image_data, tuple):
+             image_data = list(image_data)
+
+         # Keep track of indices of None images
+         empty_image_indices = [i for i, img in enumerate(image_data) if img is None]
+
+         # Remove None images
+         image_data = [img for img in image_data if img is not None]
+
+         if image_data:
+             # Preprocess image_data and convert to tensor
+             image_data = [preprocess_val(img).unsqueeze(0) for img in image_data]
+             image_data = torch.stack(image_data).squeeze(1).to(device)
+
+             # Generate image embeddings
+             with torch.no_grad():
+                 image_embeddings = model.encode_image(image_data)
+
+             # Convert each embedding tensor to a plain Python list
+             image_embeddings = [embedding.detach().cpu().numpy().tolist() for embedding in image_embeddings]
+
+         # Insert empty strings at indices of empty images
+         for i in empty_image_indices:
+             image_embeddings.insert(i, "")
+
+     # Calculate cosine similarity between text and image embeddings
+     similarity = []
+     empty_similarity_indices = []
+     if text_embeddings and image_embeddings:
+         # Filter out pairs with an empty text or image embedding, tracking their indices
+         text_embeddings_filtered = []
+         image_embeddings_filtered = []
+         for i, (text_embedding, image_embedding) in enumerate(zip(text_embeddings, image_embeddings)):
+             if text_embedding != "" and image_embedding != "":
+                 text_embeddings_filtered.append(text_embedding)
+                 image_embeddings_filtered.append(image_embedding)
+             else:
+                 empty_similarity_indices.append(i)
+
+         # Calculate cosine similarity if there are any non-empty embedding pairs
+         if image_embeddings_filtered and text_embeddings_filtered:
+             # Convert lists back to tensors for processing
+             text_embeddings_tensor = torch.tensor(text_embeddings_filtered)
+             image_embeddings_tensor = torch.tensor(image_embeddings_filtered)
+
+             # Normalize the embeddings
+             text_embedding_norm = text_embeddings_tensor / text_embeddings_tensor.norm(dim=-1, keepdim=True)
+             image_embedding_norm = image_embeddings_tensor / image_embeddings_tensor.norm(dim=-1, keepdim=True)
+
+             # Calculate cosine similarity
+             similarity = torch.nn.functional.cosine_similarity(text_embedding_norm, image_embedding_norm, dim=-1)
+             # Convert to percentage as text
+             similarity = [f"{sim.item() * 100:.2f}%" for sim in similarity]
+
+         # Insert empty strings at indices of skipped pairs
+         for i in empty_similarity_indices:
+             similarity.insert(i, "")
+
+     return (text_embeddings, image_embeddings, similarity)
+
+
+ # Define Gradio interface
+ demo = gr.Interface(
+     fn=generate_embedding,
+     inputs=[
+         gr.Textbox(lines=5, max_lines=5, placeholder="Enter Text Here...", label="Text to Embed"),
+         gr.Image(height=512, type="pil", label="Image to Embed")
+     ],
+     outputs=[
+         gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False),
+         gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False),
+         gr.Textbox(label="Cosine Similarity")
+     ],
+     title="OpenCLIP Embedding Generator",
+     description="Generate embeddings using OpenCLIP model for text and images.",
+     allow_flagging="never",
+     batch=True,
+     api_name="embed"
+ )
+
+ # Enable queueing and launch the app
+ if __name__ == "__main__":
+     demo.queue().launch(show_api=True)
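
Because app.py launches with show_api=True and names the endpoint "embed", the running app can also be queried programmatically. Below is a minimal sketch using the gradio_client package; the local URL, prompt, and image path are illustrative assumptions rather than part of this commit, and depending on the gradio_client version the image argument may need to be wrapped with handle_file().

from gradio_client import Client

# Assumed default local URL from demo.launch(); replace with the deployed Space URL if applicable.
client = Client("http://127.0.0.1:7860/")

# Call the named "embed" endpoint with one text prompt and one local image path (both illustrative).
text_emb, image_emb, similarity = client.predict(
    "a photo of a golden retriever",  # Text to Embed
    "example.jpg",                    # Image to Embed (path to a local image file)
    api_name="/embed",
)
print(similarity)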
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ gradio
+ open_clip_torch
+ torch