multimodalart (HF staff) committed
Commit: 18d0601
1 Parent(s): fe00fdd

Upload 52 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +4 -0
  2. cog_sdxl/.dockerignore +35 -0
  3. cog_sdxl/.gitignore +23 -0
  4. cog_sdxl/LICENSE +202 -0
  5. cog_sdxl/README.md +41 -0
  6. cog_sdxl/cog.yaml +33 -0
  7. cog_sdxl/dataset_and_utils.py +421 -0
  8. cog_sdxl/example_datasets/README.md +3 -0
  9. cog_sdxl/example_datasets/kiriko.png +3 -0
  10. cog_sdxl/example_datasets/kiriko/0.src.jpg +0 -0
  11. cog_sdxl/example_datasets/kiriko/1.src.jpg +0 -0
  12. cog_sdxl/example_datasets/kiriko/10.src.jpg +0 -0
  13. cog_sdxl/example_datasets/kiriko/11.src.jpg +0 -0
  14. cog_sdxl/example_datasets/kiriko/12.src.jpg +0 -0
  15. cog_sdxl/example_datasets/kiriko/2.src.jpg +0 -0
  16. cog_sdxl/example_datasets/kiriko/3.src.jpg +0 -0
  17. cog_sdxl/example_datasets/kiriko/4.src.jpg +0 -0
  18. cog_sdxl/example_datasets/kiriko/5.src.jpg +0 -0
  19. cog_sdxl/example_datasets/kiriko/6.src.jpg +0 -0
  20. cog_sdxl/example_datasets/kiriko/7.src.jpg +0 -0
  21. cog_sdxl/example_datasets/kiriko/8.src.jpg +0 -0
  22. cog_sdxl/example_datasets/kiriko/9.src.jpg +0 -0
  23. cog_sdxl/example_datasets/monster.png +0 -0
  24. cog_sdxl/example_datasets/monster/caption.csv +6 -0
  25. cog_sdxl/example_datasets/monster/monstertoy (1).jpg +0 -0
  26. cog_sdxl/example_datasets/monster/monstertoy (2).jpg +0 -0
  27. cog_sdxl/example_datasets/monster/monstertoy (3).jpg +0 -0
  28. cog_sdxl/example_datasets/monster/monstertoy (4).jpg +0 -0
  29. cog_sdxl/example_datasets/monster/monstertoy (5).jpg +0 -0
  30. cog_sdxl/example_datasets/monster_uni.png +3 -0
  31. cog_sdxl/example_datasets/zeke.zip +3 -0
  32. cog_sdxl/example_datasets/zeke/0.src.jpg +0 -0
  33. cog_sdxl/example_datasets/zeke/1.src.jpg +0 -0
  34. cog_sdxl/example_datasets/zeke/2.src.jpg +0 -0
  35. cog_sdxl/example_datasets/zeke/3.src.jpg +0 -0
  36. cog_sdxl/example_datasets/zeke/4.src.jpg +0 -0
  37. cog_sdxl/example_datasets/zeke/5.src.jpg +0 -0
  38. cog_sdxl/example_datasets/zeke_unicorn.png +3 -0
  39. cog_sdxl/feature-extractor/preprocessor_config.json +20 -0
  40. cog_sdxl/no_init.py +121 -0
  41. cog_sdxl/predict.py +462 -0
  42. cog_sdxl/preprocess.py +599 -0
  43. cog_sdxl/requirements_test.txt +5 -0
  44. cog_sdxl/samples.py +155 -0
  45. cog_sdxl/script/download_preprocessing_weights.py +54 -0
  46. cog_sdxl/script/download_weights.py +50 -0
  47. cog_sdxl/tests/assets/out.png +3 -0
  48. cog_sdxl/tests/test_predict.py +205 -0
  49. cog_sdxl/tests/test_remote_train.py +69 -0
  50. cog_sdxl/tests/test_utils.py +105 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ cog_sdxl/example_datasets/kiriko.png filter=lfs diff=lfs merge=lfs -text
+ cog_sdxl/example_datasets/monster_uni.png filter=lfs diff=lfs merge=lfs -text
+ cog_sdxl/example_datasets/zeke_unicorn.png filter=lfs diff=lfs merge=lfs -text
+ cog_sdxl/tests/assets/out.png filter=lfs diff=lfs merge=lfs -text
cog_sdxl/.dockerignore ADDED
@@ -0,0 +1,35 @@
+ sdxl-cache/
+ refiner-cache/
+ safety-cache/
+ trained-model/
+ *.png
+ cache/
+ checkpoint/
+ training_out/
+ dreambooth/
+ lora/
+ ttemp/
+ .git/
+ cog_class_data/
+ dataset/
+ training_data/
+ temp/
+ temp_in/
+ cog_instance_data/
+ example_datasets/
+ trained_model.tar
+ zeke_data.tar
+ data.tar
+ zeke.zip
+ sketch-mountains-input.jpeg
+ training_out*
+ weights
+ inference_*
+ trained-model
+ *.zip
+ tmp/
+ blip-cache/
+ clipseg-cache/
+ swin2sr-cache/
+ weights-cache/
+ tests/
cog_sdxl/.gitignore ADDED
@@ -0,0 +1,23 @@
+
+ refiner-cache
+ sdxl-cache
+ safety-cache
+ trained-model
+ temp
+ temp_in
+ cache
+ .cog
+ __pycache__
+ wandb
+ ft*
+ *.ipynb
+ dataset
+ training_data
+ training_out
+ output*
+ training_out*
+ trained_model.tar
+ checkpoint*
+ weights
+ __*.zip
+ **-cache
cog_sdxl/LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright 2023, Replicate, Inc.
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
cog_sdxl/README.md ADDED
@@ -0,0 +1,41 @@
+ # Cog-SDXL
+
+ [![Replicate demo and cloud API](https://replicate.com/stability-ai/sdxl/badge)](https://replicate.com/stability-ai/sdxl)
+
+ This is an implementation of Stability AI's [SDXL](https://github.com/Stability-AI/generative-models) as a [Cog](https://github.com/replicate/cog) model.
+
+ ## Development
+
+ Follow the [model pushing guide](https://replicate.com/docs/guides/push-a-model) to push your own fork of SDXL to [Replicate](https://replicate.com).
+
+ ## Basic Usage
+
+ For prediction:
+
+ ```bash
+ cog predict -i prompt="a photo of TOK"
+ ```
+
+ ```bash
+ cog train -i input_images=@example_datasets/__data.zip -i use_face_detection_instead=True
+ ```
+
+ ```bash
+ cog run -p 5000 python -m cog.server.http
+ ```
+
+ ## Update notes
+
+ **2023-08-17**
+ * ROI problem is fixed.
+ * BLIP `caption_prefix` no longer interferes with the BLIP captioner.
+
+
+ **2023-08-12**
+ * Input types are inferred from input file extensions, or from the `input_images_filetype` argument.
+ * Preprocessing is now done in fp16, and if no mask is found, the model will use the whole image.
+
+ **2023-08-11**
+ * Default training resolution is now 768x768.
+ * Rank is now an argument, defaulting to 32.
+ * Now uses Swin2SR `caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr` as the default upscaler, and will upscale + downscale to 768x768.
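
The `cog train` command in the README above takes a zip of training images. As a hedged illustration (not part of this commit), the sketch below bundles a flat folder of images into such a zip; the folder path, helper name, and output filename are placeholders. The example datasets added in this commit (`example_datasets/zeke`, `example_datasets/monster`) follow the same flat layout.

```python
# Minimal sketch: zip a flat folder of training images for `cog train -i input_images=@...`.
# Paths are illustrative; run from the cog_sdxl directory.
from pathlib import Path
from zipfile import ZipFile

def make_training_zip(image_dir: str, out_zip: str = "data.zip") -> str:
    image_paths = sorted(
        p for p in Path(image_dir).iterdir()
        if p.suffix.lower() in {".jpg", ".jpeg", ".png"}
    )
    with ZipFile(out_zip, "w") as zf:
        for p in image_paths:
            # store files flat, as in the example datasets (0.src.jpg, 1.src.jpg, ...)
            zf.write(p, arcname=p.name)
    return out_zip

if __name__ == "__main__":
    print(make_training_zip("example_datasets/zeke"))
```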
cog_sdxl/cog.yaml ADDED
@@ -0,0 +1,33 @@
+ # Configuration for Cog ⚙️
+ # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+ build:
+   gpu: true
+   cuda: "11.8"
+   python_version: "3.9"
+   system_packages:
+     - "libgl1-mesa-glx"
+     - "ffmpeg"
+     - "libsm6"
+     - "libxext6"
+     - "wget"
+   python_packages:
+     - "diffusers<=0.25"
+     - "torch==2.0.1"
+     - "transformers==4.31.0"
+     - "invisible-watermark==0.2.0"
+     - "accelerate==0.21.0"
+     - "pandas==2.0.3"
+     - "torchvision==0.15.2"
+     - "numpy==1.25.1"
+     - "pandas==2.0.3"
+     - "fire==0.5.0"
+     - "opencv-python>=4.1.0.25"
+     - "mediapipe==0.10.2"
+
+   run:
+     - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)" && chmod +x /usr/local/bin/pget
+     - wget http://thegiflibrary.tumblr.com/post/11565547760 -O face_landmarker_v2_with_blendshapes.task -q https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task
+
+ predict: "predict.py:Predictor"
+ train: "train.py:train"
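
The `run` step above installs Replicate's `pget` downloader, which is what fetches the SDXL, refiner, and safety-checker tarballs at runtime. The sketch below simply mirrors the `download_weights` helper defined in `predict.py` later in this diff, shown here to connect the build step to its runtime use; the URL and destination are placeholders.

```python
# Mirrors download_weights() in predict.py: `pget -x` downloads a tar archive
# and extracts it into the destination directory. Values below are placeholders.
import subprocess
import time

def download_weights(url: str, dest: str) -> None:
    start = time.time()
    print("downloading url:", url)
    print("downloading to:", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took:", time.time() - start)

# e.g. download_weights(SDXL_URL, "./sdxl-cache"), with SDXL_URL as in predict.py
```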
cog_sdxl/dataset_and_utils.py ADDED
@@ -0,0 +1,421 @@
1
+ import os
2
+ from typing import Dict, List, Optional, Tuple
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import PIL
7
+ import torch
8
+ import torch.utils.checkpoint
9
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
10
+ from PIL import Image
11
+ from safetensors import safe_open
12
+ from safetensors.torch import save_file
13
+ from torch.utils.data import Dataset
14
+ from transformers import AutoTokenizer, PretrainedConfig
15
+
16
+
17
+ def prepare_image(
18
+ pil_image: PIL.Image.Image, w: int = 512, h: int = 512
19
+ ) -> torch.Tensor:
20
+ pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
21
+ arr = np.array(pil_image.convert("RGB"))
22
+ arr = arr.astype(np.float32) / 127.5 - 1
23
+ arr = np.transpose(arr, [2, 0, 1])
24
+ image = torch.from_numpy(arr).unsqueeze(0)
25
+ return image
26
+
27
+
28
+ def prepare_mask(
29
+ pil_image: PIL.Image.Image, w: int = 512, h: int = 512
30
+ ) -> torch.Tensor:
31
+ pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
32
+ arr = np.array(pil_image.convert("L"))
33
+ arr = arr.astype(np.float32) / 255.0
34
+ arr = np.expand_dims(arr, 0)
35
+ image = torch.from_numpy(arr).unsqueeze(0)
36
+ return image
37
+
38
+
39
+ class PreprocessedDataset(Dataset):
40
+ def __init__(
41
+ self,
42
+ csv_path: str,
43
+ tokenizer_1,
44
+ tokenizer_2,
45
+ vae_encoder,
46
+ text_encoder_1=None,
47
+ text_encoder_2=None,
48
+ do_cache: bool = False,
49
+ size: int = 512,
50
+ text_dropout: float = 0.0,
51
+ scale_vae_latents: bool = True,
52
+ substitute_caption_map: Dict[str, str] = {},
53
+ ):
54
+ super().__init__()
55
+
56
+ self.data = pd.read_csv(csv_path)
57
+ self.csv_path = csv_path
58
+
59
+ self.caption = self.data["caption"]
60
+ # make it lowercase
61
+ self.caption = self.caption.str.lower()
62
+ for key, value in substitute_caption_map.items():
63
+ self.caption = self.caption.str.replace(key.lower(), value)
64
+
65
+ self.image_path = self.data["image_path"]
66
+
67
+ if "mask_path" not in self.data.columns:
68
+ self.mask_path = None
69
+ else:
70
+ self.mask_path = self.data["mask_path"]
71
+
72
+ if text_encoder_1 is None:
73
+ self.return_text_embeddings = False
74
+ else:
75
+ self.text_encoder_1 = text_encoder_1
76
+ self.text_encoder_2 = text_encoder_2
77
+ self.return_text_embeddings = True
78
+ raise NotImplementedError(
79
+ "Preprocessing with the text encoders is not implemented yet"
80
+ )
81
+
82
+ self.tokenizer_1 = tokenizer_1
83
+ self.tokenizer_2 = tokenizer_2
84
+
85
+ self.vae_encoder = vae_encoder
86
+ self.scale_vae_latents = scale_vae_latents
87
+ self.text_dropout = text_dropout
88
+
89
+ self.size = size
90
+
91
+ if do_cache:
92
+ self.vae_latents = []
93
+ self.tokens_tuple = []
94
+ self.masks = []
95
+
96
+ self.do_cache = True
97
+
98
+ print("Captions to train on: ")
99
+ for idx in range(len(self.data)):
100
+ token, vae_latent, mask = self._process(idx)
101
+ self.vae_latents.append(vae_latent)
102
+ self.tokens_tuple.append(token)
103
+ self.masks.append(mask)
104
+
105
+ del self.vae_encoder
106
+
107
+ else:
108
+ self.do_cache = False
109
+
110
+ @torch.no_grad()
111
+ def _process(
112
+ self, idx: int
113
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
114
+ image_path = self.image_path[idx]
115
+ image_path = os.path.join(os.path.dirname(self.csv_path), image_path)
116
+
117
+ image = PIL.Image.open(image_path).convert("RGB")
118
+ image = prepare_image(image, self.size, self.size).to(
119
+ dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
120
+ )
121
+
122
+ caption = self.caption[idx]
123
+
124
+ print(caption)
125
+
126
+ # tokenizer_1
127
+ ti1 = self.tokenizer_1(
128
+ caption,
129
+ padding="max_length",
130
+ max_length=77,
131
+ truncation=True,
132
+ add_special_tokens=True,
133
+ return_tensors="pt",
134
+ ).input_ids
135
+
136
+ ti2 = self.tokenizer_2(
137
+ caption,
138
+ padding="max_length",
139
+ max_length=77,
140
+ truncation=True,
141
+ add_special_tokens=True,
142
+ return_tensors="pt",
143
+ ).input_ids
144
+
145
+ vae_latent = self.vae_encoder.encode(image).latent_dist.sample()
146
+
147
+ if self.scale_vae_latents:
148
+ vae_latent = vae_latent * self.vae_encoder.config.scaling_factor
149
+
150
+ if self.mask_path is None:
151
+ mask = torch.ones_like(
152
+ vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
153
+ )
154
+
155
+ else:
156
+ mask_path = self.mask_path[idx]
157
+ mask_path = os.path.join(os.path.dirname(self.csv_path), mask_path)
158
+
159
+ mask = PIL.Image.open(mask_path)
160
+ mask = prepare_mask(mask, self.size, self.size).to(
161
+ dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
162
+ )
163
+
164
+ mask = torch.nn.functional.interpolate(
165
+ mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest"
166
+ )
167
+ mask = mask.repeat(1, vae_latent.shape[1], 1, 1)
168
+
169
+ assert len(mask.shape) == 4 and len(vae_latent.shape) == 4
170
+
171
+ return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze()
172
+
173
+ def __len__(self) -> int:
174
+ return len(self.data)
175
+
176
+ def atidx(
177
+ self, idx: int
178
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
179
+ if self.do_cache:
180
+ return self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx]
181
+ else:
182
+ return self._process(idx)
183
+
184
+ def __getitem__(
185
+ self, idx: int
186
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
187
+ token, vae_latent, mask = self.atidx(idx)
188
+ return token, vae_latent, mask
189
+
190
+
191
+ def import_model_class_from_model_name_or_path(
192
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
193
+ ):
194
+ text_encoder_config = PretrainedConfig.from_pretrained(
195
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
196
+ )
197
+ model_class = text_encoder_config.architectures[0]
198
+
199
+ if model_class == "CLIPTextModel":
200
+ from transformers import CLIPTextModel
201
+
202
+ return CLIPTextModel
203
+ elif model_class == "CLIPTextModelWithProjection":
204
+ from transformers import CLIPTextModelWithProjection
205
+
206
+ return CLIPTextModelWithProjection
207
+ else:
208
+ raise ValueError(f"{model_class} is not supported.")
209
+
210
+
211
+ def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
212
+ tokenizer_one = AutoTokenizer.from_pretrained(
213
+ pretrained_model_name_or_path,
214
+ subfolder="tokenizer",
215
+ revision=revision,
216
+ use_fast=False,
217
+ )
218
+ tokenizer_two = AutoTokenizer.from_pretrained(
219
+ pretrained_model_name_or_path,
220
+ subfolder="tokenizer_2",
221
+ revision=revision,
222
+ use_fast=False,
223
+ )
224
+
225
+ # Load scheduler and models
226
+ noise_scheduler = DDPMScheduler.from_pretrained(
227
+ pretrained_model_name_or_path, subfolder="scheduler"
228
+ )
229
+ # import correct text encoder classes
230
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
231
+ pretrained_model_name_or_path, revision
232
+ )
233
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
234
+ pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
235
+ )
236
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
237
+ pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
238
+ )
239
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
240
+ pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
241
+ )
242
+
243
+ vae = AutoencoderKL.from_pretrained(
244
+ pretrained_model_name_or_path, subfolder="vae", revision=revision
245
+ )
246
+ unet = UNet2DConditionModel.from_pretrained(
247
+ pretrained_model_name_or_path, subfolder="unet", revision=revision
248
+ )
249
+
250
+ vae.requires_grad_(False)
251
+ text_encoder_one.requires_grad_(False)
252
+ text_encoder_two.requires_grad_(False)
253
+
254
+ unet.to(device, dtype=weight_dtype)
255
+ vae.to(device, dtype=torch.float32)
256
+ text_encoder_one.to(device, dtype=weight_dtype)
257
+ text_encoder_two.to(device, dtype=weight_dtype)
258
+
259
+ return (
260
+ tokenizer_one,
261
+ tokenizer_two,
262
+ noise_scheduler,
263
+ text_encoder_one,
264
+ text_encoder_two,
265
+ vae,
266
+ unet,
267
+ )
268
+
269
+
270
+ def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
271
+ """
272
+ Returns:
273
+ a state dict containing just the attention processor parameters.
274
+ """
275
+ attn_processors = unet.attn_processors
276
+
277
+ attn_processors_state_dict = {}
278
+
279
+ for attn_processor_key, attn_processor in attn_processors.items():
280
+ for parameter_key, parameter in attn_processor.state_dict().items():
281
+ attn_processors_state_dict[
282
+ f"{attn_processor_key}.{parameter_key}"
283
+ ] = parameter
284
+
285
+ return attn_processors_state_dict
286
+
287
+
288
+ class TokenEmbeddingsHandler:
289
+ def __init__(self, text_encoders, tokenizers):
290
+ self.text_encoders = text_encoders
291
+ self.tokenizers = tokenizers
292
+
293
+ self.train_ids: Optional[torch.Tensor] = None
294
+ self.inserting_toks: Optional[List[str]] = None
295
+ self.embeddings_settings = {}
296
+
297
+ def initialize_new_tokens(self, inserting_toks: List[str]):
298
+ idx = 0
299
+ for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
300
+ assert isinstance(
301
+ inserting_toks, list
302
+ ), "inserting_toks should be a list of strings."
303
+ assert all(
304
+ isinstance(tok, str) for tok in inserting_toks
305
+ ), "All elements in inserting_toks should be strings."
306
+
307
+ self.inserting_toks = inserting_toks
308
+ special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
309
+ tokenizer.add_special_tokens(special_tokens_dict)
310
+ text_encoder.resize_token_embeddings(len(tokenizer))
311
+
312
+ self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
313
+
314
+ # random initialization of new tokens
315
+
316
+ std_token_embedding = (
317
+ text_encoder.text_model.embeddings.token_embedding.weight.data.std()
318
+ )
319
+
320
+ print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
321
+
322
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
323
+ self.train_ids
324
+ ] = (
325
+ torch.randn(
326
+ len(self.train_ids), text_encoder.text_model.config.hidden_size
327
+ )
328
+ .to(device=self.device)
329
+ .to(dtype=self.dtype)
330
+ * std_token_embedding
331
+ )
332
+ self.embeddings_settings[
333
+ f"original_embeddings_{idx}"
334
+ ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
335
+ self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
336
+
337
+ inu = torch.ones((len(tokenizer),), dtype=torch.bool)
338
+ inu[self.train_ids] = False
339
+
340
+ self.embeddings_settings[f"index_no_updates_{idx}"] = inu
341
+
342
+ print(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
343
+
344
+ idx += 1
345
+
346
+ def save_embeddings(self, file_path: str):
347
+ assert (
348
+ self.train_ids is not None
349
+ ), "Initialize new tokens before saving embeddings."
350
+ tensors = {}
351
+ for idx, text_encoder in enumerate(self.text_encoders):
352
+ assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[
353
+ 0
354
+ ] == len(self.tokenizers[0]), "Tokenizers should be the same."
355
+ new_token_embeddings = (
356
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
357
+ self.train_ids
358
+ ]
359
+ )
360
+ tensors[f"text_encoders_{idx}"] = new_token_embeddings
361
+
362
+ save_file(tensors, file_path)
363
+
364
+ @property
365
+ def dtype(self):
366
+ return self.text_encoders[0].dtype
367
+
368
+ @property
369
+ def device(self):
370
+ return self.text_encoders[0].device
371
+
372
+ def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
373
+ # Assuming new tokens are of the format <s_i>
374
+ self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
375
+ special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
376
+ tokenizer.add_special_tokens(special_tokens_dict)
377
+ text_encoder.resize_token_embeddings(len(tokenizer))
378
+
379
+ self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
380
+ assert self.train_ids is not None, "New tokens could not be converted to IDs."
381
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
382
+ self.train_ids
383
+ ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
384
+
385
+ @torch.no_grad()
386
+ def retract_embeddings(self):
387
+ for idx, text_encoder in enumerate(self.text_encoders):
388
+ index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
389
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
390
+ index_no_updates
391
+ ] = (
392
+ self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
393
+ .to(device=text_encoder.device)
394
+ .to(dtype=text_encoder.dtype)
395
+ )
396
+
397
+ # for the parts that were updated, we need to normalize them
398
+ # to have the same std as before
399
+ std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
400
+
401
+ index_updates = ~index_no_updates
402
+ new_embeddings = (
403
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
404
+ index_updates
405
+ ]
406
+ )
407
+ off_ratio = std_token_embedding / new_embeddings.std()
408
+
409
+ new_embeddings = new_embeddings * (off_ratio**0.1)
410
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
411
+ index_updates
412
+ ] = new_embeddings
413
+
414
+ def load_embeddings(self, file_path: str):
415
+ with safe_open(file_path, framework="pt", device=self.device.type) as f:
416
+ for idx in range(len(self.text_encoders)):
417
+ text_encoder = self.text_encoders[idx]
418
+ tokenizer = self.tokenizers[idx]
419
+
420
+ loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
421
+ self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
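
A hedged sketch of how the utilities in `dataset_and_utils.py` fit together: `load_models` returns the tokenizers, scheduler, text encoders, VAE, and UNet, and `PreprocessedDataset` expects a CSV with `caption` and `image_path` columns (normally produced by the preprocessing step). The base-model name, CSV path, batch size, and resolution below are illustrative assumptions, not values taken from this repo's training script.

```python
# Minimal usage sketch for the helpers above; requires a CUDA device and a
# preprocessed captions CSV (hypothetical path shown).
import torch
from torch.utils.data import DataLoader

from dataset_and_utils import PreprocessedDataset, load_models

(
    tokenizer_one,
    tokenizer_two,
    noise_scheduler,
    text_encoder_one,
    text_encoder_two,
    vae,
    unet,
) = load_models(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed base model
    revision=None,
    device="cuda",
    weight_dtype=torch.float16,
)

dataset = PreprocessedDataset(
    "training_data/captions.csv",  # hypothetical output of preprocess.py
    tokenizer_one,
    tokenizer_two,
    vae_encoder=vae,
    do_cache=True,   # pre-encode every image into VAE latents once, up front
    size=768,        # matches the 768x768 default noted in the README
)

loader = DataLoader(dataset, batch_size=2, shuffle=True)
(tok1, tok2), latents, masks = next(iter(loader))
print(tok1.shape, latents.shape, masks.shape)
```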
cog_sdxl/example_datasets/README.md ADDED
@@ -0,0 +1,3 @@
+ ## Example Datasets
+
+ This folder contains three example datasets that were used to tune SDXL using the Replicate API, along with (at the top level) example outputs generated from those datasets.
cog_sdxl/example_datasets/kiriko.png ADDED

Git LFS Details

  • SHA256: 9d9861dc28bf9fd0b33992f927630f1ade740017158be76f0afa385008b0775a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB
cog_sdxl/example_datasets/kiriko/0.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/1.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/10.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/11.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/12.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/2.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/3.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/4.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/5.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/6.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/7.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/8.src.jpg ADDED
cog_sdxl/example_datasets/kiriko/9.src.jpg ADDED
cog_sdxl/example_datasets/monster.png ADDED
cog_sdxl/example_datasets/monster/caption.csv ADDED
@@ -0,0 +1,6 @@
+ caption,image_file
+ a TOK on a windowsill,monstertoy (1).jpg
+ a photo of smiling TOK in an office,monstertoy (2).jpg
+ a photo of TOK sitting by a window,monstertoy (3).jpg
+ a photo of TOK on a car,monstertoy (4).jpg
+ a photo of TOK smiling on the ground,monstertoy (5).jpg
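
The caption file pairs each caption with an image filename, using `TOK` as the trigger token that prompts later refer to (see the README's `a photo of TOK`). A small sketch for inspecting it, assuming it is run from the `cog_sdxl` directory; pandas is already a pinned project dependency.

```python
# Print the example captions and their image files.
import pandas as pd

df = pd.read_csv("example_datasets/monster/caption.csv")
for caption, image_file in zip(df["caption"], df["image_file"]):
    print(f"{image_file}: {caption}")
```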
cog_sdxl/example_datasets/monster/monstertoy (1).jpg ADDED
cog_sdxl/example_datasets/monster/monstertoy (2).jpg ADDED
cog_sdxl/example_datasets/monster/monstertoy (3).jpg ADDED
cog_sdxl/example_datasets/monster/monstertoy (4).jpg ADDED
cog_sdxl/example_datasets/monster/monstertoy (5).jpg ADDED
cog_sdxl/example_datasets/monster_uni.png ADDED

Git LFS Details

  • SHA256: 98bf9d0cbef77d7cc5a541940a32a02a9ea49d8122f9722401c9b3c7956aa47a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.71 MB
cog_sdxl/example_datasets/zeke.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:64d655ee118eec386272a15c8e3c2522bc40155cd0f39f451596f7800df403e6
+ size 860587
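
`zeke.zip` is stored through Git LFS, so what the repository actually tracks is the three-line pointer above rather than the archive itself. A small sketch for checking whether a checked-out file is still a pointer (path assumes the repository root):

```python
# Detect a Git LFS pointer file by its fixed first line, as shown above.
from pathlib import Path

def is_lfs_pointer(path: str) -> bool:
    head = Path(path).read_bytes()[:200]
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")

print(is_lfs_pointer("cog_sdxl/example_datasets/zeke.zip"))
```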
cog_sdxl/example_datasets/zeke/0.src.jpg ADDED
cog_sdxl/example_datasets/zeke/1.src.jpg ADDED
cog_sdxl/example_datasets/zeke/2.src.jpg ADDED
cog_sdxl/example_datasets/zeke/3.src.jpg ADDED
cog_sdxl/example_datasets/zeke/4.src.jpg ADDED
cog_sdxl/example_datasets/zeke/5.src.jpg ADDED
cog_sdxl/example_datasets/zeke_unicorn.png ADDED

Git LFS Details

  • SHA256: 59339a736d96dde6f8459ac1f357ed63707e5f5eb50fea3616a64eaaf2586416
  • Pointer size: 132 Bytes
  • Size of remote file: 1.65 MB
cog_sdxl/feature-extractor/preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "crop_size": 224,
+   "do_center_crop": true,
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_resize": true,
+   "feature_extractor_type": "CLIPFeatureExtractor",
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "resample": 3,
+   "size": 224
+ }
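
This is the CLIP image-processor configuration that `predict.py` loads for the safety checker via `CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)`. A minimal sketch of using it directly; the sample image is one of the example dataset files, and paths are relative to the `cog_sdxl` directory.

```python
# Load the feature-extractor config above and preprocess one image for the
# safety checker, the same way predict.py's run_safety_checker does.
from PIL import Image
from transformers import CLIPImageProcessor

feature_extractor = CLIPImageProcessor.from_pretrained("./feature-extractor")
pixel_values = feature_extractor(
    Image.open("example_datasets/zeke/0.src.jpg"), return_tensors="pt"
).pixel_values
print(pixel_values.shape)  # (1, 3, 224, 224), given crop_size/size = 224
```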
cog_sdxl/no_init.py ADDED
@@ -0,0 +1,121 @@
1
+ import contextlib
2
+ import contextvars
3
+ import threading
4
+ from typing import (
5
+ Callable,
6
+ ContextManager,
7
+ NamedTuple,
8
+ Optional,
9
+ TypeVar,
10
+ Union,
11
+ )
12
+
13
+ import torch
14
+
15
+ __all__ = ["no_init_or_tensor"]
16
+
17
+
18
+ Model = TypeVar("Model")
19
+
20
+
21
+ def no_init_or_tensor(
22
+ loading_code: Optional[Callable[..., Model]] = None
23
+ ) -> Union[Model, ContextManager]:
24
+ """
25
+ Suppress the initialization of weights while loading a model.
26
+
27
+ Can either directly be passed a callable containing model-loading code,
28
+ which will be evaluated with weight initialization suppressed,
29
+ or used as a context manager around arbitrary model-loading code.
30
+
31
+ Args:
32
+ loading_code: Either a callable to evaluate
33
+ with model weight initialization suppressed,
34
+ or None (the default) to use as a context manager.
35
+
36
+ Returns:
37
+ The return value of `loading_code`, if `loading_code` is callable.
38
+
39
+ Otherwise, if `loading_code` is None, returns a context manager
40
+ to be used in a `with`-statement.
41
+
42
+ Examples:
43
+ As a context manager::
44
+
45
+ from transformers import AutoConfig, AutoModelForCausalLM
46
+ config = AutoConfig("EleutherAI/gpt-j-6B")
47
+ with no_init_or_tensor():
48
+ model = AutoModelForCausalLM.from_config(config)
49
+
50
+ Or, directly passing a callable::
51
+
52
+ from transformers import AutoConfig, AutoModelForCausalLM
53
+ config = AutoConfig("EleutherAI/gpt-j-6B")
54
+ model = no_init_or_tensor(lambda: AutoModelForCausalLM.from_config(config))
55
+ """
56
+ if loading_code is None:
57
+ return _NoInitOrTensorImpl.context_manager()
58
+ elif callable(loading_code):
59
+ with _NoInitOrTensorImpl.context_manager():
60
+ return loading_code()
61
+ else:
62
+ raise TypeError(
63
+ "no_init_or_tensor() expected a callable to evaluate,"
64
+ " or None if being used as a context manager;"
65
+ f' got an object of type "{type(loading_code).__name__}" instead.'
66
+ )
67
+
68
+
69
+ class _NoInitOrTensorImpl:
70
+ # Implementation of the thread-safe, async-safe, re-entrant context manager
71
+ # version of no_init_or_tensor().
72
+ # This class essentially acts as a namespace.
73
+ # It is not instantiable, because modifications to torch functions
74
+ # inherently affect the global scope, and thus there is no worthwhile data
75
+ # to store in the class instance scope.
76
+ _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
77
+ _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)
78
+ _ORIGINAL_EMPTY = torch.empty
79
+
80
+ is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False)
81
+ _count_active: int = 0
82
+ _count_active_lock = threading.Lock()
83
+
84
+ @classmethod
85
+ @contextlib.contextmanager
86
+ def context_manager(cls):
87
+ if cls.is_active.get():
88
+ yield
89
+ return
90
+
91
+ with cls._count_active_lock:
92
+ cls._count_active += 1
93
+ if cls._count_active == 1:
94
+ for mod in cls._MODULES:
95
+ mod.reset_parameters = cls._disable(mod.reset_parameters)
96
+ # When torch.empty is called, make it map to meta device by replacing
97
+ # the device in kwargs.
98
+ torch.empty = cls._ORIGINAL_EMPTY
99
+ reset_token = cls.is_active.set(True)
100
+
101
+ try:
102
+ yield
103
+ finally:
104
+ cls.is_active.reset(reset_token)
105
+ with cls._count_active_lock:
106
+ cls._count_active -= 1
107
+ if cls._count_active == 0:
108
+ torch.empty = cls._ORIGINAL_EMPTY
109
+ for mod, original in cls._MODULE_ORIGINALS:
110
+ mod.reset_parameters = original
111
+
112
+ @staticmethod
113
+ def _disable(func):
114
+ def wrapper(*args, **kwargs):
115
+ # Behaves as normal except in an active context
116
+ if not _NoInitOrTensorImpl.is_active.get():
117
+ return func(*args, **kwargs)
118
+
119
+ return wrapper
120
+
121
+ __init__ = None
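
`predict.py` later in this diff wraps the construction of `LoRAAttnProcessor2_0` modules in `no_init_or_tensor()` so that freshly created layers skip weight initialization before the LoRA tensors are loaded over them. A hedged sketch of that pattern; the hidden size, cross-attention dimension, and rank are illustrative values, not ones read from a checkpoint.

```python
# Build a module whose weights will be overwritten immediately, with weight
# initialization suppressed (reset_parameters on Linear/Embedding/LayerNorm
# becomes a no-op inside the context).
import torch
from diffusers.models.attention_processor import LoRAAttnProcessor2_0

from no_init import no_init_or_tensor

with no_init_or_tensor():
    proc = LoRAAttnProcessor2_0(hidden_size=1280, cross_attention_dim=2048, rank=32)

# Parameters exist but hold uninitialized memory until real weights are loaded.
print(sum(p.numel() for p in proc.parameters()))
```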
cog_sdxl/predict.py ADDED
@@ -0,0 +1,462 @@
1
+ import hashlib
2
+ import json
3
+ import os
4
+ import shutil
5
+ import subprocess
6
+ import time
7
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8
+ from weights import WeightsDownloadCache
9
+
10
+ import numpy as np
11
+ import torch
12
+ from cog import BasePredictor, Input, Path
13
+ from diffusers import (
14
+ DDIMScheduler,
15
+ DiffusionPipeline,
16
+ DPMSolverMultistepScheduler,
17
+ EulerAncestralDiscreteScheduler,
18
+ EulerDiscreteScheduler,
19
+ HeunDiscreteScheduler,
20
+ PNDMScheduler,
21
+ StableDiffusionXLImg2ImgPipeline,
22
+ StableDiffusionXLInpaintPipeline,
23
+ )
24
+ from diffusers.models.attention_processor import LoRAAttnProcessor2_0
25
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
26
+ StableDiffusionSafetyChecker,
27
+ )
28
+ from diffusers.utils import load_image
29
+ from safetensors import safe_open
30
+ from safetensors.torch import load_file
31
+ from transformers import CLIPImageProcessor
32
+
33
+ from dataset_and_utils import TokenEmbeddingsHandler
34
+
35
+ SDXL_MODEL_CACHE = "./sdxl-cache"
36
+ REFINER_MODEL_CACHE = "./refiner-cache"
37
+ SAFETY_CACHE = "./safety-cache"
38
+ FEATURE_EXTRACTOR = "./feature-extractor"
39
+ SDXL_URL = "https://weights.replicate.delivery/default/sdxl/sdxl-vae-upcast-fix.tar"
40
+ REFINER_URL = (
41
+ "https://weights.replicate.delivery/default/sdxl/refiner-no-vae-no-encoder-1.0.tar"
42
+ )
43
+ SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
44
+
45
+
46
+ class KarrasDPM:
47
+ def from_config(config):
48
+ return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True)
49
+
50
+
51
+ SCHEDULERS = {
52
+ "DDIM": DDIMScheduler,
53
+ "DPMSolverMultistep": DPMSolverMultistepScheduler,
54
+ "HeunDiscrete": HeunDiscreteScheduler,
55
+ "KarrasDPM": KarrasDPM,
56
+ "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
57
+ "K_EULER": EulerDiscreteScheduler,
58
+ "PNDM": PNDMScheduler,
59
+ }
60
+
61
+
62
+ def download_weights(url, dest):
63
+ start = time.time()
64
+ print("downloading url: ", url)
65
+ print("downloading to: ", dest)
66
+ subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
67
+ print("downloading took: ", time.time() - start)
68
+
69
+
70
+ class Predictor(BasePredictor):
71
+ def load_trained_weights(self, weights, pipe):
72
+ from no_init import no_init_or_tensor
73
+
74
+ # weights can be a URLPath, which behaves in unexpected ways
75
+ weights = str(weights)
76
+ if self.tuned_weights == weights:
77
+ print("skipping loading .. weights already loaded")
78
+ return
79
+
80
+ # predictions can be cancelled while in this function, which
81
+ # interrupts this finishing. To protect against odd states we
82
+ # set tuned_weights to a value that lets the next prediction
83
+ # know if it should try to load weights or if loading completed
84
+ self.tuned_weights = 'loading'
85
+
86
+ local_weights_cache = self.weights_cache.ensure(weights)
87
+
88
+ # load UNET
89
+ print("Loading fine-tuned model")
90
+ self.is_lora = False
91
+
92
+ maybe_unet_path = os.path.join(local_weights_cache, "unet.safetensors")
93
+ if not os.path.exists(maybe_unet_path):
94
+ print("Does not have Unet. assume we are using LoRA")
95
+ self.is_lora = True
96
+
97
+ if not self.is_lora:
98
+ print("Loading Unet")
99
+
100
+ new_unet_params = load_file(
101
+ os.path.join(local_weights_cache, "unet.safetensors")
102
+ )
103
+ # this should return _IncompatibleKeys(missing_keys=[...], unexpected_keys=[])
104
+ pipe.unet.load_state_dict(new_unet_params, strict=False)
105
+
106
+ else:
107
+ print("Loading Unet LoRA")
108
+
109
+ unet = pipe.unet
110
+
111
+ tensors = load_file(os.path.join(local_weights_cache, "lora.safetensors"))
112
+
113
+ unet_lora_attn_procs = {}
114
+ name_rank_map = {}
115
+ for tk, tv in tensors.items():
116
+ # up is N, d
117
+ tensors[tk] = tv.half()
118
+ if tk.endswith("up.weight"):
119
+ proc_name = ".".join(tk.split(".")[:-3])
120
+ r = tv.shape[1]
121
+ name_rank_map[proc_name] = r
122
+
123
+ for name, attn_processor in unet.attn_processors.items():
124
+ cross_attention_dim = (
125
+ None
126
+ if name.endswith("attn1.processor")
127
+ else unet.config.cross_attention_dim
128
+ )
129
+ if name.startswith("mid_block"):
130
+ hidden_size = unet.config.block_out_channels[-1]
131
+ elif name.startswith("up_blocks"):
132
+ block_id = int(name[len("up_blocks.")])
133
+ hidden_size = list(reversed(unet.config.block_out_channels))[
134
+ block_id
135
+ ]
136
+ elif name.startswith("down_blocks"):
137
+ block_id = int(name[len("down_blocks.")])
138
+ hidden_size = unet.config.block_out_channels[block_id]
139
+ with no_init_or_tensor():
140
+ module = LoRAAttnProcessor2_0(
141
+ hidden_size=hidden_size,
142
+ cross_attention_dim=cross_attention_dim,
143
+ rank=name_rank_map[name],
144
+ ).half()
145
+ unet_lora_attn_procs[name] = module.to("cuda", non_blocking=True)
146
+
147
+ unet.set_attn_processor(unet_lora_attn_procs)
148
+ unet.load_state_dict(tensors, strict=False)
149
+
150
+ # load text
151
+ handler = TokenEmbeddingsHandler(
152
+ [pipe.text_encoder, pipe.text_encoder_2], [pipe.tokenizer, pipe.tokenizer_2]
153
+ )
154
+ handler.load_embeddings(os.path.join(local_weights_cache, "embeddings.pti"))
155
+
156
+ # load params
157
+ with open(os.path.join(local_weights_cache, "special_params.json"), "r") as f:
158
+ params = json.load(f)
159
+
160
+ self.token_map = params
161
+ self.tuned_weights = weights
162
+ self.tuned_model = True
163
+
164
+ def unload_trained_weights(self, pipe: DiffusionPipeline):
165
+ print("unloading loras")
166
+
167
+ def _recursive_unset_lora(module: torch.nn.Module):
168
+ if hasattr(module, "lora_layer"):
169
+ module.lora_layer = None
170
+
171
+ for _, child in module.named_children():
172
+ _recursive_unset_lora(child)
173
+
174
+ _recursive_unset_lora(pipe.unet)
175
+ self.tuned_weights = None
176
+ self.tuned_model = False
177
+
178
+ def setup(self, weights: Optional[Path] = None):
179
+ """Load the model into memory to make running multiple predictions efficient"""
180
+
181
+ start = time.time()
182
+ self.tuned_model = False
183
+ self.tuned_weights = None
184
+ if str(weights) == "weights":
185
+ weights = None
186
+
187
+ self.weights_cache = WeightsDownloadCache()
188
+
189
+ print("Loading safety checker...")
190
+ if not os.path.exists(SAFETY_CACHE):
191
+ download_weights(SAFETY_URL, SAFETY_CACHE)
192
+ self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
193
+ SAFETY_CACHE, torch_dtype=torch.float16
194
+ ).to("cuda")
195
+ self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
196
+
197
+ if not os.path.exists(SDXL_MODEL_CACHE):
198
+ download_weights(SDXL_URL, SDXL_MODEL_CACHE)
199
+
200
+ print("Loading sdxl txt2img pipeline...")
201
+ self.txt2img_pipe = DiffusionPipeline.from_pretrained(
202
+ SDXL_MODEL_CACHE,
203
+ torch_dtype=torch.float16,
204
+ use_safetensors=True,
205
+ variant="fp16",
206
+ )
207
+ self.is_lora = False
208
+ if weights or os.path.exists("./trained-model"):
209
+ self.load_trained_weights(weights, self.txt2img_pipe)
210
+
211
+ self.txt2img_pipe.to("cuda")
212
+
213
+ print("Loading SDXL img2img pipeline...")
214
+ self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
215
+ vae=self.txt2img_pipe.vae,
216
+ text_encoder=self.txt2img_pipe.text_encoder,
217
+ text_encoder_2=self.txt2img_pipe.text_encoder_2,
218
+ tokenizer=self.txt2img_pipe.tokenizer,
219
+ tokenizer_2=self.txt2img_pipe.tokenizer_2,
220
+ unet=self.txt2img_pipe.unet,
221
+ scheduler=self.txt2img_pipe.scheduler,
222
+ )
223
+ self.img2img_pipe.to("cuda")
224
+
225
+ print("Loading SDXL inpaint pipeline...")
226
+ self.inpaint_pipe = StableDiffusionXLInpaintPipeline(
227
+ vae=self.txt2img_pipe.vae,
228
+ text_encoder=self.txt2img_pipe.text_encoder,
229
+ text_encoder_2=self.txt2img_pipe.text_encoder_2,
230
+ tokenizer=self.txt2img_pipe.tokenizer,
231
+ tokenizer_2=self.txt2img_pipe.tokenizer_2,
232
+ unet=self.txt2img_pipe.unet,
233
+ scheduler=self.txt2img_pipe.scheduler,
234
+ )
235
+ self.inpaint_pipe.to("cuda")
236
+
237
+ print("Loading SDXL refiner pipeline...")
238
+ # FIXME(ja): should the vae/text_encoder_2 be loaded from SDXL always?
239
+ # - in the case of fine-tuned SDXL should we still?
240
+ # FIXME(ja): if the answer to above is use VAE/Text_Encoder_2 from fine-tune
241
+ # what does this imply about lora + refiner? does the refiner need to know about
242
+
243
+ if not os.path.exists(REFINER_MODEL_CACHE):
244
+ download_weights(REFINER_URL, REFINER_MODEL_CACHE)
245
+
246
+ print("Loading refiner pipeline...")
247
+ self.refiner = DiffusionPipeline.from_pretrained(
248
+ REFINER_MODEL_CACHE,
249
+ text_encoder_2=self.txt2img_pipe.text_encoder_2,
250
+ vae=self.txt2img_pipe.vae,
251
+ torch_dtype=torch.float16,
252
+ use_safetensors=True,
253
+ variant="fp16",
254
+ )
255
+ self.refiner.to("cuda")
256
+ print("setup took: ", time.time() - start)
257
+ # self.txt2img_pipe.__class__.encode_prompt = new_encode_prompt
258
+
259
+ def load_image(self, path):
260
+ shutil.copyfile(path, "/tmp/image.png")
261
+ return load_image("/tmp/image.png").convert("RGB")
262
+
263
+ def run_safety_checker(self, image):
264
+ safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
265
+ "cuda"
266
+ )
267
+ np_image = [np.array(val) for val in image]
268
+ image, has_nsfw_concept = self.safety_checker(
269
+ images=np_image,
270
+ clip_input=safety_checker_input.pixel_values.to(torch.float16),
271
+ )
272
+ return image, has_nsfw_concept
273
+
274
+ @torch.inference_mode()
275
+ def predict(
276
+ self,
277
+ prompt: str = Input(
278
+ description="Input prompt",
279
+ default="An astronaut riding a rainbow unicorn",
280
+ ),
281
+ negative_prompt: str = Input(
282
+ description="Input Negative Prompt",
283
+ default="",
284
+ ),
285
+ image: Path = Input(
286
+ description="Input image for img2img or inpaint mode",
287
+ default=None,
288
+ ),
289
+ mask: Path = Input(
290
+ description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.",
291
+ default=None,
292
+ ),
293
+ width: int = Input(
294
+ description="Width of output image",
295
+ default=1024,
296
+ ),
297
+ height: int = Input(
298
+ description="Height of output image",
299
+ default=1024,
300
+ ),
301
+ num_outputs: int = Input(
302
+ description="Number of images to output.",
303
+ ge=1,
304
+ le=4,
305
+ default=1,
306
+ ),
307
+ scheduler: str = Input(
308
+ description="scheduler",
309
+ choices=SCHEDULERS.keys(),
310
+ default="K_EULER",
311
+ ),
312
+ num_inference_steps: int = Input(
313
+ description="Number of denoising steps", ge=1, le=500, default=50
314
+ ),
315
+ guidance_scale: float = Input(
316
+ description="Scale for classifier-free guidance", ge=1, le=50, default=7.5
317
+ ),
318
+ prompt_strength: float = Input(
319
+ description="Prompt strength when using img2img / inpaint. 1.0 corresponds to full destruction of information in image",
320
+ ge=0.0,
321
+ le=1.0,
322
+ default=0.8,
323
+ ),
324
+ seed: int = Input(
325
+ description="Random seed. Leave blank to randomize the seed", default=None
326
+ ),
327
+ refine: str = Input(
328
+ description="Which refine style to use",
329
+ choices=["no_refiner", "expert_ensemble_refiner", "base_image_refiner"],
330
+ default="no_refiner",
331
+ ),
332
+ high_noise_frac: float = Input(
333
+ description="For expert_ensemble_refiner, the fraction of noise to use",
334
+ default=0.8,
335
+ le=1.0,
336
+ ge=0.0,
337
+ ),
338
+ refine_steps: int = Input(
339
+ description="For base_image_refiner, the number of steps to refine, defaults to num_inference_steps",
340
+ default=None,
341
+ ),
342
+ apply_watermark: bool = Input(
343
+ description="Applies a watermark to enable determining if an image is generated in downstream applications. If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.",
344
+ default=True,
345
+ ),
346
+ lora_scale: float = Input(
347
+ description="LoRA additive scale. Only applicable on trained models.",
348
+ ge=0.0,
349
+ le=1.0,
350
+ default=0.6,
351
+ ),
352
+ replicate_weights: str = Input(
353
+ description="Replicate LoRA weights to use. Leave blank to use the default weights.",
354
+ default=None,
355
+ ),
356
+ disable_safety_checker: bool = Input(
357
+ description="Disable safety checker for generated images. This feature is only available through the API. See [https://replicate.com/docs/how-does-replicate-work#safety](https://replicate.com/docs/how-does-replicate-work#safety)",
358
+ default=False,
359
+ ),
360
+ ) -> List[Path]:
361
+ """Run a single prediction on the model."""
362
+ if seed is None:
363
+ seed = int.from_bytes(os.urandom(2), "big")
364
+ print(f"Using seed: {seed}")
365
+
366
+ if replicate_weights:
367
+ self.load_trained_weights(replicate_weights, self.txt2img_pipe)
368
+ elif self.tuned_model:
369
+ self.unload_trained_weights(self.txt2img_pipe)
370
+
371
+ # OOMs can leave vae in bad state
372
+ if self.txt2img_pipe.vae.dtype == torch.float32:
373
+ self.txt2img_pipe.vae.to(dtype=torch.float16)
374
+
375
+ sdxl_kwargs = {}
376
+ if self.tuned_model:
377
+ # consistency with fine-tuning API
378
+ for k, v in self.token_map.items():
379
+ prompt = prompt.replace(k, v)
380
+ print(f"Prompt: {prompt}")
381
+ if image and mask:
382
+ print("inpainting mode")
383
+ sdxl_kwargs["image"] = self.load_image(image)
384
+ sdxl_kwargs["mask_image"] = self.load_image(mask)
385
+ sdxl_kwargs["strength"] = prompt_strength
386
+ sdxl_kwargs["width"] = width
387
+ sdxl_kwargs["height"] = height
388
+ pipe = self.inpaint_pipe
389
+ elif image:
390
+ print("img2img mode")
391
+ sdxl_kwargs["image"] = self.load_image(image)
392
+ sdxl_kwargs["strength"] = prompt_strength
393
+ pipe = self.img2img_pipe
394
+ else:
395
+ print("txt2img mode")
396
+ sdxl_kwargs["width"] = width
397
+ sdxl_kwargs["height"] = height
398
+ pipe = self.txt2img_pipe
399
+
400
+ if refine == "expert_ensemble_refiner":
401
+ sdxl_kwargs["output_type"] = "latent"
402
+ sdxl_kwargs["denoising_end"] = high_noise_frac
403
+ elif refine == "base_image_refiner":
404
+ sdxl_kwargs["output_type"] = "latent"
405
+
406
+ if not apply_watermark:
407
+ # toggles watermark for this prediction
408
+ watermark_cache = pipe.watermark
409
+ pipe.watermark = None
410
+ self.refiner.watermark = None
411
+
412
+ pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
413
+ generator = torch.Generator("cuda").manual_seed(seed)
414
+
415
+ common_args = {
416
+ "prompt": [prompt] * num_outputs,
417
+ "negative_prompt": [negative_prompt] * num_outputs,
418
+ "guidance_scale": guidance_scale,
419
+ "generator": generator,
420
+ "num_inference_steps": num_inference_steps,
421
+ }
422
+
423
+ if self.is_lora:
424
+ sdxl_kwargs["cross_attention_kwargs"] = {"scale": lora_scale}
425
+
426
+ output = pipe(**common_args, **sdxl_kwargs)
427
+
428
+ if refine in ["expert_ensemble_refiner", "base_image_refiner"]:
429
+ refiner_kwargs = {
430
+ "image": output.images,
431
+ }
432
+
433
+ if refine == "expert_ensemble_refiner":
434
+ refiner_kwargs["denoising_start"] = high_noise_frac
435
+ if refine == "base_image_refiner" and refine_steps:
436
+ common_args["num_inference_steps"] = refine_steps
437
+
438
+ output = self.refiner(**common_args, **refiner_kwargs)
439
+
440
+ if not apply_watermark:
441
+ pipe.watermark = watermark_cache
442
+ self.refiner.watermark = watermark_cache
443
+
444
+ if not disable_safety_checker:
445
+ _, has_nsfw_content = self.run_safety_checker(output.images)
446
+
447
+ output_paths = []
448
+ for i, image in enumerate(output.images):
449
+ if not disable_safety_checker:
450
+ if has_nsfw_content[i]:
451
+ print(f"NSFW content detected in image {i}")
452
+ continue
453
+ output_path = f"/tmp/out-{i}.png"
454
+ image.save(output_path)
455
+ output_paths.append(Path(output_path))
456
+
457
+ if len(output_paths) == 0:
458
+ raise Exception(
459
+ f"NSFW content detected. Try running it again, or try a different prompt."
460
+ )
461
+
462
+ return output_paths
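
For quick verification of the refiner paths above, a minimal sketch (not part of this upload) of hitting a local cog server, mirroring the request shape used in samples.py; the endpoint and output filename are assumptions:

import base64
import requests

# Assumes `cog run -p 5000 python -m cog.server.http` is serving this model locally.
resp = requests.post(
    "http://localhost:5000/predictions",
    json={
        "input": {
            "prompt": "A studio portrait photo of a cat",
            "refine": "expert_ensemble_refiner",
            "high_noise_frac": 0.8,
            "num_inference_steps": 25,
            "seed": 1000,
        }
    },
)
datauri = resp.json()["output"][0]  # data URI, decoded the same way samples.py does
with open("refined-sample.png", "wb") as f:  # hypothetical output filename
    f.write(base64.b64decode(datauri.split(",")[1]))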
cog_sdxl/preprocess.py ADDED
@@ -0,0 +1,599 @@
1
+ # Have SwinIR upsample
2
+ # Have BLIP auto caption
3
+ # Have CLIPSeg auto mask concept
4
+
5
+ import gc
6
+ import fnmatch
7
+ import mimetypes
8
+ import os
9
+ import re
10
+ import shutil
11
+ import tarfile
12
+ from pathlib import Path
13
+ from typing import List, Literal, Optional, Tuple, Union
14
+ from zipfile import ZipFile
15
+
16
+ import cv2
17
+ import mediapipe as mp
18
+ import numpy as np
19
+ import pandas as pd
20
+ import torch
21
+ from PIL import Image, ImageFilter
22
+ from tqdm import tqdm
23
+ from transformers import (
24
+ BlipForConditionalGeneration,
25
+ BlipProcessor,
26
+ CLIPSegForImageSegmentation,
27
+ CLIPSegProcessor,
28
+ Swin2SRForImageSuperResolution,
29
+ Swin2SRImageProcessor,
30
+ )
31
+
32
+ from predict import download_weights
33
+
34
+ # model is fixed to Salesforce/blip-image-captioning-large
35
+ BLIP_URL = "https://weights.replicate.delivery/default/blip_large/blip_large.tar"
36
+ BLIP_PROCESSOR_URL = (
37
+ "https://weights.replicate.delivery/default/blip_processor/blip_processor.tar"
38
+ )
39
+ BLIP_PATH = "./blip-cache"
40
+ BLIP_PROCESSOR_PATH = "./blip-proc-cache"
41
+
42
+ # model is fixed to CIDAS/clipseg-rd64-refined
43
+ CLIPSEG_URL = "https://weights.replicate.delivery/default/clip_seg_rd64_refined/clip_seg_rd64_refined.tar"
44
+ CLIPSEG_PROCESSOR = "https://weights.replicate.delivery/default/clip_seg_processor/clip_seg_processor.tar"
45
+ CLIPSEG_PATH = "./clipseg-cache"
46
+ CLIPSEG_PROCESSOR_PATH = "./clipseg-proc-cache"
47
+
48
+ # model is fixed to caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr
49
+ SWIN2SR_URL = "https://weights.replicate.delivery/default/swin2sr_realworld_sr_x4_64_bsrgan_psnr/swin2sr_realworld_sr_x4_64_bsrgan_psnr.tar"
50
+ SWIN2SR_PATH = "./swin2sr-cache"
51
+
52
+ TEMP_OUT_DIR = "./temp/"
53
+ TEMP_IN_DIR = "./temp_in/"
54
+
55
+ CSV_MATCH = "caption"
56
+
57
+
58
+ def preprocess(
59
+ input_images_filetype: str,
60
+ input_zip_path: Path,
61
+ caption_text: str,
62
+ mask_target_prompts: str,
63
+ target_size: int,
64
+ crop_based_on_salience: bool,
65
+ use_face_detection_instead: bool,
66
+ temp: float,
67
+ substitution_tokens: List[str],
68
+ ) -> Path:
69
+ # assert str(files).endswith(".zip"), "files must be a zip file"
70
+
71
+ # clear TEMP_IN_DIR first.
72
+
73
+ for path in [TEMP_OUT_DIR, TEMP_IN_DIR]:
74
+ if os.path.exists(path):
75
+ shutil.rmtree(path)
76
+ os.makedirs(path)
77
+
78
+ caption_csv = None
79
+
80
+ if input_images_filetype == "zip" or str(input_zip_path).endswith(".zip"):
81
+ with ZipFile(str(input_zip_path), "r") as zip_ref:
82
+ for zip_info in zip_ref.infolist():
83
+ if zip_info.filename[-1] == "/" or zip_info.filename.startswith(
84
+ "__MACOSX"
85
+ ):
86
+ continue
87
+ mt = mimetypes.guess_type(zip_info.filename)
88
+ if mt and mt[0] and mt[0].startswith("image/"):
89
+ zip_info.filename = os.path.basename(zip_info.filename)
90
+ zip_ref.extract(zip_info, TEMP_IN_DIR)
91
+ if (
92
+ mt
93
+ and mt[0]
94
+ and mt[0] == "text/csv"
95
+ and CSV_MATCH in zip_info.filename
96
+ ):
97
+ zip_info.filename = os.path.basename(zip_info.filename)
98
+ zip_ref.extract(zip_info, TEMP_IN_DIR)
99
+ caption_csv = os.path.join(TEMP_IN_DIR, zip_info.filename)
100
+ elif input_images_filetype == "tar" or str(input_zip_path).endswith(".tar"):
101
+ assert str(input_zip_path).endswith(
102
+ ".tar"
103
+ ), "files must be a tar file if not zip"
104
+ with tarfile.open(input_zip_path, "r") as tar_ref:
105
+ for tar_info in tar_ref:
106
+ if tar_info.name[-1] == "/" or tar_info.name.startswith("__MACOSX"):
107
+ continue
108
+
109
+ mt = mimetypes.guess_type(tar_info.name)
110
+ if mt and mt[0] and mt[0].startswith("image/"):
111
+ tar_info.name = os.path.basename(tar_info.name)
112
+ tar_ref.extract(tar_info, TEMP_IN_DIR)
113
+ if mt and mt[0] and mt[0] == "text/csv" and CSV_MATCH in tar_info.name:
114
+ tar_info.name = os.path.basename(tar_info.name)
115
+ tar_ref.extract(tar_info, TEMP_IN_DIR)
116
+ caption_csv = os.path.join(TEMP_IN_DIR, tar_info.name)
117
+ else:
118
+ assert False, "input_images_filetype must be zip or tar"
119
+
120
+ output_dir: str = TEMP_OUT_DIR
121
+
122
+ load_and_save_masks_and_captions(
123
+ files=TEMP_IN_DIR,
124
+ output_dir=output_dir,
125
+ caption_text=caption_text,
126
+ caption_csv=caption_csv,
127
+ mask_target_prompts=mask_target_prompts,
128
+ target_size=target_size,
129
+ crop_based_on_salience=crop_based_on_salience,
130
+ use_face_detection_instead=use_face_detection_instead,
131
+ temp=temp,
132
+ substitution_tokens=substitution_tokens,
133
+ )
134
+
135
+ return Path(TEMP_OUT_DIR)
136
+
137
+
138
+ @torch.no_grad()
139
+ @torch.cuda.amp.autocast()
140
+ def swin_ir_sr(
141
+ images: List[Image.Image],
142
+ target_size: Optional[Tuple[int, int]] = None,
143
+ device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
144
+ **kwargs,
145
+ ) -> List[Image.Image]:
146
+ """
147
+ Upscales images using SwinIR. Returns a list of PIL images.
148
+ If the image is already larger than the target size, it will not be upscaled
149
+ and will be returned as is.
150
+
151
+ """
152
+ if not os.path.exists(SWIN2SR_PATH):
153
+ download_weights(SWIN2SR_URL, SWIN2SR_PATH)
154
+ model = Swin2SRForImageSuperResolution.from_pretrained(SWIN2SR_PATH).to(device)
155
+ processor = Swin2SRImageProcessor()
156
+
157
+ out_images = []
158
+
159
+ for image in tqdm(images):
160
+ ori_w, ori_h = image.size
161
+ if target_size is not None:
162
+ if ori_w >= target_size[0] and ori_h >= target_size[1]:
163
+ out_images.append(image)
164
+ continue
165
+
166
+ inputs = processor(image, return_tensors="pt").to(device)
167
+ with torch.no_grad():
168
+ outputs = model(**inputs)
169
+
170
+ output = (
171
+ outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
172
+ )
173
+ output = np.moveaxis(output, source=0, destination=-1)
174
+ output = (output * 255.0).round().astype(np.uint8)
175
+ output = Image.fromarray(output)
176
+
177
+ out_images.append(output)
178
+
179
+ return out_images
180
+
181
+
182
+ @torch.no_grad()
183
+ @torch.cuda.amp.autocast()
184
+ def clipseg_mask_generator(
185
+ images: List[Image.Image],
186
+ target_prompts: Union[List[str], str],
187
+ device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
188
+ bias: float = 0.01,
189
+ temp: float = 1.0,
190
+ **kwargs,
191
+ ) -> List[Image.Image]:
192
+ """
193
+ Returns a greyscale mask for each image, where the mask is the probability of the target prompt being present in the image
194
+ """
195
+
196
+ if isinstance(target_prompts, str):
197
+ print(
198
+ f'Warning: only one target prompt "{target_prompts}" was given, so it will be used for all images'
199
+ )
200
+
201
+ target_prompts = [target_prompts] * len(images)
202
+ if not os.path.exists(CLIPSEG_PROCESSOR_PATH):
203
+ download_weights(CLIPSEG_PROCESSOR, CLIPSEG_PROCESSOR_PATH)
204
+ if not os.path.exists(CLIPSEG_PATH):
205
+ download_weights(CLIPSEG_URL, CLIPSEG_PATH)
206
+ processor = CLIPSegProcessor.from_pretrained(CLIPSEG_PROCESSOR_PATH)
207
+ model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_PATH).to(device)
208
+
209
+ masks = []
210
+
211
+ for image, prompt in tqdm(zip(images, target_prompts)):
212
+ original_size = image.size
213
+
214
+ inputs = processor(
215
+ text=[prompt, ""],
216
+ images=[image] * 2,
217
+ padding="max_length",
218
+ truncation=True,
219
+ return_tensors="pt",
220
+ ).to(device)
221
+
222
+ outputs = model(**inputs)
223
+
224
+ logits = outputs.logits
225
+ probs = torch.nn.functional.softmax(logits / temp, dim=0)[0]
226
+ probs = (probs + bias).clamp_(0, 1)
227
+ probs = 255 * probs / probs.max()
228
+
229
+ # make mask greyscale
230
+ mask = Image.fromarray(probs.cpu().numpy()).convert("L")
231
+
232
+ # resize mask to original size
233
+ mask = mask.resize(original_size)
234
+
235
+ masks.append(mask)
236
+
237
+ return masks
238
+
239
+
240
+ @torch.no_grad()
241
+ def blip_captioning_dataset(
242
+ images: List[Image.Image],
243
+ text: Optional[str] = None,
244
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
245
+ substitution_tokens: Optional[List[str]] = None,
246
+ **kwargs,
247
+ ) -> List[str]:
248
+ """
249
+ Returns a list of captions for the given images
250
+ """
251
+ if not os.path.exists(BLIP_PROCESSOR_PATH):
252
+ download_weights(BLIP_PROCESSOR_URL, BLIP_PROCESSOR_PATH)
253
+ if not os.path.exists(BLIP_PATH):
254
+ download_weights(BLIP_URL, BLIP_PATH)
255
+ processor = BlipProcessor.from_pretrained(BLIP_PROCESSOR_PATH)
256
+ model = BlipForConditionalGeneration.from_pretrained(BLIP_PATH).to(device)
257
+ captions = []
258
+ text = (text or "").strip()
259
+ print(f"Input captioning text: {text}")
260
+ for image in tqdm(images):
261
+ inputs = processor(image, return_tensors="pt").to(device)
262
+ out = model.generate(
263
+ **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
264
+ )
265
+ caption = processor.decode(out[0], skip_special_tokens=True)
266
+
267
+ # BLIP lowercases all-caps tokens. Re-insert any substitution tokens without clobbering subwords (crude, but works).
268
+ for token in substitution_tokens or []:
269
+ print(token)
270
+ sub_cap = " " + caption + " "
271
+ print(sub_cap)
272
+ sub_cap = sub_cap.replace(" " + token.lower() + " ", " " + token + " ")
273
+ caption = sub_cap.strip()
274
+
275
+ captions.append(text + " " + caption)
276
+ print("Generated captions", captions)
277
+ return captions
278
+
279
+
280
+ def face_mask_google_mediapipe(
281
+ images: List[Image.Image], blur_amount: float = 0.0, bias: float = 50.0
282
+ ) -> List[Image.Image]:
283
+ """
284
+ Returns a list of images with masks on the face parts.
285
+ """
286
+ mp_face_detection = mp.solutions.face_detection
287
+ mp_face_mesh = mp.solutions.face_mesh
288
+
289
+ face_detection = mp_face_detection.FaceDetection(
290
+ model_selection=1, min_detection_confidence=0.1
291
+ )
292
+ face_mesh = mp_face_mesh.FaceMesh(
293
+ static_image_mode=True, max_num_faces=1, min_detection_confidence=0.1
294
+ )
295
+
296
+ masks = []
297
+ for image in tqdm(images):
298
+ image_np = np.array(image)
299
+
300
+ # Perform face detection
301
+ results_detection = face_detection.process(image_np)
302
+ ih, iw, _ = image_np.shape
303
+ if results_detection.detections:
304
+ for detection in results_detection.detections:
305
+ bboxC = detection.location_data.relative_bounding_box
306
+
307
+ bbox = (
308
+ int(bboxC.xmin * iw),
309
+ int(bboxC.ymin * ih),
310
+ int(bboxC.width * iw),
311
+ int(bboxC.height * ih),
312
+ )
313
+
314
+ # make sure bbox is within image
315
+ bbox = (
316
+ max(0, bbox[0]),
317
+ max(0, bbox[1]),
318
+ min(iw - bbox[0], bbox[2]),
319
+ min(ih - bbox[1], bbox[3]),
320
+ )
321
+
322
+ print(bbox)
323
+
324
+ # Extract face landmarks
325
+ face_landmarks = face_mesh.process(
326
+ image_np[bbox[1] : bbox[1] + bbox[3], bbox[0] : bbox[0] + bbox[2]]
327
+ ).multi_face_landmarks
328
+
329
+ # https://github.com/google/mediapipe/issues/1615
330
+ # This was def helpful
331
+ indexes = [
332
+ 10,
333
+ 338,
334
+ 297,
335
+ 332,
336
+ 284,
337
+ 251,
338
+ 389,
339
+ 356,
340
+ 454,
341
+ 323,
342
+ 361,
343
+ 288,
344
+ 397,
345
+ 365,
346
+ 379,
347
+ 378,
348
+ 400,
349
+ 377,
350
+ 152,
351
+ 148,
352
+ 176,
353
+ 149,
354
+ 150,
355
+ 136,
356
+ 172,
357
+ 58,
358
+ 132,
359
+ 93,
360
+ 234,
361
+ 127,
362
+ 162,
363
+ 21,
364
+ 54,
365
+ 103,
366
+ 67,
367
+ 109,
368
+ ]
369
+
370
+ if face_landmarks:
371
+ mask = Image.new("L", (iw, ih), 0)
372
+ mask_np = np.array(mask)
373
+
374
+ for face_landmark in face_landmarks:
375
+ face_landmark = [face_landmark.landmark[idx] for idx in indexes]
376
+ landmark_points = [
377
+ (int(l.x * bbox[2]) + bbox[0], int(l.y * bbox[3]) + bbox[1])
378
+ for l in face_landmark
379
+ ]
380
+ mask_np = cv2.fillPoly(
381
+ mask_np, [np.array(landmark_points)], 255
382
+ )
383
+
384
+ mask = Image.fromarray(mask_np)
385
+
386
+ # Apply blur to the mask
387
+ if blur_amount > 0:
388
+ mask = mask.filter(ImageFilter.GaussianBlur(blur_amount))
389
+
390
+ # Apply bias to the mask
391
+ if bias > 0:
392
+ mask = np.array(mask)
393
+ mask = mask + bias * np.ones(mask.shape, dtype=mask.dtype)
394
+ mask = np.clip(mask, 0, 255)
395
+ mask = Image.fromarray(mask)
396
+
397
+ # Convert mask to 'L' mode (grayscale) before saving
398
+ mask = mask.convert("L")
399
+
400
+ masks.append(mask)
401
+ else:
402
+ # If face landmarks are not available, fall back to a full white mask of the same size as the image
403
+ masks.append(Image.new("L", (iw, ih), 255))
404
+
405
+ else:
406
+ print("No face detected, adding full mask")
407
+ # If no face is detected, add a white mask of the same size as the image
408
+ masks.append(Image.new("L", (iw, ih), 255))
409
+
410
+ return masks
411
+
412
+
413
+ def _crop_to_square(
414
+ image: Image.Image, com: List[Tuple[int, int]], resize_to: Optional[int] = None
415
+ ):
416
+ cx, cy = com
417
+ width, height = image.size
418
+ if width > height:
419
+ left_possible = max(cx - height / 2, 0)
420
+ left = min(left_possible, width - height)
421
+ right = left + height
422
+ top = 0
423
+ bottom = height
424
+ else:
425
+ left = 0
426
+ right = width
427
+ top_possible = max(cy - width / 2, 0)
428
+ top = min(top_possible, height - width)
429
+ bottom = top + width
430
+
431
+ image = image.crop((left, top, right, bottom))
432
+
433
+ if resize_to:
434
+ image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS)
435
+
436
+ return image
437
+
438
+
439
+ def _center_of_mass(mask: Image.Image):
440
+ """
441
+ Returns the center of mass of the mask
442
+ """
443
+ x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1]))
444
+ mask_np = np.array(mask) + 0.01
445
+ x_ = x * mask_np
446
+ y_ = y * mask_np
447
+
448
+ x = np.sum(x_) / np.sum(mask_np)
449
+ y = np.sum(y_) / np.sum(mask_np)
450
+
451
+ return x, y
452
+
453
+
454
+ def load_and_save_masks_and_captions(
455
+ files: Union[str, List[str]],
456
+ output_dir: str = TEMP_OUT_DIR,
457
+ caption_text: Optional[str] = None,
458
+ caption_csv: Optional[str] = None,
459
+ mask_target_prompts: Optional[Union[List[str], str]] = None,
460
+ target_size: int = 1024,
461
+ crop_based_on_salience: bool = True,
462
+ use_face_detection_instead: bool = False,
463
+ temp: float = 1.0,
464
+ n_length: int = -1,
465
+ substitution_tokens: Optional[List[str]] = None,
466
+ ):
467
+ """
468
+ Loads images from the given files, generates masks for them, and saves the masks and captions and upscale images
469
+ to output dir. If mask_target_prompts is given, it will generate kinda-segmentation-masks for the prompts and save them as well.
470
+
471
+ Example:
472
+ >>> x = load_and_save_masks_and_captions(
473
+ files="./data/images",
474
+ output_dir="./data/masks_and_captions",
475
+ caption_text="a photo of",
476
+ mask_target_prompts="cat",
477
+ target_size=768,
478
+ crop_based_on_salience=True,
479
+ use_face_detection_instead=False,
480
+ temp=1.0,
481
+ n_length=-1,
482
+ )
483
+ """
484
+ os.makedirs(output_dir, exist_ok=True)
485
+
486
+ # load images
487
+ if isinstance(files, str):
488
+ # check if it is a directory
489
+ if os.path.isdir(files):
490
+ # get all the .png .jpg in the directory
491
+ files = (
492
+ _find_files("*.png", files)
493
+ + _find_files("*.jpg", files)
494
+ + _find_files("*.jpeg", files)
495
+ )
496
+
497
+ if len(files) == 0:
498
+ raise Exception(
499
+ f"No files found in {files}. Either {files} is not a directory or it does not contain any .png or .jpg/jpeg files."
500
+ )
501
+ if n_length == -1:
502
+ n_length = len(files)
503
+ files = sorted(files)[:n_length]
504
+ print("Image files: ", files)
505
+ images = [Image.open(file).convert("RGB") for file in files]
506
+
507
+ # captions
508
+ if caption_csv:
509
+ print(f"Using provided captions")
510
+ caption_df = pd.read_csv(caption_csv)
511
+ # sort images to be consistent with 'sorted' above
512
+ caption_df = caption_df.sort_values("image_file")
513
+ captions = caption_df["caption"].values
514
+ print("Captions: ", captions)
515
+ if len(captions) != len(images):
516
+ print("Not the same number of captions as images!")
517
+ print(f"Num captions: {len(captions)}, Num images: {len(images)}")
518
+ print("Captions: ", captions)
519
+ print("Images: ", files)
520
+ raise Exception(
521
+ "Not the same number of captions as images! Check that all files passed in have a caption in your caption csv, and vice versa"
522
+ )
523
+
524
+ else:
525
+ print(f"Generating {len(images)} captions...")
526
+ captions = blip_captioning_dataset(
527
+ images, text=caption_text, substitution_tokens=substitution_tokens
528
+ )
529
+
530
+ if mask_target_prompts is None:
531
+ mask_target_prompts = ""
532
+ temp = 999
533
+
534
+ print(f"Generating {len(images)} masks...")
535
+ if not use_face_detection_instead:
536
+ seg_masks = clipseg_mask_generator(
537
+ images=images, target_prompts=mask_target_prompts, temp=temp
538
+ )
539
+ else:
540
+ seg_masks = face_mask_google_mediapipe(images=images)
541
+
542
+ # find the center of mass of the mask
543
+ if crop_based_on_salience:
544
+ coms = [_center_of_mass(mask) for mask in seg_masks]
545
+ else:
546
+ coms = [(image.size[0] / 2, image.size[1] / 2) for image in images]
547
+ # based on the center of mass, crop the image to a square
548
+ images = [
549
+ _crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms)
550
+ ]
551
+
552
+ print(f"Upscaling {len(images)} images...")
553
+ # upscale images anyways
554
+ images = swin_ir_sr(images, target_size=(target_size, target_size))
555
+ images = [
556
+ image.resize((target_size, target_size), Image.Resampling.LANCZOS)
557
+ for image in images
558
+ ]
559
+
560
+ seg_masks = [
561
+ _crop_to_square(mask, com, resize_to=target_size)
562
+ for mask, com in zip(seg_masks, coms)
563
+ ]
564
+
565
+ data = []
566
+
567
+ # clean TEMP_OUT_DIR first
568
+ if os.path.exists(output_dir):
569
+ for file in os.listdir(output_dir):
570
+ os.remove(os.path.join(output_dir, file))
571
+
572
+ os.makedirs(output_dir, exist_ok=True)
573
+
574
+ # iterate through the images, masks, and captions and add a row to the dataframe for each
575
+ for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)):
576
+ image_name = f"{idx}.src.png"
577
+ mask_file = f"{idx}.mask.png"
578
+
579
+ # save the image and mask files
580
+ image.save(os.path.join(output_dir, image_name))
581
+ mask.save(os.path.join(output_dir, mask_file))
582
+
583
+ # add a new row to the dataframe with the file names and caption
584
+ data.append(
585
+ {"image_path": image_name, "mask_path": mask_file, "caption": caption},
586
+ )
587
+
588
+ df = pd.DataFrame(columns=["image_path", "mask_path", "caption"], data=data)
589
+ # save the dataframe to a CSV file
590
+ df.to_csv(os.path.join(output_dir, "captions.csv"), index=False)
591
+
592
+
593
+ def _find_files(pattern, dir="."):
594
+ """Return list of files matching pattern in a given directory, in absolute format.
595
+ Unlike glob, this is case-insensitive.
596
+ """
597
+
598
+ rule = re.compile(fnmatch.translate(pattern), re.IGNORECASE)
599
+ return [os.path.join(dir, f) for f in os.listdir(dir) if rule.match(f)]
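
For orientation, a hedged sketch (not part of the upload) of driving preprocess() directly from the repo root; the archive name is an assumption, and the SwinIR/BLIP/CLIPSeg weight tars must be reachable since they are downloaded on first use:

from pathlib import Path

from preprocess import preprocess

out_dir = preprocess(
    input_images_filetype="zip",
    input_zip_path=Path("my_images.zip"),  # hypothetical local archive
    caption_text="a photo of TOK",
    mask_target_prompts="a person",
    target_size=1024,
    crop_based_on_salience=True,
    use_face_detection_instead=False,
    temp=1.0,
    substitution_tokens=["TOK"],
)
# out_dir is ./temp/ containing <i>.src.png, <i>.mask.png and captions.csv
print(out_dir)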
cog_sdxl/requirements_test.txt ADDED
@@ -0,0 +1,5 @@
1
+ numpy
2
+ pytest
3
+ replicate
4
+ requests
5
+ Pillow
cog_sdxl/samples.py ADDED
@@ -0,0 +1,155 @@
1
+ """
2
+ A handy utility for verifying SDXL image generation locally.
3
+ To set up, first run a local cog server using:
4
+ cog run -p 5000 python -m cog.server.http
5
+ Then, in a separate terminal, generate samples with:
6
+ python samples.py
7
+ """
8
+
9
+
10
+ import base64
11
+ import os
12
+ import sys
13
+
14
+ import requests
15
+
16
+
17
+ def gen(output_fn, **kwargs):
18
+ if os.path.exists(output_fn):
19
+ return
20
+
21
+ print("Generating", output_fn)
22
+ url = "http://localhost:5000/predictions"
23
+ response = requests.post(url, json={"input": kwargs})
24
+ data = response.json()
25
+
26
+ try:
27
+ datauri = data["output"][0]
28
+ base64_encoded_data = datauri.split(",")[1]
29
+ data = base64.b64decode(base64_encoded_data)
30
+ except Exception:
31
+ print("Error!")
32
+ print("input:", kwargs)
33
+ print(data["logs"])
34
+ sys.exit(1)
35
+
36
+ with open(output_fn, "wb") as f:
37
+ f.write(data)
38
+
39
+
40
+ def main():
41
+ SCHEDULERS = [
42
+ "DDIM",
43
+ "DPMSolverMultistep",
44
+ "HeunDiscrete",
45
+ "KarrasDPM",
46
+ "K_EULER_ANCESTRAL",
47
+ "K_EULER",
48
+ "PNDM",
49
+ ]
50
+
51
+ gen(
52
+ f"sample.txt2img.png",
53
+ prompt="A studio portrait photo of a cat",
54
+ num_inference_steps=25,
55
+ guidance_scale=7,
56
+ negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
57
+ seed=1000,
58
+ width=1024,
59
+ height=1024,
60
+ )
61
+
62
+ for refiner in ["base_image_refiner", "expert_ensemble_refiner", "no_refiner"]:
63
+ gen(
64
+ f"sample.img2img.{refiner}.png",
65
+ prompt="a photo of an astronaut riding a horse on mars",
66
+ image="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png",
67
+ prompt_strength=0.8,
68
+ num_inference_steps=25,
69
+ refine=refiner,
70
+ guidance_scale=7,
71
+ negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
72
+ seed=42,
73
+ )
74
+
75
+ gen(
76
+ f"sample.inpaint.{refiner}.png",
77
+ prompt="A majestic tiger sitting on a bench",
78
+ image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png",
79
+ mask="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png",
80
+ prompt_strength=0.8,
81
+ num_inference_steps=25,
82
+ refine=refiner,
83
+ guidance_scale=7,
84
+ negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
85
+ seed=42,
86
+ )
87
+
88
+ for split in range(0, 10):
89
+ split = split / 10.0
90
+ gen(
91
+ f"sample.expert_ensemble_refiner.{split}.txt2img.png",
92
+ prompt="A studio portrait photo of a cat",
93
+ num_inference_steps=25,
94
+ guidance_scale=7,
95
+ refine="expert_ensemble_refiner",
96
+ high_noise_frac=split,
97
+ negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
98
+ seed=1000,
99
+ width=1024,
100
+ height=1024,
101
+ )
102
+
103
+ gen(
104
+ f"sample.refine.txt2img.png",
105
+ prompt="A studio portrait photo of a cat",
106
+ num_inference_steps=25,
107
+ guidance_scale=7,
108
+ refine="base_image_refiner",
109
+ negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
110
+ seed=1000,
111
+ width=1024,
112
+ height=1024,
113
+ )
114
+ gen(
115
+ f"sample.refine.10.txt2img.png",
116
+ prompt="A studio portrait photo of a cat",
117
+ num_inference_steps=25,
118
+ guidance_scale=7,
119
+ refine="base_image_refiner",
120
+ refine_steps=10,
121
+ negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
122
+ seed=1000,
123
+ width=1024,
124
+ height=1024,
125
+ )
126
+
127
+ gen(
128
+ "samples.2.txt2img.png",
129
+ prompt="A studio portrait photo of a cat",
130
+ num_inference_steps=25,
131
+ guidance_scale=7,
132
+ negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
133
+ scheduler="KarrasDPM",
134
+ num_outputs=2,
135
+ seed=1000,
136
+ width=1024,
137
+ height=1024,
138
+ )
139
+
140
+ for s in SCHEDULERS:
141
+ gen(
142
+ f"sample.{s}.txt2img.png",
143
+ prompt="A studio portrait photo of a cat",
144
+ num_inference_steps=25,
145
+ guidance_scale=7,
146
+ negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
147
+ scheduler=s,
148
+ seed=1000,
149
+ width=1024,
150
+ height=1024,
151
+ )
152
+
153
+
154
+ if __name__ == "__main__":
155
+ main()
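
As a sketch (not in the repo), gen() can also be reused for one-off checks of inputs samples.py does not sweep, e.g. the LoRA scale; the output filename is an assumption:

from samples import gen  # assumes the local cog server from the docstring above is running

gen(
    "sample.lora_scale.txt2img.png",  # hypothetical output filename
    prompt="A studio portrait photo of a cat",
    num_inference_steps=25,
    guidance_scale=7,
    lora_scale=0.6,
    seed=1000,
    width=1024,
    height=1024,
)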
cog_sdxl/script/download_preprocessing_weights.py ADDED
@@ -0,0 +1,54 @@
1
+ import argparse
2
+ import os
3
+ import shutil
4
+
5
+ from transformers import (
6
+ BlipForConditionalGeneration,
7
+ BlipProcessor,
8
+ CLIPSegForImageSegmentation,
9
+ CLIPSegProcessor,
10
+ Swin2SRForImageSuperResolution,
11
+ )
12
+
13
+ DEFAULT_BLIP = "Salesforce/blip-image-captioning-large"
14
+ DEFAULT_CLIPSEG = "CIDAS/clipseg-rd64-refined"
15
+ DEFAULT_SWINIR = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
16
+
17
+
18
+ def upload(args):
19
+ blip_processor = BlipProcessor.from_pretrained(DEFAULT_BLIP)
20
+ blip_model = BlipForConditionalGeneration.from_pretrained(DEFAULT_BLIP)
21
+
22
+ clip_processor = CLIPSegProcessor.from_pretrained(DEFAULT_CLIPSEG)
23
+ clip_model = CLIPSegForImageSegmentation.from_pretrained(DEFAULT_CLIPSEG)
24
+
25
+ swin_model = Swin2SRForImageSuperResolution.from_pretrained(DEFAULT_SWINIR)
26
+
27
+ temp_models = "tmp/models"
28
+ if os.path.exists(temp_models):
29
+ shutil.rmtree(temp_models)
30
+ os.makedirs(temp_models)
31
+
32
+ blip_processor.save_pretrained(os.path.join(temp_models, "blip_processor"))
33
+ blip_model.save_pretrained(os.path.join(temp_models, "blip_large"))
34
+ clip_processor.save_pretrained(os.path.join(temp_models, "clip_seg_processor"))
35
+ clip_model.save_pretrained(os.path.join(temp_models, "clip_seg_rd64_refined"))
36
+ swin_model.save_pretrained(
37
+ os.path.join(temp_models, "swin2sr_realworld_sr_x4_64_bsrgan_psnr")
38
+ )
39
+
40
+ for val in os.listdir(temp_models):
41
+ if "tar" not in val:
42
+ os.system(
43
+ f"sudo tar -cvf {os.path.join(temp_models, val)}.tar -C {os.path.join(temp_models, val)} ."
44
+ )
45
+ os.system(
46
+ f"gcloud storage cp -R {os.path.join(temp_models, val)}.tar gs://{args.bucket}/{val}/"
47
+ )
48
+
49
+
50
+ if __name__ == "__main__":
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument("--bucket", "-m", type=str)
53
+ args = parser.parse_args()
54
+ upload(args)
cog_sdxl/script/download_weights.py ADDED
@@ -0,0 +1,50 @@
1
+ # Run this before you deploy the model on Replicate; otherwise the weights
2
+ # will be downloaded from the internet every time the model runs, which
3
+ # takes a long time.
4
+
5
+ import torch
6
+ from diffusers import AutoencoderKL, DiffusionPipeline
7
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
8
+ StableDiffusionSafetyChecker,
9
+ )
10
+
11
+ # pipe = DiffusionPipeline.from_pretrained(
12
+ # "stabilityai/stable-diffusion-xl-base-1.0",
13
+ # torch_dtype=torch.float16,
14
+ # use_safetensors=True,
15
+ # variant="fp16",
16
+ # )
17
+
18
+ # pipe.save_pretrained("./cache", safe_serialization=True)
19
+
20
+ better_vae = AutoencoderKL.from_pretrained(
21
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
22
+ )
23
+
24
+ pipe = DiffusionPipeline.from_pretrained(
25
+ "stabilityai/stable-diffusion-xl-base-1.0",
26
+ vae=better_vae,
27
+ torch_dtype=torch.float16,
28
+ use_safetensors=True,
29
+ variant="fp16",
30
+ )
31
+
32
+ pipe.save_pretrained("./sdxl-cache", safe_serialization=True)
33
+
34
+ pipe = DiffusionPipeline.from_pretrained(
35
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
36
+ torch_dtype=torch.float16,
37
+ use_safetensors=True,
38
+ variant="fp16",
39
+ )
40
+
41
+ # TODO - we don't need to save all of this and in fact should save just the unet, tokenizer, and config.
42
+ pipe.save_pretrained("./refiner-cache", safe_serialization=True)
43
+
44
+
45
+ safety = StableDiffusionSafetyChecker.from_pretrained(
46
+ "CompVis/stable-diffusion-safety-checker",
47
+ torch_dtype=torch.float16,
48
+ )
49
+
50
+ safety.save_pretrained("./safety-cache")
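
A hedged sketch of the consumption side: the caches written above are meant to be reloaded locally at setup time instead of pulling from the Hub (the exact wiring lives in predict.py; this snippet is only illustrative):

import torch
from diffusers import DiffusionPipeline

# Reload the base pipeline from the cache written above; no network access needed.
pipe = DiffusionPipeline.from_pretrained("./sdxl-cache", torch_dtype=torch.float16)
pipe.to("cuda")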
cog_sdxl/tests/assets/out.png ADDED

Git LFS Details

  • SHA256: e8fe96688cb1e33a7a99eed8645529eb900c6b7b4f9afeaeee8f4bf0afd762df
  • Pointer size: 132 Bytes
  • Size of remote file: 1.52 MB
cog_sdxl/tests/test_predict.py ADDED
@@ -0,0 +1,205 @@
1
+ import base64
2
+ import os
3
+ import pickle
4
+ import subprocess
5
+ import sys
6
+ import time
7
+ from functools import partial
8
+ from io import BytesIO
9
+
10
+ import numpy as np
11
+ import pytest
12
+ import replicate
13
+ import requests
14
+ from PIL import Image, ImageChops
15
+
16
+ ENV = os.getenv('TEST_ENV', 'local')
17
+ LOCAL_ENDPOINT = "http://localhost:5000/predictions"
18
+ MODEL = os.getenv('STAGING_MODEL', 'no model configured')
19
+
20
+ def local_run(model_endpoint: str, model_input: dict):
21
+ response = requests.post(model_endpoint, json={"input": model_input})
22
+ data = response.json()
23
+
24
+ try:
25
+ # TODO: this will break if we test batching
26
+ datauri = data["output"][0]
27
+ base64_encoded_data = datauri.split(",")[1]
28
+ data = base64.b64decode(base64_encoded_data)
29
+ return Image.open(BytesIO(data))
30
+ except Exception as e:
31
+ print("Error!")
32
+ print("input:", model_input)
33
+ print(data["logs"])
34
+ raise e
35
+
36
+
37
+ def replicate_run(model: str, version: str, model_input: dict):
38
+ output = replicate.run(
39
+ f"{model}:{version}",
40
+ input=model_input)
41
+ url = output[0]
42
+
43
+ response = requests.get(url)
44
+ return Image.open(BytesIO(response.content))
45
+
46
+
47
+ def wait_for_server_to_be_ready(url, timeout=300):
48
+ """
49
+ Waits for the server to be ready.
50
+
51
+ Args:
52
+ - url: The health check URL to poll.
53
+ - timeout: Maximum time (in seconds) to wait for the server to be ready.
54
+ """
55
+ start_time = time.time()
56
+ while True:
57
+ try:
58
+ response = requests.get(url)
59
+ data = response.json()
60
+
61
+ if data["status"] == "READY":
62
+ return
63
+ elif data["status"] == "SETUP_FAILED":
64
+ raise RuntimeError(
65
+ "Server initialization failed with status: SETUP_FAILED"
66
+ )
67
+
68
+ except requests.RequestException:
69
+ pass
70
+
71
+ if time.time() - start_time > timeout:
72
+ raise TimeoutError("Server did not become ready in the expected time.")
73
+
74
+ time.sleep(5) # Poll every 5 seconds
75
+
76
+
77
+ @pytest.fixture(scope="session")
78
+ def inference_func():
79
+ """
80
+ Local inference hits the local server over its HTTP API; staging inference uses the Replicate Python client because it's cleaner.
81
+ """
82
+ if ENV == 'local':
83
+ return partial(local_run, LOCAL_ENDPOINT)
84
+ elif ENV == 'staging':
85
+ model = replicate.models.get(MODEL)
86
+ print("model:", model)
87
+ version = model.versions.list()[0]
88
+ return partial(replicate_run, MODEL, version.id)
89
+ else:
90
+ raise Exception(f"env should be local or staging but was {ENV}")
91
+
92
+
93
+ @pytest.fixture(scope="session", autouse=True)
94
+ def service():
95
+ """
96
+ Spins up local cog server to hit for tests if running locally, no-op otherwise
97
+ """
98
+ if ENV == 'local':
99
+ print("building model")
100
+ # starts local server if we're running things locally
101
+ build_command = 'cog build -t test-model'.split()
102
+ subprocess.run(build_command, check=True)
103
+ container_name = 'cog-test'
104
+ try:
105
+ subprocess.check_output(['docker', 'inspect', '--format="{{.State.Running}}"', container_name])
106
+ print(f"Container '{container_name}' is running. Stopping and removing...")
107
+ subprocess.check_call(['docker', 'stop', container_name])
108
+ subprocess.check_call(['docker', 'rm', container_name])
109
+ print(f"Container '{container_name}' stopped and removed.")
110
+ except subprocess.CalledProcessError:
111
+ # Container not found
112
+ print(f"Container '{container_name}' not found or not running.")
113
+
114
+ run_command = f'docker run -d -p 5000:5000 --gpus all --name {container_name} test-model '.split()
115
+ process = subprocess.Popen(run_command, stdout=sys.stdout, stderr=sys.stderr)
116
+
117
+ wait_for_server_to_be_ready("http://localhost:5000/health-check")
118
+
119
+ yield
120
+ process.terminate()
121
+ process.wait()
122
+ stop_command = "docker stop cog-test".split()
123
+ subprocess.run(stop_command)
124
+ else:
125
+ yield
126
+
127
+
128
+ def image_equal_fuzzy(img_expected, img_actual, test_name='default', tol=20):
129
+ """
130
+ Check that the mean absolute pixel difference across the image is below tol.
131
+ Tol was determined empirically: holding everything else fixed and varying only
132
+ the seed produces images whose mean delta is at least 50.
133
+ """
134
+ img1 = np.array(img_expected, dtype=np.int32)
135
+ img2 = np.array(img_actual, dtype=np.int32)
136
+
137
+ mean_delta = np.mean(np.abs(img1 - img2))
138
+ imgs_equal = (mean_delta < tol)
139
+ if not imgs_equal:
140
+ # save failures for quick inspection
141
+ save_dir = f"tmp/{test_name}"
142
+ if not os.path.exists(save_dir):
143
+ os.makedirs(save_dir)
144
+ img_expected.save(os.path.join(save_dir, 'expected.png'))
145
+ img_actual.save(os.path.join(save_dir, 'actual.png'))
146
+ difference = ImageChops.difference(img_expected, img_actual)
147
+ difference.save(os.path.join(save_dir, 'delta.png'))
148
+
149
+ return imgs_equal
150
+
151
+
152
+ def test_seeded_prediction(inference_func, request):
153
+ """
154
+ SDXL w/seed should be deterministic. may need to adjust tolerance for optimized SDXLs
155
+ """
156
+ data = {
157
+ "prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic",
158
+ "num_inference_steps": 50,
159
+ "width": 1024,
160
+ "height": 1024,
161
+ "scheduler": "DDIM",
162
+ "refine": "expert_ensemble_refiner",
163
+ "seed": 12103,
164
+ }
165
+ actual_image = inference_func(data)
166
+ expected_image = Image.open("tests/assets/out.png")
167
+ assert image_equal_fuzzy(actual_image, expected_image, test_name=request.node.name)
168
+
169
+
170
+ def test_lora_load_unload(inference_func, request):
171
+ """
172
+ Tests generation with & without loras.
173
+ This is checking for some gnarly state issues (can SDXL load / unload LoRAs), so predictions need to run in series.
174
+ """
175
+ SEED = 1234
176
+ base_data = {
177
+ "prompt": "A photo of a dog on the beach",
178
+ "num_inference_steps": 50,
179
+ # Add other parameters here
180
+ "seed": SEED,
181
+ }
182
+ base_img_1 = inference_func(base_data)
183
+
184
+ lora_a_data = {
185
+ "prompt": "A photo of a TOK on the beach",
186
+ "num_inference_steps": 50,
187
+ # Add other parameters here
188
+ "replicate_weights": "https://storage.googleapis.com/dan-scratch-public/sdxl/other_model.tar",
189
+ "seed": SEED
190
+ }
191
+ lora_a_img_1 = inference_func(lora_a_data)
192
+ assert not image_equal_fuzzy(lora_a_img_1, base_img_1, test_name=request.node.name)
193
+
194
+ lora_a_img_2 = inference_func(lora_a_data)
195
+ assert image_equal_fuzzy(lora_a_img_1, lora_a_img_2, test_name=request.node.name)
196
+
197
+ lora_b_data = {
198
+ "prompt": "A photo of a TOK on the beach",
199
+ "num_inference_steps": 50,
200
+ "replicate_weights": "https://storage.googleapis.com/dan-scratch-public/sdxl/monstertoy_model.tar",
201
+ "seed": SEED,
202
+ }
203
+ lora_b_img = inference_func(lora_b_data)
204
+ assert not image_equal_fuzzy(lora_a_img_1, lora_b_img, test_name=request.node.name)
205
+ assert not image_equal_fuzzy(base_img_1, lora_b_img, test_name=request.node.name)
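
The fuzzy comparison above boils down to a mean absolute pixel delta; a standalone sketch of the same metric (file paths are illustrative):

import numpy as np
from PIL import Image

def mean_pixel_delta(path_a, path_b):
    # Images must share a size; the tests compare fixed 1024x1024 outputs.
    a = np.array(Image.open(path_a), dtype=np.int32)
    b = np.array(Image.open(path_b), dtype=np.int32)
    return np.mean(np.abs(a - b))

# The tests treat a mean delta below tol=20 as "equal"; changing only the seed
# was observed to push it above 50.
print(mean_pixel_delta("tests/assets/out.png", "my_new_out.png"))  # hypothetical second image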
cog_sdxl/tests/test_remote_train.py ADDED
@@ -0,0 +1,69 @@
1
+ import time
2
+ import pytest
3
+ import replicate
4
+
5
+
6
+ @pytest.fixture(scope="module")
7
+ def model_name(request):
8
+ return "stability-ai/sdxl"
9
+
10
+
11
+ @pytest.fixture(scope="module")
12
+ def model(model_name):
13
+ return replicate.models.get(model_name)
14
+
15
+
16
+ @pytest.fixture(scope="module")
17
+ def version(model):
18
+ versions = model.versions.list()
19
+ return versions[0]
20
+
21
+
22
+ @pytest.fixture(scope="module")
23
+ def training(model_name, version):
24
+ training_input = {
25
+ "input_images": "https://storage.googleapis.com/replicate-datasets/sdxl-test/monstertoy-captions.tar"
26
+ }
27
+ print(f"Training on {model_name}:{version.id}")
28
+ return replicate.trainings.create(
29
+ version=model_name + ":" + version.id,
30
+ input=training_input,
31
+ destination="replicate-internal/training-scratch",
32
+ )
33
+
34
+
35
+ @pytest.fixture(scope="module")
36
+ def prediction_tests():
37
+ return [
38
+ {
39
+ "prompt": "A photo of TOK at the beach",
40
+ "refine": "expert_ensemble_refiner",
41
+ },
42
+ ]
43
+
44
+
45
+ def test_training(training):
46
+ while training.completed_at is None:
47
+ time.sleep(60)
48
+ training.reload()
49
+ assert training.status == "succeeded"
50
+
51
+
52
+ @pytest.fixture(scope="module")
53
+ def trained_model_and_version(training):
54
+ trained_model, trained_version = training.output["version"].split(":")
55
+ return trained_model, trained_version
56
+
57
+
58
+ def test_post_training_predictions(trained_model_and_version, prediction_tests):
59
+ trained_model, trained_version = trained_model_and_version
60
+ model = replicate.models.get(trained_model)
61
+ version = model.versions.get(trained_version)
62
+ predictions = [
63
+ replicate.predictions.create(version=version, input=val)
64
+ for val in prediction_tests
65
+ ]
66
+
67
+ for val in predictions:
68
+ val.wait()
69
+ assert val.status == "succeeded"
cog_sdxl/tests/test_utils.py ADDED
@@ -0,0 +1,105 @@
1
+ import os
2
+ import json
3
+ import requests
4
+ import time
5
+ from threading import Thread, Lock
6
+ import re
7
+ import multiprocessing
8
+ import subprocess
9
+
10
+ ERROR_PATTERN = re.compile(r"ERROR:")
11
+
12
+
13
+ def get_image_name():
14
+ current_dir = os.path.basename(os.getcwd())
15
+
16
+ if "cog" in current_dir:
17
+ return current_dir
18
+ else:
19
+ return f"cog-{current_dir}"
20
+
21
+
22
+ def process_log_line(line):
23
+ line = line.decode("utf-8").strip()
24
+ try:
25
+ log_data = json.loads(line)
26
+ return json.dumps(log_data, indent=2)
27
+ except json.JSONDecodeError:
28
+ return line
29
+
30
+
31
+ def capture_output(pipe, print_lock, logs=None, error_detected=None):
32
+ for line in iter(pipe.readline, b""):
33
+ formatted_line = process_log_line(line)
34
+ with print_lock:
35
+ print(formatted_line)
36
+ if logs is not None:
37
+ logs.append(formatted_line)
38
+ if error_detected is not None:
39
+ if ERROR_PATTERN.search(formatted_line):
40
+ error_detected[0] = True
41
+
42
+
43
+ def wait_for_server_to_be_ready(url, timeout=300):
44
+ """
45
+ Waits for the server to be ready.
46
+
47
+ Args:
48
+ - url: The health check URL to poll.
49
+ - timeout: Maximum time (in seconds) to wait for the server to be ready.
50
+ """
51
+ start_time = time.time()
52
+ while True:
53
+ try:
54
+ response = requests.get(url)
55
+ data = response.json()
56
+
57
+ if data["status"] == "READY":
58
+ return
59
+ elif data["status"] == "SETUP_FAILED":
60
+ raise RuntimeError(
61
+ "Server initialization failed with status: SETUP_FAILED"
62
+ )
63
+
64
+ except requests.RequestException:
65
+ pass
66
+
67
+ if time.time() - start_time > timeout:
68
+ raise TimeoutError("Server did not become ready in the expected time.")
69
+
70
+ time.sleep(5) # Poll every 5 seconds
71
+
72
+
73
+ def run_training_subprocess(command):
74
+ # Start the subprocess with pipes for stdout and stderr
75
+ process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
76
+
77
+ # Create a lock for printing and a list to accumulate logs
78
+ print_lock = multiprocessing.Lock()
79
+ logs = multiprocessing.Manager().list()
80
+ error_detected = multiprocessing.Manager().list([False])
81
+
82
+ # Start two separate processes to handle stdout and stderr
83
+ stdout_processor = multiprocessing.Process(
84
+ target=capture_output, args=(process.stdout, print_lock, logs, error_detected)
85
+ )
86
+ stderr_processor = multiprocessing.Process(
87
+ target=capture_output, args=(process.stderr, print_lock, logs, error_detected)
88
+ )
89
+
90
+ # Start the log processors
91
+ stdout_processor.start()
92
+ stderr_processor.start()
93
+
94
+ # Wait for the subprocess to finish
95
+ process.wait()
96
+
97
+ # Wait for the log processors to finish
98
+ stdout_processor.join()
99
+ stderr_processor.join()
100
+
101
+ # Check if an error pattern was detected
102
+ if error_detected[0]:
103
+ raise Exception("Error detected in training logs! Check logs for details")
104
+
105
+ return list(logs)
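
A rough sketch (not in the repo) of how these helpers combine: start the test container the way tests/test_predict.py does, wait for the health check, then stream a training command's logs through the ERROR filter; the training invocation itself is a hypothetical example:

import subprocess

from tests.test_utils import run_training_subprocess, wait_for_server_to_be_ready  # assumes repo root on PYTHONPATH

subprocess.run(
    "docker run -d -p 5000:5000 --gpus all --name cog-test test-model".split(),
    check=True,
)
wait_for_server_to_be_ready("http://localhost:5000/health-check")
logs = run_training_subprocess(["cog", "train", "-i", "input_images=@data.tar"])  # hypothetical command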