FluttyProger commited on
Commit
37e3a2f
1 Parent(s): 01a50e9

Upload 9 files

Browse files
Files changed (9) hide show
  1. LICENSE +201 -0
  2. README.md +654 -0
  3. chat_interface.py +638 -0
  4. gradio_app.py +382 -0
  5. lib_omost/canvas.py +248 -0
  6. lib_omost/memory_management.py +67 -0
  7. lib_omost/pipeline.py +435 -0
  8. requirements.txt +13 -0
  9. run.bat +5 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Omost
2
+
3
+ Omost is a project to convert LLM's coding capability to image generation (or more accurately, image composing) capability.
4
+
5
+ The name `Omost` (pronunciation: almost) has two meanings: 1) everytime after you use Omost, your image is almost there; 2) the `O` mean "omni" (multi-modal) and `most` means we want to get the most out of it.
6
+
7
+ Omost provides LLMs models that will write codes to compose image visual contents with Omost's virtual `Canvas` agent. This `Canvas` can be rendered by specific implementations of image generators to actually generate images.
8
+
9
+ Currently, we provide 3 pretrained LLM models based on variations of Llama3 and Phi3 (see also the model notes at the end of this page).
10
+
11
+ All models are trained with mixed data of (1) ground-truth annotations of several datasets including Open-Images, (2) extracted data by automatically annotating images, (3) reinforcement from DPO (Direct Preference Optimization, "whether the codes can be compiled by python 3.10 or not" as a direct preference), and (4) a small amount of tuning data from OpenAI GPT4o's multi-modal capability.
12
+
13
+ # Get Started
14
+
15
+ You can just use the [official HuggingFace space](https://huggingface.co/spaces/lllyasviel/Omost).
16
+
17
+ Or, you can use the below deployment (requires 8GB Nvidia VRAM):
18
+
19
+ git clone https://github.com/lllyasviel/Omost.git
20
+ cd Omost
21
+ conda create -n omost python=3.10
22
+ conda activate omost
23
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
24
+ pip install -r requirements.txt
25
+ python gradio_app.py
26
+
27
+ (Note that quant LLM requires `bitsandbytes` - some 9XX or 10XX or 20XX GPUs may have trouble in running it. If that happens, just use our official huggingface space.)
28
+
29
+ # Screenshots
30
+
31
+ (All with random seed 12345)
32
+
33
+ a ragged man wearing a tattered jacket in the nineteenth century:
34
+
35
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/6acc5eac-87e4-428b-a209-1d4b947c590c)
36
+
37
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/c60dadfc-6e82-4582-9561-8389260714c0)
38
+
39
+ the greatest escape after the apocalypse, the hope of all humanity:
40
+
41
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/997d41ab-ea52-4036-898b-47c8b863acf4)
42
+
43
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/c6653558-bfd8-40d9-95fb-c5ec1b68c289)
44
+
45
+ jurassic dinosaur battle:
46
+
47
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/7efe0b56-9da2-4967-bbb0-ed3cf4eddd23)
48
+
49
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/f4b39798-488d-429c-bf96-27a7da6f91f7)
50
+
51
+ the supreme ruler with endless magic on a gothic throne, with undead armies kneeling:
52
+
53
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/336b7f94-0f8e-425a-b541-6e8cd5642a2e)
54
+
55
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/9de68d20-0c01-4cee-b656-04d59cbdc695)
56
+
57
+ a funny cartoon batman fights joker:
58
+
59
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/dc751ded-f0f7-4312-8cd2-2a3aced15a60)
60
+
61
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/f3e6e8e8-2e08-4e5e-a109-a8054461e71f)
62
+
63
+ # Example Transcripts (with Conversational Editing)
64
+
65
+ *(Below is a raw transcript from default model, with random seed 123456)*
66
+
67
+ *(If the GitHub's code block rendering is causing troubles to reading long lines, click [here](https://raw.githubusercontent.com/lllyasviel/Omost/main/README.md) to read raw texts.)*
68
+
69
+ ### User
70
+
71
+ generate an image of the fierce battle of warriors and the dragon
72
+
73
+ ### Assistant
74
+
75
+ ```python
76
+ # Initialize the canvas
77
+ canvas = Canvas()
78
+
79
+ # Set a global description for the canvas
80
+ canvas.set_global_description(
81
+ description='A fierce battle between warriors and a dragon.',
82
+ detailed_descriptions=[
83
+ 'In this intense scene, a group of fierce warriors is engaged in an epic battle with a mighty dragon.',
84
+ 'The warriors, clad in armor and wielding swords and shields, are positioned on the left side of the image.',
85
+ 'Their expressions are determined and focused, reflecting their resolve to defeat the dragon.',
86
+ 'The dragon, with its massive wings spread wide and its fiery breath illuminating the scene, dominates the center of the image.',
87
+ 'Its scales glisten with a metallic sheen, and its eyes burn with a fierce intensity.',
88
+ 'The background is a dark, stormy sky with flashes of lightning, adding to the drama and tension of the battle.',
89
+ 'The ground is littered with debris and the remnants of previous battles, enhancing the sense of a long and brutal conflict.',
90
+ 'The overall atmosphere is one of chaos and intensity, with the warriors and the dragon locked in a fierce struggle for survival.',
91
+ ],
92
+ tags='battle, warriors, dragon, fierce, armor, swords, shields, determined, focused, epic, intense, metallic, glistening, fiery breath, stormy sky, lightning, debris, conflict, struggle, survival, chaos, tension, drama, wings, scales, eyes, burning, illuminated',
93
+ HTML_web_color_name='darkslategray',
94
+ )
95
+
96
+ # Add a group of fierce warriors.
97
+ canvas.add_local_description(
98
+ location='on the left',
99
+ offset='no offset',
100
+ area='a large horizontal area',
101
+ distance_to_viewer=5.0,
102
+ description='A group of fierce warriors.',
103
+ detailed_descriptions=[
104
+ 'The warriors, clad in gleaming armor, are positioned on the left side of the image.',
105
+ 'They are armed with swords, shields, and spears, ready for battle.',
106
+ 'Their faces are set with determination and focus, reflecting their resolve to defeat the dragon.',
107
+ 'Some warriors are in mid-action, swinging their swords or shields, while others stand ready to strike.',
108
+ 'The armor they wear is intricately designed, with symbols and patterns that signify their rank and bravery.',
109
+ 'The ground beneath them is uneven and littered with debris, adding to the sense of a chaotic and intense battle.',
110
+ ],
111
+ tags='warriors, armor, swords, shields, spears, determined, focused, mid-action, intricate design, symbols, patterns, rank, bravery, uneven ground, debris, chaotic, intense, battle',
112
+ atmosphere='Determined and focused, ready for the fierce battle.',
113
+ style='Highly detailed and dynamic, capturing the intensity of the warriors.',
114
+ quality_meta='High resolution with intricate details and dynamic poses.',
115
+ HTML_web_color_name='darkgoldenrod',
116
+ )
117
+
118
+ # Add a mighty dragon.
119
+ canvas.add_local_description(
120
+ location='in the center',
121
+ offset='no offset',
122
+ area='a large square area',
123
+ distance_to_viewer=7.0,
124
+ description='A mighty dragon.',
125
+ detailed_descriptions=[
126
+ 'The dragon is a massive creature, dominating the center of the image with its wide-spread wings and fiery breath.',
127
+ 'Its scales glisten with a metallic sheen, reflecting the light from its fiery breath.',
128
+ "The dragon's eyes burn with a fierce intensity, and its teeth are sharp and menacing.",
129
+ 'The wings of the dragon are powerful and spread wide, casting shadows over the battlefield.',
130
+ 'The dragon’s body is muscular and covered in protective scales, with a long, sinewy tail that adds to its formidable appearance.',
131
+ 'The fiery breath of the dragon illuminates the scene, casting a reddish glow over the warriors and the battlefield.',
132
+ ],
133
+ tags='dragon, massive, wings, fiery breath, glistening scales, metallic sheen, fierce eyes, sharp teeth, powerful wings, shadows, battlefield, muscular body, protective scales, sinewy tail, formidable, reddish glow, illumination',
134
+ atmosphere='Intense and menacing, with a powerful presence.',
135
+ style='Epic and dramatic, emphasizing the grandeur and danger of the dragon.',
136
+ quality_meta='High resolution with dramatic lighting and detailed textures.',
137
+ HTML_web_color_name='firebrick',
138
+ )
139
+
140
+ # Add a stormy sky with flashes of lightning.
141
+ canvas.add_local_description(
142
+ location='on the top',
143
+ offset='no offset',
144
+ area='a large horizontal area',
145
+ distance_to_viewer=10.0,
146
+ description='A stormy sky with flashes of lightning.',
147
+ detailed_descriptions=[
148
+ 'The background of the image is a dark, stormy sky filled with swirling clouds and flashes of lightning.',
149
+ 'The sky is turbulent, with clouds dark and foreboding, adding to the dramatic tension of the battle.',
150
+ 'The lightning flashes illuminate the scene, casting sharp, brief lights over the warriors and the dragon.',
151
+ 'The stormy sky creates a sense of chaos and unpredictability, heightening the intensity of the battle below.',
152
+ 'The overall atmosphere is one of impending doom and relentless conflict, with the storm mirroring the fierce struggle between the warriors and the dragon.',
153
+ ],
154
+ tags='stormy sky, dark clouds, lightning, turbulent, foreboding, dramatic tension, illumination, chaos, unpredictability, intensity, impending doom, relentless conflict, battle, warriors, dragon, swirling clouds, sharp lights, brief lights',
155
+ atmosphere='Chaotic and intense, mirroring the fierce battle below.',
156
+ style='Dramatic and turbulent, emphasizing the conflict and tension.',
157
+ quality_meta='High resolution with dynamic lighting and detailed cloud textures.',
158
+ HTML_web_color_name='midnightblue',
159
+ )
160
+
161
+ # Add a debris-covered battlefield.
162
+ canvas.add_local_description(
163
+ location='on the bottom',
164
+ offset='no offset',
165
+ area='a large horizontal area',
166
+ distance_to_viewer=5.0,
167
+ description='A debris-covered battlefield.',
168
+ detailed_descriptions=[
169
+ 'The ground of the battlefield is littered with debris, remnants of previous battles.',
170
+ 'Broken weapons, shattered shields, and scattered armor pieces are strewn across the battlefield.',
171
+ 'The terrain is uneven, with patches of mud and dirt, adding to the sense of a long and brutal conflict.',
172
+ 'The debris-covered battlefield enhances the chaotic and intense atmosphere of the scene, reflecting the ferocity and duration of the battle.',
173
+ 'The overall appearance is one of destruction and turmoil, with the remnants of previous battles serving as a grim reminder of the ongoing struggle.',
174
+ ],
175
+ tags='battlefield, debris, broken weapons, shattered shields, scattered armor, uneven terrain, mud, dirt, brutal conflict, chaos, intensity, destruction, turmoil, previous battles, ongoing struggle, remnants, ferocity, duration',
176
+ atmosphere='Chaotic and intense, reflecting the ferocity of the battle.',
177
+ style='Detailed and realistic, emphasizing the destruction and chaos.',
178
+ quality_meta='High resolution with detailed textures and realistic debris.',
179
+ HTML_web_color_name='darkolivegreen',
180
+ )
181
+ ```
182
+
183
+ ### User
184
+
185
+ change the dragon to a dinosaur
186
+
187
+ ### Assistant
188
+
189
+ ```python
190
+ # Initialize the canvas
191
+ canvas = Canvas()
192
+
193
+ # Set a global description for the canvas
194
+ canvas.set_global_description(
195
+ description='A fierce battle between warriors and a dinosaur.',
196
+ detailed_descriptions=[
197
+ 'In this intense scene, a group of fierce warriors is engaged in an epic battle with a mighty dinosaur.',
198
+ 'The warriors, clad in armor and wielding swords and shields, are positioned on the left side of the image.',
199
+ 'Their expressions are determined and focused, reflecting their resolve to defeat the dinosaur.',
200
+ 'The dinosaur, with its massive body and powerful legs, dominates the center of the image.',
201
+ 'Its scales glisten with a metallic sheen, and its eyes burn with a fierce intensity.',
202
+ 'The background is a dark, stormy sky with flashes of lightning, adding to the drama and tension of the battle.',
203
+ 'The ground is littered with debris and the remnants of previous battles, enhancing the sense of a long and brutal conflict.',
204
+ 'The overall atmosphere is one of chaos and intensity, with the warriors and the dinosaur locked in a fierce struggle for survival.',
205
+ ],
206
+ tags='battle, warriors, dinosaur, fierce, armor, swords, shields, determined, focused, epic, intense, metallic, glistening, stormy sky, lightning, debris, conflict, struggle, survival, chaos, tension, drama, powerful legs, scales, eyes, burning, illuminated',
207
+ HTML_web_color_name='darkslategray',
208
+ )
209
+
210
+ # Add a group of fierce warriors.
211
+ canvas.add_local_description(
212
+ location='on the left',
213
+ offset='no offset',
214
+ area='a large horizontal area',
215
+ distance_to_viewer=5.0,
216
+ description='A group of fierce warriors.',
217
+ detailed_descriptions=[
218
+ 'The warriors, clad in gleaming armor, are positioned on the left side of the image.',
219
+ 'They are armed with swords, shields, and spears, ready for battle.',
220
+ 'Their faces are set with determination and focus, reflecting their resolve to defeat the dinosaur.',
221
+ 'Some warriors are in mid-action, swinging their swords or shields, while others stand ready to strike.',
222
+ 'The armor they wear is intricately designed, with symbols and patterns that signify their rank and bravery.',
223
+ 'The ground beneath them is uneven and littered with debris, adding to the sense of a chaotic and intense battle.',
224
+ ],
225
+ tags='warriors, armor, swords, shields, spears, determined, focused, mid-action, intricate design, symbols, patterns, rank, bravery, uneven ground, debris, chaotic, intense, battle',
226
+ atmosphere='Determined and focused, ready for the fierce battle.',
227
+ style='Highly detailed and dynamic, capturing the intensity of the warriors.',
228
+ quality_meta='High resolution with intricate details and dynamic poses.',
229
+ HTML_web_color_name='darkgoldenrod',
230
+ )
231
+
232
+ # Add a mighty dinosaur.
233
+ canvas.add_local_description(
234
+ location='in the center',
235
+ offset='no offset',
236
+ area='a large square area',
237
+ distance_to_viewer=7.0,
238
+ description='A mighty dinosaur.',
239
+ detailed_descriptions=[
240
+ 'The dinosaur is a massive creature, dominating the center of the image with its powerful legs and fierce roar.',
241
+ 'Its scales glisten with a metallic sheen, reflecting the light from the stormy sky.',
242
+ "The dinosaur's eyes burn with a fierce intensity, and its teeth are sharp and menacing.",
243
+ 'The powerful legs of the dinosaur are muscular and add to its formidable appearance.',
244
+ 'The ground shakes beneath its feet, enhancing the sense of a long and brutal conflict.',
245
+ 'The overall appearance is one of destruction and turmoil, with the dinosaur adding to the chaotic atmosphere.',
246
+ ],
247
+ tags='dinosaur, massive, powerful legs, fierce roar, glistening scales, metallic sheen, fierce eyes, sharp teeth, muscular legs, formidable, ground shaking, destruction, turmoil, chaos, battle',
248
+ atmosphere='Intense and menacing, with a powerful presence.',
249
+ style='Epic and dramatic, emphasizing the grandeur and danger of the dinosaur.',
250
+ quality_meta='High resolution with dramatic lighting and detailed textures.',
251
+ HTML_web_color_name='firebrick',
252
+ )
253
+
254
+ # Add a stormy sky with flashes of lightning.
255
+ canvas.add_local_description(
256
+ location='on the top',
257
+ offset='no offset',
258
+ area='a large horizontal area',
259
+ distance_to_viewer=10.0,
260
+ description='A stormy sky with flashes of lightning.',
261
+ detailed_descriptions=[
262
+ 'The background of the image is a dark, stormy sky filled with swirling clouds and flashes of lightning.',
263
+ 'The sky is turbulent, with clouds dark and foreboding, adding to the dramatic tension of the battle.',
264
+ 'The lightning flashes illuminate the scene, casting sharp, brief lights over the warriors and the dinosaur.',
265
+ 'The stormy sky creates a sense of chaos and unpredictability, heightening the intensity of the battle below.',
266
+ 'The overall atmosphere is one of impending doom and relentless conflict, with the storm mirroring the fierce struggle between the warriors and the dinosaur.',
267
+ ],
268
+ tags='stormy sky, dark clouds, lightning, turbulent, foreboding, dramatic tension, illumination, chaos, unpredictability, intensity, impending doom, relentless conflict, battle, warriors, dinosaur, swirling clouds, sharp lights, brief lights',
269
+ atmosphere='Chaotic and intense, mirroring the fierce battle below.',
270
+ style='Dramatic and turbulent, emphasizing the conflict and tension.',
271
+ quality_meta='High resolution with dynamic lighting and detailed cloud textures.',
272
+ HTML_web_color_name='midnightblue',
273
+ )
274
+
275
+ # Add a debris-covered battlefield.
276
+ canvas.add_local_description(
277
+ location='on the bottom',
278
+ offset='no offset',
279
+ area='a large horizontal area',
280
+ distance_to_viewer=5.0,
281
+ description='A debris-covered battlefield.',
282
+ detailed_descriptions=[
283
+ 'The ground of the battlefield is littered with debris, remnants of previous battles.',
284
+ 'Broken weapons, shattered shields, and scattered armor pieces are strewn across the battlefield.',
285
+ 'The terrain is uneven, with patches of mud and dirt, adding to the sense of a long and brutal conflict.',
286
+ 'The debris-covered battlefield enhances the chaotic and intense atmosphere of the scene, reflecting the ferocity and duration of the battle.',
287
+ 'The overall appearance is one of destruction and turmoil, with the remnants of previous battles serving as a grim reminder of the ongoing struggle.',
288
+ ],
289
+ tags='battlefield, debris, broken weapons, shattered shields, scattered armor, uneven terrain, mud, dirt, brutal conflict, chaos, intensity, destruction, turmoil, previous battles, ongoing struggle, remnants, ferocity, duration',
290
+ atmosphere='Chaotic and intense, reflecting the ferocity of the battle.',
291
+ style='Detailed and realistic, emphasizing the destruction and chaos.',
292
+ quality_meta='High resolution with detailed textures and realistic debris.',
293
+ HTML_web_color_name='darkolivegreen',
294
+ )
295
+ ```
296
+
297
+ (User can click the render button to actually render the image from any of those `Canvas` instances.)
298
+
299
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/c8d56f4c-f3e3-4338-8ffc-d9b14445fb54)
300
+
301
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/23d8b0fb-b504-442f-a449-48a7bf613bb4)
302
+
303
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/39434e27-aaf9-441f-a76a-3850ff480c5f)
304
+
305
+ # Symbols
306
+
307
+ All Omost LLMs are trained to obey the following symbols
308
+
309
+ ```python
310
+ class Canvas:
311
+ def set_global_description(
312
+ self,
313
+ description: str,
314
+ detailed_descriptions: list[str],
315
+ tags: str,
316
+ HTML_web_color_name: str
317
+ ):
318
+ pass
319
+
320
+ def add_local_description(
321
+ self,
322
+ location: str,
323
+ offset: str,
324
+ area: str,
325
+ distance_to_viewer: float,
326
+ description: str,
327
+ detailed_descriptions: list[str],
328
+ tags: str,
329
+ atmosphere: str,
330
+ style: str,
331
+ quality_meta: str,
332
+ HTML_web_color_name: str
333
+ ):
334
+ assert location in [
335
+ "in the center",
336
+ "on the left",
337
+ "on the right",
338
+ "on the top",
339
+ "on the bottom",
340
+ "on the top-left",
341
+ "on the top-right",
342
+ "on the bottom-left",
343
+ "on the bottom-right"
344
+ ]
345
+ assert offset in [
346
+ "no offset",
347
+ "slightly to the left",
348
+ "slightly to the right",
349
+ "slightly to the upper",
350
+ "slightly to the lower",
351
+ "slightly to the upper-left",
352
+ "slightly to the upper-right",
353
+ "slightly to the lower-left",
354
+ "slightly to the lower-right"
355
+ ]
356
+ assert area in [
357
+ "a small square area",
358
+ "a small vertical area",
359
+ "a small horizontal area",
360
+ "a medium-sized square area",
361
+ "a medium-sized vertical area",
362
+ "a medium-sized horizontal area",
363
+ "a large square area",
364
+ "a large vertical area",
365
+ "a large horizontal area"
366
+ ]
367
+ assert distance_to_viewer > 0
368
+ pass
369
+ ```
370
+
371
+ During training, the above symbols are associated with specific concepts and use cases related to image generation.
372
+
373
+ The design is to make those codes easy to learn for LLMs, but also easy to handle for diffusion models.
374
+
375
+ Lets breakdown each part:
376
+
377
+ ## Function: Canvas.set_global_description and Canvas.add_local_description
378
+
379
+ They set descriptions to images. The meanings of the parameters are same for them, with `add_local_description` having more fields than `set_global_description`.
380
+
381
+ The `set_global_description` annotate entire image, while `add_local_description` annotates a part of image.
382
+
383
+ ## Parameter: description and detailed_descriptions
384
+
385
+ Let us introduce a concept called "sub-prompt". If a prompt is less than 75 tokens, and is self-supported to describe a thing without relying on other prompts, we call it a "sub-prompt".
386
+
387
+ The `description` is a sub-prompt, and the `detailed_descriptions` is a list of sub-prompts.
388
+
389
+ Note that each sub-prompt is strictly less than 75 tokens (and typically less than 40 tokens), you can safely encode them with any clip without worrying the truncation position affecting the semantics.
390
+
391
+ The design of sub-prompt also allows more satisfying text encoding based on greedy merge. For example, if you have
392
+
393
+ sub-prompt A: 25 tokens
394
+ sub-prompt B: 35 tokens
395
+ sub-prompt C: 5 tokens
396
+ sub-prompt D: 60 tokens
397
+ sub-prompt E: 15 tokens
398
+ sub-prompt F: 25 tokens
399
+
400
+ and since every sub-prompt is promised to be self-supported to describe a thing independently, we can use greedy method to merge them to bags like
401
+
402
+ bag 1 {A, B, C} : 65 tokens
403
+ bag 2 {D} : 60 tokens
404
+ bag 1 {E, F} : 40 tokens
405
+
406
+ where each bag is less than 75 tokens and can be encoded by any clip in one pass (and then concat them).
407
+
408
+ Encoding texts in this way will make sure that text-encoder will never make semantic truncation mistakes.
409
+
410
+ One may ask - if all sub-prompts are less than 75 tokens with independent semantics, why not just encode them without merge and then concat? This is mainly because we want the text embedding to be more coherent. For example, lets say sub-prompt A is "a man" while sub-prompt B is "handsome, professional", then merging them before encoding will give you a more mixed text embedding concept with coherent features of a handsome professional man.
411
+
412
+ All Omost LLMs are trained to give strictly well-defined sub-prompts. You can make use of these definitions to design lossless text encoding methods.
413
+
414
+ ### Parameter: location, offset, area
415
+
416
+ The three parameters defines a bounding box. Note that they must obey
417
+
418
+ ```python
419
+ assert location in [
420
+ "in the center",
421
+ "on the left",
422
+ "on the right",
423
+ "on the top",
424
+ "on the bottom",
425
+ "on the top-left",
426
+ "on the top-right",
427
+ "on the bottom-left",
428
+ "on the bottom-right"
429
+ ]
430
+ assert offset in [
431
+ "no offset",
432
+ "slightly to the left",
433
+ "slightly to the right",
434
+ "slightly to the upper",
435
+ "slightly to the lower",
436
+ "slightly to the upper-left",
437
+ "slightly to the upper-right",
438
+ "slightly to the lower-left",
439
+ "slightly to the lower-right"
440
+ ]
441
+ assert area in [
442
+ "a small square area",
443
+ "a small vertical area",
444
+ "a small horizontal area",
445
+ "a medium-sized square area",
446
+ "a medium-sized vertical area",
447
+ "a medium-sized horizontal area",
448
+ "a large square area",
449
+ "a large vertical area",
450
+ "a large horizontal area"
451
+ ]
452
+ ```
453
+
454
+ First we divide a canvas into 3*3=9 locations:
455
+
456
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/5d39cf93-c229-4c83-ae82-3eeeae2fabea)
457
+
458
+ Then we further divide each location to 3\*3 offsets, resulting in 9\*9=81 positions:
459
+
460
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/b744d787-11f3-4aeb-9d3a-aeba7a41b433)
461
+
462
+ Using these positions as centers, we further define 9 types of bounding boxes:
463
+
464
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/0e484b73-680f-486b-8b61-4373c9eec9a0)
465
+
466
+ We can see that this method allows 9\*9\*9=729 different bounding boxes, covering almost all common possible locations of an object in the image.
467
+
468
+ One may argue that why this is necessary - why not just let the LLMs to learn pixel index or x, y coordinates - and should that be much more accurate? Below is several of my notes:
469
+
470
+ 1. I have tried several representations, including pixel index like {x=32, y=16, w=58, h=99}, or margin pixels like {left=32, right=15, top=27, bottom=33}, or percentage pixel index like {x=0.124, y=0.65, w=0.335, h=0.251}, or percentage margin like {left=0.251, right=0.154, top=0.254, bottom=0.441}. The result is that opensource LLMs are really not very good at learning these representations even for Llama3 (perhaps GPT4o can learn it). Sometimes it works sometimes it gives completely random numbers. Note that our problem is very different from MLLM. The vision-LLM usually have image embedding as inputs and in that case estimating numeric position is like a look-up table problem and can somewhat be learned, but our case is where the LLM need to generate every composition from scratch without help of any image embedding to look-up.
471
+ 2. But the natural language like "on the right", "slightly to the top-right", "a small vertical area" etc, works very well. The model converges very fast and the learning is stable. It aligns to the pretrained knowledge of LLMs very well.
472
+ 3. I have also tried adding some special tokens to represent spatial locations and also train the embedding layers. But that model is very difficult to train and debug. Also, the token-embedding-based method needs many hyperparameter tuning everytime we change the LLM - for example when changing from Llama3 to Phi, if we use the token-embedding method, we need to design training parameters again.
473
+ 4. The number 9\*9\*9=729 is not really a small number from the perspective of bounding box proposals. This can also be called ROI (region of interest) and some old semantic segmentation tech uses (RPN) Region Proposal Network to produce a similar number (<1000) of regions.
474
+ 5. Most region-guided diffusion methods are coarse-level methods (like multi-diffusion and attention couple and gligen), and they do not need pixel-perfect regions.
475
+ 6. These are very personal results from me - if you are working on some similar multi-modal LLM research, using pixel indices is completely okay, worth trying, and probably other training methods can also achieve a robust system.
476
+
477
+ ### Parameter: distance_to_viewer and HTML_web_color_name
478
+
479
+ The `distance_to_viewer` can be viewed as relative depth. Note that this value's absolute number is not reliable at all (because opensource LLMs are not very good at producing image-space numbers) and it should only be used in sorting elements into background-to-foreground layers.
480
+
481
+ You can always use `distance_to_viewer` to sort all local elements before rendering them using a diffusion model. The global description can be always viewed as the most far away background layer.
482
+
483
+ The `HTML_web_color_name` is one of these:
484
+
485
+ ```python
486
+ possible_HTML_web_color_names = { # r, g, b
487
+ 'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255),
488
+ 'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220),
489
+ 'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255),
490
+ 'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135),
491
+ 'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 105, 30),
492
+ 'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220),
493
+ 'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139),
494
+ 'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169),
495
+ 'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139),
496
+ 'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204),
497
+ 'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143),
498
+ 'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79),
499
+ 'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147),
500
+ 'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105),
501
+ 'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240),
502
+ 'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220),
503
+ 'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32),
504
+ 'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47),
505
+ 'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92),
506
+ 'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250),
507
+ 'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205),
508
+ 'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255),
509
+ 'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211),
510
+ 'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122),
511
+ 'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153),
512
+ 'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 224),
513
+ 'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255),
514
+ 'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205),
515
+ 'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113),
516
+ 'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154),
517
+ 'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112),
518
+ 'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181),
519
+ 'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128),
520
+ 'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35),
521
+ 'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214),
522
+ 'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238),
523
+ 'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185),
524
+ 'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230),
525
+ 'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0),
526
+ 'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19),
527
+ 'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87),
528
+ 'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192),
529
+ 'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144),
530
+ 'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127),
531
+ 'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216),
532
+ 'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238),
533
+ 'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245),
534
+ 'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50)
535
+ }
536
+ ```
537
+
538
+ By combining `distance_to_viewer` and `HTML_web_color_name`, you can draw a very coarse image of the composition. For example, if the LLM works well, "a green bottle in front of a red bottle on a wood table in a dark room" should make it possible for you to compute an image like:
539
+
540
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/ab501872-bbcc-4fd4-8ab4-6fecd1a44d4d)
541
+
542
+ You can use this image as an initial latent and use denoise strength like 0.95 to 0.99 to generate the image.
543
+
544
+ Or if you do not like this and still prefer to let diffusion models to generate from zero-mean (even when you know that most diffusion models have tsnr problems), you can ignore this image and or just use this image as a debugger.
545
+
546
+ Besides, the layer sorting can also be useful in some very special attention formulation - we will discuss this later.
547
+
548
+ # Parameter: tags and atmosphere and style and quality_meta
549
+
550
+ The `tags` is designed as a possible replacement for the `description` since many diffusion models prefer tags. If used with anime models, one may hard code some logics to replace all "girl" to "1girl". If used with Pony then probably always hard code adding "score_9, score_8 ..." to this.
551
+
552
+ The `atmosphere` and `style` and `quality_meta` are some experimental parameters without very specific use cases. Current we can just treat them as sub-prompts and involve them in the greedy merge of sub-prompt bags. This in my experiments will improve the atmosphere and quality a bit.
553
+
554
+ # A Baseline Renderer
555
+
556
+ In this repo, we provide a baseline render for Omost LLMs based on attention manipulation.
557
+
558
+ ### Regional Prompter
559
+
560
+ As of 2024, if we want to achieve a region guided diffusion system, some possible options are:
561
+
562
+ 1. multi-diffusion / mixture-of-diffusers: these method run UNet on different locations, and then merge the estimated epsilon or x0 using weights or masks for different regions.
563
+ 2. attention decomposition: lets say attention is like `y=softmax(q@k)@v`, then one can achieve attention decomposition like `y=mask_A * softmax(q@k_A)@v_A + mask_B * softmax(q@k_B)@v_B` where mask_A, k_A, v_A are masks, k, v for region A; mask_B, k_B, v_B are masks, k, v for region B. This method usually yields image quality a bit better than (1) and some people call it Attention Couple or Region Prompter Attention Mode. But this method has a consideration: the mask only makes regional attention numerically possible, but it does not force the UNet to really attend its activations in those regions. That is to say, the attention is indeed masked, but there is no promise that the attention softmax will really be activated in the masked area, and there is also no promise that the attention softmax will never be activated outside the masked area.
564
+ 3. attention score manipulation: this is a more advanced method compared to (2). It directly manipulates the attention scores to make sure that the activations in mask each area are encouraged and those outside the masks are discouraged. The formulation is like `y=softmax(modify(q@k))@v` where `modify()` is a complicated non-linear function with many normalizations and tricks to change the score's distributions. This method goes beyond a simple masked attention to really make sure that those layers get wanted activations. A typical example is [Dense Diffusion](https://github.com/naver-ai/DenseDiffusion).
565
+ 4. gradient optimization: since the attention can tell us where each part is corresponding to what prompts, we can split prompts into segments and then get attention activations to each prompt segment. Then we compare those activations with external masks to compute a loss function, and back propagate the gradients. Those methods are usually very high quality but VRAM hungry and very slow. Typical methods are [BoxDiff](https://github.com/showlab/BoxDiff) and [Attend-and-Excite](https://github.com/yuval-alaluf/Attend-and-Excite).
566
+ 5. Use external control models like gligen and [InstanceDiffusion](https://github.com/frank-xwang/InstanceDiffusion). Those methods give the highest benchmark performance on region following but will also introduce some style offset to the base model since they are trained parameters. Also, those methods need to convert prompts to vectors and usually do not support prompts of arbitary length (but one can use them together with other attention methods to achieve arbitrary length).
567
+ 6. Some more possible layer options like layerdiffuse and [mulan](https://mulan-dataset.github.io/).
568
+
569
+ In this repo I wrote a baseline formulation based on (3). I consider this parameter-free formulation as a very standard baseline implementation that will almost introduce zero style offsets or quality degradation. In the future we may consider training some parametrized methods for Omost.
570
+
571
+ Lets consider an extremely simplified image with only 2\*2=4 pixels:
572
+
573
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/00f97ad6-202b-4a39-9091-da6d76b0aacb)
574
+
575
+ Then we have three prompts "two cats", "a black cat", "a white cat", and we have their masks:
576
+
577
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/f9f5e87c-5f82-41fe-8a49-580d3eb6f2be)
578
+
579
+ Then we can draw this attention score table:
580
+
581
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/a77936b3-050e-4894-9252-476713144f6c)
582
+
583
+ where the upper arrow mean that we want to encourage the activation, while the lower arrow means we want to get rid of those activation.
584
+
585
+ This manipulation directly modify attention scores and compute all prompts conditions in one single SDP attention pass. (See also the codes for more details.)
586
+
587
+ ### Prompt Prefix Tree
588
+
589
+ In this repo, I also included another trick that I find out to improve prompt understanding a lot. Lets call it a Prompt Prefix Tree. The motivation is that, since now that all our prompts are sub-prompts that can be merged arbitrarily (recall that all sub-prompts are strictly less than 75 tokens and typically less than 40 tokens, describe independent concepts, and can be arbitrarily merged as common prompts for clip to encode), finding a better method to merge those sub-prompts may improve the results and prompt interpretation.
590
+
591
+ For example below is a tree structure of global/local overall/detailed descriptions.
592
+
593
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/2b3a895f-9bb5-4da8-9d5d-989bac9c1a7e)
594
+
595
+ The idea is that, since all sub-prompts can be merged arbitrarily, we can use the paths in this tree graph as prompts.
596
+
597
+ For example the below path will give a prompt "A cat and a dog. The cat on sofa."
598
+
599
+ ![image](https://github.com/lllyasviel/Omost/assets/19834515/902ec939-65be-4232-912b-d1bc6f5da44f)
600
+
601
+ Note that we can use this together with greedy subprompt bag merging when a path exceed 75 tokens. And, if a path has remaining place to contain more subprompts, the greedy subprompt bag merging will also take care of it. And again, since all sub prompts describe independent concepts, the greedy subprompt bag merging never makes semantic truncation mistakes. So satisfying!
602
+
603
+ # Model Notes
604
+
605
+ Currently, we provide 3 models (you can get them by adding the prefix `https://huggingface.co/lllyasviel/` to the below names):
606
+
607
+ omost-llama-3-8b
608
+ omost-dolphin-2.9-llama3-8b
609
+ omost-phi-3-mini-128k
610
+
611
+ And their quant versions:
612
+
613
+ omost-llama-3-8b-4bits
614
+ omost-dolphin-2.9-llama3-8b-4bits
615
+ omost-phi-3-mini-128k-8bits
616
+
617
+ Some notes:
618
+
619
+ 1. The recommended quant for `omost-llama-3-8b` is 4bits, and for `omost-phi-3-mini-128k` (3.8B) is 8 bits. They all fit in 8GB VRAM without offloads. The performance degradation caused by quant is very minimal and I personally never observed any evidences of degradation. However, quant `omost-phi-3-mini-128k` into 4 bits is not recommended since I noticed some obvious performance degradation. The 4bit inference of `omost-phi-3-mini-128k` should be viewed as a last method in extreme cases when you really do not have more capable GPUs.
620
+ 2. My user study shows that `omost-llama-3-8b-4bits` > `omost-dolphin-2.9-llama3-8b-4bits` > `omost-phi-3-mini-128k-8bits`. So in most cases one should just use `omost-llama-3-8b-4bits`.
621
+ 3. The `omost-llama-3-8b` and `omost-phi-3-mini-128k` are trained with filtered safe data without NSFW or inappropriate contents. See (4) if you need a different option.
622
+ 4. The `omost-dolphin-2.9-llama3-8b` is trained with all data WITHOUT any filtering. You must apply your own safety alignment methods if you expose any service of `omost-dolphin-2.9-llama3-8b` to public.
623
+ 5. Note that the filtering in (3) is not because of any policy - the reason is that I noticed slight instability in training gradients in those models since they are pretrained with instruct following regulated by safety alignment, causing the performance to degrade a bit. But the instruct following of `omost-dolphin-2.9-llama3-8b` is pretrained with community efforts and do not have this problem.
624
+ 6. The 128k context length of `omost-phi-3-mini-128k` cannot be trusted. The performance of it will degrade a lot after the tokens reach about 8k. One should just view it as a model with about 8k content length.
625
+ 7. A model of 8k context length can do about 5 to 6 rounds of conversational editing. If you are about to run out of token lengths, use the UI to modify your message and respond again (this can be done with infinite times).
626
+ 8. All models are fully trained with our H100 clusters at precision fp16 without any tricks like quant or Q-LoRA etc. The optimizer is Adam without any tricks.
627
+ 9. You must also follow the licenses of Llama-3 and Phi-3.
628
+ 10. You can request us to train on other LLMs if reasonable and necessary.
629
+
630
+ # Cite
631
+
632
+ @Misc{omost,
633
+ author = {Omost Team},
634
+ title = {Omost GitHub Page},
635
+ year = {2024},
636
+ }
637
+
638
+ # Related Work
639
+
640
+ Also read ...
641
+
642
+ [DOCCI: Descriptions of Connected and Contrasting Images](https://google.github.io/docci/)
643
+
644
+ [(RPG-DiffusionMaster) Mastering Text-to-Image Diffusion: Recaptioning, Planning, and Generating with Multimodal LLMs](https://github.com/YangLing0818/RPG-DiffusionMaster)
645
+
646
+ [LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models](https://arxiv.org/abs/2305.13655) and [Self-correcting LLM-controlled Diffusion Models](https://arxiv.org/abs/2311.16090)
647
+
648
+ [MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation](https://multidiffusion.github.io/)
649
+
650
+ [sd-webui-regional-prompter](https://github.com/hako-mikan/sd-webui-regional-prompter)
651
+
652
+ (please open issue or email me if you want to add more links here)
653
+
654
+
chat_interface.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import inspect
8
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
9
+
10
+ import anyio
11
+ from gradio_client.documentation import document
12
+
13
+ from gradio.blocks import Blocks
14
+ from gradio.components import (
15
+ Button,
16
+ Chatbot,
17
+ Component,
18
+ Markdown,
19
+ MultimodalTextbox,
20
+ State,
21
+ Textbox,
22
+ get_component_instance,
23
+ Dataset,
24
+ )
25
+ from gradio.events import Dependency, on
26
+ from gradio.helpers import special_args
27
+ from gradio.layouts import Accordion, Group, Row
28
+ from gradio.routes import Request
29
+ from gradio.themes import ThemeClass as Theme
30
+ from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda
31
+
32
+
33
+ @document()
34
+ class ChatInterface(Blocks):
35
+ """
36
+ ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create
37
+ a web-based demo around a chatbot model in a few lines of code. Only one parameter is required: fn, which
38
+ takes a function that governs the response of the chatbot based on the user input and chat history. Additional
39
+ parameters can be used to control the appearance and behavior of the demo.
40
+
41
+ Example:
42
+ import gradio as gr
43
+
44
+ def echo(message, history):
45
+ return message
46
+
47
+ demo = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot")
48
+ demo.launch()
49
+ Demos: chatinterface_multimodal, chatinterface_random_response, chatinterface_streaming_echo
50
+ Guides: creating-a-chatbot-fast, sharing-your-app
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ fn: Callable,
56
+ post_fn: Callable,
57
+ pre_fn: Callable,
58
+ chatbot: Chatbot,
59
+ *,
60
+ post_fn_kwargs: dict = None,
61
+ pre_fn_kwargs: dict = None,
62
+ multimodal: bool = False,
63
+ textbox: Textbox | MultimodalTextbox | None = None,
64
+ additional_inputs: str | Component | list[str | Component] | None = None,
65
+ additional_inputs_accordion_name: str | None = None,
66
+ additional_inputs_accordion: str | Accordion | None = None,
67
+ examples: Dataset = None,
68
+ title: str | None = None,
69
+ description: str | None = None,
70
+ theme: Theme | str | None = None,
71
+ css: str | None = None,
72
+ js: str | None = None,
73
+ head: str | None = None,
74
+ analytics_enabled: bool | None = None,
75
+ submit_btn: str | None | Button = "Submit",
76
+ stop_btn: str | None | Button = "Stop",
77
+ retry_btn: str | None | Button = "🔄 Retry",
78
+ undo_btn: str | None | Button = "↩️ Undo",
79
+ clear_btn: str | None | Button = "🗑️ Clear",
80
+ autofocus: bool = True,
81
+ concurrency_limit: int | None | Literal["default"] = "default",
82
+ fill_height: bool = True,
83
+ delete_cache: tuple[int, int] | None = None,
84
+ ):
85
+ super().__init__(
86
+ analytics_enabled=analytics_enabled,
87
+ mode="chat_interface",
88
+ css=css,
89
+ title=title or "Gradio",
90
+ theme=theme,
91
+ js=js,
92
+ head=head,
93
+ fill_height=fill_height,
94
+ delete_cache=delete_cache,
95
+ )
96
+
97
+ if post_fn_kwargs is None:
98
+ post_fn_kwargs = []
99
+
100
+ self.post_fn = post_fn
101
+ self.post_fn_kwargs = post_fn_kwargs
102
+
103
+ self.pre_fn = pre_fn
104
+ self.pre_fn_kwargs = pre_fn_kwargs
105
+
106
+ self.interrupter = State(None)
107
+
108
+ self.multimodal = multimodal
109
+ self.concurrency_limit = concurrency_limit
110
+ self.fn = fn
111
+ self.is_async = inspect.iscoroutinefunction(
112
+ self.fn
113
+ ) or inspect.isasyncgenfunction(self.fn)
114
+ self.is_generator = inspect.isgeneratorfunction(
115
+ self.fn
116
+ ) or inspect.isasyncgenfunction(self.fn)
117
+
118
+ if additional_inputs:
119
+ if not isinstance(additional_inputs, list):
120
+ additional_inputs = [additional_inputs]
121
+ self.additional_inputs = [
122
+ get_component_instance(i)
123
+ for i in additional_inputs # type: ignore
124
+ ]
125
+ else:
126
+ self.additional_inputs = []
127
+ if additional_inputs_accordion_name is not None:
128
+ print(
129
+ "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
130
+ )
131
+ self.additional_inputs_accordion_params = {
132
+ "label": additional_inputs_accordion_name
133
+ }
134
+ if additional_inputs_accordion is None:
135
+ self.additional_inputs_accordion_params = {
136
+ "label": "Additional Inputs",
137
+ "open": False,
138
+ }
139
+ elif isinstance(additional_inputs_accordion, str):
140
+ self.additional_inputs_accordion_params = {
141
+ "label": additional_inputs_accordion
142
+ }
143
+ elif isinstance(additional_inputs_accordion, Accordion):
144
+ self.additional_inputs_accordion_params = (
145
+ additional_inputs_accordion.recover_kwargs(
146
+ additional_inputs_accordion.get_config()
147
+ )
148
+ )
149
+ else:
150
+ raise ValueError(
151
+ f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
152
+ )
153
+
154
+ with self:
155
+ if title:
156
+ Markdown(
157
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
158
+ )
159
+ if description:
160
+ Markdown(description)
161
+
162
+ self.chatbot = chatbot.render()
163
+
164
+ self.buttons = [retry_btn, undo_btn, clear_btn]
165
+
166
+ with Group():
167
+ with Row():
168
+ if textbox:
169
+ if self.multimodal:
170
+ submit_btn = None
171
+ else:
172
+ textbox.container = False
173
+ textbox.show_label = False
174
+ textbox_ = textbox.render()
175
+ if not isinstance(textbox_, (Textbox, MultimodalTextbox)):
176
+ raise TypeError(
177
+ f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {type(textbox_)}"
178
+ )
179
+ self.textbox = textbox_
180
+ elif self.multimodal:
181
+ submit_btn = None
182
+ self.textbox = MultimodalTextbox(
183
+ show_label=False,
184
+ label="Message",
185
+ placeholder="Type a message...",
186
+ scale=7,
187
+ autofocus=autofocus,
188
+ )
189
+ else:
190
+ self.textbox = Textbox(
191
+ container=False,
192
+ show_label=False,
193
+ label="Message",
194
+ placeholder="Type a message...",
195
+ scale=7,
196
+ autofocus=autofocus,
197
+ )
198
+ if submit_btn is not None and not multimodal:
199
+ if isinstance(submit_btn, Button):
200
+ submit_btn.render()
201
+ elif isinstance(submit_btn, str):
202
+ submit_btn = Button(
203
+ submit_btn,
204
+ variant="primary",
205
+ scale=1,
206
+ min_width=150,
207
+ )
208
+ else:
209
+ raise ValueError(
210
+ f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
211
+ )
212
+ if stop_btn is not None:
213
+ if isinstance(stop_btn, Button):
214
+ stop_btn.visible = False
215
+ stop_btn.render()
216
+ elif isinstance(stop_btn, str):
217
+ stop_btn = Button(
218
+ stop_btn,
219
+ variant="stop",
220
+ visible=False,
221
+ scale=1,
222
+ min_width=150,
223
+ )
224
+ else:
225
+ raise ValueError(
226
+ f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
227
+ )
228
+ self.buttons.extend([submit_btn, stop_btn]) # type: ignore
229
+
230
+ self.fake_api_btn = Button("Fake API", visible=False)
231
+ self.fake_response_textbox = Textbox(label="Response", visible=False)
232
+ (
233
+ self.retry_btn,
234
+ self.undo_btn,
235
+ self.clear_btn,
236
+ self.submit_btn,
237
+ self.stop_btn,
238
+ ) = self.buttons
239
+
240
+ any_unrendered_inputs = any(
241
+ not inp.is_rendered for inp in self.additional_inputs
242
+ )
243
+ if self.additional_inputs and any_unrendered_inputs:
244
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
245
+ for input_component in self.additional_inputs:
246
+ if not input_component.is_rendered:
247
+ input_component.render()
248
+
249
+ self.saved_input = State()
250
+ self.chatbot_state = (
251
+ State(self.chatbot.value) if self.chatbot.value else State([])
252
+ )
253
+
254
+ self._setup_events()
255
+ self._setup_api()
256
+
257
+ if examples:
258
+ examples.click(lambda x: x[0], inputs=[examples], outputs=self.textbox, show_progress=False, queue=False)
259
+
260
+ def _setup_events(self) -> None:
261
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
262
+ submit_triggers = (
263
+ [self.textbox.submit, self.submit_btn.click]
264
+ if self.submit_btn
265
+ else [self.textbox.submit]
266
+ )
267
+ submit_event = (
268
+ on(
269
+ submit_triggers,
270
+ self._clear_and_save_textbox,
271
+ [self.textbox],
272
+ [self.textbox, self.saved_input],
273
+ show_api=False,
274
+ queue=False,
275
+ )
276
+ .then(
277
+ self.pre_fn,
278
+ **self.pre_fn_kwargs,
279
+ show_api=False,
280
+ queue=False,
281
+ )
282
+ .then(
283
+ self._display_input,
284
+ [self.saved_input, self.chatbot_state],
285
+ [self.chatbot, self.chatbot_state],
286
+ show_api=False,
287
+ queue=False,
288
+ )
289
+ .then(
290
+ submit_fn,
291
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
292
+ [self.chatbot, self.chatbot_state, self.interrupter],
293
+ show_api=False,
294
+ concurrency_limit=cast(
295
+ Union[int, Literal["default"], None], self.concurrency_limit
296
+ ),
297
+ ).then(
298
+ self.post_fn,
299
+ **self.post_fn_kwargs,
300
+ show_api=False,
301
+ concurrency_limit=cast(
302
+ Union[int, Literal["default"], None], self.concurrency_limit
303
+ ),
304
+ )
305
+ )
306
+ self._setup_stop_events(submit_triggers, submit_event)
307
+
308
+ if self.retry_btn:
309
+ retry_event = (
310
+ self.retry_btn.click(
311
+ self._delete_prev_fn,
312
+ [self.saved_input, self.chatbot_state],
313
+ [self.chatbot, self.saved_input, self.chatbot_state],
314
+ show_api=False,
315
+ queue=False,
316
+ )
317
+ .then(
318
+ self.pre_fn,
319
+ **self.pre_fn_kwargs,
320
+ show_api=False,
321
+ queue=False,
322
+ )
323
+ .then(
324
+ self._display_input,
325
+ [self.saved_input, self.chatbot_state],
326
+ [self.chatbot, self.chatbot_state],
327
+ show_api=False,
328
+ queue=False,
329
+ )
330
+ .then(
331
+ submit_fn,
332
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
333
+ [self.chatbot, self.chatbot_state],
334
+ show_api=False,
335
+ concurrency_limit=cast(
336
+ Union[int, Literal["default"], None], self.concurrency_limit
337
+ ),
338
+ ).then(
339
+ self.post_fn,
340
+ **self.post_fn_kwargs,
341
+ show_api=False,
342
+ concurrency_limit=cast(
343
+ Union[int, Literal["default"], None], self.concurrency_limit
344
+ ),
345
+ )
346
+ )
347
+ self._setup_stop_events([self.retry_btn.click], retry_event)
348
+
349
+ if self.undo_btn:
350
+ self.undo_btn.click(
351
+ self._delete_prev_fn,
352
+ [self.saved_input, self.chatbot_state],
353
+ [self.chatbot, self.saved_input, self.chatbot_state],
354
+ show_api=False,
355
+ queue=False,
356
+ ).then(
357
+ self.pre_fn,
358
+ **self.pre_fn_kwargs,
359
+ show_api=False,
360
+ queue=False,
361
+ ).then(
362
+ async_lambda(lambda x: x),
363
+ [self.saved_input],
364
+ [self.textbox],
365
+ show_api=False,
366
+ queue=False,
367
+ ).then(
368
+ self.post_fn,
369
+ **self.post_fn_kwargs,
370
+ show_api=False,
371
+ concurrency_limit=cast(
372
+ Union[int, Literal["default"], None], self.concurrency_limit
373
+ ),
374
+ )
375
+
376
+ if self.clear_btn:
377
+ self.clear_btn.click(
378
+ async_lambda(lambda: ([], [], None)),
379
+ None,
380
+ [self.chatbot, self.chatbot_state, self.saved_input],
381
+ queue=False,
382
+ show_api=False,
383
+ ).then(
384
+ self.pre_fn,
385
+ **self.pre_fn_kwargs,
386
+ show_api=False,
387
+ queue=False,
388
+ ).then(
389
+ self.post_fn,
390
+ **self.post_fn_kwargs,
391
+ show_api=False,
392
+ concurrency_limit=cast(
393
+ Union[int, Literal["default"], None], self.concurrency_limit
394
+ ),
395
+ )
396
+
397
+ def _setup_stop_events(
398
+ self, event_triggers: list[Callable], event_to_cancel: Dependency
399
+ ) -> None:
400
+ def perform_interrupt(ipc):
401
+ if ipc is not None:
402
+ ipc()
403
+ return
404
+
405
+ if self.stop_btn and self.is_generator:
406
+ if self.submit_btn:
407
+ for event_trigger in event_triggers:
408
+ event_trigger(
409
+ async_lambda(
410
+ lambda: (
411
+ Button(visible=False),
412
+ Button(visible=True),
413
+ )
414
+ ),
415
+ None,
416
+ [self.submit_btn, self.stop_btn],
417
+ show_api=False,
418
+ queue=False,
419
+ )
420
+ event_to_cancel.then(
421
+ async_lambda(lambda: (Button(visible=True), Button(visible=False))),
422
+ None,
423
+ [self.submit_btn, self.stop_btn],
424
+ show_api=False,
425
+ queue=False,
426
+ )
427
+ else:
428
+ for event_trigger in event_triggers:
429
+ event_trigger(
430
+ async_lambda(lambda: Button(visible=True)),
431
+ None,
432
+ [self.stop_btn],
433
+ show_api=False,
434
+ queue=False,
435
+ )
436
+ event_to_cancel.then(
437
+ async_lambda(lambda: Button(visible=False)),
438
+ None,
439
+ [self.stop_btn],
440
+ show_api=False,
441
+ queue=False,
442
+ )
443
+ self.stop_btn.click(
444
+ fn=perform_interrupt,
445
+ inputs=[self.interrupter],
446
+ cancels=event_to_cancel,
447
+ show_api=False,
448
+ )
449
+
450
+ def _setup_api(self) -> None:
451
+ api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
452
+
453
+ self.fake_api_btn.click(
454
+ api_fn,
455
+ [self.textbox, self.chatbot_state] + self.additional_inputs,
456
+ [self.textbox, self.chatbot_state],
457
+ api_name="chat",
458
+ concurrency_limit=cast(
459
+ Union[int, Literal["default"], None], self.concurrency_limit
460
+ ),
461
+ )
462
+
463
+ def _clear_and_save_textbox(self, message: str) -> tuple[str | dict, str]:
464
+ if self.multimodal:
465
+ return {"text": "", "files": []}, message
466
+ else:
467
+ return "", message
468
+
469
+ def _append_multimodal_history(
470
+ self,
471
+ message: dict[str, list],
472
+ response: str | None,
473
+ history: list[list[str | tuple | None]],
474
+ ):
475
+ for x in message["files"]:
476
+ history.append([(x,), None])
477
+ if message["text"] is None or not isinstance(message["text"], str):
478
+ return
479
+ elif message["text"] == "" and message["files"] != []:
480
+ history.append([None, response])
481
+ else:
482
+ history.append([message["text"], response])
483
+
484
+ async def _display_input(
485
+ self, message: str | dict[str, list], history: list[list[str | tuple | None]]
486
+ ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
487
+ if self.multimodal and isinstance(message, dict):
488
+ self._append_multimodal_history(message, None, history)
489
+ elif isinstance(message, str):
490
+ history.append([message, None])
491
+ return history, history
492
+
493
+ async def _submit_fn(
494
+ self,
495
+ message: str | dict[str, list],
496
+ history_with_input: list[list[str | tuple | None]],
497
+ request: Request,
498
+ *args,
499
+ ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
500
+ if self.multimodal and isinstance(message, dict):
501
+ remove_input = (
502
+ len(message["files"]) + 1
503
+ if message["text"] is not None
504
+ else len(message["files"])
505
+ )
506
+ history = history_with_input[:-remove_input]
507
+ else:
508
+ history = history_with_input[:-1]
509
+ inputs, _, _ = special_args(
510
+ self.fn, inputs=[message, history, *args], request=request
511
+ )
512
+
513
+ if self.is_async:
514
+ response = await self.fn(*inputs)
515
+ else:
516
+ response = await anyio.to_thread.run_sync(
517
+ self.fn, *inputs, limiter=self.limiter
518
+ )
519
+
520
+ if self.multimodal and isinstance(message, dict):
521
+ self._append_multimodal_history(message, response, history)
522
+ elif isinstance(message, str):
523
+ history.append([message, response])
524
+ return history, history
525
+
526
+ async def _stream_fn(
527
+ self,
528
+ message: str | dict[str, list],
529
+ history_with_input: list[list[str | tuple | None]],
530
+ request: Request,
531
+ *args,
532
+ ) -> AsyncGenerator:
533
+ if self.multimodal and isinstance(message, dict):
534
+ remove_input = (
535
+ len(message["files"]) + 1
536
+ if message["text"] is not None
537
+ else len(message["files"])
538
+ )
539
+ history = history_with_input[:-remove_input]
540
+ else:
541
+ history = history_with_input[:-1]
542
+ inputs, _, _ = special_args(
543
+ self.fn, inputs=[message, history, *args], request=request
544
+ )
545
+
546
+ if self.is_async:
547
+ generator = self.fn(*inputs)
548
+ else:
549
+ generator = await anyio.to_thread.run_sync(
550
+ self.fn, *inputs, limiter=self.limiter
551
+ )
552
+ generator = SyncToAsyncIterator(generator, self.limiter)
553
+ try:
554
+ first_response, first_interrupter = await async_iteration(generator)
555
+ if self.multimodal and isinstance(message, dict):
556
+ for x in message["files"]:
557
+ history.append([(x,), None])
558
+ update = history + [[message["text"], first_response]]
559
+ yield update, update
560
+ else:
561
+ update = history + [[message, first_response]]
562
+ yield update, update, first_interrupter
563
+ except StopIteration:
564
+ if self.multimodal and isinstance(message, dict):
565
+ self._append_multimodal_history(message, None, history)
566
+ yield history, history
567
+ else:
568
+ update = history + [[message, None]]
569
+ yield update, update, first_interrupter
570
+ async for response, interrupter in generator:
571
+ if self.multimodal and isinstance(message, dict):
572
+ update = history + [[message["text"], response]]
573
+ yield update, update
574
+ else:
575
+ update = history + [[message, response]]
576
+ yield update, update, interrupter
577
+
578
+ async def _api_submit_fn(
579
+ self, message: str, history: list[list[str | None]], request: Request, *args
580
+ ) -> tuple[str, list[list[str | None]]]:
581
+ inputs, _, _ = special_args(
582
+ self.fn, inputs=[message, history, *args], request=request
583
+ )
584
+
585
+ if self.is_async:
586
+ response = await self.fn(*inputs)
587
+ else:
588
+ response = await anyio.to_thread.run_sync(
589
+ self.fn, *inputs, limiter=self.limiter
590
+ )
591
+ history.append([message, response])
592
+ return response, history
593
+
594
+ async def _api_stream_fn(
595
+ self, message: str, history: list[list[str | None]], request: Request, *args
596
+ ) -> AsyncGenerator:
597
+ inputs, _, _ = special_args(
598
+ self.fn, inputs=[message, history, *args], request=request
599
+ )
600
+
601
+ if self.is_async:
602
+ generator = self.fn(*inputs)
603
+ else:
604
+ generator = await anyio.to_thread.run_sync(
605
+ self.fn, *inputs, limiter=self.limiter
606
+ )
607
+ generator = SyncToAsyncIterator(generator, self.limiter)
608
+ try:
609
+ first_response = await async_iteration(generator)
610
+ yield first_response, history + [[message, first_response]]
611
+ except StopIteration:
612
+ yield None, history + [[message, None]]
613
+ async for response in generator:
614
+ yield response, history + [[message, response]]
615
+
616
+ async def _delete_prev_fn(
617
+ self,
618
+ message: str | dict[str, list],
619
+ history: list[list[str | tuple | None]],
620
+ ) -> tuple[
621
+ list[list[str | tuple | None]],
622
+ str | dict[str, list],
623
+ list[list[str | tuple | None]],
624
+ ]:
625
+ if self.multimodal and isinstance(message, dict):
626
+ remove_input = (
627
+ len(message["files"]) + 1
628
+ if message["text"] is not None
629
+ else len(message["files"])
630
+ )
631
+ history = history[:-remove_input]
632
+ else:
633
+ while history:
634
+ deleted_a, deleted_b = history[-1]
635
+ history = history[:-1]
636
+ if isinstance(deleted_a, str) and isinstance(deleted_b, str):
637
+ break
638
+ return history, message or "", history
gradio_app.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
4
+ HF_TOKEN = None
5
+
6
+ import lib_omost.memory_management as memory_management
7
+ import uuid
8
+
9
+ import torch
10
+ import numpy as np
11
+ import gradio as gr
12
+ import tempfile
13
+
14
+ gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio')
15
+ os.makedirs(gradio_temp_dir, exist_ok=True)
16
+
17
+ from threading import Thread
18
+
19
+ # Phi3 Hijack
20
+ from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel
21
+
22
+ Phi3PreTrainedModel._supports_sdpa = True
23
+
24
+ from PIL import Image
25
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
26
+ from diffusers import AutoencoderKL, UNet2DConditionModel
27
+ from diffusers.models.attention_processor import AttnProcessor2_0
28
+ from transformers import CLIPTextModel, CLIPTokenizer
29
+ from lib_omost.pipeline import StableDiffusionXLOmostPipeline
30
+ from chat_interface import ChatInterface
31
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
32
+
33
+ import lib_omost.canvas as omost_canvas
34
+
35
+
36
+ # SDXL
37
+
38
+ sdxl_name = 'SG161222/RealVisXL_V4.0'
39
+ # sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'
40
+
41
+ tokenizer = CLIPTokenizer.from_pretrained(
42
+ sdxl_name, subfolder="tokenizer")
43
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
44
+ sdxl_name, subfolder="tokenizer_2")
45
+ text_encoder = CLIPTextModel.from_pretrained(
46
+ sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16")
47
+ text_encoder_2 = CLIPTextModel.from_pretrained(
48
+ sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16")
49
+ vae = AutoencoderKL.from_pretrained(
50
+ sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16") # bfloat16 vae
51
+ unet = UNet2DConditionModel.from_pretrained(
52
+ sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16")
53
+
54
+ unet.set_attn_processor(AttnProcessor2_0())
55
+ vae.set_attn_processor(AttnProcessor2_0())
56
+
57
+ pipeline = StableDiffusionXLOmostPipeline(
58
+ vae=vae,
59
+ text_encoder=text_encoder,
60
+ tokenizer=tokenizer,
61
+ text_encoder_2=text_encoder_2,
62
+ tokenizer_2=tokenizer_2,
63
+ unet=unet,
64
+ scheduler=None, # We completely give up diffusers sampling system and use A1111's method
65
+ )
66
+
67
+ memory_management.unload_all_models([text_encoder, text_encoder_2, vae, unet])
68
+
69
+ # LLM
70
+
71
+ # llm_name = 'lllyasviel/omost-phi-3-mini-128k-8bits'
72
+ llm_name = 'lllyasviel/omost-llama-3-8b-4bits'
73
+ # llm_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b-4bits'
74
+
75
+ llm_model = AutoModelForCausalLM.from_pretrained(
76
+ llm_name,
77
+ torch_dtype=torch.bfloat16, # This is computation type, not load/memory type. The loading quant type is baked in config.
78
+ token=HF_TOKEN,
79
+ device_map="auto" # This will load model to gpu with an offload system
80
+ )
81
+
82
+ llm_tokenizer = AutoTokenizer.from_pretrained(
83
+ llm_name,
84
+ token=HF_TOKEN
85
+ )
86
+
87
+ memory_management.unload_all_models(llm_model)
88
+
89
+
90
+ @torch.inference_mode()
91
+ def pytorch2numpy(imgs):
92
+ results = []
93
+ for x in imgs:
94
+ y = x.movedim(0, -1)
95
+ y = y * 127.5 + 127.5
96
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
97
+ results.append(y)
98
+ return results
99
+
100
+
101
+ @torch.inference_mode()
102
+ def numpy2pytorch(imgs):
103
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
104
+ h = h.movedim(-1, 1)
105
+ return h
106
+
107
+
108
+ def resize_without_crop(image, target_width, target_height):
109
+ pil_image = Image.fromarray(image)
110
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
111
+ return np.array(resized_image)
112
+
113
+
114
+ @torch.inference_mode()
115
+ def chat_fn(message: str, history: list, seed:int, temperature: float, top_p: float, max_new_tokens: int) -> str:
116
+ np.random.seed(int(seed))
117
+ torch.manual_seed(int(seed))
118
+
119
+ conversation = [{"role": "system", "content": omost_canvas.system_prompt}]
120
+
121
+ for user, assistant in history:
122
+ if isinstance(user, str) and isinstance(assistant, str):
123
+ if len(user) > 0 and len(assistant) > 0:
124
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
125
+
126
+ conversation.append({"role": "user", "content": message})
127
+
128
+ memory_management.load_models_to_gpu(llm_model)
129
+
130
+ input_ids = llm_tokenizer.apply_chat_template(
131
+ conversation, return_tensors="pt", add_generation_prompt=True).to(llm_model.device)
132
+
133
+ streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
134
+
135
+ def interactive_stopping_criteria(*args, **kwargs) -> bool:
136
+ if getattr(streamer, 'user_interrupted', False):
137
+ print('User stopped generation')
138
+ return True
139
+ else:
140
+ return False
141
+
142
+ stopping_criteria = StoppingCriteriaList([interactive_stopping_criteria])
143
+
144
+ def interrupter():
145
+ streamer.user_interrupted = True
146
+ return
147
+
148
+ generate_kwargs = dict(
149
+ input_ids=input_ids,
150
+ streamer=streamer,
151
+ stopping_criteria=stopping_criteria,
152
+ max_new_tokens=max_new_tokens,
153
+ do_sample=True,
154
+ temperature=temperature,
155
+ top_p=top_p,
156
+ )
157
+
158
+ if temperature == 0:
159
+ generate_kwargs['do_sample'] = False
160
+
161
+ Thread(target=llm_model.generate, kwargs=generate_kwargs).start()
162
+
163
+ outputs = []
164
+ for text in streamer:
165
+ outputs.append(text)
166
+ # print(outputs)
167
+ yield "".join(outputs), interrupter
168
+
169
+ return
170
+
171
+
172
+ @torch.inference_mode()
173
+ def post_chat(history):
174
+ canvas_outputs = None
175
+
176
+ try:
177
+ if history:
178
+ history = [(user, assistant) for user, assistant in history if isinstance(user, str) and isinstance(assistant, str)]
179
+ last_assistant = history[-1][1] if len(history) > 0 else None
180
+ canvas = omost_canvas.Canvas.from_bot_response(last_assistant)
181
+ canvas_outputs = canvas.process()
182
+ except Exception as e:
183
+ print('Last assistant response is not valid canvas:', e)
184
+
185
+ return canvas_outputs, gr.update(visible=canvas_outputs is not None), gr.update(interactive=len(history) > 0)
186
+
187
+
188
+ @torch.inference_mode()
189
+ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_height,
190
+ highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt):
191
+
192
+ use_initial_latent = False
193
+ eps = 0.05
194
+
195
+ image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64
196
+
197
+ rng = torch.Generator(device=memory_management.gpu).manual_seed(seed)
198
+
199
+ memory_management.load_models_to_gpu([text_encoder, text_encoder_2])
200
+
201
+ positive_cond, positive_pooler, negative_cond, negative_pooler = pipeline.all_conds_from_canvas(canvas_outputs, negative_prompt)
202
+
203
+ if use_initial_latent:
204
+ memory_management.load_models_to_gpu([vae])
205
+ initial_latent = torch.from_numpy(canvas_outputs['initial_latent'])[None].movedim(-1, 1) / 127.5 - 1.0
206
+ initial_latent_blur = 40
207
+ initial_latent = torch.nn.functional.avg_pool2d(
208
+ torch.nn.functional.pad(initial_latent, (initial_latent_blur,) * 4, mode='reflect'),
209
+ kernel_size=(initial_latent_blur * 2 + 1,) * 2, stride=(1, 1))
210
+ initial_latent = torch.nn.functional.interpolate(initial_latent, (image_height, image_width))
211
+ initial_latent = initial_latent.to(dtype=vae.dtype, device=vae.device)
212
+ initial_latent = vae.encode(initial_latent).latent_dist.mode() * vae.config.scaling_factor
213
+ else:
214
+ initial_latent = torch.zeros(size=(num_samples, 4, image_height // 8, image_width // 8), dtype=torch.float32)
215
+
216
+ memory_management.load_models_to_gpu([unet])
217
+
218
+ initial_latent = initial_latent.to(dtype=unet.dtype, device=unet.device)
219
+
220
+ latents = pipeline(
221
+ initial_latent=initial_latent,
222
+ strength=1.0,
223
+ num_inference_steps=int(steps),
224
+ batch_size=num_samples,
225
+ prompt_embeds=positive_cond,
226
+ negative_prompt_embeds=negative_cond,
227
+ pooled_prompt_embeds=positive_pooler,
228
+ negative_pooled_prompt_embeds=negative_pooler,
229
+ generator=rng,
230
+ guidance_scale=float(cfg),
231
+ ).images
232
+
233
+ memory_management.load_models_to_gpu([vae])
234
+ latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
235
+ pixels = vae.decode(latents).sample
236
+ B, C, H, W = pixels.shape
237
+ pixels = pytorch2numpy(pixels)
238
+
239
+ if highres_scale > 1.0 + eps:
240
+ pixels = [
241
+ resize_without_crop(
242
+ image=p,
243
+ target_width=int(round(W * highres_scale / 64.0) * 64),
244
+ target_height=int(round(H * highres_scale / 64.0) * 64)
245
+ ) for p in pixels
246
+ ]
247
+
248
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
249
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
250
+
251
+ memory_management.load_models_to_gpu([unet])
252
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
253
+
254
+ latents = pipeline(
255
+ initial_latent=latents,
256
+ strength=highres_denoise,
257
+ num_inference_steps=highres_steps,
258
+ batch_size=num_samples,
259
+ prompt_embeds=positive_cond,
260
+ negative_prompt_embeds=negative_cond,
261
+ pooled_prompt_embeds=positive_pooler,
262
+ negative_pooled_prompt_embeds=negative_pooler,
263
+ generator=rng,
264
+ guidance_scale=float(cfg),
265
+ ).images
266
+
267
+ memory_management.load_models_to_gpu([vae])
268
+ latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
269
+ pixels = vae.decode(latents).sample
270
+ pixels = pytorch2numpy(pixels)
271
+
272
+ for i in range(len(pixels)):
273
+ unique_hex = uuid.uuid4().hex
274
+ image_path = os.path.join(gradio_temp_dir, f"{unique_hex}_{i}.png")
275
+ image = Image.fromarray(pixels[i])
276
+ image.save(image_path)
277
+ chatbot = chatbot + [(None, (image_path, 'image'))]
278
+
279
+ return chatbot
280
+
281
+
282
+ css = '''
283
+ code {white-space: pre-wrap !important;}
284
+ .gradio-container {max-width: none !important;}
285
+ .outer_parent {flex: 1;}
286
+ .inner_parent {flex: 1;}
287
+ footer {display: none !important; visibility: hidden !important;}
288
+ .translucent {display: none !important; visibility: hidden !important;}
289
+ '''
290
+
291
+ from gradio.themes.utils import colors
292
+
293
+ with gr.Blocks(
294
+ fill_height=True, css=css,
295
+ theme=gr.themes.Default(primary_hue=colors.blue, secondary_hue=colors.cyan, neutral_hue=colors.gray)
296
+ ) as demo:
297
+ with gr.Row(elem_classes='outer_parent'):
298
+ with gr.Column(scale=25):
299
+ with gr.Row():
300
+ clear_btn = gr.Button("➕ New Chat", variant="secondary", size="sm", min_width=60)
301
+ retry_btn = gr.Button("Retry", variant="secondary", size="sm", min_width=60, visible=False)
302
+ undo_btn = gr.Button("✏️️ Edit Last Input", variant="secondary", size="sm", min_width=60, interactive=False)
303
+
304
+ seed = gr.Number(label="Random Seed", value=12345, precision=0)
305
+
306
+ with gr.Accordion(open=True, label='Language Model'):
307
+ with gr.Group():
308
+ with gr.Row():
309
+ temperature = gr.Slider(
310
+ minimum=0.0,
311
+ maximum=2.0,
312
+ step=0.01,
313
+ value=0.6,
314
+ label="Temperature")
315
+ top_p = gr.Slider(
316
+ minimum=0.0,
317
+ maximum=1.0,
318
+ step=0.01,
319
+ value=0.9,
320
+ label="Top P")
321
+ max_new_tokens = gr.Slider(
322
+ minimum=128,
323
+ maximum=4096,
324
+ step=1,
325
+ value=4096,
326
+ label="Max New Tokens")
327
+ with gr.Accordion(open=True, label='Image Diffusion Model'):
328
+ with gr.Group():
329
+ with gr.Row():
330
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=896, step=64)
331
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=1152, step=64)
332
+
333
+ with gr.Row():
334
+ num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1)
335
+ steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1)
336
+
337
+ with gr.Accordion(open=False, label='Advanced'):
338
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=5.0, step=0.01)
339
+ highres_scale = gr.Slider(label="HR-fix Scale (\"1\" is disabled)", minimum=1.0, maximum=2.0, value=1.0, step=0.01)
340
+ highres_steps = gr.Slider(label="Highres Fix Steps", minimum=1, maximum=100, value=20, step=1)
341
+ highres_denoise = gr.Slider(label="Highres Fix Denoise", minimum=0.1, maximum=1.0, value=0.4, step=0.01)
342
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
343
+
344
+ render_button = gr.Button("Render the Image!", size='lg', variant="primary", visible=False)
345
+
346
+ examples = gr.Dataset(
347
+ samples=[
348
+ ['generate an image of the fierce battle of warriors and a dragon'],
349
+ ['change the dragon to a dinosaur']
350
+ ],
351
+ components=[gr.Textbox(visible=False)],
352
+ label='Quick Prompts'
353
+ )
354
+ with gr.Column(scale=75, elem_classes='inner_parent'):
355
+ canvas_state = gr.State(None)
356
+ chatbot = gr.Chatbot(label='Omost', scale=1, show_copy_button=True, layout="panel", render=False)
357
+ chatInterface = ChatInterface(
358
+ fn=chat_fn,
359
+ post_fn=post_chat,
360
+ post_fn_kwargs=dict(inputs=[chatbot], outputs=[canvas_state, render_button, undo_btn]),
361
+ pre_fn=lambda: gr.update(visible=False),
362
+ pre_fn_kwargs=dict(outputs=[render_button]),
363
+ chatbot=chatbot,
364
+ retry_btn=retry_btn,
365
+ undo_btn=undo_btn,
366
+ clear_btn=clear_btn,
367
+ additional_inputs=[seed, temperature, top_p, max_new_tokens],
368
+ examples=examples
369
+ )
370
+
371
+ render_button.click(
372
+ fn=diffusion_fn, inputs=[
373
+ chatInterface.chatbot, canvas_state,
374
+ num_samples, seed, image_width, image_height, highres_scale,
375
+ steps, cfg, highres_steps, highres_denoise, n_prompt
376
+ ], outputs=[chatInterface.chatbot]).then(
377
+ fn=lambda x: x, inputs=[
378
+ chatInterface.chatbot
379
+ ], outputs=[chatInterface.chatbot_state])
380
+
381
+ if __name__ == "__main__":
382
+ demo.queue().launch(inbrowser=True, server_name='0.0.0.0')
lib_omost/canvas.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import difflib
3
+ import numpy as np
4
+
5
+ system_prompt = r'''You are a helpful AI assistant to compose images using the below python class `Canvas`:
6
+
7
+ ```python
8
+ class Canvas:
9
+ def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
10
+ pass
11
+
12
+ def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
13
+ assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
14
+ assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
15
+ assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
16
+ assert distance_to_viewer > 0
17
+ pass
18
+ ```'''
19
+
20
+ valid_colors = { # r, g, b
21
+ 'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255),
22
+ 'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220),
23
+ 'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255),
24
+ 'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135),
25
+ 'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 105, 30),
26
+ 'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220),
27
+ 'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139),
28
+ 'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169),
29
+ 'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139),
30
+ 'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204),
31
+ 'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143),
32
+ 'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79),
33
+ 'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147),
34
+ 'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105),
35
+ 'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240),
36
+ 'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220),
37
+ 'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32),
38
+ 'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47),
39
+ 'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92),
40
+ 'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250),
41
+ 'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205),
42
+ 'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255),
43
+ 'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211),
44
+ 'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122),
45
+ 'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153),
46
+ 'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 224),
47
+ 'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255),
48
+ 'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205),
49
+ 'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113),
50
+ 'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154),
51
+ 'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112),
52
+ 'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181),
53
+ 'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128),
54
+ 'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35),
55
+ 'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214),
56
+ 'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238),
57
+ 'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185),
58
+ 'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230),
59
+ 'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0),
60
+ 'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19),
61
+ 'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87),
62
+ 'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192),
63
+ 'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144),
64
+ 'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127),
65
+ 'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216),
66
+ 'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238),
67
+ 'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245),
68
+ 'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50)
69
+ }
70
+
71
+ valid_locations = { # x, y in 90*90
72
+ 'in the center': (45, 45),
73
+ 'on the left': (15, 45),
74
+ 'on the right': (75, 45),
75
+ 'on the top': (45, 15),
76
+ 'on the bottom': (45, 75),
77
+ 'on the top-left': (15, 15),
78
+ 'on the top-right': (75, 15),
79
+ 'on the bottom-left': (15, 75),
80
+ 'on the bottom-right': (75, 75)
81
+ }
82
+
83
+ valid_offsets = { # x, y in 90*90
84
+ 'no offset': (0, 0),
85
+ 'slightly to the left': (-10, 0),
86
+ 'slightly to the right': (10, 0),
87
+ 'slightly to the upper': (0, -10),
88
+ 'slightly to the lower': (0, 10),
89
+ 'slightly to the upper-left': (-10, -10),
90
+ 'slightly to the upper-right': (10, -10),
91
+ 'slightly to the lower-left': (-10, 10),
92
+ 'slightly to the lower-right': (10, 10)}
93
+
94
+ valid_areas = { # w, h in 90*90
95
+ "a small square area": (50, 50),
96
+ "a small vertical area": (40, 60),
97
+ "a small horizontal area": (60, 40),
98
+ "a medium-sized square area": (60, 60),
99
+ "a medium-sized vertical area": (50, 80),
100
+ "a medium-sized horizontal area": (80, 50),
101
+ "a large square area": (70, 70),
102
+ "a large vertical area": (60, 90),
103
+ "a large horizontal area": (90, 60)
104
+ }
105
+
106
+
107
+ def closest_name(input_str, options):
108
+ input_str = input_str.lower()
109
+
110
+ closest_match = difflib.get_close_matches(input_str, list(options.keys()), n=1, cutoff=0.5)
111
+ assert isinstance(closest_match, list) and len(closest_match) > 0, f'The value [{input_str}] is not valid!'
112
+ result = closest_match[0]
113
+
114
+ if result != input_str:
115
+ print(f'Automatically corrected [{input_str}] -> [{result}].')
116
+
117
+ return result
118
+
119
+
120
+ def safe_str(x):
121
+ return x.strip(',. ') + '.'
122
+
123
+
124
+ def binary_nonzero_positions(n, offset=0):
125
+ binary_str = bin(n)[2:]
126
+ positions = [i + offset for i, bit in enumerate(reversed(binary_str)) if bit == '1']
127
+ return positions
128
+
129
+
130
+ class Canvas:
131
+ @staticmethod
132
+ def from_bot_response(response: str):
133
+ matched = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
134
+ assert matched, 'Response does not contain codes!'
135
+ code_content = matched.group(1)
136
+ assert 'canvas = Canvas()' in code_content, 'Code block must include valid canvas var!'
137
+ local_vars = {'Canvas': Canvas}
138
+ exec(code_content, {}, local_vars)
139
+ canvas = local_vars.get('canvas', None)
140
+ assert isinstance(canvas, Canvas), 'Code block must produce valid canvas var!'
141
+ return canvas
142
+
143
+ def __init__(self):
144
+ self.components = []
145
+ self.color = None
146
+ self.record_tags = True
147
+ self.prefixes = []
148
+ self.suffixes = []
149
+ return
150
+
151
+ def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str,
152
+ HTML_web_color_name: str):
153
+ assert isinstance(description, str), 'Global description is not valid!'
154
+ assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
155
+ 'Global detailed_descriptions is not valid!'
156
+ assert isinstance(tags, str), 'Global tags is not valid!'
157
+
158
+ HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
159
+ self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
160
+
161
+ self.prefixes = [description]
162
+ self.suffixes = detailed_descriptions
163
+
164
+ if self.record_tags:
165
+ self.suffixes = self.suffixes + [tags]
166
+
167
+ self.prefixes = [safe_str(x) for x in self.prefixes]
168
+ self.suffixes = [safe_str(x) for x in self.suffixes]
169
+
170
+ return
171
+
172
+ def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
173
+ detailed_descriptions: list[str], tags: str, atmosphere: str, style: str,
174
+ quality_meta: str, HTML_web_color_name: str):
175
+ assert isinstance(description, str), 'Local description is wrong!'
176
+ assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
177
+ f'The distance_to_viewer for [{description}] is not positive float number!'
178
+ assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
179
+ f'The detailed_descriptions for [{description}] is not valid!'
180
+ assert isinstance(tags, str), f'The tags for [{description}] is not valid!'
181
+ assert isinstance(atmosphere, str), f'The atmosphere for [{description}] is not valid!'
182
+ assert isinstance(style, str), f'The style for [{description}] is not valid!'
183
+ assert isinstance(quality_meta, str), f'The quality_meta for [{description}] is not valid!'
184
+
185
+ location = closest_name(location, valid_locations)
186
+ offset = closest_name(offset, valid_offsets)
187
+ area = closest_name(area, valid_areas)
188
+ HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
189
+
190
+ xb, yb = valid_locations[location]
191
+ xo, yo = valid_offsets[offset]
192
+ w, h = valid_areas[area]
193
+ rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
194
+ rect = [max(0, min(90, i)) for i in rect]
195
+ color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
196
+
197
+ prefixes = self.prefixes + [description]
198
+ suffixes = detailed_descriptions
199
+
200
+ if self.record_tags:
201
+ suffixes = suffixes + [tags, atmosphere, style, quality_meta]
202
+
203
+ prefixes = [safe_str(x) for x in prefixes]
204
+ suffixes = [safe_str(x) for x in suffixes]
205
+
206
+ self.components.append(dict(
207
+ rect=rect,
208
+ distance_to_viewer=distance_to_viewer,
209
+ color=color,
210
+ prefixes=prefixes,
211
+ suffixes=suffixes
212
+ ))
213
+
214
+ return
215
+
216
+ def process(self):
217
+ # sort components
218
+ self.components = sorted(self.components, key=lambda x: x['distance_to_viewer'], reverse=True)
219
+
220
+ # compute initial latent
221
+ initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
222
+
223
+ for component in self.components:
224
+ a, b, c, d = component['rect']
225
+ initial_latent[a:b, c:d] = 0.7 * component['color'] + 0.3 * initial_latent[a:b, c:d]
226
+
227
+ initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
228
+
229
+ # compute conditions
230
+
231
+ bag_of_conditions = [
232
+ dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes)
233
+ ]
234
+
235
+ for i, component in enumerate(self.components):
236
+ a, b, c, d = component['rect']
237
+ m = np.zeros(shape=(90, 90), dtype=np.float32)
238
+ m[a:b, c:d] = 1.0
239
+ bag_of_conditions.append(dict(
240
+ mask=m,
241
+ prefixes=component['prefixes'],
242
+ suffixes=component['suffixes']
243
+ ))
244
+
245
+ return dict(
246
+ initial_latent=initial_latent,
247
+ bag_of_conditions=bag_of_conditions,
248
+ )
lib_omost/memory_management.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from contextlib import contextmanager
3
+
4
+
5
+ high_vram = False
6
+ gpu = torch.device('cuda')
7
+ cpu = torch.device('cpu')
8
+
9
+ torch.zeros((1, 1)).to(gpu, torch.float32)
10
+ torch.cuda.empty_cache()
11
+
12
+ models_in_gpu = []
13
+
14
+
15
+ @contextmanager
16
+ def movable_bnb_model(m):
17
+ if hasattr(m, 'quantization_method'):
18
+ m.quantization_method_backup = m.quantization_method
19
+ del m.quantization_method
20
+ try:
21
+ yield None
22
+ finally:
23
+ if hasattr(m, 'quantization_method_backup'):
24
+ m.quantization_method = m.quantization_method_backup
25
+ del m.quantization_method_backup
26
+ return
27
+
28
+
29
+ def load_models_to_gpu(models):
30
+ global models_in_gpu
31
+
32
+ if not isinstance(models, (tuple, list)):
33
+ models = [models]
34
+
35
+ models_to_remain = [m for m in set(models) if m in models_in_gpu]
36
+ models_to_load = [m for m in set(models) if m not in models_in_gpu]
37
+ models_to_unload = [m for m in set(models_in_gpu) if m not in models_to_remain]
38
+
39
+ if not high_vram:
40
+ for m in models_to_unload:
41
+ with movable_bnb_model(m):
42
+ m.to(cpu)
43
+ print('Unload to CPU:', m.__class__.__name__)
44
+ models_in_gpu = models_to_remain
45
+
46
+ for m in models_to_load:
47
+ with movable_bnb_model(m):
48
+ m.to(gpu)
49
+ print('Load to GPU:', m.__class__.__name__)
50
+
51
+ models_in_gpu = list(set(models_in_gpu + models))
52
+ torch.cuda.empty_cache()
53
+ return
54
+
55
+
56
+ def unload_all_models(extra_models=None):
57
+ global models_in_gpu
58
+
59
+ if extra_models is None:
60
+ extra_models = []
61
+
62
+ if not isinstance(extra_models, (tuple, list)):
63
+ extra_models = [extra_models]
64
+
65
+ models_in_gpu = list(set(models_in_gpu + extra_models))
66
+
67
+ return load_models_to_gpu([])
lib_omost/pipeline.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import copy
3
+
4
+ from tqdm.auto import trange
5
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import *
6
+ from diffusers.models.transformers import Transformer2DModel
7
+
8
+
9
+ original_Transformer2DModel_forward = Transformer2DModel.forward
10
+
11
+
12
+ def hacked_Transformer2DModel_forward(
13
+ self,
14
+ hidden_states: torch.Tensor,
15
+ encoder_hidden_states: Optional[torch.Tensor] = None,
16
+ timestep: Optional[torch.LongTensor] = None,
17
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
18
+ class_labels: Optional[torch.LongTensor] = None,
19
+ cross_attention_kwargs: Dict[str, Any] = None,
20
+ attention_mask: Optional[torch.Tensor] = None,
21
+ encoder_attention_mask: Optional[torch.Tensor] = None,
22
+ return_dict: bool = True,
23
+ ):
24
+ cross_attention_kwargs = cross_attention_kwargs or {}
25
+ cross_attention_kwargs['hidden_states_original_shape'] = hidden_states.shape
26
+ return original_Transformer2DModel_forward(
27
+ self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, class_labels, cross_attention_kwargs,
28
+ attention_mask, encoder_attention_mask, return_dict)
29
+
30
+
31
+ Transformer2DModel.forward = hacked_Transformer2DModel_forward
32
+
33
+
34
+ @torch.no_grad()
35
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
36
+ """DPM-Solver++(2M)."""
37
+ extra_args = {} if extra_args is None else extra_args
38
+ s_in = x.new_ones([x.shape[0]])
39
+ sigma_fn = lambda t: t.neg().exp()
40
+ t_fn = lambda sigma: sigma.log().neg()
41
+ old_denoised = None
42
+
43
+ for i in trange(len(sigmas) - 1, disable=disable):
44
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
45
+ if callback is not None:
46
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
47
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
48
+ h = t_next - t
49
+ if old_denoised is None or sigmas[i + 1] == 0:
50
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
51
+ else:
52
+ h_last = t - t_fn(sigmas[i - 1])
53
+ r = h_last / h
54
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
55
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
56
+ old_denoised = denoised
57
+ return x
58
+
59
+
60
+ class KModel:
61
+ def __init__(self, unet, timesteps=1000, linear_start=0.00085, linear_end=0.012):
62
+ betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float64) ** 2
63
+ alphas = 1. - betas
64
+ alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
65
+
66
+ self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
67
+ self.log_sigmas = self.sigmas.log()
68
+ self.sigma_data = 1.0
69
+ self.unet = unet
70
+ return
71
+
72
+ @property
73
+ def sigma_min(self):
74
+ return self.sigmas[0]
75
+
76
+ @property
77
+ def sigma_max(self):
78
+ return self.sigmas[-1]
79
+
80
+ def timestep(self, sigma):
81
+ log_sigma = sigma.log()
82
+ dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
83
+ return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
84
+
85
+ def get_sigmas_karras(self, n, rho=7.):
86
+ ramp = torch.linspace(0, 1, n)
87
+ min_inv_rho = self.sigma_min ** (1 / rho)
88
+ max_inv_rho = self.sigma_max ** (1 / rho)
89
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
90
+ return torch.cat([sigmas, sigmas.new_zeros([1])])
91
+
92
+ def __call__(self, x, sigma, **extra_args):
93
+ x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
94
+ t = self.timestep(sigma)
95
+ cfg_scale = extra_args['cfg_scale']
96
+ eps_positive = self.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
97
+ eps_negative = self.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
98
+ noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
99
+ return x - noise_pred * sigma[:, None, None, None]
100
+
101
+
102
+ class OmostSelfAttnProcessor:
103
+ def __call__(self, attn, hidden_states, encoder_hidden_states, hidden_states_original_shape, *args, **kwargs):
104
+ batch_size, sequence_length, _ = hidden_states.shape
105
+
106
+ query = attn.to_q(hidden_states)
107
+ key = attn.to_k(hidden_states)
108
+ value = attn.to_v(hidden_states)
109
+
110
+ inner_dim = key.shape[-1]
111
+ head_dim = inner_dim // attn.heads
112
+
113
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
114
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
115
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
116
+
117
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
118
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
119
+ )
120
+
121
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
122
+ hidden_states = hidden_states.to(query.dtype)
123
+ hidden_states = attn.to_out[0](hidden_states)
124
+ hidden_states = attn.to_out[1](hidden_states)
125
+
126
+ return hidden_states
127
+
128
+
129
+ class OmostCrossAttnProcessor:
130
+ def __call__(self, attn, hidden_states, encoder_hidden_states, hidden_states_original_shape, *args, **kwargs):
131
+ B, C, H, W = hidden_states_original_shape
132
+
133
+ conds = []
134
+ masks = []
135
+
136
+ for m, c in encoder_hidden_states:
137
+ m = torch.nn.functional.interpolate(m[None, None, :, :], (H, W), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, c.size(1))
138
+ conds.append(c)
139
+ masks.append(m)
140
+
141
+ conds = torch.cat(conds, dim=1)
142
+ masks = torch.cat(masks, dim=1)
143
+
144
+ mask_bool = masks > 0.5
145
+ mask_scale = (H * W) / torch.sum(masks, dim=0, keepdim=True)
146
+
147
+ batch_size, sequence_length, _ = conds.shape
148
+
149
+ query = attn.to_q(hidden_states)
150
+ key = attn.to_k(conds)
151
+ value = attn.to_v(conds)
152
+
153
+ inner_dim = key.shape[-1]
154
+ head_dim = inner_dim // attn.heads
155
+
156
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
157
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
158
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
159
+
160
+ mask_bool = mask_bool[None, None, :, :].repeat(query.size(0), query.size(1), 1, 1)
161
+ mask_scale = mask_scale[None, None, :, :].repeat(query.size(0), query.size(1), 1, 1)
162
+
163
+ sim = query @ key.transpose(-2, -1) * attn.scale
164
+ sim = sim * mask_scale.to(sim)
165
+ sim.masked_fill_(mask_bool.logical_not(), float("-inf"))
166
+ sim = sim.softmax(dim=-1)
167
+
168
+ h = sim @ value
169
+ h = h.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
170
+
171
+ h = attn.to_out[0](h)
172
+ h = attn.to_out[1](h)
173
+ return h
174
+
175
+
176
+ class StableDiffusionXLOmostPipeline(StableDiffusionXLImg2ImgPipeline):
177
+ def __init__(self, *args, **kwargs):
178
+ super().__init__(*args, **kwargs)
179
+ self.k_model = KModel(unet=self.unet)
180
+
181
+ attn_procs = {}
182
+ for name in self.unet.attn_processors.keys():
183
+ if name.endswith("attn2.processor"):
184
+ attn_procs[name] = OmostCrossAttnProcessor()
185
+ else:
186
+ attn_procs[name] = OmostSelfAttnProcessor()
187
+
188
+ self.unet.set_attn_processor(attn_procs)
189
+ return
190
+
191
+ @torch.inference_mode()
192
+ def encode_bag_of_subprompts_greedy(self, prefixes: list[str], suffixes: list[str]):
193
+ device = self.text_encoder.device
194
+
195
+ @torch.inference_mode()
196
+ def greedy_partition(items, max_sum):
197
+ bags = []
198
+ current_bag = []
199
+ current_sum = 0
200
+
201
+ for item in items:
202
+ num = item['length']
203
+ if current_sum + num > max_sum:
204
+ if current_bag:
205
+ bags.append(current_bag)
206
+ current_bag = [item]
207
+ current_sum = num
208
+ else:
209
+ current_bag.append(item)
210
+ current_sum += num
211
+
212
+ if current_bag:
213
+ bags.append(current_bag)
214
+
215
+ return bags
216
+
217
+ @torch.inference_mode()
218
+ def get_77_tokens_in_torch(subprompt_inds, tokenizer):
219
+ # Note that all subprompt are theoretically less than 75 tokens (without bos/eos)
220
+ result = [tokenizer.bos_token_id] + subprompt_inds[:75] + [tokenizer.eos_token_id] + [tokenizer.pad_token_id] * 75
221
+ result = result[:77]
222
+ result = torch.tensor([result]).to(device=device, dtype=torch.int64)
223
+ return result
224
+
225
+ @torch.inference_mode()
226
+ def merge_with_prefix(bag):
227
+ merged_ids_t1 = copy.deepcopy(prefix_ids_t1)
228
+ merged_ids_t2 = copy.deepcopy(prefix_ids_t2)
229
+
230
+ for item in bag:
231
+ merged_ids_t1.extend(item['ids_t1'])
232
+ merged_ids_t2.extend(item['ids_t2'])
233
+
234
+ return dict(
235
+ ids_t1=get_77_tokens_in_torch(merged_ids_t1, self.tokenizer),
236
+ ids_t2=get_77_tokens_in_torch(merged_ids_t2, self.tokenizer_2)
237
+ )
238
+
239
+ @torch.inference_mode()
240
+ def double_encode(pair_of_inds):
241
+ inds = [pair_of_inds['ids_t1'], pair_of_inds['ids_t2']]
242
+ text_encoders = [self.text_encoder, self.text_encoder_2]
243
+
244
+ pooled_prompt_embeds = None
245
+ prompt_embeds_list = []
246
+
247
+ for text_input_ids, text_encoder in zip(inds, text_encoders):
248
+ prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)
249
+
250
+ # Only last pooler_output is needed
251
+ pooled_prompt_embeds = prompt_embeds.pooler_output
252
+
253
+ # "2" because SDXL always indexes from the penultimate layer.
254
+ prompt_embeds = prompt_embeds.hidden_states[-2]
255
+ prompt_embeds_list.append(prompt_embeds)
256
+
257
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
258
+ return prompt_embeds, pooled_prompt_embeds
259
+
260
+ # Begin with tokenizing prefixes
261
+
262
+ prefix_length = 0
263
+ prefix_ids_t1 = []
264
+ prefix_ids_t2 = []
265
+
266
+ for prefix in prefixes:
267
+ ids_t1 = self.tokenizer(prefix, truncation=False, add_special_tokens=False).input_ids
268
+ ids_t2 = self.tokenizer_2(prefix, truncation=False, add_special_tokens=False).input_ids
269
+ assert len(ids_t1) == len(ids_t2)
270
+ prefix_length += len(ids_t1)
271
+ prefix_ids_t1 += ids_t1
272
+ prefix_ids_t2 += ids_t2
273
+
274
+ # Then tokenizing suffixes
275
+
276
+ allowed_suffix_length = 75 - prefix_length
277
+ suffix_targets = []
278
+
279
+ for subprompt in suffixes:
280
+ # Note that all subprompt are theoretically less than 75 tokens (without bos/eos)
281
+ # So we can safely just crop it to 75
282
+ ids_t1 = self.tokenizer(subprompt, truncation=False, add_special_tokens=False).input_ids[:75]
283
+ ids_t2 = self.tokenizer_2(subprompt, truncation=False, add_special_tokens=False).input_ids[:75]
284
+ assert len(ids_t1) == len(ids_t2)
285
+ suffix_targets.append(dict(
286
+ length=len(ids_t1),
287
+ ids_t1=ids_t1,
288
+ ids_t2=ids_t2
289
+ ))
290
+
291
+ # Then merge prefix and suffix tokens
292
+
293
+ suffix_targets = greedy_partition(suffix_targets, max_sum=allowed_suffix_length)
294
+ targets = [merge_with_prefix(b) for b in suffix_targets]
295
+
296
+ # Encode!
297
+
298
+ conds, poolers = [], []
299
+
300
+ for target in targets:
301
+ cond, pooler = double_encode(target)
302
+ conds.append(cond)
303
+ poolers.append(pooler)
304
+
305
+ conds_merged = torch.concat(conds, dim=1)
306
+ poolers_merged = poolers[0]
307
+
308
+ return dict(cond=conds_merged, pooler=poolers_merged)
309
+
310
+ @torch.inference_mode()
311
+ def all_conds_from_canvas(self, canvas_outputs, negative_prompt):
312
+ mask_all = torch.ones(size=(90, 90), dtype=torch.float32)
313
+ negative_cond, negative_pooler = self.encode_cropped_prompt_77tokens(negative_prompt)
314
+ negative_result = [(mask_all, negative_cond)]
315
+
316
+ positive_result = []
317
+ positive_pooler = None
318
+
319
+ for item in canvas_outputs['bag_of_conditions']:
320
+ current_mask = torch.from_numpy(item['mask']).to(torch.float32)
321
+ current_prefixes = item['prefixes']
322
+ current_suffixes = item['suffixes']
323
+
324
+ current_cond = self.encode_bag_of_subprompts_greedy(prefixes=current_prefixes, suffixes=current_suffixes)
325
+
326
+ if positive_pooler is None:
327
+ positive_pooler = current_cond['pooler']
328
+
329
+ positive_result.append((current_mask, current_cond['cond']))
330
+
331
+ return positive_result, positive_pooler, negative_result, negative_pooler
332
+
333
+ @torch.inference_mode()
334
+ def encode_cropped_prompt_77tokens(self, prompt: str):
335
+ device = self.text_encoder.device
336
+ tokenizers = [self.tokenizer, self.tokenizer_2]
337
+ text_encoders = [self.text_encoder, self.text_encoder_2]
338
+
339
+ pooled_prompt_embeds = None
340
+ prompt_embeds_list = []
341
+
342
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
343
+ text_input_ids = tokenizer(
344
+ prompt,
345
+ padding="max_length",
346
+ max_length=tokenizer.model_max_length,
347
+ truncation=True,
348
+ return_tensors="pt",
349
+ ).input_ids
350
+
351
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
352
+
353
+ # Only last pooler_output is needed
354
+ pooled_prompt_embeds = prompt_embeds.pooler_output
355
+
356
+ # "2" because SDXL always indexes from the penultimate layer.
357
+ prompt_embeds = prompt_embeds.hidden_states[-2]
358
+ prompt_embeds_list.append(prompt_embeds)
359
+
360
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
361
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
362
+
363
+ return prompt_embeds, pooled_prompt_embeds
364
+
365
+ @torch.inference_mode()
366
+ def __call__(
367
+ self,
368
+ initial_latent: torch.FloatTensor = None,
369
+ strength: float = 1.0,
370
+ num_inference_steps: int = 25,
371
+ guidance_scale: float = 5.0,
372
+ batch_size: Optional[int] = 1,
373
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
374
+ prompt_embeds: Optional[torch.FloatTensor] = None,
375
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
376
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
377
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
378
+ cross_attention_kwargs: Optional[dict] = None,
379
+ ):
380
+
381
+ device = self.unet.device
382
+ cross_attention_kwargs = cross_attention_kwargs or {}
383
+
384
+ # Sigmas
385
+
386
+ sigmas = self.k_model.get_sigmas_karras(int(num_inference_steps / strength))
387
+ sigmas = sigmas[-(num_inference_steps + 1):].to(device)
388
+
389
+ # Initial latents
390
+
391
+ _, C, H, W = initial_latent.shape
392
+ noise = randn_tensor((batch_size, C, H, W), generator=generator, device=device, dtype=self.unet.dtype)
393
+ latents = initial_latent.to(noise) + noise * sigmas[0].to(noise)
394
+
395
+ # Shape
396
+
397
+ height, width = latents.shape[-2:]
398
+ height = height * self.vae_scale_factor
399
+ width = width * self.vae_scale_factor
400
+
401
+ add_time_ids = list((height, width) + (0, 0) + (height, width))
402
+ add_time_ids = torch.tensor([add_time_ids], dtype=self.unet.dtype)
403
+ add_neg_time_ids = add_time_ids.clone()
404
+
405
+ # Batch
406
+
407
+ latents = latents.to(device)
408
+ add_time_ids = add_time_ids.repeat(batch_size, 1).to(device)
409
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1).to(device)
410
+ prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(noise)) for k, v in prompt_embeds]
411
+ negative_prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(noise)) for k, v in negative_prompt_embeds]
412
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(noise)
413
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(noise)
414
+
415
+ # Feeds
416
+
417
+ sampler_kwargs = dict(
418
+ cfg_scale=guidance_scale,
419
+ positive=dict(
420
+ encoder_hidden_states=prompt_embeds,
421
+ added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
422
+ cross_attention_kwargs=cross_attention_kwargs
423
+ ),
424
+ negative=dict(
425
+ encoder_hidden_states=negative_prompt_embeds,
426
+ added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids},
427
+ cross_attention_kwargs=cross_attention_kwargs
428
+ )
429
+ )
430
+
431
+ # Sample
432
+
433
+ results = sample_dpmpp_2m(self.k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False)
434
+
435
+ return StableDiffusionXLPipelineOutput(images=results)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.28.0
2
+ transformers==4.41.1
3
+ gradio==4.31.5
4
+ bitsandbytes==0.43.1
5
+ accelerate==0.30.1
6
+ protobuf==3.20
7
+ opencv-python
8
+ tensorboardX
9
+ safetensors
10
+ pillow
11
+ einops
12
+ torch
13
+ peft
run.bat ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ @echo off
2
+
3
+ "%~dp0venv\Scripts\python.exe" "gradio_app.py"
4
+
5
+ PAUSE