yuxi-liu-wired committed
Commit 1b1d8c3
1 Parent(s): e435be0

example usage

.gitignore ADDED
@@ -0,0 +1,3 @@
+ **.parquet
+ **.json
+ .ipynb_checkpoints
examples/example.ipynb ADDED
@@ -0,0 +1,285 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "id": "32b7d029-64ce-4361-acde-dc72d67637b7",
+    "metadata": {
+     "tags": []
+    },
+    "outputs": [],
+    "source": [
+     "import copy\n",
+     "import os\n",
+     "import io\n",
+     "import torch\n",
+     "import torch.nn as nn\n",
+     "import clip\n",
+     "import pandas as pd\n",
+     "from PIL import Image\n",
+     "from tqdm import tqdm\n",
+     "import numpy as np\n",
+     "from transformers import Pipeline, CLIPProcessor, CLIPVisionModel\n",
+     "from huggingface_hub import PyTorchModelHubMixin\n",
+     "from typing import List, Union\n",
+     "from transformers import PretrainedConfig\n",
+     "import json\n",
+     "import safetensors\n",
+     "\n",
+     "class CSDCLIPConfig(PretrainedConfig):\n",
+     "    model_type = \"csd_clip\"\n",
+     "\n",
+     "    def __init__(\n",
+     "        self,\n",
+     "        name=\"csd_large\",\n",
+     "        embedding_dim=1024,\n",
+     "        feature_dim=1024,\n",
+     "        content_dim=768,\n",
+     "        style_dim=768,\n",
+     "        content_proj_head=\"default\",\n",
+     "        **kwargs\n",
+     "    ):\n",
+     "        super().__init__(**kwargs)\n",
+     "        self.name = name\n",
+     "        self.embedding_dim = embedding_dim\n",
+     "        self.content_proj_head = content_proj_head\n",
+     "        self.task_specific_params = None # Add this line\n",
+     "\n",
+     "class CSD_CLIP(nn.Module, PyTorchModelHubMixin):\n",
+     "    \"\"\"backbone + projection head\"\"\"\n",
+     "    def __init__(self, name='vit_large',content_proj_head='default'):\n",
+     "        super(CSD_CLIP, self).__init__()\n",
+     "        self.content_proj_head = content_proj_head\n",
+     "        if name == 'vit_large':\n",
+     "            clipmodel, _ = clip.load(\"ViT-L/14\")\n",
+     "            self.backbone = clipmodel.visual\n",
+     "            self.embedding_dim = 1024\n",
+     "            self.feature_dim = 1024\n",
+     "            self.content_dim = 768\n",
+     "            self.style_dim = 768\n",
+     "            self.name = \"csd_large\"\n",
+     "        elif name == 'vit_base':\n",
+     "            clipmodel, _ = clip.load(\"ViT-B/16\")\n",
+     "            self.backbone = clipmodel.visual\n",
+     "            self.embedding_dim = 768 \n",
+     "            self.feature_dim = 512\n",
+     "            self.content_dim = 512\n",
+     "            self.style_dim = 512\n",
+     "            self.name = \"csd_base\"\n",
+     "        else:\n",
+     "            raise Exception('This model is not implemented')\n",
+     "\n",
+     "        self.last_layer_style = copy.deepcopy(self.backbone.proj)\n",
+     "        self.last_layer_content = copy.deepcopy(self.backbone.proj)\n",
+     "\n",
+     "        self.backbone.proj = None\n",
+     " \n",
+     "        self.config = CSDCLIPConfig(\n",
+     "            name=self.name,\n",
+     "            embedding_dim=self.embedding_dim,\n",
+     "            feature_dim=self.feature_dim,\n",
+     "            content_dim=self.content_dim,\n",
+     "            style_dim=self.style_dim,\n",
+     "            content_proj_head=self.content_proj_head\n",
+     "        )\n",
+     "\n",
+     "    def get_config(self):\n",
+     "        return self.config.to_dict()\n",
+     "\n",
+     "    @property\n",
+     "    def dtype(self):\n",
+     "        return self.backbone.conv1.weight.dtype\n",
+     " \n",
+     "    @property\n",
+     "    def device(self):\n",
+     "        return next(self.parameters()).device\n",
+     "\n",
+     "    def forward(self, input_data):\n",
+     " \n",
+     "        feature = self.backbone(input_data)\n",
+     "\n",
+     "        style_output = feature @ self.last_layer_style\n",
+     "        style_output = nn.functional.normalize(style_output, dim=1, p=2)\n",
+     "\n",
+     "        content_output = feature @ self.last_layer_content\n",
+     "        content_output = nn.functional.normalize(content_output, dim=1, p=2)\n",
+     " \n",
+     "        return feature, content_output, style_output\n",
+     "\n",
+     "device = 'cuda'\n",
+     "model = CSD_CLIP.from_pretrained(\"yuxi-liu-wired/CSD\")\n",
+     "model.to(device);"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "bbd750f6-fde9-48ed-a7d8-42ee5d31429d",
+    "metadata": {
+     "tags": []
+    },
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "from transformers import Pipeline\n",
+     "from typing import Union, List\n",
+     "from PIL import Image\n",
+     "\n",
+     "class CSDCLIPPipeline(Pipeline):\n",
+     "    def __init__(self, model, processor, device=None):\n",
+     "        if device is None:\n",
+     "            device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+     "        super().__init__(model=model, tokenizer=None, device=device)\n",
+     "        self.processor = processor\n",
+     "\n",
+     "    def _sanitize_parameters(self, **kwargs):\n",
+     "        return {}, {}, {}\n",
+     "\n",
+     "    def preprocess(self, images):\n",
+     "        if isinstance(images, (str, Image.Image)):\n",
+     "            images = [images]\n",
+     " \n",
+     "        processed = self.processor(images=images, return_tensors=\"pt\", padding=True, truncation=True)\n",
+     "        return {k: v.to(self.device) for k, v in processed.items()}\n",
+     "\n",
+     "    def _forward(self, model_inputs):\n",
+     "        pixel_values = model_inputs['pixel_values'].to(self.model.dtype)\n",
+     "        with torch.no_grad():\n",
+     "            features, content_output, style_output = self.model(pixel_values)\n",
+     "        return {\"features\": features, \"content_output\": content_output, \"style_output\": style_output}\n",
+     "\n",
+     "    def postprocess(self, model_outputs):\n",
+     "        return {\n",
+     "            \"features\": model_outputs[\"features\"].cpu().numpy(),\n",
+     "            \"content_output\": model_outputs[\"content_output\"].cpu().numpy(),\n",
+     "            \"style_output\": model_outputs[\"style_output\"].cpu().numpy()\n",
+     "        }\n",
+     "\n",
+     "    def __call__(self, images: Union[str, List[str], Image.Image, List[Image.Image]]):\n",
+     "        return super().__call__(images)\n",
+     "\n",
+     "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
+     "pipeline = CSDCLIPPipeline(model=model, processor=processor, device=device)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "4107999a-c48c-4cb4-9247-9836dfb27e98",
+    "metadata": {
+     "tags": []
+    },
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Processing images: 100%|█████████████████████████████████████████████████████████████| 900/900 [01:09<00:00, 12.86it/s]\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Processing complete. Results saved to 'processed_dataset.parquet'\n"
+      ]
+     }
+    ],
+    "source": [
+     "import io\n",
+     "from PIL import Image\n",
+     "import requests\n",
+     "from datasets import load_dataset\n",
+     "import pandas as pd\n",
+     "import numpy as np\n",
+     "from tqdm import tqdm\n",
+     "\n",
+     "def to_jpeg(image):\n",
+     "    buffered = io.BytesIO()\n",
+     "    if image.mode not in (\"RGB\"):\n",
+     "        image = image.convert(\"RGB\")\n",
+     "    image.save(buffered, format='JPEG')\n",
+     "    return buffered.getvalue() \n",
+     "\n",
+     "def scale_image(image, max_resolution):\n",
+     "    if max(image.width, image.height) > max_resolution:\n",
+     "        image = image.resize((max_resolution, int(image.height * max_resolution / image.width)))\n",
+     "    return image\n",
+     "\n",
+     "def process_dataset(pipeline, dataset_name, dataset_size=900, max_resolution=192):\n",
+     "    dataset = load_dataset(dataset_name, split='train')\n",
+     "    dataset = dataset.select(range(dataset_size))\n",
+     " \n",
+     "    # Print the column names\n",
+     "    print(\"Dataset columns:\", dataset.column_names)\n",
+     " \n",
+     "    # Initialize lists to store results\n",
+     "    embeddings = []\n",
+     "    jpeg_images = []\n",
+     " \n",
+     "    # Process each item in the dataset\n",
+     "    for item in tqdm(dataset, desc=\"Processing images\"):\n",
+     "        try:\n",
+     "            img = item['image']\n",
+     " \n",
+     "            # If img is a string (file path), load the image\n",
+     "            if isinstance(img, str):\n",
+     "                img = Image.open(img)\n",
+     "\n",
+     "\n",
+     "            output = pipeline(img)\n",
+     "            style_output = output[\"style_output\"].squeeze(0)\n",
+     " \n",
+     "            img = scale_image(img, max_resolution)\n",
+     "            jpeg_img = to_jpeg(img)\n",
+     " \n",
+     "            # Append results to lists\n",
+     "            embeddings.append(style_output)\n",
+     "            jpeg_images.append(jpeg_img)\n",
+     "        except Exception as e:\n",
+     "            print(f\"Error processing item: {e}\")\n",
+     " \n",
+     "    # Create a DataFrame with the results\n",
+     "    df = pd.DataFrame({\n",
+     "        'embedding': embeddings,\n",
+     "        'image': jpeg_images\n",
+     "    })\n",
+     " \n",
+     "    df.to_parquet('processed_dataset.parquet')\n",
+     "    print(\"Processing complete. Results saved to 'processed_dataset.parquet'\")\n",
+     "\n",
+     "process_dataset(pipeline, \"yuxi-liu-wired/style-content-grid-SDXL\", \n",
+     "                dataset_size=900, max_resolution=192)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "066ec067-edb1-4110-a0fe-8d7c97311790",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python [conda env:diffgan]",
+    "language": "python",
+    "name": "conda-env-diffgan-py"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.14"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
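
Usage note (not part of the committed files): because `CSD_CLIP.forward` L2-normalizes `style_output`, the style embeddings the notebook stores can be compared with a plain dot product (equal to cosine similarity). Below is a minimal sketch, assuming the `pipeline` object constructed in `examples/example.ipynb` above; the image paths are placeholders.

import numpy as np
from PIL import Image

def style_similarity(pipeline, path_a, path_b):
    # style_output comes back L2-normalized, so the dot product is cosine similarity
    emb_a = pipeline(Image.open(path_a))["style_output"].squeeze(0)
    emb_b = pipeline(Image.open(path_b))["style_output"].squeeze(0)
    return float(np.dot(emb_a, emb_b))

# print(style_similarity(pipeline, "image_a.jpg", "image_b.jpg"))  # placeholder paths
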
examples/tsne_visualization.py ADDED
@@ -0,0 +1,217 @@
+ import pandas as pd
+ import numpy as np
+ from sklearn.manifold import TSNE
+ import json
+ import base64
+
+ def generate_tsne_embedding(input_file, output_file):
+     # Load the Parquet file
+     df = pd.read_parquet(input_file)
+
+     # Extract embeddings and convert to numpy array
+     embeddings = np.array(df['embedding'].tolist())
+
+     # Perform t-SNE
+     tsne = TSNE(n_components=2, random_state=42)
+     tsne_results = tsne.fit_transform(embeddings)
+
+     # Prepare output data
+     output_data = []
+     for i, (x, y) in enumerate(tsne_results):
+         image_base64 = base64.b64encode(df['image'][i]).decode('utf-8')
+         output_data.append({
+             'x': float(x),
+             'y': float(y),
+             'image': image_base64
+         })
+
+     # Save results to JSON file
+     with open(output_file, 'w') as f:
+         json.dump(output_data, f)
+
+ ## ----------------------------
+ ## Dash app
+ ## ----------------------------
+
+ import os
+ import base64
+ import json
+ import numpy as np
+ from dash import dcc, html, Input, Output, no_update, Dash
+ import numpy as np
+ from sklearn.cluster import KMeans
+ from scipy.spatial.distance import cdist
+ import plotly.graph_objects as go
+ from PIL import Image
+ import random
+ import socket
+
+ def find_free_port():
+     while True:
+         port = random.randint(49152, 65535) # Use dynamic/private port range
+         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+             try:
+                 s.bind(('', port))
+                 return port
+             except OSError:
+                 pass
+
+ def create_dash_app(fig, images):
+     app = Dash(__name__)
+
+     app.layout = html.Div(
+         className="container",
+         children=[
+             dcc.Graph(id="graph", figure=fig, clear_on_unhover=True),
+             dcc.Tooltip(id="graph-tooltip", direction='bottom'),
+         ],
+     )
+
+     @app.callback(
+         Output("graph-tooltip", "show"),
+         Output("graph-tooltip", "bbox"),
+         Output("graph-tooltip", "children"),
+         Input("graph", "hoverData"),
+     )
+     def display_hover(hoverData):
+         if hoverData is None:
+             return False, no_update, no_update
+
+         hover_data = hoverData["points"][0]
+         bbox = hover_data["bbox"]
+         num = hover_data["pointNumber"]
+
+         image_base64 = images[num]
+         children = [
+             html.Div([
+                 html.Img(
+                     src=f"data:image/jpeg;base64,{image_base64}",
+                     style={"width": "200px",
+                            "height": "200px",
+                            'display': 'block', 'margin': '0 auto'},
+                 ),
+             ])
+         ]
+
+         return True, bbox, children
+
+     return app
+
+ def perform_kmeans(data, k=20):
+     # Extract x, y coordinates
+     coords = np.array([[point['x'], point['y']] for point in data])
+
+     # Perform k-means clustering
+     kmeans = KMeans(n_clusters=k, random_state=42)
+     kmeans.fit(coords)
+
+     return kmeans
+
+ def find_nearest_images(data, kmeans):
+     coords = np.array([[point['x'], point['y']] for point in data])
+     images = [point['image'] for point in data]
+
+     # Calculate distances to cluster centers
+     distances = cdist(coords, kmeans.cluster_centers_, metric='euclidean')
+
+     # Find the index of the nearest point for each cluster
+     nearest_indices = distances.argmin(axis=0)
+
+     # Get the images nearest to each cluster center
+     nearest_images = [images[i] for i in nearest_indices]
+
+     return nearest_images, kmeans.cluster_centers_
+
+ def create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title):
+     # Extract x, y coordinates
+     x = [point['x'] for point in data]
+     y = [point['y'] for point in data]
+     images = [point['image'] for point in data]
+
+     # Determine the range for both axes
+     max_range = max(max(x) - min(x), max(y) - min(y)) / 2
+     center_x = (max(x) + min(x)) / 2
+     center_y = (max(y) + min(y)) / 2
+
+     # Create the scatter plot
+     fig = go.Figure()
+
+     # Add data points
+     fig.add_trace(go.Scatter(
+         x=x,
+         y=y,
+         mode='markers',
+         marker=dict(
+             size=5,
+             color=kmeans_result.labels_,
+             colorscale='Viridis',
+             showscale=False
+         ),
+         name='Data Points'
+     ))
+
+     # Add cluster centers and images
+
+     fig.update_layout(
+         title=title,
+         width=1000, height=1000,
+         xaxis=dict(
+             range=[center_x - max_range, center_x + max_range],
+             scaleanchor="y",
+             scaleratio=1,
+         ),
+         yaxis=dict(
+             range=[center_y - max_range, center_y + max_range],
+         ),
+         showlegend=False,
+     )
+
+     fig.update_traces(
+         hoverinfo="none",
+         hovertemplate=None,
+     )
+     # Add images
+     for i, (cx, cy) in enumerate(cluster_centers):
+         fig.add_layout_image(
+             dict(
+                 source=f"data:image/jpg;base64,{nearest_images[i]}",
+                 x=cx,
+                 y=cy,
+                 xref="x",
+                 yref="y",
+                 sizex=10,
+                 sizey=10,
+                 sizing="contain",
+                 opacity=1,
+                 layer="below"
+             )
+         )
+
+     # Remove x and y axes ticks
+     fig.update_layout(xaxis=dict(visible=False), yaxis=dict(visible=False))
+
+     return fig, images
+
+ def make_dash_kmeans(data, title, k=40):
+     kmeans_result = perform_kmeans(data, k=k)
+     nearest_images, cluster_centers = find_nearest_images(data, kmeans_result)
+     fig, images = create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title)
+     app = create_dash_app(fig, images)
+     port = find_free_port()
+     print(f"Serving on http://127.0.0.1:{port}/")
+     print(f"To serve this over the Internet, run `ngrok http {port}`")
+     app.run_server(port=port)
+     return app
+
+ if __name__ == "__main__":
+
+     dataset_folder = os.path.dirname('./')
+     name = "style"
+     image_embedding_path = os.path.join(dataset_folder, f"processed_dataset.parquet")
+     tsne_path = os.path.join(dataset_folder, f"processed_dataset.json")
+
+     generate_tsne_embedding(image_embedding_path, tsne_path)
+     with open(tsne_path, "r") as f:
+         data = json.load(f)
+
+     make_dash_kmeans(data, name, k=40)
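
Usage note (not part of the committed files): a minimal sketch of rendering the same t-SNE/k-means scatter to a standalone HTML file instead of serving it with Dash. It assumes the script above is importable (e.g. run from the `examples/` directory) and that `processed_dataset.json` has already been produced by `generate_tsne_embedding`; the output filename is a placeholder. The hover tooltips showing each image remain a Dash-only feature.

import json
from tsne_visualization import perform_kmeans, find_nearest_images, create_dash_fig

with open("processed_dataset.json") as f:
    data = json.load(f)

kmeans_result = perform_kmeans(data, k=40)
nearest_images, centers = find_nearest_images(data, kmeans_result)
fig, _ = create_dash_fig(data, kmeans_result, nearest_images, centers, "style")
fig.write_html("style_tsne.html")  # placeholder path; writes a standalone, shareable HTML file
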