layerdiffusion commited on
Commit
9ab270d
1 Parent(s): fdeb859
Files changed (5) hide show
  1. LICENSE +201 -0
  2. app.py +357 -8
  3. chat_interface.py +628 -0
  4. lib_omost/canvas.py +248 -0
  5. lib_omost/pipeline.py +435 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py CHANGED
@@ -1,14 +1,363 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import spaces
 
 
 
 
 
3
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- zero = torch.Tensor([0]).cuda()
6
- print(zero.device) # <-- 'cpu' 🤔
7
 
8
  @spaces.GPU
9
- def greet(n):
10
- print(zero.device) # <-- 'cuda:0' 🤗
11
- return f"Hello {zero + n} Tensor"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14
- demo.launch()
 
1
+ # import gradio as gr
2
+ #
3
+ # import torch
4
+ #
5
+ # zero = torch.Tensor([0]).cuda()
6
+ # print(zero.device) # <-- 'cpu' 🤔
7
+ #
8
+ # @spaces.GPU
9
+ # def greet(n):
10
+ # print(zero.device) # <-- 'cuda:0' 🤗
11
+ # return f"Hello {zero + n} Tensor"
12
+ #
13
+ # demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14
+ # demo.launch()
15
+
16
+ import os
17
  import spaces
18
+
19
+ os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
20
+ HF_TOKEN = os.environ['hf_token'] if 'hf_token' in os.environ else None
21
+
22
+ import uuid
23
  import torch
24
+ import numpy as np
25
+ import gradio as gr
26
+ import tempfile
27
+
28
+ gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio')
29
+ os.makedirs(gradio_temp_dir, exist_ok=True)
30
+
31
+ from threading import Thread
32
+
33
+ # Phi3 Hijack
34
+ from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel
35
+
36
+ Phi3PreTrainedModel._supports_sdpa = True
37
+
38
+ from PIL import Image
39
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
40
+ from diffusers import AutoencoderKL, UNet2DConditionModel
41
+ from diffusers.models.attention_processor import AttnProcessor2_0
42
+ from transformers import CLIPTextModel, CLIPTokenizer
43
+ from lib_omost.pipeline import StableDiffusionXLOmostPipeline
44
+ from chat_interface import ChatInterface
45
+
46
+ import lib_omost.canvas as omost_canvas
47
+
48
+
49
+ # SDXL
50
+
51
+ sdxl_name = 'SG161222/RealVisXL_V4.0'
52
+ # sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'
53
+
54
+ tokenizer = CLIPTokenizer.from_pretrained(
55
+ sdxl_name, subfolder="tokenizer")
56
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
57
+ sdxl_name, subfolder="tokenizer_2")
58
+ text_encoder = CLIPTextModel.from_pretrained(
59
+ sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16", device_map="auto")
60
+ text_encoder_2 = CLIPTextModel.from_pretrained(
61
+ sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16", device_map="auto")
62
+ vae = AutoencoderKL.from_pretrained(
63
+ sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16", device_map="auto") # bfloat16 vae
64
+ unet = UNet2DConditionModel.from_pretrained(
65
+ sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16", device_map="auto")
66
+
67
+ unet.set_attn_processor(AttnProcessor2_0())
68
+ vae.set_attn_processor(AttnProcessor2_0())
69
+
70
+ pipeline = StableDiffusionXLOmostPipeline(
71
+ vae=vae,
72
+ text_encoder=text_encoder,
73
+ tokenizer=tokenizer,
74
+ text_encoder_2=text_encoder_2,
75
+ tokenizer_2=tokenizer_2,
76
+ unet=unet,
77
+ scheduler=None, # We completely give up diffusers sampling system and use A1111's method
78
+ )
79
+
80
+ # LLM
81
+
82
+ # model_name = 'lllyasviel/omost-phi-3-mini-128k-8bits'
83
+ llm_name = 'lllyasviel/omost-llama-3-8b-4bits'
84
+ # model_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b-4bits'
85
+
86
+ llm_model = AutoModelForCausalLM.from_pretrained(
87
+ llm_name,
88
+ torch_dtype=torch.bfloat16, # This is computation type, not load/memory type. The loading quant type is baked in config.
89
+ token=HF_TOKEN,
90
+ device_map="auto"
91
+ )
92
+
93
+ llm_tokenizer = AutoTokenizer.from_pretrained(
94
+ llm_name,
95
+ token=HF_TOKEN
96
+ )
97
+
98
+
99
+ @torch.inference_mode()
100
+ def pytorch2numpy(imgs):
101
+ results = []
102
+ for x in imgs:
103
+ y = x.movedim(0, -1)
104
+ y = y * 127.5 + 127.5
105
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
106
+ results.append(y)
107
+ return results
108
+
109
+
110
+ @torch.inference_mode()
111
+ def numpy2pytorch(imgs):
112
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
113
+ h = h.movedim(-1, 1)
114
+ return h
115
+
116
+
117
+ def resize_without_crop(image, target_width, target_height):
118
+ pil_image = Image.fromarray(image)
119
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
120
+ return np.array(resized_image)
121
 
 
 
122
 
123
  @spaces.GPU
124
+ @torch.inference_mode()
125
+ def chat_fn(message: str, history: list, seed:int, temperature: float, top_p: float, max_new_tokens: int) -> str:
126
+ np.random.seed(int(seed))
127
+ torch.manual_seed(int(seed))
128
+
129
+ conversation = [{"role": "system", "content": omost_canvas.system_prompt}]
130
+
131
+ for user, assistant in history:
132
+ if user is None or assistant is None:
133
+ continue
134
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
135
+
136
+ conversation.append({"role": "user", "content": message})
137
+
138
+ input_ids = llm_tokenizer.apply_chat_template(
139
+ conversation, return_tensors="pt", add_generation_prompt=True).to(llm_model.device)
140
+
141
+ streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
142
+
143
+ generate_kwargs = dict(
144
+ input_ids=input_ids,
145
+ streamer=streamer,
146
+ max_new_tokens=max_new_tokens,
147
+ do_sample=True,
148
+ temperature=temperature,
149
+ top_p=top_p,
150
+ )
151
+
152
+ if temperature == 0:
153
+ generate_kwargs['do_sample'] = False
154
+
155
+ Thread(target=llm_model.generate, kwargs=generate_kwargs).start()
156
+
157
+ outputs = []
158
+ for text in streamer:
159
+ outputs.append(text)
160
+ # print(outputs)
161
+ yield "".join(outputs)
162
+
163
+ return
164
+
165
+
166
+ @torch.inference_mode()
167
+ def post_chat(history):
168
+ history = [(user, assistant) for user, assistant in history if isinstance(user, str) and isinstance(assistant, str)]
169
+ last_assistant = history[-1][1]
170
+ canvas_outputs = None
171
+
172
+ try:
173
+ canvas = omost_canvas.Canvas.from_bot_response(last_assistant)
174
+ canvas_outputs = canvas.process()
175
+ except Exception as e:
176
+ print('Last assistant response is not valid canvas:', e)
177
+
178
+ return canvas_outputs, gr.update(visible=canvas_outputs is not None)
179
+
180
+
181
+ @spaces.GPU
182
+ @torch.inference_mode()
183
+ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_height,
184
+ highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt):
185
+
186
+ use_initial_latent = False
187
+ eps = 0.05
188
+
189
+ image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64
190
+
191
+ rng = torch.Generator(unet.device).manual_seed(seed)
192
+
193
+ positive_cond, positive_pooler, negative_cond, negative_pooler = pipeline.all_conds_from_canvas(canvas_outputs, negative_prompt)
194
+
195
+ if use_initial_latent:
196
+ initial_latent = torch.from_numpy(canvas_outputs['initial_latent'])[None].movedim(-1, 1) / 127.5 - 1.0
197
+ initial_latent_blur = 40
198
+ initial_latent = torch.nn.functional.avg_pool2d(
199
+ torch.nn.functional.pad(initial_latent, (initial_latent_blur,) * 4, mode='reflect'),
200
+ kernel_size=(initial_latent_blur * 2 + 1,) * 2, stride=(1, 1))
201
+ initial_latent = torch.nn.functional.interpolate(initial_latent, (image_height, image_width))
202
+ initial_latent = initial_latent.to(dtype=vae.dtype, device=vae.device)
203
+ initial_latent = vae.encode(initial_latent).latent_dist.mode() * vae.config.scaling_factor
204
+ else:
205
+ initial_latent = torch.zeros(size=(num_samples, 4, image_height // 8, image_width // 8), dtype=torch.float32)
206
+
207
+ initial_latent = initial_latent.to(dtype=unet.dtype, device=unet.device)
208
+
209
+ latents = pipeline(
210
+ initial_latent=initial_latent,
211
+ strength=1.0,
212
+ num_inference_steps=int(steps),
213
+ batch_size=num_samples,
214
+ prompt_embeds=positive_cond,
215
+ negative_prompt_embeds=negative_cond,
216
+ pooled_prompt_embeds=positive_pooler,
217
+ negative_pooled_prompt_embeds=negative_pooler,
218
+ generator=rng,
219
+ guidance_scale=float(cfg),
220
+ ).images
221
+
222
+ latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
223
+ pixels = vae.decode(latents).sample
224
+ B, C, H, W = pixels.shape
225
+ pixels = pytorch2numpy(pixels)
226
+
227
+ if highres_scale > 1.0 + eps:
228
+ pixels = [
229
+ resize_without_crop(
230
+ image=p,
231
+ target_width=int(round(W * highres_scale / 64.0) * 64),
232
+ target_height=int(round(H * highres_scale / 64.0) * 64)
233
+ ) for p in pixels
234
+ ]
235
+
236
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
237
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
238
+
239
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
240
+
241
+ latents = pipeline(
242
+ initial_latent=latents,
243
+ strength=highres_denoise,
244
+ num_inference_steps=highres_steps,
245
+ batch_size=num_samples,
246
+ prompt_embeds=positive_cond,
247
+ negative_prompt_embeds=negative_cond,
248
+ pooled_prompt_embeds=positive_pooler,
249
+ negative_pooled_prompt_embeds=negative_pooler,
250
+ generator=rng,
251
+ guidance_scale=float(cfg),
252
+ ).images
253
+
254
+ latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
255
+ pixels = vae.decode(latents).sample
256
+ pixels = pytorch2numpy(pixels)
257
+
258
+ for i in range(len(pixels)):
259
+ unique_hex = uuid.uuid4().hex
260
+ image_path = os.path.join(gradio_temp_dir, f"{unique_hex}_{i}.png")
261
+ image = Image.fromarray(pixels[i])
262
+ image.save(image_path)
263
+ chatbot = chatbot + [(None, (image_path, 'image'))]
264
+
265
+ return chatbot
266
+
267
+
268
+ css = '''
269
+ code {white-space: pre-wrap !important;}
270
+ .gradio-container {max-width: none !important;}
271
+ .outer_parent {flex: 1;}
272
+ .inner_parent {flex: 1;}
273
+ footer {display: none !important; visibility: hidden !important;}
274
+ .translucent {display: none !important; visibility: hidden !important;}
275
+ '''
276
+
277
+ with gr.Blocks(fill_height=True, css=css) as demo:
278
+ with gr.Row(elem_classes='outer_parent'):
279
+ with gr.Column(scale=25):
280
+ with gr.Row():
281
+ retry_btn = gr.Button("🔄 Retry", variant="secondary", size="sm", min_width=60)
282
+ undo_btn = gr.Button("↩️ Undo", variant="secondary", size="sm", min_width=60)
283
+ clear_btn = gr.Button("⭐️ New Chat", variant="secondary", size="sm", min_width=60)
284
+
285
+ seed = gr.Number(label="Random Seed", value=12345, precision=0)
286
+
287
+ with gr.Accordion(open=True, label='Language Model'):
288
+ with gr.Group():
289
+ with gr.Row():
290
+ temperature = gr.Slider(
291
+ minimum=0.0,
292
+ maximum=2.0,
293
+ step=0.01,
294
+ value=0.6,
295
+ label="Temperature")
296
+ top_p = gr.Slider(
297
+ minimum=0.0,
298
+ maximum=1.0,
299
+ step=0.01,
300
+ value=0.9,
301
+ label="Top P")
302
+ max_new_tokens = gr.Slider(
303
+ minimum=128,
304
+ maximum=4096,
305
+ step=1,
306
+ value=4096,
307
+ label="Max New Tokens")
308
+ with gr.Accordion(open=True, label='Image Diffusion Model'):
309
+ with gr.Group():
310
+ with gr.Row():
311
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=896, step=64)
312
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=1152, step=64)
313
+
314
+ with gr.Row():
315
+ num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1)
316
+ steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1)
317
+
318
+ with gr.Accordion(open=False, label='Advanced'):
319
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=5.0, step=0.01)
320
+ highres_scale = gr.Slider(label="HR-fix Scale (\"1\" is disabled)", minimum=1.0, maximum=2.0, value=1.0, step=0.01)
321
+ highres_steps = gr.Slider(label="Highres Fix Steps", minimum=1, maximum=100, value=20, step=1)
322
+ highres_denoise = gr.Slider(label="Highres Fix Denoise", minimum=0.1, maximum=1.0, value=0.4, step=0.01)
323
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
324
+
325
+ render_button = gr.Button("Render the Image!", size='lg', variant="primary", visible=False)
326
+
327
+ examples = gr.Dataset(
328
+ samples=[
329
+ ['generate an image of the fierce battle of warriors and a dragon'],
330
+ ['change the dragon to a dinosaur']
331
+ ],
332
+ components=[gr.Textbox(visible=False)],
333
+ label='Quick Prompts'
334
+ )
335
+ with gr.Column(scale=75, elem_classes='inner_parent'):
336
+ canvas_state = gr.State(None)
337
+ chatbot = gr.Chatbot(label='Omost', scale=1, bubble_full_width=True, render=False)
338
+ chatInterface = ChatInterface(
339
+ fn=chat_fn,
340
+ post_fn=post_chat,
341
+ post_fn_kwargs=dict(inputs=[chatbot], outputs=[canvas_state, render_button]),
342
+ pre_fn=lambda: gr.update(visible=False),
343
+ pre_fn_kwargs=dict(outputs=[render_button]),
344
+ chatbot=chatbot,
345
+ retry_btn=retry_btn,
346
+ undo_btn=undo_btn,
347
+ clear_btn=clear_btn,
348
+ additional_inputs=[seed, temperature, top_p, max_new_tokens],
349
+ examples=examples
350
+ )
351
+
352
+ render_button.click(
353
+ fn=diffusion_fn, inputs=[
354
+ chatInterface.chatbot, canvas_state,
355
+ num_samples, seed, image_width, image_height, highres_scale,
356
+ steps, cfg, highres_steps, highres_denoise, n_prompt
357
+ ], outputs=[chatInterface.chatbot]).then(
358
+ fn=lambda x: x, inputs=[
359
+ chatInterface.chatbot
360
+ ], outputs=[chatInterface.chatbot_state])
361
 
362
+ if __name__ == "__main__":
363
+ demo.queue().launch(inbrowser=True, server_name='0.0.0.0')
chat_interface.py ADDED
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import inspect
8
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
9
+
10
+ import anyio
11
+ from gradio_client.documentation import document
12
+
13
+ from gradio.blocks import Blocks
14
+ from gradio.components import (
15
+ Button,
16
+ Chatbot,
17
+ Component,
18
+ Markdown,
19
+ MultimodalTextbox,
20
+ State,
21
+ Textbox,
22
+ get_component_instance,
23
+ Dataset
24
+ )
25
+ from gradio.events import Dependency, on
26
+ from gradio.helpers import special_args
27
+ from gradio.layouts import Accordion, Group, Row
28
+ from gradio.routes import Request
29
+ from gradio.themes import ThemeClass as Theme
30
+ from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda
31
+
32
+
33
+ @document()
34
+ class ChatInterface(Blocks):
35
+ """
36
+ ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create
37
+ a web-based demo around a chatbot model in a few lines of code. Only one parameter is required: fn, which
38
+ takes a function that governs the response of the chatbot based on the user input and chat history. Additional
39
+ parameters can be used to control the appearance and behavior of the demo.
40
+
41
+ Example:
42
+ import gradio as gr
43
+
44
+ def echo(message, history):
45
+ return message
46
+
47
+ demo = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot")
48
+ demo.launch()
49
+ Demos: chatinterface_multimodal, chatinterface_random_response, chatinterface_streaming_echo
50
+ Guides: creating-a-chatbot-fast, sharing-your-app
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ fn: Callable,
56
+ post_fn: Callable,
57
+ pre_fn: Callable,
58
+ chatbot: Chatbot,
59
+ *,
60
+ post_fn_kwargs: dict = None,
61
+ pre_fn_kwargs: dict = None,
62
+ multimodal: bool = False,
63
+ textbox: Textbox | MultimodalTextbox | None = None,
64
+ additional_inputs: str | Component | list[str | Component] | None = None,
65
+ additional_inputs_accordion_name: str | None = None,
66
+ additional_inputs_accordion: str | Accordion | None = None,
67
+ examples: Dataset = None,
68
+ title: str | None = None,
69
+ description: str | None = None,
70
+ theme: Theme | str | None = None,
71
+ css: str | None = None,
72
+ js: str | None = None,
73
+ head: str | None = None,
74
+ analytics_enabled: bool | None = None,
75
+ submit_btn: str | None | Button = "Submit",
76
+ stop_btn: str | None | Button = "Stop",
77
+ retry_btn: str | None | Button = "🔄 Retry",
78
+ undo_btn: str | None | Button = "↩️ Undo",
79
+ clear_btn: str | None | Button = "🗑️ Clear",
80
+ autofocus: bool = True,
81
+ concurrency_limit: int | None | Literal["default"] = "default",
82
+ fill_height: bool = True,
83
+ delete_cache: tuple[int, int] | None = None,
84
+ ):
85
+ super().__init__(
86
+ analytics_enabled=analytics_enabled,
87
+ mode="chat_interface",
88
+ css=css,
89
+ title=title or "Gradio",
90
+ theme=theme,
91
+ js=js,
92
+ head=head,
93
+ fill_height=fill_height,
94
+ delete_cache=delete_cache,
95
+ )
96
+
97
+ if post_fn_kwargs is None:
98
+ post_fn_kwargs = []
99
+
100
+ self.post_fn = post_fn
101
+ self.post_fn_kwargs = post_fn_kwargs
102
+
103
+ self.pre_fn = pre_fn
104
+ self.pre_fn_kwargs = pre_fn_kwargs
105
+
106
+ self.multimodal = multimodal
107
+ self.concurrency_limit = concurrency_limit
108
+ self.fn = fn
109
+ self.is_async = inspect.iscoroutinefunction(
110
+ self.fn
111
+ ) or inspect.isasyncgenfunction(self.fn)
112
+ self.is_generator = inspect.isgeneratorfunction(
113
+ self.fn
114
+ ) or inspect.isasyncgenfunction(self.fn)
115
+
116
+ if additional_inputs:
117
+ if not isinstance(additional_inputs, list):
118
+ additional_inputs = [additional_inputs]
119
+ self.additional_inputs = [
120
+ get_component_instance(i)
121
+ for i in additional_inputs # type: ignore
122
+ ]
123
+ else:
124
+ self.additional_inputs = []
125
+ if additional_inputs_accordion_name is not None:
126
+ print(
127
+ "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
128
+ )
129
+ self.additional_inputs_accordion_params = {
130
+ "label": additional_inputs_accordion_name
131
+ }
132
+ if additional_inputs_accordion is None:
133
+ self.additional_inputs_accordion_params = {
134
+ "label": "Additional Inputs",
135
+ "open": False,
136
+ }
137
+ elif isinstance(additional_inputs_accordion, str):
138
+ self.additional_inputs_accordion_params = {
139
+ "label": additional_inputs_accordion
140
+ }
141
+ elif isinstance(additional_inputs_accordion, Accordion):
142
+ self.additional_inputs_accordion_params = (
143
+ additional_inputs_accordion.recover_kwargs(
144
+ additional_inputs_accordion.get_config()
145
+ )
146
+ )
147
+ else:
148
+ raise ValueError(
149
+ f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
150
+ )
151
+
152
+ with self:
153
+ if title:
154
+ Markdown(
155
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
156
+ )
157
+ if description:
158
+ Markdown(description)
159
+
160
+ self.chatbot = chatbot.render()
161
+
162
+ self.buttons = [retry_btn, undo_btn, clear_btn]
163
+
164
+ with Group():
165
+ with Row():
166
+ if textbox:
167
+ if self.multimodal:
168
+ submit_btn = None
169
+ else:
170
+ textbox.container = False
171
+ textbox.show_label = False
172
+ textbox_ = textbox.render()
173
+ if not isinstance(textbox_, (Textbox, MultimodalTextbox)):
174
+ raise TypeError(
175
+ f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {type(textbox_)}"
176
+ )
177
+ self.textbox = textbox_
178
+ elif self.multimodal:
179
+ submit_btn = None
180
+ self.textbox = MultimodalTextbox(
181
+ show_label=False,
182
+ label="Message",
183
+ placeholder="Type a message...",
184
+ scale=7,
185
+ autofocus=autofocus,
186
+ )
187
+ else:
188
+ self.textbox = Textbox(
189
+ container=False,
190
+ show_label=False,
191
+ label="Message",
192
+ placeholder="Type a message...",
193
+ scale=7,
194
+ autofocus=autofocus,
195
+ )
196
+ if submit_btn is not None and not multimodal:
197
+ if isinstance(submit_btn, Button):
198
+ submit_btn.render()
199
+ elif isinstance(submit_btn, str):
200
+ submit_btn = Button(
201
+ submit_btn,
202
+ variant="primary",
203
+ scale=1,
204
+ min_width=150,
205
+ )
206
+ else:
207
+ raise ValueError(
208
+ f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
209
+ )
210
+ if stop_btn is not None:
211
+ if isinstance(stop_btn, Button):
212
+ stop_btn.visible = False
213
+ stop_btn.render()
214
+ elif isinstance(stop_btn, str):
215
+ stop_btn = Button(
216
+ stop_btn,
217
+ variant="stop",
218
+ visible=False,
219
+ scale=1,
220
+ min_width=150,
221
+ )
222
+ else:
223
+ raise ValueError(
224
+ f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
225
+ )
226
+ self.buttons.extend([submit_btn, stop_btn]) # type: ignore
227
+
228
+ self.fake_api_btn = Button("Fake API", visible=False)
229
+ self.fake_response_textbox = Textbox(label="Response", visible=False)
230
+ (
231
+ self.retry_btn,
232
+ self.undo_btn,
233
+ self.clear_btn,
234
+ self.submit_btn,
235
+ self.stop_btn,
236
+ ) = self.buttons
237
+
238
+ any_unrendered_inputs = any(
239
+ not inp.is_rendered for inp in self.additional_inputs
240
+ )
241
+ if self.additional_inputs and any_unrendered_inputs:
242
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
243
+ for input_component in self.additional_inputs:
244
+ if not input_component.is_rendered:
245
+ input_component.render()
246
+
247
+ self.saved_input = State()
248
+ self.chatbot_state = (
249
+ State(self.chatbot.value) if self.chatbot.value else State([])
250
+ )
251
+
252
+ self._setup_events()
253
+ self._setup_api()
254
+
255
+ if examples:
256
+ examples.click(lambda x: x[0], inputs=[examples], outputs=self.textbox, show_progress=False, queue=False)
257
+
258
+ def _setup_events(self) -> None:
259
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
260
+ submit_triggers = (
261
+ [self.textbox.submit, self.submit_btn.click]
262
+ if self.submit_btn
263
+ else [self.textbox.submit]
264
+ )
265
+ submit_event = (
266
+ on(
267
+ submit_triggers,
268
+ self._clear_and_save_textbox,
269
+ [self.textbox],
270
+ [self.textbox, self.saved_input],
271
+ show_api=False,
272
+ queue=False,
273
+ )
274
+ .then(
275
+ self.pre_fn,
276
+ **self.pre_fn_kwargs,
277
+ show_api=False,
278
+ queue=False,
279
+ )
280
+ .then(
281
+ self._display_input,
282
+ [self.saved_input, self.chatbot_state],
283
+ [self.chatbot, self.chatbot_state],
284
+ show_api=False,
285
+ queue=False,
286
+ )
287
+ .then(
288
+ submit_fn,
289
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
290
+ [self.chatbot, self.chatbot_state],
291
+ show_api=False,
292
+ concurrency_limit=cast(
293
+ Union[int, Literal["default"], None], self.concurrency_limit
294
+ ),
295
+ ).then(
296
+ self.post_fn,
297
+ **self.post_fn_kwargs,
298
+ show_api=False,
299
+ concurrency_limit=cast(
300
+ Union[int, Literal["default"], None], self.concurrency_limit
301
+ ),
302
+ )
303
+ )
304
+ self._setup_stop_events(submit_triggers, submit_event)
305
+
306
+ if self.retry_btn:
307
+ retry_event = (
308
+ self.retry_btn.click(
309
+ self._delete_prev_fn,
310
+ [self.saved_input, self.chatbot_state],
311
+ [self.chatbot, self.saved_input, self.chatbot_state],
312
+ show_api=False,
313
+ queue=False,
314
+ )
315
+ .then(
316
+ self.pre_fn,
317
+ **self.pre_fn_kwargs,
318
+ show_api=False,
319
+ queue=False,
320
+ )
321
+ .then(
322
+ self._display_input,
323
+ [self.saved_input, self.chatbot_state],
324
+ [self.chatbot, self.chatbot_state],
325
+ show_api=False,
326
+ queue=False,
327
+ )
328
+ .then(
329
+ submit_fn,
330
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
331
+ [self.chatbot, self.chatbot_state],
332
+ show_api=False,
333
+ concurrency_limit=cast(
334
+ Union[int, Literal["default"], None], self.concurrency_limit
335
+ ),
336
+ ).then(
337
+ self.post_fn,
338
+ **self.post_fn_kwargs,
339
+ show_api=False,
340
+ concurrency_limit=cast(
341
+ Union[int, Literal["default"], None], self.concurrency_limit
342
+ ),
343
+ )
344
+ )
345
+ self._setup_stop_events([self.retry_btn.click], retry_event)
346
+
347
+ if self.undo_btn:
348
+ self.undo_btn.click(
349
+ self._delete_prev_fn,
350
+ [self.saved_input, self.chatbot_state],
351
+ [self.chatbot, self.saved_input, self.chatbot_state],
352
+ show_api=False,
353
+ queue=False,
354
+ ).then(
355
+ self.pre_fn,
356
+ **self.pre_fn_kwargs,
357
+ show_api=False,
358
+ queue=False,
359
+ ).then(
360
+ async_lambda(lambda x: x),
361
+ [self.saved_input],
362
+ [self.textbox],
363
+ show_api=False,
364
+ queue=False,
365
+ ).then(
366
+ self.post_fn,
367
+ **self.post_fn_kwargs,
368
+ show_api=False,
369
+ concurrency_limit=cast(
370
+ Union[int, Literal["default"], None], self.concurrency_limit
371
+ ),
372
+ )
373
+
374
+ if self.clear_btn:
375
+ self.clear_btn.click(
376
+ async_lambda(lambda: ([], [], None)),
377
+ None,
378
+ [self.chatbot, self.chatbot_state, self.saved_input],
379
+ queue=False,
380
+ show_api=False,
381
+ ).then(
382
+ self.pre_fn,
383
+ **self.pre_fn_kwargs,
384
+ show_api=False,
385
+ queue=False,
386
+ ).then(
387
+ self.post_fn,
388
+ **self.post_fn_kwargs,
389
+ show_api=False,
390
+ concurrency_limit=cast(
391
+ Union[int, Literal["default"], None], self.concurrency_limit
392
+ ),
393
+ )
394
+
395
+ def _setup_stop_events(
396
+ self, event_triggers: list[Callable], event_to_cancel: Dependency
397
+ ) -> None:
398
+ if self.stop_btn and self.is_generator:
399
+ if self.submit_btn:
400
+ for event_trigger in event_triggers:
401
+ event_trigger(
402
+ async_lambda(
403
+ lambda: (
404
+ Button(visible=False),
405
+ Button(visible=True),
406
+ )
407
+ ),
408
+ None,
409
+ [self.submit_btn, self.stop_btn],
410
+ show_api=False,
411
+ queue=False,
412
+ )
413
+ event_to_cancel.then(
414
+ async_lambda(lambda: (Button(visible=True), Button(visible=False))),
415
+ None,
416
+ [self.submit_btn, self.stop_btn],
417
+ show_api=False,
418
+ queue=False,
419
+ )
420
+ else:
421
+ for event_trigger in event_triggers:
422
+ event_trigger(
423
+ async_lambda(lambda: Button(visible=True)),
424
+ None,
425
+ [self.stop_btn],
426
+ show_api=False,
427
+ queue=False,
428
+ )
429
+ event_to_cancel.then(
430
+ async_lambda(lambda: Button(visible=False)),
431
+ None,
432
+ [self.stop_btn],
433
+ show_api=False,
434
+ queue=False,
435
+ )
436
+ self.stop_btn.click(
437
+ None,
438
+ None,
439
+ None,
440
+ cancels=event_to_cancel,
441
+ show_api=False,
442
+ )
443
+
444
+ def _setup_api(self) -> None:
445
+ api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
446
+
447
+ self.fake_api_btn.click(
448
+ api_fn,
449
+ [self.textbox, self.chatbot_state] + self.additional_inputs,
450
+ [self.textbox, self.chatbot_state],
451
+ api_name="chat",
452
+ concurrency_limit=cast(
453
+ Union[int, Literal["default"], None], self.concurrency_limit
454
+ ),
455
+ )
456
+
457
+ def _clear_and_save_textbox(self, message: str) -> tuple[str | dict, str]:
458
+ if self.multimodal:
459
+ return {"text": "", "files": []}, message
460
+ else:
461
+ return "", message
462
+
463
+ def _append_multimodal_history(
464
+ self,
465
+ message: dict[str, list],
466
+ response: str | None,
467
+ history: list[list[str | tuple | None]],
468
+ ):
469
+ for x in message["files"]:
470
+ history.append([(x,), None])
471
+ if message["text"] is None or not isinstance(message["text"], str):
472
+ return
473
+ elif message["text"] == "" and message["files"] != []:
474
+ history.append([None, response])
475
+ else:
476
+ history.append([message["text"], response])
477
+
478
+ async def _display_input(
479
+ self, message: str | dict[str, list], history: list[list[str | tuple | None]]
480
+ ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
481
+ if self.multimodal and isinstance(message, dict):
482
+ self._append_multimodal_history(message, None, history)
483
+ elif isinstance(message, str):
484
+ history.append([message, None])
485
+ return history, history
486
+
487
+ async def _submit_fn(
488
+ self,
489
+ message: str | dict[str, list],
490
+ history_with_input: list[list[str | tuple | None]],
491
+ request: Request,
492
+ *args,
493
+ ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
494
+ if self.multimodal and isinstance(message, dict):
495
+ remove_input = (
496
+ len(message["files"]) + 1
497
+ if message["text"] is not None
498
+ else len(message["files"])
499
+ )
500
+ history = history_with_input[:-remove_input]
501
+ else:
502
+ history = history_with_input[:-1]
503
+ inputs, _, _ = special_args(
504
+ self.fn, inputs=[message, history, *args], request=request
505
+ )
506
+
507
+ if self.is_async:
508
+ response = await self.fn(*inputs)
509
+ else:
510
+ response = await anyio.to_thread.run_sync(
511
+ self.fn, *inputs, limiter=self.limiter
512
+ )
513
+
514
+ if self.multimodal and isinstance(message, dict):
515
+ self._append_multimodal_history(message, response, history)
516
+ elif isinstance(message, str):
517
+ history.append([message, response])
518
+ return history, history
519
+
520
+ async def _stream_fn(
521
+ self,
522
+ message: str | dict[str, list],
523
+ history_with_input: list[list[str | tuple | None]],
524
+ request: Request,
525
+ *args,
526
+ ) -> AsyncGenerator:
527
+ if self.multimodal and isinstance(message, dict):
528
+ remove_input = (
529
+ len(message["files"]) + 1
530
+ if message["text"] is not None
531
+ else len(message["files"])
532
+ )
533
+ history = history_with_input[:-remove_input]
534
+ else:
535
+ history = history_with_input[:-1]
536
+ inputs, _, _ = special_args(
537
+ self.fn, inputs=[message, history, *args], request=request
538
+ )
539
+
540
+ if self.is_async:
541
+ generator = self.fn(*inputs)
542
+ else:
543
+ generator = await anyio.to_thread.run_sync(
544
+ self.fn, *inputs, limiter=self.limiter
545
+ )
546
+ generator = SyncToAsyncIterator(generator, self.limiter)
547
+ try:
548
+ first_response = await async_iteration(generator)
549
+ if self.multimodal and isinstance(message, dict):
550
+ for x in message["files"]:
551
+ history.append([(x,), None])
552
+ update = history + [[message["text"], first_response]]
553
+ yield update, update
554
+ else:
555
+ update = history + [[message, first_response]]
556
+ yield update, update
557
+ except StopIteration:
558
+ if self.multimodal and isinstance(message, dict):
559
+ self._append_multimodal_history(message, None, history)
560
+ yield history, history
561
+ else:
562
+ update = history + [[message, None]]
563
+ yield update, update
564
+ async for response in generator:
565
+ if self.multimodal and isinstance(message, dict):
566
+ update = history + [[message["text"], response]]
567
+ yield update, update
568
+ else:
569
+ update = history + [[message, response]]
570
+ yield update, update
571
+
572
+ async def _api_submit_fn(
573
+ self, message: str, history: list[list[str | None]], request: Request, *args
574
+ ) -> tuple[str, list[list[str | None]]]:
575
+ inputs, _, _ = special_args(
576
+ self.fn, inputs=[message, history, *args], request=request
577
+ )
578
+
579
+ if self.is_async:
580
+ response = await self.fn(*inputs)
581
+ else:
582
+ response = await anyio.to_thread.run_sync(
583
+ self.fn, *inputs, limiter=self.limiter
584
+ )
585
+ history.append([message, response])
586
+ return response, history
587
+
588
+ async def _api_stream_fn(
589
+ self, message: str, history: list[list[str | None]], request: Request, *args
590
+ ) -> AsyncGenerator:
591
+ inputs, _, _ = special_args(
592
+ self.fn, inputs=[message, history, *args], request=request
593
+ )
594
+
595
+ if self.is_async:
596
+ generator = self.fn(*inputs)
597
+ else:
598
+ generator = await anyio.to_thread.run_sync(
599
+ self.fn, *inputs, limiter=self.limiter
600
+ )
601
+ generator = SyncToAsyncIterator(generator, self.limiter)
602
+ try:
603
+ first_response = await async_iteration(generator)
604
+ yield first_response, history + [[message, first_response]]
605
+ except StopIteration:
606
+ yield None, history + [[message, None]]
607
+ async for response in generator:
608
+ yield response, history + [[message, response]]
609
+
610
+ async def _delete_prev_fn(
611
+ self,
612
+ message: str | dict[str, list],
613
+ history: list[list[str | tuple | None]],
614
+ ) -> tuple[
615
+ list[list[str | tuple | None]],
616
+ str | dict[str, list],
617
+ list[list[str | tuple | None]],
618
+ ]:
619
+ if self.multimodal and isinstance(message, dict):
620
+ remove_input = (
621
+ len(message["files"]) + 1
622
+ if message["text"] is not None
623
+ else len(message["files"])
624
+ )
625
+ history = history[:-remove_input]
626
+ else:
627
+ history = history[:-1]
628
+ return history, message or "", history
lib_omost/canvas.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import difflib
3
+ import numpy as np
4
+
5
+ system_prompt = r'''You are a helpful AI assistant to compose images using the below python class `Canvas`:
6
+
7
+ ```python
8
+ class Canvas:
9
+ def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
10
+ pass
11
+
12
+ def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
13
+ assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
14
+ assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
15
+ assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
16
+ assert distance_to_viewer > 0
17
+ pass
18
+ ```'''
19
+
20
+ valid_colors = { # r, g, b
21
+ 'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255),
22
+ 'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220),
23
+ 'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255),
24
+ 'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135),
25
+ 'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 105, 30),
26
+ 'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220),
27
+ 'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139),
28
+ 'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169),
29
+ 'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139),
30
+ 'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204),
31
+ 'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143),
32
+ 'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79),
33
+ 'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147),
34
+ 'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105),
35
+ 'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240),
36
+ 'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220),
37
+ 'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32),
38
+ 'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47),
39
+ 'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92),
40
+ 'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250),
41
+ 'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205),
42
+ 'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255),
43
+ 'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211),
44
+ 'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122),
45
+ 'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153),
46
+ 'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 224),
47
+ 'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255),
48
+ 'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205),
49
+ 'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113),
50
+ 'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154),
51
+ 'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112),
52
+ 'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181),
53
+ 'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128),
54
+ 'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35),
55
+ 'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214),
56
+ 'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238),
57
+ 'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185),
58
+ 'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230),
59
+ 'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0),
60
+ 'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19),
61
+ 'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87),
62
+ 'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192),
63
+ 'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144),
64
+ 'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127),
65
+ 'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216),
66
+ 'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238),
67
+ 'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245),
68
+ 'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50)
69
+ }
70
+
71
+ valid_locations = { # x, y in 90*90
72
+ 'in the center': (45, 45),
73
+ 'on the left': (15, 45),
74
+ 'on the right': (75, 45),
75
+ 'on the top': (45, 15),
76
+ 'on the bottom': (45, 75),
77
+ 'on the top-left': (15, 15),
78
+ 'on the top-right': (75, 15),
79
+ 'on the bottom-left': (15, 75),
80
+ 'on the bottom-right': (75, 75)
81
+ }
82
+
83
+ valid_offsets = { # x, y in 90*90
84
+ 'no offset': (0, 0),
85
+ 'slightly to the left': (-10, 0),
86
+ 'slightly to the right': (10, 0),
87
+ 'slightly to the upper': (0, -10),
88
+ 'slightly to the lower': (0, 10),
89
+ 'slightly to the upper-left': (-10, -10),
90
+ 'slightly to the upper-right': (10, -10),
91
+ 'slightly to the lower-left': (-10, 10),
92
+ 'slightly to the lower-right': (10, 10)}
93
+
94
+ valid_areas = { # w, h in 90*90
95
+ "a small square area": (50, 50),
96
+ "a small vertical area": (40, 60),
97
+ "a small horizontal area": (60, 40),
98
+ "a medium-sized square area": (60, 60),
99
+ "a medium-sized vertical area": (50, 80),
100
+ "a medium-sized horizontal area": (80, 50),
101
+ "a large square area": (70, 70),
102
+ "a large vertical area": (60, 90),
103
+ "a large horizontal area": (90, 60)
104
+ }
105
+
106
+
107
+ def closest_name(input_str, options):
108
+ input_str = input_str.lower()
109
+
110
+ closest_match = difflib.get_close_matches(input_str, list(options.keys()), n=1, cutoff=0.5)
111
+ assert isinstance(closest_match, list) and len(closest_match) > 0, f'The value [{input_str}] is not valid!'
112
+ result = closest_match[0]
113
+
114
+ if result != input_str:
115
+ print(f'Automatically corrected [{input_str}] -> [{result}].')
116
+
117
+ return result
118
+
119
+
120
+ def safe_str(x):
121
+ return x.strip(',. ') + '.'
122
+
123
+
124
+ def binary_nonzero_positions(n, offset=0):
125
+ binary_str = bin(n)[2:]
126
+ positions = [i + offset for i, bit in enumerate(reversed(binary_str)) if bit == '1']
127
+ return positions
128
+
129
+
130
+ class Canvas:
131
+ @staticmethod
132
+ def from_bot_response(response: str):
133
+ matched = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
134
+ assert matched, 'Response does not contain codes!'
135
+ code_content = matched.group(1)
136
+ assert 'canvas = Canvas()' in code_content, 'Code block must include valid canvas var!'
137
+ local_vars = {'Canvas': Canvas}
138
+ exec(code_content, {}, local_vars)
139
+ canvas = local_vars.get('canvas', None)
140
+ assert isinstance(canvas, Canvas), 'Code block must produce valid canvas var!'
141
+ return canvas
142
+
143
+ def __init__(self):
144
+ self.components = []
145
+ self.color = None
146
+ self.record_tags = True
147
+ self.prefixes = []
148
+ self.suffixes = []
149
+ return
150
+
151
+ def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str,
152
+ HTML_web_color_name: str):
153
+ assert isinstance(description, str), 'Global description is not valid!'
154
+ assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
155
+ 'Global detailed_descriptions is not valid!'
156
+ assert isinstance(tags, str), 'Global tags is not valid!'
157
+
158
+ HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
159
+ self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
160
+
161
+ self.prefixes = [description]
162
+ self.suffixes = detailed_descriptions
163
+
164
+ if self.record_tags:
165
+ self.suffixes = self.suffixes + [tags]
166
+
167
+ self.prefixes = [safe_str(x) for x in self.prefixes]
168
+ self.suffixes = [safe_str(x) for x in self.suffixes]
169
+
170
+ return
171
+
172
+ def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
173
+ detailed_descriptions: list[str], tags: str, atmosphere: str, style: str,
174
+ quality_meta: str, HTML_web_color_name: str):
175
+ assert isinstance(description, str), 'Local description is wrong!'
176
+ assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
177
+ f'The distance_to_viewer for [{description}] is not positive float number!'
178
+ assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
179
+ f'The detailed_descriptions for [{description}] is not valid!'
180
+ assert isinstance(tags, str), f'The tags for [{description}] is not valid!'
181
+ assert isinstance(atmosphere, str), f'The atmosphere for [{description}] is not valid!'
182
+ assert isinstance(style, str), f'The style for [{description}] is not valid!'
183
+ assert isinstance(quality_meta, str), f'The quality_meta for [{description}] is not valid!'
184
+
185
+ location = closest_name(location, valid_locations)
186
+ offset = closest_name(offset, valid_offsets)
187
+ area = closest_name(area, valid_areas)
188
+ HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
189
+
190
+ xb, yb = valid_locations[location]
191
+ xo, yo = valid_offsets[offset]
192
+ w, h = valid_areas[area]
193
+ rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
194
+ rect = [max(0, min(90, i)) for i in rect]
195
+ color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
196
+
197
+ prefixes = self.prefixes + [description]
198
+ suffixes = detailed_descriptions
199
+
200
+ if self.record_tags:
201
+ suffixes = suffixes + [tags, atmosphere, style, quality_meta]
202
+
203
+ prefixes = [safe_str(x) for x in prefixes]
204
+ suffixes = [safe_str(x) for x in suffixes]
205
+
206
+ self.components.append(dict(
207
+ rect=rect,
208
+ distance_to_viewer=distance_to_viewer,
209
+ color=color,
210
+ prefixes=prefixes,
211
+ suffixes=suffixes
212
+ ))
213
+
214
+ return
215
+
216
+ def process(self):
217
+ # sort components
218
+ self.components = sorted(self.components, key=lambda x: x['distance_to_viewer'], reverse=True)
219
+
220
+ # compute initial latent
221
+ initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
222
+
223
+ for component in self.components:
224
+ a, b, c, d = component['rect']
225
+ initial_latent[a:b, c:d] = 0.7 * component['color'] + 0.3 * initial_latent[a:b, c:d]
226
+
227
+ initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
228
+
229
+ # compute conditions
230
+
231
+ bag_of_conditions = [
232
+ dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes)
233
+ ]
234
+
235
+ for i, component in enumerate(self.components):
236
+ a, b, c, d = component['rect']
237
+ m = np.zeros(shape=(90, 90), dtype=np.float32)
238
+ m[a:b, c:d] = 1.0
239
+ bag_of_conditions.append(dict(
240
+ mask=m,
241
+ prefixes=component['prefixes'],
242
+ suffixes=component['suffixes']
243
+ ))
244
+
245
+ return dict(
246
+ initial_latent=initial_latent,
247
+ bag_of_conditions=bag_of_conditions,
248
+ )
lib_omost/pipeline.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import copy
3
+
4
+ from tqdm.auto import trange
5
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import *
6
+ from diffusers.models.transformers import Transformer2DModel
7
+
8
+
9
+ original_Transformer2DModel_forward = Transformer2DModel.forward
10
+
11
+
12
+ def hacked_Transformer2DModel_forward(
13
+ self,
14
+ hidden_states: torch.Tensor,
15
+ encoder_hidden_states: Optional[torch.Tensor] = None,
16
+ timestep: Optional[torch.LongTensor] = None,
17
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
18
+ class_labels: Optional[torch.LongTensor] = None,
19
+ cross_attention_kwargs: Dict[str, Any] = None,
20
+ attention_mask: Optional[torch.Tensor] = None,
21
+ encoder_attention_mask: Optional[torch.Tensor] = None,
22
+ return_dict: bool = True,
23
+ ):
24
+ cross_attention_kwargs = cross_attention_kwargs or {}
25
+ cross_attention_kwargs['hidden_states_original_shape'] = hidden_states.shape
26
+ return original_Transformer2DModel_forward(
27
+ self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, class_labels, cross_attention_kwargs,
28
+ attention_mask, encoder_attention_mask, return_dict)
29
+
30
+
31
+ Transformer2DModel.forward = hacked_Transformer2DModel_forward
32
+
33
+
34
+ @torch.no_grad()
35
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
36
+ """DPM-Solver++(2M)."""
37
+ extra_args = {} if extra_args is None else extra_args
38
+ s_in = x.new_ones([x.shape[0]])
39
+ sigma_fn = lambda t: t.neg().exp()
40
+ t_fn = lambda sigma: sigma.log().neg()
41
+ old_denoised = None
42
+
43
+ for i in trange(len(sigmas) - 1, disable=disable):
44
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
45
+ if callback is not None:
46
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
47
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
48
+ h = t_next - t
49
+ if old_denoised is None or sigmas[i + 1] == 0:
50
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
51
+ else:
52
+ h_last = t - t_fn(sigmas[i - 1])
53
+ r = h_last / h
54
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
55
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
56
+ old_denoised = denoised
57
+ return x
58
+
59
+
60
+ class KModel:
61
+ def __init__(self, unet, timesteps=1000, linear_start=0.00085, linear_end=0.012):
62
+ betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float64) ** 2
63
+ alphas = 1. - betas
64
+ alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
65
+
66
+ self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
67
+ self.log_sigmas = self.sigmas.log()
68
+ self.sigma_data = 1.0
69
+ self.unet = unet
70
+ return
71
+
72
+ @property
73
+ def sigma_min(self):
74
+ return self.sigmas[0]
75
+
76
+ @property
77
+ def sigma_max(self):
78
+ return self.sigmas[-1]
79
+
80
+ def timestep(self, sigma):
81
+ log_sigma = sigma.log()
82
+ dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
83
+ return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
84
+
85
+ def get_sigmas_karras(self, n, rho=7.):
86
+ ramp = torch.linspace(0, 1, n)
87
+ min_inv_rho = self.sigma_min ** (1 / rho)
88
+ max_inv_rho = self.sigma_max ** (1 / rho)
89
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
90
+ return torch.cat([sigmas, sigmas.new_zeros([1])])
91
+
92
+ def __call__(self, x, sigma, **extra_args):
93
+ x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
94
+ t = self.timestep(sigma)
95
+ cfg_scale = extra_args['cfg_scale']
96
+ eps_positive = self.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
97
+ eps_negative = self.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
98
+ noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
99
+ return x - noise_pred * sigma[:, None, None, None]
100
+
101
+
102
+ class OmostSelfAttnProcessor:
103
+ def __call__(self, attn, hidden_states, encoder_hidden_states, hidden_states_original_shape, *args, **kwargs):
104
+ batch_size, sequence_length, _ = hidden_states.shape
105
+
106
+ query = attn.to_q(hidden_states)
107
+ key = attn.to_k(hidden_states)
108
+ value = attn.to_v(hidden_states)
109
+
110
+ inner_dim = key.shape[-1]
111
+ head_dim = inner_dim // attn.heads
112
+
113
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
114
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
115
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
116
+
117
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
118
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
119
+ )
120
+
121
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
122
+ hidden_states = hidden_states.to(query.dtype)
123
+ hidden_states = attn.to_out[0](hidden_states)
124
+ hidden_states = attn.to_out[1](hidden_states)
125
+
126
+ return hidden_states
127
+
128
+
129
+ class OmostCrossAttnProcessor:
130
+ def __call__(self, attn, hidden_states, encoder_hidden_states, hidden_states_original_shape, *args, **kwargs):
131
+ B, C, H, W = hidden_states_original_shape
132
+
133
+ conds = []
134
+ masks = []
135
+
136
+ for m, c in encoder_hidden_states:
137
+ m = torch.nn.functional.interpolate(m[None, None, :, :], (H, W), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, c.size(1))
138
+ conds.append(c)
139
+ masks.append(m)
140
+
141
+ conds = torch.cat(conds, dim=1)
142
+ masks = torch.cat(masks, dim=1)
143
+
144
+ mask_bool = masks > 0.5
145
+ mask_scale = (H * W) / torch.sum(masks, dim=0, keepdim=True)
146
+
147
+ batch_size, sequence_length, _ = conds.shape
148
+
149
+ query = attn.to_q(hidden_states)
150
+ key = attn.to_k(conds)
151
+ value = attn.to_v(conds)
152
+
153
+ inner_dim = key.shape[-1]
154
+ head_dim = inner_dim // attn.heads
155
+
156
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
157
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
158
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
159
+
160
+ mask_bool = mask_bool[None, None, :, :].repeat(query.size(0), query.size(1), 1, 1)
161
+ mask_scale = mask_scale[None, None, :, :].repeat(query.size(0), query.size(1), 1, 1)
162
+
163
+ sim = query @ key.transpose(-2, -1) * attn.scale
164
+ sim = sim * mask_scale.to(sim)
165
+ sim.masked_fill_(mask_bool.logical_not(), float("-inf"))
166
+ sim = sim.softmax(dim=-1)
167
+
168
+ h = sim @ value
169
+ h = h.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
170
+
171
+ h = attn.to_out[0](h)
172
+ h = attn.to_out[1](h)
173
+ return h
174
+
175
+
176
+ class StableDiffusionXLOmostPipeline(StableDiffusionXLImg2ImgPipeline):
177
+ def __init__(self, *args, **kwargs):
178
+ super().__init__(*args, **kwargs)
179
+ self.k_model = KModel(unet=self.unet)
180
+
181
+ attn_procs = {}
182
+ for name in self.unet.attn_processors.keys():
183
+ if name.endswith("attn2.processor"):
184
+ attn_procs[name] = OmostCrossAttnProcessor()
185
+ else:
186
+ attn_procs[name] = OmostSelfAttnProcessor()
187
+
188
+ self.unet.set_attn_processor(attn_procs)
189
+ return
190
+
191
+ @torch.inference_mode()
192
+ def encode_bag_of_subprompts_greedy(self, prefixes: list[str], suffixes: list[str]):
193
+ device = self.text_encoder.device
194
+
195
+ @torch.inference_mode()
196
+ def greedy_partition(items, max_sum):
197
+ bags = []
198
+ current_bag = []
199
+ current_sum = 0
200
+
201
+ for item in items:
202
+ num = item['length']
203
+ if current_sum + num > max_sum:
204
+ if current_bag:
205
+ bags.append(current_bag)
206
+ current_bag = [item]
207
+ current_sum = num
208
+ else:
209
+ current_bag.append(item)
210
+ current_sum += num
211
+
212
+ if current_bag:
213
+ bags.append(current_bag)
214
+
215
+ return bags
216
+
217
+ @torch.inference_mode()
218
+ def get_77_tokens_in_torch(subprompt_inds, tokenizer):
219
+ # Note that all subprompt are theoretically less than 75 tokens (without bos/eos)
220
+ result = [tokenizer.bos_token_id] + subprompt_inds[:75] + [tokenizer.eos_token_id] + [tokenizer.pad_token_id] * 75
221
+ result = result[:77]
222
+ result = torch.tensor([result]).to(device=device, dtype=torch.int64)
223
+ return result
224
+
225
+ @torch.inference_mode()
226
+ def merge_with_prefix(bag):
227
+ merged_ids_t1 = copy.deepcopy(prefix_ids_t1)
228
+ merged_ids_t2 = copy.deepcopy(prefix_ids_t2)
229
+
230
+ for item in bag:
231
+ merged_ids_t1.extend(item['ids_t1'])
232
+ merged_ids_t2.extend(item['ids_t2'])
233
+
234
+ return dict(
235
+ ids_t1=get_77_tokens_in_torch(merged_ids_t1, self.tokenizer),
236
+ ids_t2=get_77_tokens_in_torch(merged_ids_t2, self.tokenizer_2)
237
+ )
238
+
239
+ @torch.inference_mode()
240
+ def double_encode(pair_of_inds):
241
+ inds = [pair_of_inds['ids_t1'], pair_of_inds['ids_t2']]
242
+ text_encoders = [self.text_encoder, self.text_encoder_2]
243
+
244
+ pooled_prompt_embeds = None
245
+ prompt_embeds_list = []
246
+
247
+ for text_input_ids, text_encoder in zip(inds, text_encoders):
248
+ prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)
249
+
250
+ # Only last pooler_output is needed
251
+ pooled_prompt_embeds = prompt_embeds.pooler_output
252
+
253
+ # "2" because SDXL always indexes from the penultimate layer.
254
+ prompt_embeds = prompt_embeds.hidden_states[-2]
255
+ prompt_embeds_list.append(prompt_embeds)
256
+
257
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
258
+ return prompt_embeds, pooled_prompt_embeds
259
+
260
+ # Begin with tokenizing prefixes
261
+
262
+ prefix_length = 0
263
+ prefix_ids_t1 = []
264
+ prefix_ids_t2 = []
265
+
266
+ for prefix in prefixes:
267
+ ids_t1 = self.tokenizer(prefix, truncation=False, add_special_tokens=False).input_ids
268
+ ids_t2 = self.tokenizer_2(prefix, truncation=False, add_special_tokens=False).input_ids
269
+ assert len(ids_t1) == len(ids_t2)
270
+ prefix_length += len(ids_t1)
271
+ prefix_ids_t1 += ids_t1
272
+ prefix_ids_t2 += ids_t2
273
+
274
+ # Then tokenizing suffixes
275
+
276
+ allowed_suffix_length = 75 - prefix_length
277
+ suffix_targets = []
278
+
279
+ for subprompt in suffixes:
280
+ # Note that all subprompt are theoretically less than 75 tokens (without bos/eos)
281
+ # So we can safely just crop it to 75
282
+ ids_t1 = self.tokenizer(subprompt, truncation=False, add_special_tokens=False).input_ids[:75]
283
+ ids_t2 = self.tokenizer_2(subprompt, truncation=False, add_special_tokens=False).input_ids[:75]
284
+ assert len(ids_t1) == len(ids_t2)
285
+ suffix_targets.append(dict(
286
+ length=len(ids_t1),
287
+ ids_t1=ids_t1,
288
+ ids_t2=ids_t2
289
+ ))
290
+
291
+ # Then merge prefix and suffix tokens
292
+
293
+ suffix_targets = greedy_partition(suffix_targets, max_sum=allowed_suffix_length)
294
+ targets = [merge_with_prefix(b) for b in suffix_targets]
295
+
296
+ # Encode!
297
+
298
+ conds, poolers = [], []
299
+
300
+ for target in targets:
301
+ cond, pooler = double_encode(target)
302
+ conds.append(cond)
303
+ poolers.append(pooler)
304
+
305
+ conds_merged = torch.concat(conds, dim=1)
306
+ poolers_merged = poolers[0]
307
+
308
+ return dict(cond=conds_merged, pooler=poolers_merged)
309
+
310
+ @torch.inference_mode()
311
+ def all_conds_from_canvas(self, canvas_outputs, negative_prompt):
312
+ mask_all = torch.ones(size=(90, 90), dtype=torch.float32)
313
+ negative_cond, negative_pooler = self.encode_cropped_prompt_77tokens(negative_prompt)
314
+ negative_result = [(mask_all, negative_cond)]
315
+
316
+ positive_result = []
317
+ positive_pooler = None
318
+
319
+ for item in canvas_outputs['bag_of_conditions']:
320
+ current_mask = torch.from_numpy(item['mask']).to(torch.float32)
321
+ current_prefixes = item['prefixes']
322
+ current_suffixes = item['suffixes']
323
+
324
+ current_cond = self.encode_bag_of_subprompts_greedy(prefixes=current_prefixes, suffixes=current_suffixes)
325
+
326
+ if positive_pooler is None:
327
+ positive_pooler = current_cond['pooler']
328
+
329
+ positive_result.append((current_mask, current_cond['cond']))
330
+
331
+ return positive_result, positive_pooler, negative_result, negative_pooler
332
+
333
+ @torch.inference_mode()
334
+ def encode_cropped_prompt_77tokens(self, prompt: str):
335
+ device = self.text_encoder.device
336
+ tokenizers = [self.tokenizer, self.tokenizer_2]
337
+ text_encoders = [self.text_encoder, self.text_encoder_2]
338
+
339
+ pooled_prompt_embeds = None
340
+ prompt_embeds_list = []
341
+
342
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
343
+ text_input_ids = tokenizer(
344
+ prompt,
345
+ padding="max_length",
346
+ max_length=tokenizer.model_max_length,
347
+ truncation=True,
348
+ return_tensors="pt",
349
+ ).input_ids
350
+
351
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
352
+
353
+ # Only last pooler_output is needed
354
+ pooled_prompt_embeds = prompt_embeds.pooler_output
355
+
356
+ # "2" because SDXL always indexes from the penultimate layer.
357
+ prompt_embeds = prompt_embeds.hidden_states[-2]
358
+ prompt_embeds_list.append(prompt_embeds)
359
+
360
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
361
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
362
+
363
+ return prompt_embeds, pooled_prompt_embeds
364
+
365
+ @torch.inference_mode()
366
+ def __call__(
367
+ self,
368
+ initial_latent: torch.FloatTensor = None,
369
+ strength: float = 1.0,
370
+ num_inference_steps: int = 25,
371
+ guidance_scale: float = 5.0,
372
+ batch_size: Optional[int] = 1,
373
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
374
+ prompt_embeds: Optional[torch.FloatTensor] = None,
375
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
376
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
377
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
378
+ cross_attention_kwargs: Optional[dict] = None,
379
+ ):
380
+
381
+ device = self.unet.device
382
+ cross_attention_kwargs = cross_attention_kwargs or {}
383
+
384
+ # Sigmas
385
+
386
+ sigmas = self.k_model.get_sigmas_karras(int(num_inference_steps / strength))
387
+ sigmas = sigmas[-(num_inference_steps + 1):].to(device)
388
+
389
+ # Initial latents
390
+
391
+ _, C, H, W = initial_latent.shape
392
+ noise = randn_tensor((batch_size, C, H, W), generator=generator, device=device, dtype=self.unet.dtype)
393
+ latents = initial_latent.to(noise) + noise * sigmas[0].to(noise)
394
+
395
+ # Shape
396
+
397
+ height, width = latents.shape[-2:]
398
+ height = height * self.vae_scale_factor
399
+ width = width * self.vae_scale_factor
400
+
401
+ add_time_ids = list((height, width) + (0, 0) + (height, width))
402
+ add_time_ids = torch.tensor([add_time_ids], dtype=self.unet.dtype)
403
+ add_neg_time_ids = add_time_ids.clone()
404
+
405
+ # Batch
406
+
407
+ latents = latents.to(device)
408
+ add_time_ids = add_time_ids.repeat(batch_size, 1).to(device)
409
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1).to(device)
410
+ prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(noise)) for k, v in prompt_embeds]
411
+ negative_prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(noise)) for k, v in negative_prompt_embeds]
412
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(noise)
413
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(noise)
414
+
415
+ # Feeds
416
+
417
+ sampler_kwargs = dict(
418
+ cfg_scale=guidance_scale,
419
+ positive=dict(
420
+ encoder_hidden_states=prompt_embeds,
421
+ added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
422
+ cross_attention_kwargs=cross_attention_kwargs
423
+ ),
424
+ negative=dict(
425
+ encoder_hidden_states=negative_prompt_embeds,
426
+ added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids},
427
+ cross_attention_kwargs=cross_attention_kwargs
428
+ )
429
+ )
430
+
431
+ # Sample
432
+
433
+ results = sample_dpmpp_2m(self.k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False)
434
+
435
+ return StableDiffusionXLPipelineOutput(images=results)