gokaygokay committed on
Commit
ece05f2
1 Parent(s): df8e598
app.py ADDED
@@ -0,0 +1,146 @@
+ import spaces
+ import streamlit as st
+ import torch
+ from huggingface_hub import snapshot_download
+ from txt2panoimg import Text2360PanoramaImagePipeline
+ from img2panoimg import Image2360PanoramaImagePipeline
+ from PIL import Image
+ from streamlit_pannellum import streamlit_pannellum
+
+ # Custom CSS to make the UI more attractive
+ st.markdown("""
+ <style>
+ .stApp {
+     max-width: 1200px;
+     margin: 0 auto;
+ }
+ .main {
+     background-color: #f0f2f6;
+ }
+ h1 {
+     color: #1E3A8A;
+     text-align: center;
+     padding: 20px 0;
+     font-size: 2.5rem;
+ }
+ .stTabs {
+     background-color: white;
+     padding: 20px;
+     border-radius: 10px;
+     box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+ }
+ .stButton>button {
+     background-color: #1E3A8A;
+     color: white;
+     font-weight: bold;
+ }
+ .viewer-column {
+     background-color: white;
+     padding: 20px;
+     border-radius: 10px;
+     box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ # Download the model
+ model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
+
+ # Initialize pipelines
+ txt2panoimg = Text2360PanoramaImagePipeline(model_path, torch_dtype=torch.float16)
+ img2panoimg = Image2360PanoramaImagePipeline(model_path, torch_dtype=torch.float16)
+
+ # Load the default mask image
+ default_mask = Image.open("i2p-mask.jpg").convert("RGB")
+
+ @spaces.GPU(duration=200)
+ def text_to_pano(prompt, upscale):
+     input_data = {'prompt': prompt, 'upscale': upscale}
+     output = txt2panoimg(input_data)
+     return output
+
+ @spaces.GPU(duration=200)
+ def image_to_pano(image, mask, prompt, upscale):
+     image = image.resize((512, 512))
+     if mask is None:
+         mask = default_mask.resize((512, 512))
+     else:
+         mask = mask.resize((512, 512))
+     input_data = {
+         'prompt': prompt,
+         'image': image,
+         'mask': mask,
+         'upscale': upscale
+     }
+     output = img2panoimg(input_data)
+     return output
+
+ st.title("360° Panorama Image Generation")
+
+ tab1, tab2 = st.tabs(["Text to 360° Panorama", "Image to 360° Panorama"])
+
+ # Function to display the panorama viewer
+ def display_panorama(image):
+     streamlit_pannellum(
+         config={
+             "default": {
+                 "firstScene": "generated",
+             },
+             "scenes": {
+                 "generated": {
+                     "title": "Generated Panorama",
+                     "type": "equirectangular",
+                     "panorama": image,
+                     "autoLoad": True,
+                 }
+             }
+         }
+     )
+
+ with tab1:
+     col1, col2 = st.columns([1, 1])
+
+     with col1:
+         st.subheader("Input")
+         t2p_input = st.text_area("Enter your prompt", height=100)
+         t2p_upscale = st.checkbox("Upscale (requires >16GB GPU)")
+         generate_button = st.button("Generate Panorama")
+
+     with col2:
+         st.subheader("Output")
+         output_placeholder = st.empty()
+         viewer_placeholder = st.empty()
+
+     if generate_button:
+         with st.spinner("Generating your 360° panorama..."):
+             output = text_to_pano(t2p_input, t2p_upscale)
+             output_placeholder.image(output, caption="Generated 360° Panorama", use_column_width=True)
+             with viewer_placeholder.container():
+                 display_panorama(output)
+
+ with tab2:
+     col1, col2 = st.columns([1, 1])
+
+     with col1:
+         st.subheader("Input")
+         i2p_image = st.file_uploader("Upload Input Image", type=["png", "jpg", "jpeg"])
+         i2p_mask = st.file_uploader("Upload Mask Image (Optional)", type=["png", "jpg", "jpeg"])
+         i2p_prompt = st.text_area("Enter your prompt", height=100)
+         i2p_upscale = st.checkbox("Upscale (requires >16GB GPU)", key="i2p_upscale")
+         generate_button = st.button("Generate Panorama", key="i2p_generate")
+
+     with col2:
+         st.subheader("Output")
+         output_placeholder = st.empty()
+         viewer_placeholder = st.empty()
+
+     if generate_button and i2p_image is not None:
+         with st.spinner("Generating your 360° panorama..."):
+             image = Image.open(i2p_image)
+             mask = Image.open(i2p_mask) if i2p_mask is not None else None
+             output = image_to_pano(image, mask, i2p_prompt, i2p_upscale)
+             output_placeholder.image(output, caption="Generated 360° Panorama", use_column_width=True)
+             with viewer_placeholder.container():
+                 display_panorama(output)
+     elif generate_button and i2p_image is None:
+         st.error("Please upload an input image.")
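Note: the two pipelines this Space wraps can also be exercised outside Streamlit. A minimal sketch, assuming the same model snapshot and the default mask shipped with this commit (the file names your_input.png, t2p_result.png and i2p_result.png are illustrative, not part of the repository):

import torch
from PIL import Image
from huggingface_hub import snapshot_download
from txt2panoimg import Text2360PanoramaImagePipeline
from img2panoimg import Image2360PanoramaImagePipeline

model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")

# text -> 360° panorama; upscale=False avoids the ">16GB GPU" requirement noted in the UI
t2p = Text2360PanoramaImagePipeline(model_path, torch_dtype=torch.float16)
t2p({'prompt': 'a quiet mountain lake at sunrise', 'upscale': False}).save('t2p_result.png')

# image -> 360° panorama, reusing the default i2p-mask.jpg as image_to_pano() does above
i2p = Image2360PanoramaImagePipeline(model_path, torch_dtype=torch.float16)
image = Image.open('your_input.png').resize((512, 512))
mask = Image.open('i2p-mask.jpg').convert("RGB").resize((512, 512))
i2p({'prompt': 'a cozy living room', 'image': image, 'mask': mask, 'upscale': False}).save('i2p_result.png')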
i2p-mask.jpg ADDED
img2panoimg/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .pipeline_i2p import StableDiffusionImage2PanoPipeline
+ from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
+ from .image_to_360panorama_image_pipeline import Image2360PanoramaImagePipeline
img2panoimg/image_to_360panorama_image_pipeline.py ADDED
@@ -0,0 +1,248 @@
+ # Copyright © Alibaba, Inc. and its affiliates.
+ import random
+ from typing import Any, Dict
+
+ import numpy as np
+ import torch
+ from diffusers import (ControlNetModel, DiffusionPipeline,
+                        EulerAncestralDiscreteScheduler,
+                        UniPCMultistepScheduler)
+ from PIL import Image
+ from RealESRGAN import RealESRGAN
+
+ from .pipeline_i2p import StableDiffusionImage2PanoPipeline
+ from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
+ import py360convert
+
+ class LazyRealESRGAN:
+     def __init__(self, device, scale):
+         self.device = device
+         self.scale = scale
+         self.model = None
+         self.model_path = None
+
+     def load_model(self):
+         if self.model is None:
+             self.model = RealESRGAN(self.device, scale=self.scale)
+             self.model.load_weights(self.model_path, download=False)
+
+     def predict(self, img):
+         self.load_model()
+         return self.model.predict(img)
+
+ class Image2360PanoramaImagePipeline(DiffusionPipeline):
+     """ Stable Diffusion for 360 Panorama Image Generation Pipeline.
+     Example:
+     >>> import torch
+     >>> from img2panoimg import Image2360PanoramaImagePipeline
+     >>> prompt = 'The mountains'
+     >>> input = {'prompt': prompt, 'image': image, 'mask': mask, 'upscale': True}
+     >>> model_id = 'models/'
+     >>> img2panoimg = Image2360PanoramaImagePipeline(model_id, torch_dtype=torch.float16)
+     >>> output = img2panoimg(input)
+     >>> output.save('result.png')
+     """
+
+     def __init__(self, model: str, device: str = 'cuda', **kwargs):
+         """
+         Use `model` to create a stable diffusion pipeline for 360 panorama image generation.
+         Args:
+             model: model id on modelscope hub.
+             device: str = 'cuda'
+         """
+         super().__init__()
+
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'
+                               ) if device is None else device
+         if device == 'gpu':
+             device = torch.device('cuda')
+
+         torch_dtype = kwargs.get('torch_dtype', torch.float16)
+         enable_xformers_memory_efficient_attention = kwargs.get(
+             'enable_xformers_memory_efficient_attention', True)
+
+         model_id = model + '/sr-base/'
+
+         # init i2p model
+         controlnet = ControlNetModel.from_pretrained(model + '/sd-i2p', torch_dtype=torch.float16)
+
+         self.pipe = StableDiffusionImage2PanoPipeline.from_pretrained(
+             model_id, controlnet=controlnet, torch_dtype=torch_dtype).to(device)
+         self.pipe.vae.enable_tiling()
+         self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
+             self.pipe.scheduler.config)
+         # remove following line if xformers is not installed
+         try:
+             if enable_xformers_memory_efficient_attention:
+                 self.pipe.enable_xformers_memory_efficient_attention()
+         except Exception as e:
+             print(e)
+         self.pipe.enable_model_cpu_offload()
+
+         # init controlnet-sr model
+         base_model_path = model + '/sr-base'
+         controlnet_path = model + '/sr-control'
+         controlnet = ControlNetModel.from_pretrained(
+             controlnet_path, torch_dtype=torch_dtype)
+         self.pipe_sr = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(
+             base_model_path, controlnet=controlnet,
+             torch_dtype=torch_dtype).to(device)
+         self.pipe_sr.scheduler = UniPCMultistepScheduler.from_config(
+             self.pipe.scheduler.config)
+         self.pipe_sr.vae.enable_tiling()
+         # remove following line if xformers is not installed
+         try:
+             if enable_xformers_memory_efficient_attention:
+                 self.pipe_sr.enable_xformers_memory_efficient_attention()
+         except Exception as e:
+             print(e)
+         self.pipe_sr.enable_model_cpu_offload()
+         device = torch.device("cuda")
+         model_path = model + '/RealESRGAN_x2plus.pth'
+         self.upsampler = LazyRealESRGAN(device=device, scale=2)
+         self.upsampler.model_path = model_path
+
+     @staticmethod
+     def process_control_image(image, mask):
+         def to_tensor(img: Image, batch_size: int, width=1024, height=512):
+             img = img.resize((width, height), resample=Image.BICUBIC)
+             img = np.array(img).astype(np.float32) / 255.0
+             img = np.vstack([img[None].transpose(0, 3, 1, 2)] * batch_size)
+             img = torch.from_numpy(img)
+             return img
+
+         zeros = np.zeros_like(np.array(image))
+         dice_np = [np.array(image) if x == 0 else zeros for x in range(6)]
+         output_image = py360convert.c2e(dice_np, 512, 1024, cube_format='list')
+         bk_image = to_tensor(image, batch_size=1)
+
+         control_image = Image.fromarray(output_image.astype(np.uint8))
+         control_image = to_tensor(control_image, batch_size=1)
+         mask_image = to_tensor(mask, batch_size=1)
+
+         control_image = (1 - mask_image) * bk_image + mask_image * control_image
+
+         control_image = torch.cat([mask_image[:, :1, :, :], control_image], dim=1)
+
+         return control_image
+
+     @staticmethod
+     def blend_h(a, b, blend_extent):
+         a = np.array(a)
+         b = np.array(b)
+         blend_extent = min(a.shape[1], b.shape[1], blend_extent)
+         for x in range(blend_extent):
+             b[:, x, :] = a[:, -blend_extent
+                            + x, :] * (1 - x / blend_extent) + b[:, x, :] * (
+                                x / blend_extent)
+         return b
+
+     def __call__(self, inputs: Dict[str, Any],
+                  **forward_params) -> Dict[str, Any]:
+         if not isinstance(inputs, dict):
+             raise ValueError(
+                 f'Expected the input to be a dictionary, but got {type(inputs)}'
+             )
+         num_inference_steps = inputs.get('num_inference_steps', 20)
+         guidance_scale = inputs.get('guidance_scale', 7.0)
+         preset_a_prompt = 'photorealistic, trend on artstation, ((best quality)), ((ultra high res))'
+         add_prompt = inputs.get('add_prompt', preset_a_prompt)
+         preset_n_prompt = 'persons, complex texture, small objects, sheltered, blur, worst quality, '\
+                           'low quality, zombie, logo, text, watermark, username, monochrome, '\
+                           'complex lighting'
+         negative_prompt = inputs.get('negative_prompt', preset_n_prompt)
+         seed = inputs.get('seed', -1)
+         upscale = inputs.get('upscale', True)
+         refinement = inputs.get('refinement', True)
+
+         guidance_scale_sr_step1 = inputs.get('guidance_scale_sr_step1', 15)
+         guidance_scale_sr_step2 = inputs.get('guidance_scale_sr_step2', 17)
+
+         image = inputs['image']
+         mask = inputs['mask']
+
+         control_image = self.process_control_image(image, mask)
+
+         if 'prompt' in inputs.keys():
+             prompt = inputs['prompt']
+         else:
+             # for demo_service
+             prompt = forward_params.get('prompt', 'the living room')
+
+         print(f'Test with prompt: {prompt}')
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         print(f'global seed: {seed}')
+
+         generator = torch.manual_seed(seed)
+
+         prompt = '<360panorama>, ' + prompt + ', ' + add_prompt
+         output_img = self.pipe(
+             prompt,
+             image=(control_image[:, 1:, :, :] / 0.5 - 1.0),
+             control_image=control_image,
+             controlnet_conditioning_scale=1.0,
+             strength=1.0,
+             negative_prompt=negative_prompt,
+             num_inference_steps=num_inference_steps,
+             height=512,
+             width=1024,
+             guidance_scale=guidance_scale,
+             generator=generator).images[0]
+
+         if not upscale:
+             print('finished')
+         else:
+             print('inputs: upscale=True, running upscaler.')
+             print('running upscaler step1. Initial super-resolution')
+             sr_scale = 2.0
+             output_img = self.pipe_sr(
+                 prompt.replace('<360panorama>, ', ''),
+                 negative_prompt=negative_prompt,
+                 image=output_img.resize(
+                     (int(1536 * sr_scale), int(768 * sr_scale))),
+                 num_inference_steps=7,
+                 generator=generator,
+                 control_image=output_img.resize(
+                     (int(1536 * sr_scale), int(768 * sr_scale))),
+                 strength=0.8,
+                 controlnet_conditioning_scale=1.0,
+                 guidance_scale=guidance_scale_sr_step1,
+             ).images[0]
+
+             print('running upscaler step2. Super-resolution with Real-ESRGAN')
+             output_img = output_img.resize((1536 * 2, 768 * 2))
+             w = output_img.size[0]
+             blend_extend = 10
+             outscale = 2
+             output_img = np.array(output_img)
+             output_img = np.concatenate(
+                 [output_img, output_img[:, :blend_extend, :]], axis=1)
+             output_img = self.upsampler.predict(
+                 output_img)
+             output_img = self.blend_h(output_img, output_img,
+                                       blend_extend * outscale)
+             output_img = Image.fromarray(output_img[:, :w * outscale, :])
+
+             if refinement:
+                 print(
+                     'inputs: refinement=True, running refinement. This is a bit time-consuming.'
+                 )
+                 sr_scale = 4
+                 output_img = self.pipe_sr(
+                     prompt.replace('<360panorama>, ', ''),
+                     negative_prompt=negative_prompt,
+                     image=output_img.resize(
+                         (int(1536 * sr_scale), int(768 * sr_scale))),
+                     num_inference_steps=7,
+                     generator=generator,
+                     control_image=output_img.resize(
+                         (int(1536 * sr_scale), int(768 * sr_scale))),
+                     strength=0.8,
+                     controlnet_conditioning_scale=1.0,
+                     guidance_scale=guidance_scale_sr_step2,
+                 ).images[0]
+             print('finished')
+
+         return output_img
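The upscaling branch above keeps the 360° seam continuous: before Real-ESRGAN runs, the first blend_extend columns are wrapped around and appended to the right edge, and after upscaling blend_h() cross-fades that duplicated strip back over the left edge before the image is cropped to its original width. A self-contained NumPy sketch of that blending step (toy sizes, illustrative only):

import numpy as np

def blend_h(a, b, blend_extent):
    # linearly cross-fade the last `blend_extent` columns of `a` into the first columns of `b`
    a = np.array(a)
    b = np.array(b)
    blend_extent = min(a.shape[1], b.shape[1], blend_extent)
    for x in range(blend_extent):
        b[:, x, :] = a[:, -blend_extent + x, :] * (1 - x / blend_extent) + b[:, x, :] * (x / blend_extent)
    return b

pano = np.random.rand(4, 16, 3)                          # stand-in for the upscaled panorama
padded = np.concatenate([pano, pano[:, :5, :]], axis=1)  # wrap the first 5 columns to the right edge
blended = blend_h(padded, padded, 5)                     # fade the wrapped strip into the left edge
result = blended[:, :pano.shape[1], :]                   # crop back to the original width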
img2panoimg/pipeline_i2p.py ADDED
@@ -0,0 +1,1740 @@
1
+ # Copyright © Alibaba, Inc. and its affiliates.
2
+ # The implementation here is modified based on diffusers.StableDiffusionPipeline,
3
+ # originally Apache 2.0 License and public available at
4
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
5
+
6
+ import copy
7
+ import inspect
8
+ import re
9
+ import warnings
10
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
11
+
12
+ import os
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from diffusers import (AutoencoderKL, DiffusionPipeline, ControlNetModel, UNet2DConditionModel)
16
+ from diffusers.image_processor import VaeImageProcessor
17
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
18
+ try:
19
+ from diffusers.models.autoencoders.vae import DecoderOutput
20
+ except ImportError:
21
+ from diffusers.models.vae import DecoderOutput
22
+ from diffusers.models.controlnet import ControlNetOutput
23
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
24
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
25
+ from diffusers.pipelines.stable_diffusion.safety_checker import \
26
+ StableDiffusionSafetyChecker
27
+ from diffusers.schedulers import KarrasDiffusionSchedulers
28
+ from diffusers.utils import (PIL_INTERPOLATION, deprecate, is_accelerate_available,
29
+ is_accelerate_version, logging,
30
+ replace_example_docstring)
31
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
32
+
33
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
34
+ import PIL
35
+ import numpy as np
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ EXAMPLE_DOC_STRING = """
41
+ Examples:
42
+ ```py
43
+ >>> import torch
44
+ >>> from diffusers import EulerAncestralDiscreteScheduler
45
+ >>> from txt2panoimage.pipeline_base import StableDiffusionBlendExtendPipeline
46
+ >>> model_id = "models/sd-base"
47
+ >>> pipe = StableDiffusionBlendExtendPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
48
+ >>> pipe = pipe.to("cuda")
49
+ >>> pipe.vae.enable_tiling()
50
+ >>> pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
51
+ >>> # remove following line if xformers is not installed
52
+ >>> pipe.enable_xformers_memory_efficient_attention()
53
+ >>> pipe.enable_model_cpu_offload()
54
+ >>> prompt = "a living room"
55
+ >>> image = pipe(prompt).images[0]
56
+ ```
57
+ """
58
+
59
+ re_attention = re.compile(
60
+ r"""
61
+ \\\(|
62
+ \\\)|
63
+ \\\[|
64
+ \\]|
65
+ \\\\|
66
+ \\|
67
+ \(|
68
+ \[|
69
+ :([+-]?[.\d]+)\)|
70
+ \)|
71
+ ]|
72
+ [^\\()\[\]:]+|
73
+ :
74
+ """,
75
+ re.X,
76
+ )
77
+
78
+
79
+ def parse_prompt_attention(text):
80
+ """
81
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
82
+ Accepted tokens are:
83
+ (abc) - increases attention to abc by a multiplier of 1.1
84
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
85
+ [abc] - decreases attention to abc by a multiplier of 1.1
86
+ """
87
+
88
+ res = []
89
+ round_brackets = []
90
+ square_brackets = []
91
+
92
+ round_bracket_multiplier = 1.1
93
+ square_bracket_multiplier = 1 / 1.1
94
+
95
+ def multiply_range(start_position, multiplier):
96
+ for p in range(start_position, len(res)):
97
+ res[p][1] *= multiplier
98
+
99
+ for m in re_attention.finditer(text):
100
+ text = m.group(0)
101
+ weight = m.group(1)
102
+
103
+ if text.startswith('\\'):
104
+ res.append([text[1:], 1.0])
105
+ elif text == '(':
106
+ round_brackets.append(len(res))
107
+ elif text == '[':
108
+ square_brackets.append(len(res))
109
+ elif weight is not None and len(round_brackets) > 0:
110
+ multiply_range(round_brackets.pop(), float(weight))
111
+ elif text == ')' and len(round_brackets) > 0:
112
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
113
+ elif text == ']' and len(square_brackets) > 0:
114
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
115
+ else:
116
+ res.append([text, 1.0])
117
+
118
+ for pos in round_brackets:
119
+ multiply_range(pos, round_bracket_multiplier)
120
+
121
+ for pos in square_brackets:
122
+ multiply_range(pos, square_bracket_multiplier)
123
+
124
+ if len(res) == 0:
125
+ res = [['', 1.0]]
126
+
127
+ # merge runs of identical weights
128
+ i = 0
129
+ while i + 1 < len(res):
130
+ if res[i][1] == res[i + 1][1]:
131
+ res[i][0] += res[i + 1][0]
132
+ res.pop(i + 1)
133
+ else:
134
+ i += 1
135
+
136
+ return res
137
+
138
+
139
+ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str],
140
+ max_length: int):
141
+ r"""
142
+ Tokenize a list of prompts and return its tokens with weights of each token.
143
+
144
+ No padding, starting or ending token is included.
145
+ """
146
+ tokens = []
147
+ weights = []
148
+ truncated = False
149
+ for text in prompt:
150
+ texts_and_weights = parse_prompt_attention(text)
151
+ text_token = []
152
+ text_weight = []
153
+ for word, weight in texts_and_weights:
154
+ # tokenize and discard the starting and the ending token
155
+ token = pipe.tokenizer(word).input_ids[1:-1]
156
+ text_token += token
157
+ # copy the weight by length of token
158
+ text_weight += [weight] * len(token)
159
+ # stop if the text is too long (longer than truncation limit)
160
+ if len(text_token) > max_length:
161
+ truncated = True
162
+ break
163
+ # truncate
164
+ if len(text_token) > max_length:
165
+ truncated = True
166
+ text_token = text_token[:max_length]
167
+ text_weight = text_weight[:max_length]
168
+ tokens.append(text_token)
169
+ weights.append(text_weight)
170
+ if truncated:
171
+ logger.warning(
172
+ 'Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples'
173
+ )
174
+ return tokens, weights
175
+
176
+
177
+ def pad_tokens_and_weights(tokens,
178
+ weights,
179
+ max_length,
180
+ bos,
181
+ eos,
182
+ pad,
183
+ no_boseos_middle=True,
184
+ chunk_length=77):
185
+ r"""
186
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
187
+ """
188
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
189
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
190
+ for i in range(len(tokens)):
191
+ tokens[i] = [
192
+ bos
193
+ ] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
194
+ if no_boseos_middle:
195
+ weights[i] = [1.0] + weights[i] + [1.0] * (
196
+ max_length - 1 - len(weights[i]))
197
+ else:
198
+ w = []
199
+ if len(weights[i]) == 0:
200
+ w = [1.0] * weights_length
201
+ else:
202
+ for j in range(max_embeddings_multiples):
203
+ w.append(1.0) # weight for starting token in this chunk
204
+ w += weights[i][j * (chunk_length - 2):min(
205
+ len(weights[i]), (j + 1) * (chunk_length - 2))]
206
+ w.append(1.0) # weight for ending token in this chunk
207
+ w += [1.0] * (weights_length - len(w))
208
+ weights[i] = w[:]
209
+
210
+ return tokens, weights
211
+
212
+
213
+ def get_unweighted_text_embeddings(
214
+ pipe: DiffusionPipeline,
215
+ text_input: torch.Tensor,
216
+ chunk_length: int,
217
+ no_boseos_middle: Optional[bool] = True,
218
+ ):
219
+ """
220
+ When the length of tokens is a multiple of the capacity of the text encoder,
221
+ it should be split into chunks and sent to the text encoder individually.
222
+ """
223
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
224
+ if max_embeddings_multiples > 1:
225
+ text_embeddings = []
226
+ for i in range(max_embeddings_multiples):
227
+ # extract the i-th chunk
228
+ text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1)
229
+ * (chunk_length - 2) + 2].clone()
230
+
231
+ # cover the head and the tail by the starting and the ending tokens
232
+ text_input_chunk[:, 0] = text_input[0, 0]
233
+ text_input_chunk[:, -1] = text_input[0, -1]
234
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
235
+
236
+ if no_boseos_middle:
237
+ if i == 0:
238
+ # discard the ending token
239
+ text_embedding = text_embedding[:, :-1]
240
+ elif i == max_embeddings_multiples - 1:
241
+ # discard the starting token
242
+ text_embedding = text_embedding[:, 1:]
243
+ else:
244
+ # discard both starting and ending tokens
245
+ text_embedding = text_embedding[:, 1:-1]
246
+
247
+ text_embeddings.append(text_embedding)
248
+ text_embeddings = torch.concat(text_embeddings, axis=1)
249
+ else:
250
+ text_embeddings = pipe.text_encoder(text_input)[0]
251
+ return text_embeddings
252
+
253
+
254
+ def get_weighted_text_embeddings(
255
+ pipe: DiffusionPipeline,
256
+ prompt: Union[str, List[str]],
257
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
258
+ max_embeddings_multiples: Optional[int] = 3,
259
+ no_boseos_middle: Optional[bool] = False,
260
+ skip_parsing: Optional[bool] = False,
261
+ skip_weighting: Optional[bool] = False,
262
+ ):
263
+ r"""
264
+ Prompts can be assigned with local weights using brackets. For example,
265
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
266
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
267
+
268
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
269
+
270
+ Args:
271
+ pipe (`DiffusionPipeline`):
272
+ Pipe to provide access to the tokenizer and the text encoder.
273
+ prompt (`str` or `List[str]`):
274
+ The prompt or prompts to guide the image generation.
275
+ uncond_prompt (`str` or `List[str]`):
276
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
277
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
278
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
279
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
280
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
281
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
282
+ ending token in each of the chunk in the middle.
283
+ skip_parsing (`bool`, *optional*, defaults to `False`):
284
+ Skip the parsing of brackets.
285
+ skip_weighting (`bool`, *optional*, defaults to `False`):
286
+ Skip the weighting. When the parsing is skipped, it is forced True.
287
+ """
288
+ max_length = (pipe.tokenizer.model_max_length
289
+ - 2) * max_embeddings_multiples + 2
290
+ if isinstance(prompt, str):
291
+ prompt = [prompt]
292
+
293
+ if not skip_parsing:
294
+ prompt_tokens, prompt_weights = get_prompts_with_weights(
295
+ pipe, prompt, max_length - 2)
296
+ if uncond_prompt is not None:
297
+ if isinstance(uncond_prompt, str):
298
+ uncond_prompt = [uncond_prompt]
299
+ uncond_tokens, uncond_weights = get_prompts_with_weights(
300
+ pipe, uncond_prompt, max_length - 2)
301
+ else:
302
+ prompt_tokens = [
303
+ token[1:-1] for token in pipe.tokenizer(
304
+ prompt, max_length=max_length, truncation=True).input_ids
305
+ ]
306
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
307
+ if uncond_prompt is not None:
308
+ if isinstance(uncond_prompt, str):
309
+ uncond_prompt = [uncond_prompt]
310
+ uncond_tokens = [
311
+ token[1:-1] for token in pipe.tokenizer(
312
+ uncond_prompt, max_length=max_length,
313
+ truncation=True).input_ids
314
+ ]
315
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
316
+
317
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
318
+ max_length = max([len(token) for token in prompt_tokens])
319
+ if uncond_prompt is not None:
320
+ max_length = max(max_length,
321
+ max([len(token) for token in uncond_tokens]))
322
+
323
+ max_embeddings_multiples = min(
324
+ max_embeddings_multiples,
325
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
326
+ )
327
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
328
+ max_length = (pipe.tokenizer.model_max_length
329
+ - 2) * max_embeddings_multiples + 2
330
+
331
+ # pad the length of tokens and weights
332
+ bos = pipe.tokenizer.bos_token_id
333
+ eos = pipe.tokenizer.eos_token_id
334
+ pad = getattr(pipe.tokenizer, 'pad_token_id', eos)
335
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
336
+ prompt_tokens,
337
+ prompt_weights,
338
+ max_length,
339
+ bos,
340
+ eos,
341
+ pad,
342
+ no_boseos_middle=no_boseos_middle,
343
+ chunk_length=pipe.tokenizer.model_max_length,
344
+ )
345
+ prompt_tokens = torch.tensor(
346
+ prompt_tokens, dtype=torch.long, device=pipe.device)
347
+ if uncond_prompt is not None:
348
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
349
+ uncond_tokens,
350
+ uncond_weights,
351
+ max_length,
352
+ bos,
353
+ eos,
354
+ pad,
355
+ no_boseos_middle=no_boseos_middle,
356
+ chunk_length=pipe.tokenizer.model_max_length,
357
+ )
358
+ uncond_tokens = torch.tensor(
359
+ uncond_tokens, dtype=torch.long, device=pipe.device)
360
+
361
+ # get the embeddings
362
+ text_embeddings = get_unweighted_text_embeddings(
363
+ pipe,
364
+ prompt_tokens,
365
+ pipe.tokenizer.model_max_length,
366
+ no_boseos_middle=no_boseos_middle,
367
+ )
368
+ prompt_weights = torch.tensor(
369
+ prompt_weights,
370
+ dtype=text_embeddings.dtype,
371
+ device=text_embeddings.device)
372
+ if uncond_prompt is not None:
373
+ uncond_embeddings = get_unweighted_text_embeddings(
374
+ pipe,
375
+ uncond_tokens,
376
+ pipe.tokenizer.model_max_length,
377
+ no_boseos_middle=no_boseos_middle,
378
+ )
379
+ uncond_weights = torch.tensor(
380
+ uncond_weights,
381
+ dtype=uncond_embeddings.dtype,
382
+ device=uncond_embeddings.device)
383
+
384
+ # assign weights to the prompts and normalize in the sense of mean
385
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
386
+ if (not skip_parsing) and (not skip_weighting):
387
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
388
+ text_embeddings.dtype)
389
+ text_embeddings *= prompt_weights.unsqueeze(-1)
390
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
391
+ text_embeddings.dtype)
392
+ text_embeddings *= (previous_mean
393
+ / current_mean).unsqueeze(-1).unsqueeze(-1)
394
+ if uncond_prompt is not None:
395
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
396
+ uncond_embeddings.dtype)
397
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
398
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
399
+ uncond_embeddings.dtype)
400
+ uncond_embeddings *= (previous_mean
401
+ / current_mean).unsqueeze(-1).unsqueeze(-1)
402
+
403
+ if uncond_prompt is not None:
404
+ return text_embeddings, uncond_embeddings
405
+ return text_embeddings, None
406
+
407
+
408
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
409
+ """
410
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
411
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
412
+ """
413
+ std_text = noise_pred_text.std(
414
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
415
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
416
+ # rescale the results from guidance (fixes overexposure)
417
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
418
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
419
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (
420
+ 1 - guidance_rescale) * noise_cfg
421
+ return noise_cfg
422
+
423
+
424
+ def prepare_image(image):
425
+ if isinstance(image, torch.Tensor):
426
+ # Batch single image
427
+ if image.ndim == 3:
428
+ image = image.unsqueeze(0)
429
+
430
+ image = image.to(dtype=torch.float32)
431
+ else:
432
+ # preprocess image
433
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
434
+ image = [image]
435
+
436
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
437
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
438
+ image = np.concatenate(image, axis=0)
439
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
440
+ image = np.concatenate([i[None, :] for i in image], axis=0)
441
+
442
+ image = image.transpose(0, 3, 1, 2)
443
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
444
+
445
+ return image
446
+
447
+ class StableDiffusionImage2PanoPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
448
+ r"""
449
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
450
+
451
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
452
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
453
+
454
+ In addition the pipeline inherits the following loading methods:
455
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
456
+
457
+ Args:
458
+ vae ([`AutoencoderKL`]):
459
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
460
+ text_encoder ([`CLIPTextModel`]):
461
+ Frozen text-encoder. Stable Diffusion uses the text portion of
462
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
463
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
464
+ tokenizer (`CLIPTokenizer`):
465
+ Tokenizer of class
466
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
467
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
468
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
469
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
470
+ as a list, the outputs from each ControlNet are added together to create one combined additional
471
+ conditioning.
472
+ scheduler ([`SchedulerMixin`]):
473
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
474
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
475
+ safety_checker ([`StableDiffusionSafetyChecker`]):
476
+ Classification module that estimates whether generated images could be considered offensive or harmful.
477
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
478
+ feature_extractor ([`CLIPImageProcessor`]):
479
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
480
+ """
481
+ _optional_components = ["safety_checker", "feature_extractor"]
482
+
483
+ def __init__(
484
+ self,
485
+ vae: AutoencoderKL,
486
+ text_encoder: CLIPTextModel,
487
+ tokenizer: CLIPTokenizer,
488
+ unet: UNet2DConditionModel,
489
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
490
+ scheduler: KarrasDiffusionSchedulers,
491
+ safety_checker: StableDiffusionSafetyChecker,
492
+ feature_extractor: CLIPImageProcessor,
493
+ requires_safety_checker: bool = True,
494
+ ):
495
+ super().__init__()
496
+
497
+ if safety_checker is None and requires_safety_checker:
498
+ logger.warning(
499
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
500
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
501
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
502
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
503
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
504
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
505
+ )
506
+
507
+ if safety_checker is not None and feature_extractor is None:
508
+ raise ValueError(
509
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
510
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
511
+ )
512
+
513
+ if isinstance(controlnet, (list, tuple)):
514
+ controlnet = MultiControlNetModel(controlnet)
515
+
516
+ self.register_modules(
517
+ vae=vae,
518
+ text_encoder=text_encoder,
519
+ tokenizer=tokenizer,
520
+ unet=unet,
521
+ controlnet=controlnet,
522
+ scheduler=scheduler,
523
+ safety_checker=safety_checker,
524
+ feature_extractor=feature_extractor,
525
+ )
526
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
527
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
528
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
529
+
530
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
531
+ def enable_vae_slicing(self):
532
+ r"""
533
+ Enable sliced VAE decoding.
534
+
535
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
536
+ steps. This is useful to save some memory and allow larger batch sizes.
537
+ """
538
+ self.vae.enable_slicing()
539
+
540
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
541
+ def disable_vae_slicing(self):
542
+ r"""
543
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
544
+ computing decoding in one step.
545
+ """
546
+ self.vae.disable_slicing()
547
+
548
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
549
+ def enable_vae_tiling(self):
550
+ r"""
551
+ Enable tiled VAE decoding.
552
+
553
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
554
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
555
+ """
556
+ self.vae.enable_tiling()
557
+
558
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
559
+ def disable_vae_tiling(self):
560
+ r"""
561
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
562
+ computing decoding in one step.
563
+ """
564
+ self.vae.disable_tiling()
565
+
566
+ def enable_sequential_cpu_offload(self, gpu_id=0):
567
+ r"""
568
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
569
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
570
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
571
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
572
+ `enable_model_cpu_offload`, but performance is lower.
573
+ """
574
+ if is_accelerate_available():
575
+ from accelerate import cpu_offload
576
+ else:
577
+ raise ImportError("Please install accelerate via `pip install accelerate`")
578
+
579
+ device = torch.device(f"cuda:{gpu_id}")
580
+
581
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
582
+ cpu_offload(cpu_offloaded_model, device)
583
+
584
+ if self.safety_checker is not None:
585
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
586
+
587
+ def enable_model_cpu_offload(self, gpu_id=0):
588
+ r"""
589
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
590
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
591
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
592
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
593
+ """
594
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
595
+ from accelerate import cpu_offload_with_hook
596
+ else:
597
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
598
+
599
+ device = torch.device(f"cuda:{gpu_id}")
600
+
601
+ hook = None
602
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
603
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
604
+
605
+ if self.safety_checker is not None:
606
+ # the safety checker can offload the vae again
607
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
608
+
609
+ # the controlnet hook has to be manually offloaded, as it alternates with the unet
610
+ cpu_offload_with_hook(self.controlnet, device)
611
+
612
+ # We'll offload the last model manually.
613
+ self.final_offload_hook = hook
614
+
615
+ @property
616
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
617
+ def _execution_device(self):
618
+ r"""
619
+ Returns the device on which the pipeline's models will be executed. After calling
620
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
621
+ hooks.
622
+ """
623
+ if not hasattr(self.unet, "_hf_hook"):
624
+ return self.device
625
+ for module in self.unet.modules():
626
+ if (
627
+ hasattr(module, "_hf_hook")
628
+ and hasattr(module._hf_hook, "execution_device")
629
+ and module._hf_hook.execution_device is not None
630
+ ):
631
+ return torch.device(module._hf_hook.execution_device)
632
+ return self.device
633
+
634
+ def _encode_prompt(
635
+ self,
636
+ prompt,
637
+ device,
638
+ num_images_per_prompt,
639
+ do_classifier_free_guidance,
640
+ negative_prompt=None,
641
+ max_embeddings_multiples=3,
642
+ prompt_embeds: Optional[torch.FloatTensor] = None,
643
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
644
+ lora_scale: Optional[float] = None,
645
+ ):
646
+ r"""
647
+ Encodes the prompt into text encoder hidden states.
648
+
649
+ Args:
650
+ prompt (`str` or `list(int)`):
651
+ prompt to be encoded
652
+ device: (`torch.device`):
653
+ torch device
654
+ num_images_per_prompt (`int`):
655
+ number of images that should be generated per prompt
656
+ do_classifier_free_guidance (`bool`):
657
+ whether to use classifier free guidance or not
658
+ negative_prompt (`str` or `List[str]`):
659
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
660
+ if `guidance_scale` is less than `1`).
661
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
662
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
663
+ """
664
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
665
+ self._lora_scale = lora_scale
666
+
667
+ if prompt is not None and isinstance(prompt, str):
668
+ batch_size = 1
669
+ elif prompt is not None and isinstance(prompt, list):
670
+ batch_size = len(prompt)
671
+ else:
672
+ batch_size = prompt_embeds.shape[0]
673
+
674
+ if negative_prompt_embeds is None:
675
+ if negative_prompt is None:
676
+ negative_prompt = [""] * batch_size
677
+ elif isinstance(negative_prompt, str):
678
+ negative_prompt = [negative_prompt] * batch_size
679
+ if batch_size != len(negative_prompt):
680
+ raise ValueError(
681
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
682
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
683
+ " the batch size of `prompt`."
684
+ )
685
+ if prompt_embeds is None or negative_prompt_embeds is None:
686
+ if isinstance(self, TextualInversionLoaderMixin):
687
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
688
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
689
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
690
+
691
+ prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
692
+ pipe=self,
693
+ prompt=prompt,
694
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
695
+ max_embeddings_multiples=max_embeddings_multiples,
696
+ )
697
+ if prompt_embeds is None:
698
+ prompt_embeds = prompt_embeds1
699
+ if negative_prompt_embeds is None:
700
+ negative_prompt_embeds = negative_prompt_embeds1
701
+
702
+ bs_embed, seq_len, _ = prompt_embeds.shape
703
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
704
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
705
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
706
+
707
+ if do_classifier_free_guidance:
708
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
709
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
710
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
711
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
712
+
713
+ return prompt_embeds
714
+
715
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
716
+ def run_safety_checker(self, image, device, dtype):
717
+ if self.safety_checker is None:
718
+ has_nsfw_concept = None
719
+ else:
720
+ if torch.is_tensor(image):
721
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
722
+ else:
723
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
724
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
725
+ image, has_nsfw_concept = self.safety_checker(
726
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
727
+ )
728
+ return image, has_nsfw_concept
729
+
730
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
731
+ def decode_latents(self, latents):
732
+ warnings.warn(
733
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
734
+ " use VaeImageProcessor instead",
735
+ FutureWarning,
736
+ )
737
+ latents = 1 / self.vae.config.scaling_factor * latents
738
+ image = self.vae.decode(latents, return_dict=False)[0]
739
+ image = (image / 2 + 0.5).clamp(0, 1)
740
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
741
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
742
+ return image
743
+
744
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
745
+ def prepare_extra_step_kwargs(self, generator, eta):
746
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
747
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
748
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
749
+ # and should be between [0, 1]
750
+
751
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
752
+ extra_step_kwargs = {}
753
+ if accepts_eta:
754
+ extra_step_kwargs["eta"] = eta
755
+
756
+ # check if the scheduler accepts generator
757
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
758
+ if accepts_generator:
759
+ extra_step_kwargs["generator"] = generator
760
+ return extra_step_kwargs
761
+
762
+ def check_inputs(
763
+ self,
764
+ prompt,
765
+ image,
766
+ height,
767
+ width,
768
+ callback_steps,
769
+ negative_prompt=None,
770
+ prompt_embeds=None,
771
+ negative_prompt_embeds=None,
772
+ controlnet_conditioning_scale=1.0,
773
+ ):
774
+ if height % 8 != 0 or width % 8 != 0:
775
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
776
+
777
+ if (callback_steps is None) or (
778
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
779
+ ):
780
+ raise ValueError(
781
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
782
+ f" {type(callback_steps)}."
783
+ )
784
+
785
+ if prompt is not None and prompt_embeds is not None:
786
+ raise ValueError(
787
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
788
+ " only forward one of the two."
789
+ )
790
+ elif prompt is None and prompt_embeds is None:
791
+ raise ValueError(
792
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
793
+ )
794
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
795
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
796
+
797
+ if negative_prompt is not None and negative_prompt_embeds is not None:
798
+ raise ValueError(
799
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
800
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
801
+ )
802
+
803
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
804
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
805
+ raise ValueError(
806
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
807
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
808
+ f" {negative_prompt_embeds.shape}."
809
+ )
810
+
811
+ # `prompt` needs more sophisticated handling when there are multiple
812
+ # conditionings.
813
+ if isinstance(self.controlnet, MultiControlNetModel):
814
+ if isinstance(prompt, list):
815
+ logger.warning(
816
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
817
+ " prompts. The conditionings will be fixed across the prompts."
818
+ )
819
+
820
+ # Check `image`
821
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
822
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
823
+ )
824
+ if (
825
+ isinstance(self.controlnet, ControlNetModel)
826
+ or is_compiled
827
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
828
+ ):
829
+ self.check_image(image, prompt, prompt_embeds)
830
+ elif (
831
+ isinstance(self.controlnet, MultiControlNetModel)
832
+ or is_compiled
833
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
834
+ ):
835
+ if not isinstance(image, list):
836
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
837
+
838
+ # When `image` is a nested list:
839
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
840
+ elif any(isinstance(i, list) for i in image):
841
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
842
+ elif len(image) != len(self.controlnet.nets):
843
+ raise ValueError(
844
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
845
+ )
846
+
847
+ for image_ in image:
848
+ self.check_image(image_, prompt, prompt_embeds)
849
+ else:
850
+ assert False
851
+
852
+ # Check `controlnet_conditioning_scale`
853
+ if (
854
+ isinstance(self.controlnet, ControlNetModel)
855
+ or is_compiled
856
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
857
+ ):
858
+ if not isinstance(controlnet_conditioning_scale, float):
859
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
860
+ elif (
861
+ isinstance(self.controlnet, MultiControlNetModel)
862
+ or is_compiled
863
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
864
+ ):
865
+ if isinstance(controlnet_conditioning_scale, list):
866
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
867
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
868
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
869
+ self.controlnet.nets
870
+ ):
871
+ raise ValueError(
872
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
873
+ " the same length as the number of controlnets"
874
+ )
875
+ else:
876
+ assert False
877
+
878
+ def check_image(self, image, prompt, prompt_embeds):
879
+ image_is_pil = isinstance(image, PIL.Image.Image)
880
+ image_is_tensor = isinstance(image, torch.Tensor)
881
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
882
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
883
+
884
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
885
+ raise TypeError(
886
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
887
+ )
888
+
889
+ if image_is_pil:
890
+ image_batch_size = 1
891
+ elif image_is_tensor:
892
+ image_batch_size = image.shape[0]
893
+ elif image_is_pil_list:
894
+ image_batch_size = len(image)
895
+ elif image_is_tensor_list:
896
+ image_batch_size = len(image)
897
+
898
+ if prompt is not None and isinstance(prompt, str):
899
+ prompt_batch_size = 1
900
+ elif prompt is not None and isinstance(prompt, list):
901
+ prompt_batch_size = len(prompt)
902
+ elif prompt_embeds is not None:
903
+ prompt_batch_size = prompt_embeds.shape[0]
904
+
905
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
906
+ raise ValueError(
907
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
908
+ )
909
+
910
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
911
+ def prepare_control_image(
912
+ self,
913
+ image,
914
+ width,
915
+ height,
916
+ batch_size,
917
+ num_images_per_prompt,
918
+ device,
919
+ dtype,
920
+ do_classifier_free_guidance=False,
921
+ guess_mode=False,
922
+ ):
923
+ if not isinstance(image, torch.Tensor):
924
+ if isinstance(image, PIL.Image.Image):
925
+ image = [image]
926
+
927
+ if isinstance(image[0], PIL.Image.Image):
928
+ images = []
929
+
930
+ for image_ in image:
931
+ image_ = image_.convert("RGB")
932
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
933
+ image_ = np.array(image_)
934
+ image_ = image_[None, :]
935
+ images.append(image_)
936
+
937
+ image = images
938
+
939
+ image = np.concatenate(image, axis=0)
940
+ image = np.array(image).astype(np.float32) / 255.0
941
+ image = image.transpose(0, 3, 1, 2)
942
+ image = torch.from_numpy(image)
943
+ elif isinstance(image[0], torch.Tensor):
944
+ image = torch.cat(image, dim=0)
945
+
946
+ image_batch_size = image.shape[0]
947
+
948
+ if image_batch_size == 1:
949
+ repeat_by = batch_size
950
+ else:
951
+ # image batch size is the same as prompt batch size
952
+ repeat_by = num_images_per_prompt
953
+
954
+ image = image.repeat_interleave(repeat_by, dim=0)
955
+
956
+ image = image.to(device=device, dtype=dtype)
957
+
958
+ if do_classifier_free_guidance and not guess_mode:
959
+ image = torch.cat([image] * 2)
960
+
961
+ return image
962
+
963
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
964
+ def get_timesteps(self, num_inference_steps, strength, device):
965
+ # get the original timestep using init_timestep
966
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
967
+
968
+ t_start = max(num_inference_steps - init_timestep, 0)
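+ # e.g. num_inference_steps=50 and strength=0.8 give init_timestep=40 and t_start=10,
+ # so only the final 40 of the 50 scheduled timesteps are actually run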
969
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
970
+
971
+ return timesteps, num_inference_steps - t_start
972
+
973
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
974
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
975
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
976
+ raise ValueError(
977
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
978
+ )
979
+
980
+ image = image.to(device=device, dtype=dtype)
981
+
982
+ batch_size = batch_size * num_images_per_prompt
983
+ if isinstance(generator, list) and len(generator) != batch_size:
984
+ raise ValueError(
985
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
986
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
987
+ )
988
+
989
+ if isinstance(generator, list):
990
+ init_latents = [
991
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
992
+ ]
993
+ init_latents = torch.cat(init_latents, dim=0)
994
+ else:
995
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
996
+
997
+ init_latents = self.vae.config.scaling_factor * init_latents
998
+
999
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
1000
+ # expand init_latents for batch_size
1001
+ deprecation_message = (
1002
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
1003
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
1004
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
1005
+ " your script to pass as many initial images as text prompts to suppress this warning."
1006
+ )
1007
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
1008
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
1009
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
1010
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
1011
+ raise ValueError(
1012
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
1013
+ )
1014
+ else:
1015
+ init_latents = torch.cat([init_latents], dim=0)
1016
+
1017
+ shape = init_latents.shape
1018
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
1019
+
1020
+ # get latents
1021
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
1022
+ latents = init_latents
1023
+
1024
+ return latents
1025
+
1026
+ def _default_height_width(self, height, width, image):
1027
+ # NOTE: It is possible that a list of images have different
1028
+ # dimensions for each image, so just checking the first image
1029
+ # is not _exactly_ correct, but it is simple.
1030
+ while isinstance(image, list):
1031
+ image = image[0]
1032
+
1033
+ if height is None:
1034
+ if isinstance(image, PIL.Image.Image):
1035
+ height = image.height
1036
+ elif isinstance(image, torch.Tensor):
1037
+ height = image.shape[2]
1038
+
1039
+ height = (height // 8) * 8 # round down to nearest multiple of 8
1040
+
1041
+ if width is None:
1042
+ if isinstance(image, PIL.Image.Image):
1043
+ width = image.width
1044
+ elif isinstance(image, torch.Tensor):
1045
+ width = image.shape[3]
1046
+
1047
+ width = (width // 8) * 8 # round down to nearest multiple of 8
1048
+
1049
+ return height, width
1050
+
1051
+ # override DiffusionPipeline
1052
+ def save_pretrained(
1053
+ self,
1054
+ save_directory: Union[str, os.PathLike],
1055
+ safe_serialization: bool = False,
1056
+ variant: Optional[str] = None,
1057
+ ):
1058
+ if isinstance(self.controlnet, ControlNetModel):
1059
+ super().save_pretrained(save_directory, safe_serialization, variant)
1060
+ else:
1061
+ raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
1062
+
1063
+ def denoise_latents(self, latents, t, prompt_embeds, control_image, controlnet_conditioning_scale, guess_mode, cross_attention_kwargs, do_classifier_free_guidance, guidance_scale, extra_step_kwargs, views_scheduler_status):
1064
+ # expand the latents if we are doing classifier free guidance
1065
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1066
+ self.scheduler.__dict__.update(views_scheduler_status[0])
1067
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1068
+
1069
+ # controlnet(s) inference
1070
+ if guess_mode and do_classifier_free_guidance:
1071
+ # Infer ControlNet only for the conditional batch.
1072
+ controlnet_latent_model_input = latents
1073
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1074
+ else:
1075
+ controlnet_latent_model_input = latent_model_input
1076
+ controlnet_prompt_embeds = prompt_embeds
1077
+
1078
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1079
+ controlnet_latent_model_input,
1080
+ t,
1081
+ encoder_hidden_states=controlnet_prompt_embeds,
1082
+ controlnet_cond=control_image,
1083
+ conditioning_scale=controlnet_conditioning_scale,
1084
+ guess_mode=guess_mode,
1085
+ return_dict=False,
1086
+ )
1087
+
1088
+ if guess_mode and do_classifier_free_guidance:
1089
+ # Inferred ControlNet only for the conditional batch.
1090
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1091
+ # add 0 to the unconditional batch to keep it unchanged.
1092
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1093
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1094
+
1095
+ # predict the noise residual
1096
+ noise_pred = self.unet(
1097
+ latent_model_input,
1098
+ t,
1099
+ encoder_hidden_states=prompt_embeds,
1100
+ cross_attention_kwargs=cross_attention_kwargs,
1101
+ down_block_additional_residuals=down_block_res_samples,
1102
+ mid_block_additional_residual=mid_block_res_sample,
1103
+ return_dict=False,
1104
+ )[0]
1105
+
1106
+ # perform guidance
1107
+ if do_classifier_free_guidance:
1108
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1109
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1110
+
1111
+ # compute the previous noisy sample x_t -> x_t-1
1112
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1113
+ return latents
1114
+
1115
+ def blend_v(self, a, b, blend_extent):
1116
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
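+ # linearly cross-fade the last `blend_extent` rows of `a` into the first rows of `b`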
1117
+ for y in range(blend_extent):
1118
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
1119
+ return b
1120
+
1121
+ def blend_h(self, a, b, blend_extent):
1122
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
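+ # linearly cross-fade the last `blend_extent` columns of `a` into the first columns of `b`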
1123
+ for x in range(blend_extent):
1124
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
1125
+ return b
1126
+
1127
+ def get_blocks(self, latents, control_image, tile_latent_min_size, overlap_size):
1128
+ rows_latents = []
1129
+ rows_control_images = []
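+ # split the latent map (and the matching control image, scaled by vae_scale_factor)
+ # into overlapping tiles so each tile can be denoised separately and blended back together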
1130
+ for i in range(0, latents.shape[2] - overlap_size, overlap_size):
1131
+ row_latents = []
1132
+ row_control_images = []
1133
+ for j in range(0, latents.shape[3] - overlap_size, overlap_size):
1134
+ latents_input = latents[:, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size]
1135
+ control_image_input = control_image[:, :,
1136
+ self.vae_scale_factor * i: self.vae_scale_factor * (i + tile_latent_min_size),
1137
+ self.vae_scale_factor * j: self.vae_scale_factor * (j + tile_latent_min_size)]
1138
+ row_latents.append(latents_input)
1139
+ row_control_images.append(control_image_input)
1140
+ rows_latents.append(row_latents)
1141
+ rows_control_images.append(row_control_images)
1142
+ return rows_latents, rows_control_images
1143
+
1144
+ @torch.no_grad()
1145
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1146
+ def __call__(
1147
+ self,
1148
+ prompt: Union[str, List[str]] = None,
1149
+ image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
1150
+ control_image: Union[
1151
+ torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
1152
+ ] = None,
1153
+ height: Optional[int] = None,
1154
+ width: Optional[int] = None,
1155
+ strength: float = 0.8,
1156
+ num_inference_steps: int = 50,
1157
+ guidance_scale: float = 7.5,
1158
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1159
+ num_images_per_prompt: Optional[int] = 1,
1160
+ eta: float = 0.0,
1161
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1162
+ latents: Optional[torch.FloatTensor] = None,
1163
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1164
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1165
+ output_type: Optional[str] = "pil",
1166
+ return_dict: bool = True,
1167
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1168
+ callback_steps: int = 1,
1169
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1170
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
1171
+ guess_mode: bool = False,
1172
+ mask: Optional[torch.FloatTensor] = None,
1173
+ ):
1174
+ r"""
1175
+ Function invoked when calling the pipeline for generation.
1176
+
1177
+ Args:
1178
+ prompt (`str` or `List[str]`, *optional*):
1179
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1180
+ instead.
1181
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
1182
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
1183
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
1184
+ the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
1185
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
1186
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
1187
+ specified in init, images must be passed as a list such that each element of the list can be correctly
1188
+ batched for input to a single controlnet.
1189
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1190
+ The height in pixels of the generated image.
1191
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1192
+ The width in pixels of the generated image.
1193
+ num_inference_steps (`int`, *optional*, defaults to 50):
1194
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1195
+ expense of slower inference.
1196
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1197
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1198
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1199
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1200
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1201
+ usually at the expense of lower image quality.
1202
+ negative_prompt (`str` or `List[str]`, *optional*):
1203
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1204
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1205
+ less than `1`).
1206
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1207
+ The number of images to generate per prompt.
1208
+ eta (`float`, *optional*, defaults to 0.0):
1209
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1210
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1211
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1212
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1213
+ to make generation deterministic.
1214
+ latents (`torch.FloatTensor`, *optional*):
1215
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1216
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1217
+ tensor will be generated by sampling using the supplied random `generator`.
1218
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1219
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1220
+ provided, text embeddings will be generated from `prompt` input argument.
1221
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1222
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1223
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1224
+ argument.
1225
+ output_type (`str`, *optional*, defaults to `"pil"`):
1226
+ The output format of the generated image. Choose between
1227
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1228
+ return_dict (`bool`, *optional*, defaults to `True`):
1229
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1230
+ plain tuple.
1231
+ callback (`Callable`, *optional*):
1232
+ A function that will be called every `callback_steps` steps during inference. The function will be
1233
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1234
+ callback_steps (`int`, *optional*, defaults to 1):
1235
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1236
+ called at every step.
1237
+ cross_attention_kwargs (`dict`, *optional*):
1238
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1239
+ `self.processor` in
1240
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1241
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1242
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
1243
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
1244
+ corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
1245
+ than for [`~StableDiffusionControlNetPipeline.__call__`].
1246
+ guess_mode (`bool`, *optional*, defaults to `False`):
1247
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
1248
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
1249
+
1250
+ Examples:
1251
+
1252
+ Returns:
1253
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1254
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1255
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1256
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1257
+ (nsfw) content, according to the `safety_checker`.
1258
+ """
1259
+
1260
+ def controlnet_forward(
1261
+ self,
1262
+ sample: torch.FloatTensor,
1263
+ timestep: Union[torch.Tensor, float, int],
1264
+ encoder_hidden_states: torch.Tensor,
1265
+ controlnet_cond: torch.FloatTensor,
1266
+ conditioning_scale: float = 1.0,
1267
+ class_labels: Optional[torch.Tensor] = None,
1268
+ timestep_cond: Optional[torch.Tensor] = None,
1269
+ attention_mask: Optional[torch.Tensor] = None,
1270
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1271
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1272
+ guess_mode: bool = False,
1273
+ return_dict: bool = True,
1274
+ mask: Optional[torch.FloatTensor] = None,
1275
+ ) -> Union[ControlNetOutput, Tuple]:
1276
+ """
1277
+ The [`ControlNetModel`] forward method.
1278
+
1279
+ Args:
1280
+ sample (`torch.FloatTensor`):
1281
+ The noisy input tensor.
1282
+ timestep (`Union[torch.Tensor, float, int]`):
1283
+ The number of timesteps to denoise an input.
1284
+ encoder_hidden_states (`torch.Tensor`):
1285
+ The encoder hidden states.
1286
+ controlnet_cond (`torch.FloatTensor`):
1287
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
1288
+ conditioning_scale (`float`, defaults to `1.0`):
1289
+ The scale factor for ControlNet outputs.
1290
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1291
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1292
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
1293
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1294
+ added_cond_kwargs (`dict`):
1295
+ Additional conditions for the Stable Diffusion XL UNet.
1296
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
1297
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
1298
+ guess_mode (`bool`, defaults to `False`):
1299
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input even if
1300
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
1301
+ return_dict (`bool`, defaults to `True`):
1302
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
1303
+
1304
+ Returns:
1305
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
1306
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
1307
+ returned where the first element is the sample tensor.
1308
+ """
1309
+ # check channel order
1310
+ channel_order = self.config.controlnet_conditioning_channel_order
1311
+
1312
+ if channel_order == "rgb":
1313
+ # in rgb order by default
1314
+ ...
1315
+ elif channel_order == "bgr":
1316
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
1317
+ else:
1318
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
1319
+
1320
+ # prepare attention_mask
1321
+ if attention_mask is not None:
1322
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1323
+ attention_mask = attention_mask.unsqueeze(1)
1324
+
1325
+ # 1. time
1326
+ timesteps = timestep
1327
+ if not torch.is_tensor(timesteps):
1328
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1329
+ # This would be a good case for the `match` statement (Python 3.10+)
1330
+ is_mps = sample.device.type == "mps"
1331
+ if isinstance(timestep, float):
1332
+ dtype = torch.float32 if is_mps else torch.float64
1333
+ else:
1334
+ dtype = torch.int32 if is_mps else torch.int64
1335
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1336
+ elif len(timesteps.shape) == 0:
1337
+ timesteps = timesteps[None].to(sample.device)
1338
+
1339
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1340
+ timesteps = timesteps.expand(sample.shape[0])
1341
+
1342
+ t_emb = self.time_proj(timesteps)
1343
+
1344
+ # timesteps does not contain any weights and will always return f32 tensors
1345
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1346
+ # there might be better ways to encapsulate this.
1347
+ t_emb = t_emb.to(dtype=sample.dtype)
1348
+
1349
+ emb = self.time_embedding(t_emb, timestep_cond)
1350
+ aug_emb = None
1351
+
1352
+ if self.class_embedding is not None:
1353
+ if class_labels is None:
1354
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
1355
+
1356
+ if self.config.class_embed_type == "timestep":
1357
+ class_labels = self.time_proj(class_labels)
1358
+
1359
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
1360
+ emb = emb + class_emb
1361
+
1362
+ if self.config.addition_embed_type is not None:
1363
+ if self.config.addition_embed_type == "text":
1364
+ aug_emb = self.add_embedding(encoder_hidden_states)
1365
+
1366
+ elif self.config.addition_embed_type == "text_time":
1367
+ if "text_embeds" not in added_cond_kwargs:
1368
+ raise ValueError(
1369
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1370
+ )
1371
+ text_embeds = added_cond_kwargs.get("text_embeds")
1372
+ if "time_ids" not in added_cond_kwargs:
1373
+ raise ValueError(
1374
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1375
+ )
1376
+ time_ids = added_cond_kwargs.get("time_ids")
1377
+ time_embeds = self.add_time_proj(time_ids.flatten())
1378
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1379
+
1380
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1381
+ add_embeds = add_embeds.to(emb.dtype)
1382
+ aug_emb = self.add_embedding(add_embeds)
1383
+
1384
+ emb = emb + aug_emb if aug_emb is not None else emb
1385
+
1386
+ # 2. pre-process
1387
+ sample = self.conv_in(sample)
1388
+
1389
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
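+ # unlike the stock ControlNet forward, which simply adds the conditioning embedding,
+ # an optional mask linearly blends it into the latent features (full replacement where mask == 1)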
1390
+
1391
+ if mask is not None:
1392
+ sample = (1 - mask.to(sample.dtype)) * sample + mask.to(sample.dtype) * controlnet_cond
1393
+ else:
1394
+ sample = sample + controlnet_cond
1395
+
1396
+ # 3. down
1397
+ down_block_res_samples = (sample,)
1398
+ for downsample_block in self.down_blocks:
1399
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1400
+ sample, res_samples = downsample_block(
1401
+ hidden_states=sample,
1402
+ temb=emb,
1403
+ encoder_hidden_states=encoder_hidden_states,
1404
+ attention_mask=attention_mask,
1405
+ cross_attention_kwargs=cross_attention_kwargs,
1406
+ )
1407
+ else:
1408
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1409
+
1410
+ down_block_res_samples += res_samples
1411
+
1412
+ # 4. mid
1413
+ if self.mid_block is not None:
1414
+ sample = self.mid_block(
1415
+ sample,
1416
+ emb,
1417
+ encoder_hidden_states=encoder_hidden_states,
1418
+ attention_mask=attention_mask,
1419
+ cross_attention_kwargs=cross_attention_kwargs,
1420
+ )
1421
+
1422
+ # 5. Control net blocks
1423
+
1424
+ controlnet_down_block_res_samples = ()
1425
+
1426
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
1427
+ down_block_res_sample = controlnet_block(down_block_res_sample)
1428
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
1429
+
1430
+ down_block_res_samples = controlnet_down_block_res_samples
1431
+
1432
+ mid_block_res_sample = self.controlnet_mid_block(sample)
1433
+
1434
+ # 6. scaling
1435
+ if guess_mode and not self.config.global_pool_conditions:
1436
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
1437
+
1438
+ scales = scales * conditioning_scale
1439
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
1440
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
1441
+ else:
1442
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
1443
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
1444
+
1445
+ if self.config.global_pool_conditions:
1446
+ down_block_res_samples = [
1447
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
1448
+ ]
1449
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
1450
+
1451
+ if not return_dict:
1452
+ return (down_block_res_samples, mid_block_res_sample)
1453
+
1454
+ return ControlNetOutput(
1455
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
1456
+ )
1457
+ self.controlnet.forward = controlnet_forward.__get__(self.controlnet, ControlNetModel)
1458
+
1459
+ def tiled_decode(
1460
+ self,
1461
+ z: torch.FloatTensor,
1462
+ return_dict: bool = True
1463
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
1464
+ r"""Decode a batch of images using a tiled decoder.
1465
+
1466
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled
+ decoding is slightly different from non-tiled decoding because each tile is decoded separately.
+ To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output.
+ You may still see tile-sized changes in the look of the output, but they should be much less noticeable.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1475
+ """
1476
+ _tile_overlap_factor = 1 - self.tile_overlap_factor
1477
+ overlap_size = int(self.tile_latent_min_size
1478
+ * _tile_overlap_factor)
1479
+ blend_extent = int(self.tile_sample_min_size
1480
+ * self.tile_overlap_factor)
1481
+ row_limit = self.tile_sample_min_size - blend_extent
1482
+ w = z.shape[3]
1483
+ z = torch.cat([z, z[:, :, :, :w // 4]], dim=-1)
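+ # wrap the first quarter of the latent width onto the right edge so the 360° seam is decoded
+ # with context from the opposite side; the extra columns are blended back into the left edge below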
1484
+ # Split z into overlapping 64x64 tiles and decode them separately.
1485
+ # The tiles have an overlap to avoid seams between tiles.
1486
+
1487
+ rows = []
1488
+ for i in range(0, z.shape[2], overlap_size):
1489
+ row = []
1490
+ tile = z[:, :, i:i + self.tile_latent_min_size, :]
1491
+ tile = self.post_quant_conv(tile)
1492
+ decoded = self.decoder(tile)
1493
+ vae_scale_factor = decoded.shape[-1] // tile.shape[-1]
1494
+ row.append(decoded)
1495
+ rows.append(row)
1496
+ result_rows = []
1497
+ for i, row in enumerate(rows):
1498
+ result_row = []
1499
+ for j, tile in enumerate(row):
1500
+ # blend the above tile and the left tile
1501
+ # to the current tile and add the current tile to the result row
1502
+ if i > 0:
1503
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
1504
+ if j > 0:
1505
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
1506
+ result_row.append(
1507
+ self.blend_h(
1508
+ tile[:, :, :row_limit, w * vae_scale_factor:],
1509
+ tile[:, :, :row_limit, :w * vae_scale_factor],
1510
+ tile.shape[-1] - w * vae_scale_factor))
1511
+ result_rows.append(torch.cat(result_row, dim=3))
1512
+
1513
+ dec = torch.cat(result_rows, dim=2)
1514
+ if not return_dict:
1515
+ return (dec, )
1516
+
1517
+ return DecoderOutput(sample=dec)
1518
+
1519
+ self.vae.tiled_decode = tiled_decode.__get__(self.vae, AutoencoderKL)
1520
+
1521
+ # 0. Default height and width to unet
1522
+ height, width = self._default_height_width(height, width, image)
1523
+ self.blend_extend = width // self.vae_scale_factor // 32
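+ # number of latent columns used for the wrap-around blend, e.g. 1536 // 8 // 32 = 6 columns
+ # for a 1536-pixel-wide panorama with a VAE scale factor of 8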
1524
+
1525
+ # 1. Check inputs. Raise error if not correct
1526
+ self.check_inputs(
1527
+ prompt,
1528
+ control_image,
1529
+ height,
1530
+ width,
1531
+ callback_steps,
1532
+ negative_prompt,
1533
+ prompt_embeds,
1534
+ negative_prompt_embeds,
1535
+ controlnet_conditioning_scale,
1536
+ )
1537
+
1538
+ # 2. Define call parameters
1539
+ if prompt is not None and isinstance(prompt, str):
1540
+ batch_size = 1
1541
+ elif prompt is not None and isinstance(prompt, list):
1542
+ batch_size = len(prompt)
1543
+ else:
1544
+ batch_size = prompt_embeds.shape[0]
1545
+
1546
+ device = self._execution_device
1547
+ self.controlnet.to(device)
1548
+
1549
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
1550
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1551
+ # corresponds to doing no classifier free guidance.
1552
+ do_classifier_free_guidance = guidance_scale > 1.0
1553
+
1554
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1555
+
1556
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1557
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1558
+
1559
+ global_pool_conditions = (
1560
+ controlnet.config.global_pool_conditions
1561
+ if isinstance(controlnet, ControlNetModel)
1562
+ else controlnet.nets[0].config.global_pool_conditions
1563
+ )
1564
+ guess_mode = guess_mode or global_pool_conditions
1565
+
1566
+ # 3. Encode input prompt
1567
+ prompt_embeds = self._encode_prompt(
1568
+ prompt,
1569
+ device,
1570
+ num_images_per_prompt,
1571
+ do_classifier_free_guidance,
1572
+ negative_prompt,
1573
+ prompt_embeds=prompt_embeds,
1574
+ negative_prompt_embeds=negative_prompt_embeds,
1575
+ )
1576
+ # 4. Prepare image, and controlnet_conditioning_image
1577
+ image = prepare_image(image)
1578
+
1579
+ # 5. Prepare control image
1580
+ if isinstance(controlnet, ControlNetModel):
1581
+ control_image = self.prepare_control_image(
1582
+ image=control_image,
1583
+ width=width,
1584
+ height=height,
1585
+ batch_size=batch_size * num_images_per_prompt,
1586
+ num_images_per_prompt=num_images_per_prompt,
1587
+ device=device,
1588
+ dtype=controlnet.dtype,
1589
+ do_classifier_free_guidance=do_classifier_free_guidance,
1590
+ guess_mode=guess_mode,
1591
+ )
1592
+ elif isinstance(controlnet, MultiControlNetModel):
1593
+ control_images = []
1594
+
1595
+ for control_image_ in control_image:
1596
+ control_image_ = self.prepare_control_image(
1597
+ image=control_image_,
1598
+ width=width,
1599
+ height=height,
1600
+ batch_size=batch_size * num_images_per_prompt,
1601
+ num_images_per_prompt=num_images_per_prompt,
1602
+ device=device,
1603
+ dtype=controlnet.dtype,
1604
+ do_classifier_free_guidance=do_classifier_free_guidance,
1605
+ guess_mode=guess_mode,
1606
+ )
1607
+
1608
+ control_images.append(control_image_)
1609
+
1610
+ control_image = control_images
1611
+ else:
1612
+ assert False
1613
+
1614
+ # 5. Prepare timesteps
1615
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1616
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
1617
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1618
+
1619
+ # 6. Prepare latent variables
1620
+ latents = self.prepare_latents(
1621
+ image,
1622
+ latent_timestep,
1623
+ batch_size,
1624
+ num_images_per_prompt,
1625
+ prompt_embeds.dtype,
1626
+ device,
1627
+ generator,
1628
+ )
1629
+ if mask is not None:
1630
+ mask = torch.cat([mask] * batch_size, dim=0)
1631
+
1632
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1633
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1634
+
1635
+ views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)]
1636
+ # value = torch.zeros_like(latents)
1637
+ latents = torch.cat([latents, latents[:, :, :, :self.blend_extend]], dim=-1)
1638
+ control_image = torch.cat([control_image, control_image[:, :, :, :self.blend_extend * self.vae_scale_factor]], dim=-1)
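+ # extend the latents and the control image on the right with their own leftmost columns so the
+ # denoiser sees the panorama seam as continuous; the extension is blended back and cropped after the loop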
1639
+ if mask is not None:
1640
+ mask = torch.cat([mask] * batch_size, dim=0)
1641
+ mask = torch.cat([mask, mask[:, :, :, :self.blend_extend]], dim=-1)
1642
+
1643
+
1644
+ # 8. Denoising loop
1645
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1646
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1647
+ for i, t in enumerate(timesteps):
1648
+ # expand the latents if we are doing classifier free guidance
1649
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1650
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1651
+ if mask is not None:
1652
+ mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
1653
+ else:
1654
+ mask_input = None
1655
+
1656
+ # controlnet(s) inference
1657
+ if guess_mode and do_classifier_free_guidance:
1658
+ # Infer ControlNet only for the conditional batch.
1659
+ controlnet_latent_model_input = latents
1660
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1661
+ else:
1662
+ controlnet_latent_model_input = latent_model_input
1663
+ controlnet_prompt_embeds = prompt_embeds
1664
+
1665
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1666
+ controlnet_latent_model_input,
1667
+ t,
1668
+ encoder_hidden_states=controlnet_prompt_embeds,
1669
+ controlnet_cond=control_image,
1670
+ conditioning_scale=controlnet_conditioning_scale,
1671
+ guess_mode=guess_mode,
1672
+ return_dict=False,
1673
+ mask=mask_input,
1674
+ )
1675
+
1676
+ if guess_mode and do_classifier_free_guidance:
1677
+ # Inferred ControlNet only for the conditional batch.
1678
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1679
+ # add 0 to the unconditional batch to keep it unchanged.
1680
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1681
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1682
+
1683
+ # predict the noise residual
1684
+ noise_pred = self.unet(
1685
+ latent_model_input,
1686
+ t,
1687
+ encoder_hidden_states=prompt_embeds,
1688
+ cross_attention_kwargs=cross_attention_kwargs,
1689
+ down_block_additional_residuals=down_block_res_samples,
1690
+ mid_block_additional_residual=mid_block_res_sample,
1691
+ return_dict=False,
1692
+ )[0]
1693
+
1694
+ # perform guidance
1695
+ if do_classifier_free_guidance:
1696
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1697
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1698
+
1699
+ # compute the previous noisy sample x_t -> x_t-1
1700
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1701
+
1702
+ # call the callback, if provided
1703
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1704
+ progress_bar.update()
1705
+ if callback is not None and i % callback_steps == 0:
1706
+ callback(i, t, latents)
1707
+ # latents = value + 0.0
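+ # blend the wrapped right-edge columns back into the left edge, then crop to the original width
+ # so the decoded equirectangular image wraps seamlessly at 360 degrees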
1708
+ latents = self.blend_h(latents, latents, self.blend_extend)
1709
+ latents = self.blend_h(latents, latents, self.blend_extend)
1710
+ latents = latents[:, :, :, :width // self.vae_scale_factor]
1711
+
1712
+ # If we do sequential model offloading, let's offload unet and controlnet
1713
+ # manually for max memory savings
1714
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1715
+ self.unet.to("cpu")
1716
+ self.controlnet.to("cpu")
1717
+ torch.cuda.empty_cache()
1718
+
1719
+ if not output_type == "latent":
1720
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1721
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1722
+ else:
1723
+ image = latents
1724
+ has_nsfw_concept = None
1725
+
1726
+ if has_nsfw_concept is None:
1727
+ do_denormalize = [True] * image.shape[0]
1728
+ else:
1729
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1730
+
1731
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1732
+
1733
+ # Offload last model to CPU
1734
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1735
+ self.final_offload_hook.offload()
1736
+
1737
+ if not return_dict:
1738
+ return (image, has_nsfw_concept)
1739
+
1740
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
img2panoimg/pipeline_sr.py ADDED
@@ -0,0 +1,1202 @@
1
+ # Copyright © Alibaba, Inc. and its affiliates.
2
+ # The implementation here is modified based on diffusers.StableDiffusionControlNetImg2ImgPipeline,
3
+ # originally Apache 2.0 License and public available at
4
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
5
+
6
+ import copy
7
+ import re
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import PIL.Image
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from diffusers import (AutoencoderKL, DiffusionPipeline,
15
+ StableDiffusionControlNetImg2ImgPipeline)
16
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
17
+ from diffusers.models import ControlNetModel
18
+ try:
19
+ from diffusers.models.autoencoders.vae import DecoderOutput
20
+ except ImportError:
21
+ from diffusers.models.vae import DecoderOutput
22
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
23
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
24
+ from diffusers.utils import logging, replace_example_docstring
25
+ from diffusers.utils.torch_utils import is_compiled_module
26
+
27
+ from transformers import CLIPTokenizer
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+ Examples:
33
+ ```py
34
+ >>> import torch
35
+ >>> from PIL import Image
36
+ >>> from txt2panoimage.pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
37
+ >>> base_model_path = "models/sr-base"
38
+ >>> controlnet_path = "models/sr-control"
39
+ >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
40
+ >>> pipe = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(base_model_path, controlnet=controlnet,
41
+ ... torch_dtype=torch.float16)
42
+ >>> pipe.vae.enable_tiling()
43
+ >>> # remove following line if xformers is not installed
44
+ >>> pipe.enable_xformers_memory_efficient_attention()
45
+ >>> pipe.enable_model_cpu_offload()
46
+ >>> input_image_path = 'data/test.png'
47
+ >>> image = Image.open(input_image_path)
48
+ >>> image = pipe(
49
+ ... "futuristic-looking woman",
50
+ ... num_inference_steps=20,
51
+ ... image=image,
52
+ ... height=768,
53
+ ... width=1536,
54
+ ... control_image=image,
55
+ ... ).images[0]
56
+
57
+ ```
58
+ """
59
+
60
+ re_attention = re.compile(
61
+ r"""
62
+ \\\(|
63
+ \\\)|
64
+ \\\[|
65
+ \\]|
66
+ \\\\|
67
+ \\|
68
+ \(|
69
+ \[|
70
+ :([+-]?[.\d]+)\)|
71
+ \)|
72
+ ]|
73
+ [^\\()\[\]:]+|
74
+ :
75
+ """,
76
+ re.X,
77
+ )
78
+
79
+
80
+ def parse_prompt_attention(text):
81
+ """
82
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
83
+ Accepted tokens are:
84
+ (abc) - increases attention to abc by a multiplier of 1.1
85
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
86
+ [abc] - decreases attention to abc by a multiplier of 1.1
87
+ """
88
+
89
+ res = []
90
+ round_brackets = []
91
+ square_brackets = []
92
+
93
+ round_bracket_multiplier = 1.1
94
+ square_bracket_multiplier = 1 / 1.1
95
+
96
+ def multiply_range(start_position, multiplier):
97
+ for p in range(start_position, len(res)):
98
+ res[p][1] *= multiplier
99
+
100
+ for m in re_attention.finditer(text):
101
+ text = m.group(0)
102
+ weight = m.group(1)
103
+
104
+ if text.startswith('\\'):
105
+ res.append([text[1:], 1.0])
106
+ elif text == '(':
107
+ round_brackets.append(len(res))
108
+ elif text == '[':
109
+ square_brackets.append(len(res))
110
+ elif weight is not None and len(round_brackets) > 0:
111
+ multiply_range(round_brackets.pop(), float(weight))
112
+ elif text == ')' and len(round_brackets) > 0:
113
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
114
+ elif text == ']' and len(square_brackets) > 0:
115
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
116
+ else:
117
+ res.append([text, 1.0])
118
+
119
+ for pos in round_brackets:
120
+ multiply_range(pos, round_bracket_multiplier)
121
+
122
+ for pos in square_brackets:
123
+ multiply_range(pos, square_bracket_multiplier)
124
+
125
+ if len(res) == 0:
126
+ res = [['', 1.0]]
127
+
128
+ # merge runs of identical weights
129
+ i = 0
130
+ while i + 1 < len(res):
131
+ if res[i][1] == res[i + 1][1]:
132
+ res[i][0] += res[i + 1][0]
133
+ res.pop(i + 1)
134
+ else:
135
+ i += 1
136
+
137
+ return res
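+ # e.g. parse_prompt_attention("a (red:1.5) car [at night]")
+ # -> [['a ', 1.0], ['red', 1.5], [' car ', 1.0], ['at night', 0.909]] (0.909 ≈ 1 / 1.1)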
138
+
139
+
140
+ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str],
141
+ max_length: int):
142
+ r"""
143
+ Tokenize a list of prompts and return their tokens together with the weight of each token.
144
+
145
+ No padding, starting or ending token is included.
146
+ """
147
+ tokens = []
148
+ weights = []
149
+ truncated = False
150
+ for text in prompt:
151
+ texts_and_weights = parse_prompt_attention(text)
152
+ text_token = []
153
+ text_weight = []
154
+ for word, weight in texts_and_weights:
155
+ # tokenize and discard the starting and the ending token
156
+ token = pipe.tokenizer(word).input_ids[1:-1]
157
+ text_token += token
158
+ # copy the weight by length of token
159
+ text_weight += [weight] * len(token)
160
+ # stop if the text is too long (longer than truncation limit)
161
+ if len(text_token) > max_length:
162
+ truncated = True
163
+ break
164
+ # truncate
165
+ if len(text_token) > max_length:
166
+ truncated = True
167
+ text_token = text_token[:max_length]
168
+ text_weight = text_weight[:max_length]
169
+ tokens.append(text_token)
170
+ weights.append(text_weight)
171
+ if truncated:
172
+ logger.warning(
173
+ 'Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples'
174
+ )
175
+ return tokens, weights
176
+
177
+
178
+ def pad_tokens_and_weights(tokens,
179
+ weights,
180
+ max_length,
181
+ bos,
182
+ eos,
183
+ pad,
184
+ no_boseos_middle=True,
185
+ chunk_length=77):
186
+ r"""
187
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
188
+ """
189
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
190
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
191
+ for i in range(len(tokens)):
192
+ tokens[i] = [
193
+ bos
194
+ ] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
195
+ if no_boseos_middle:
196
+ weights[i] = [1.0] + weights[i] + [1.0] * (
197
+ max_length - 1 - len(weights[i]))
198
+ else:
199
+ w = []
200
+ if len(weights[i]) == 0:
201
+ w = [1.0] * weights_length
202
+ else:
203
+ for j in range(max_embeddings_multiples):
204
+ w.append(1.0) # weight for starting token in this chunk
205
+ w += weights[i][j * (chunk_length - 2):min(
206
+ len(weights[i]), (j + 1) * (chunk_length - 2))]
207
+ w.append(1.0) # weight for ending token in this chunk
208
+ w += [1.0] * (weights_length - len(w))
209
+ weights[i] = w[:]
210
+
211
+ return tokens, weights
212
+
213
+
214
+ def get_unweighted_text_embeddings(
215
+ pipe: DiffusionPipeline,
216
+ text_input: torch.Tensor,
217
+ chunk_length: int,
218
+ no_boseos_middle: Optional[bool] = True,
219
+ ):
220
+ """
221
+ When the length of tokens is a multiple of the capacity of the text encoder,
222
+ it should be split into chunks and sent to the text encoder individually.
223
+ """
224
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
225
+ if max_embeddings_multiples > 1:
226
+ text_embeddings = []
227
+ for i in range(max_embeddings_multiples):
228
+ # extract the i-th chunk
229
+ text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1)
230
+ * (chunk_length - 2) + 2].clone()
231
+
232
+ # cover the head and the tail by the starting and the ending tokens
233
+ text_input_chunk[:, 0] = text_input[0, 0]
234
+ text_input_chunk[:, -1] = text_input[0, -1]
235
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
236
+
237
+ if no_boseos_middle:
238
+ if i == 0:
239
+ # discard the ending token
240
+ text_embedding = text_embedding[:, :-1]
241
+ elif i == max_embeddings_multiples - 1:
242
+ # discard the starting token
243
+ text_embedding = text_embedding[:, 1:]
244
+ else:
245
+ # discard both starting and ending tokens
246
+ text_embedding = text_embedding[:, 1:-1]
247
+
248
+ text_embeddings.append(text_embedding)
249
+ text_embeddings = torch.concat(text_embeddings, axis=1)
250
+ else:
251
+ text_embeddings = pipe.text_encoder(text_input)[0]
252
+ return text_embeddings
253
+
254
+
255
+ def get_weighted_text_embeddings(
256
+ pipe: DiffusionPipeline,
257
+ prompt: Union[str, List[str]],
258
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
259
+ max_embeddings_multiples: Optional[int] = 3,
260
+ no_boseos_middle: Optional[bool] = False,
261
+ skip_parsing: Optional[bool] = False,
262
+ skip_weighting: Optional[bool] = False,
263
+ ):
264
+ r"""
265
+ Prompts can be assigned with local weights using brackets. For example,
266
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
267
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
268
+
269
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
270
+
271
+ Args:
272
+ pipe (`DiffusionPipeline`):
273
+ Pipe to provide access to the tokenizer and the text encoder.
274
+ prompt (`str` or `List[str]`):
275
+ The prompt or prompts to guide the image generation.
276
+ uncond_prompt (`str` or `List[str]`):
277
+ The unconditional prompt or prompts for guide the image generation. If unconditional prompt
278
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
279
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
280
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
281
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
282
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
283
+ ending token in each of the chunk in the middle.
284
+ skip_parsing (`bool`, *optional*, defaults to `False`):
285
+ Skip the parsing of brackets.
286
+ skip_weighting (`bool`, *optional*, defaults to `False`):
287
+ Skip the weighting. When the parsing is skipped, it is forced True.
288
+ """
289
+ max_length = (pipe.tokenizer.model_max_length
290
+ - 2) * max_embeddings_multiples + 2
291
+ if isinstance(prompt, str):
292
+ prompt = [prompt]
293
+
294
+ if not skip_parsing:
295
+ prompt_tokens, prompt_weights = get_prompts_with_weights(
296
+ pipe, prompt, max_length - 2)
297
+ if uncond_prompt is not None:
298
+ if isinstance(uncond_prompt, str):
299
+ uncond_prompt = [uncond_prompt]
300
+ uncond_tokens, uncond_weights = get_prompts_with_weights(
301
+ pipe, uncond_prompt, max_length - 2)
302
+ else:
303
+ prompt_tokens = [
304
+ token[1:-1] for token in pipe.tokenizer(
305
+ prompt, max_length=max_length, truncation=True).input_ids
306
+ ]
307
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
308
+ if uncond_prompt is not None:
309
+ if isinstance(uncond_prompt, str):
310
+ uncond_prompt = [uncond_prompt]
311
+ uncond_tokens = [
312
+ token[1:-1] for token in pipe.tokenizer(
313
+ uncond_prompt, max_length=max_length,
314
+ truncation=True).input_ids
315
+ ]
316
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
317
+
318
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
319
+ max_length = max([len(token) for token in prompt_tokens])
320
+ if uncond_prompt is not None:
321
+ max_length = max(max_length,
322
+ max([len(token) for token in uncond_tokens]))
323
+
324
+ max_embeddings_multiples = min(
325
+ max_embeddings_multiples,
326
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
327
+ )
328
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
329
+ max_length = (pipe.tokenizer.model_max_length
330
+ - 2) * max_embeddings_multiples + 2
331
+
332
+ # pad the length of tokens and weights
333
+ bos = pipe.tokenizer.bos_token_id
334
+ eos = pipe.tokenizer.eos_token_id
335
+ pad = getattr(pipe.tokenizer, 'pad_token_id', eos)
336
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
337
+ prompt_tokens,
338
+ prompt_weights,
339
+ max_length,
340
+ bos,
341
+ eos,
342
+ pad,
343
+ no_boseos_middle=no_boseos_middle,
344
+ chunk_length=pipe.tokenizer.model_max_length,
345
+ )
346
+ prompt_tokens = torch.tensor(
347
+ prompt_tokens, dtype=torch.long, device=pipe.device)
348
+ if uncond_prompt is not None:
349
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
350
+ uncond_tokens,
351
+ uncond_weights,
352
+ max_length,
353
+ bos,
354
+ eos,
355
+ pad,
356
+ no_boseos_middle=no_boseos_middle,
357
+ chunk_length=pipe.tokenizer.model_max_length,
358
+ )
359
+ uncond_tokens = torch.tensor(
360
+ uncond_tokens, dtype=torch.long, device=pipe.device)
361
+
362
+ # get the embeddings
363
+ text_embeddings = get_unweighted_text_embeddings(
364
+ pipe,
365
+ prompt_tokens,
366
+ pipe.tokenizer.model_max_length,
367
+ no_boseos_middle=no_boseos_middle,
368
+ )
369
+ prompt_weights = torch.tensor(
370
+ prompt_weights,
371
+ dtype=text_embeddings.dtype,
372
+ device=text_embeddings.device)
373
+ if uncond_prompt is not None:
374
+ uncond_embeddings = get_unweighted_text_embeddings(
375
+ pipe,
376
+ uncond_tokens,
377
+ pipe.tokenizer.model_max_length,
378
+ no_boseos_middle=no_boseos_middle,
379
+ )
380
+ uncond_weights = torch.tensor(
381
+ uncond_weights,
382
+ dtype=uncond_embeddings.dtype,
383
+ device=uncond_embeddings.device)
384
+
385
+ # assign weights to the prompts and normalize in the sense of mean
386
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
387
+ if (not skip_parsing) and (not skip_weighting):
388
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
389
+ text_embeddings.dtype)
390
+ text_embeddings *= prompt_weights.unsqueeze(-1)
391
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
392
+ text_embeddings.dtype)
393
+ text_embeddings *= (previous_mean
394
+ / current_mean).unsqueeze(-1).unsqueeze(-1)
395
+ if uncond_prompt is not None:
396
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
397
+ uncond_embeddings.dtype)
398
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
399
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
400
+ uncond_embeddings.dtype)
401
+ uncond_embeddings *= (previous_mean
402
+ / current_mean).unsqueeze(-1).unsqueeze(-1)
403
+
404
+ if uncond_prompt is not None:
405
+ return text_embeddings, uncond_embeddings
406
+ return text_embeddings, None
407
+
408
+
409
+ def prepare_image(image):
410
+ if isinstance(image, torch.Tensor):
411
+ # Batch single image
412
+ if image.ndim == 3:
413
+ image = image.unsqueeze(0)
414
+
415
+ image = image.to(dtype=torch.float32)
416
+ else:
417
+ # preprocess image
418
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
419
+ image = [image]
420
+
421
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
422
+ image = [np.array(i.convert('RGB'))[None, :] for i in image]
423
+ image = np.concatenate(image, axis=0)
424
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
425
+ image = np.concatenate([i[None, :] for i in image], axis=0)
426
+
427
+ image = image.transpose(0, 3, 1, 2)
428
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
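+ # map 8-bit pixel values to the [-1, 1] range expected by the VAE encoder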
429
+
430
+ return image
431
+
432
+
433
+ class StableDiffusionControlNetImg2ImgPanoPipeline(
434
+ StableDiffusionControlNetImg2ImgPipeline):
435
+ r"""
436
+ Pipeline for text-guided image-to-image generation using Stable Diffusion with ControlNet guidance.
437
+
438
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
439
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
440
+
441
+ In addition the pipeline inherits the following loading methods:
442
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
443
+
444
+ Args:
445
+ vae ([`AutoencoderKL`]):
446
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
447
+ text_encoder ([`CLIPTextModel`]):
448
+ Frozen text-encoder. Stable Diffusion uses the text portion of
449
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
450
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
451
+ tokenizer (`CLIPTokenizer`):
452
+ Tokenizer of class
453
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/
454
+ model_doc/clip#transformers.CLIPTokenizer).
455
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
456
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
457
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
458
+ as a list, the outputs from each ControlNet are added together to create one combined additional
459
+ conditioning.
460
+ scheduler ([`SchedulerMixin`]):
461
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
462
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
463
+ safety_checker ([`StableDiffusionSafetyChecker`]):
464
+ Classification module that estimates whether generated images could be considered offensive or harmful.
465
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
466
+ feature_extractor ([`CLIPImageProcessor`]):
467
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
468
+ """
469
+ _optional_components = ['safety_checker', 'feature_extractor']
470
+
471
+ def check_inputs(
472
+ self,
473
+ prompt,
474
+ image,
475
+ height,
476
+ width,
477
+ callback_steps,
478
+ negative_prompt=None,
479
+ prompt_embeds=None,
480
+ negative_prompt_embeds=None,
481
+ controlnet_conditioning_scale=1.0,
482
+ ):
483
+ if height % 8 != 0 or width % 8 != 0:
484
+ raise ValueError(
485
+ f'`height` and `width` have to be divisible by 8 but are {height} and {width}.'
486
+ )
487
+ condition_1 = callback_steps is not None
488
+ condition_2 = not isinstance(callback_steps,
489
+ int) or callback_steps <= 0
490
+ if (callback_steps is None) or (condition_1 and condition_2):
491
+ raise ValueError(
492
+ f'`callback_steps` has to be a positive integer but is {callback_steps} of type'
493
+ f' {type(callback_steps)}.')
494
+ if prompt is not None and prompt_embeds is not None:
495
+ raise ValueError(
496
+ f'Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to'
497
+ ' only forward one of the two.')
498
+ elif prompt is None and prompt_embeds is None:
499
+ raise ValueError(
500
+ 'Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.'
501
+ )
502
+ elif prompt is not None and (not isinstance(prompt, str)
503
+ and not isinstance(prompt, list)):
504
+ raise ValueError(
505
+ f'`prompt` has to be of type `str` or `list` but is {type(prompt)}'
506
+ )
507
+ if negative_prompt is not None and negative_prompt_embeds is not None:
508
+ raise ValueError(
509
+ f'Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:'
510
+ f' {negative_prompt_embeds}. Please make sure to only forward one of the two.'
511
+ )
512
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
513
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
514
+ raise ValueError(
515
+ '`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but'
516
+ f' got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`'
517
+ f' {negative_prompt_embeds.shape}.')
518
+ # `prompt` needs more sophisticated handling when there are multiple
519
+ # conditionings.
520
+ if isinstance(self.controlnet, MultiControlNetModel):
521
+ if isinstance(prompt, list):
522
+ logger.warning(
523
+ f'You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}'
524
+ ' prompts. The conditionings will be fixed across the prompts.'
525
+ )
526
+ # Check `image`
527
+ is_compiled = hasattr(
528
+ F, 'scaled_dot_product_attention') and isinstance(
529
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule)
530
+ if (isinstance(self.controlnet, ControlNetModel) or is_compiled
531
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)):
532
+ self.check_image(image, prompt, prompt_embeds)
533
+ elif (isinstance(self.controlnet, MultiControlNetModel) or is_compiled
534
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)):
535
+ if not isinstance(image, list):
536
+ raise TypeError(
537
+ 'For multiple controlnets: `image` must be type `list`')
538
+ # When `image` is a nested list:
539
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
540
+ elif any(isinstance(i, list) for i in image):
541
+ raise ValueError(
542
+ 'Only a single batch of multiple conditionings is supported at the moment.'
543
+ )
544
+ elif len(image) != len(self.controlnet.nets):
545
+ raise ValueError(
546
+ 'For multiple controlnets: `image` must have the same length as the number of controlnets.'
547
+ )
548
+ for image_ in image:
549
+ self.check_image(image_, prompt, prompt_embeds)
550
+ else:
551
+ assert False
552
+ # Check `controlnet_conditioning_scale`
553
+ if (isinstance(self.controlnet, ControlNetModel) or is_compiled
554
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)):
555
+ if not isinstance(controlnet_conditioning_scale, float):
556
+ raise TypeError(
557
+ 'For single controlnet: `controlnet_conditioning_scale` must be type `float`.'
558
+ )
559
+ elif (isinstance(self.controlnet, MultiControlNetModel) or is_compiled
560
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)):
561
+ if isinstance(controlnet_conditioning_scale, list):
562
+ if any(
563
+ isinstance(i, list)
564
+ for i in controlnet_conditioning_scale):
565
+ raise ValueError(
566
+ 'Only a single batch of multiple conditionings is supported at the moment.'
567
+ )
568
+ elif isinstance(
569
+ controlnet_conditioning_scale,
570
+ list) and len(controlnet_conditioning_scale) != len(
571
+ self.controlnet.nets):
572
+ raise ValueError(
573
+ 'For multiple controlnets: When `controlnet_conditioning_scale` '
574
+ 'is specified as `list`, it must have'
575
+ ' the same length as the number of controlnets')
576
+ else:
577
+ assert False
578
+
579
+ def _default_height_width(self, height, width, image):
580
+ # NOTE: It is possible that a list of images have different
581
+ # dimensions for each image, so just checking the first image
582
+ # is not _exactly_ correct, but it is simple.
583
+ while isinstance(image, list):
584
+ image = image[0]
585
+ if height is None:
586
+ if isinstance(image, PIL.Image.Image):
587
+ height = image.height
588
+ elif isinstance(image, torch.Tensor):
589
+ height = image.shape[2]
590
+ height = (height // 8) * 8 # round down to nearest multiple of 8
591
+ if width is None:
592
+ if isinstance(image, PIL.Image.Image):
593
+ width = image.width
594
+ elif isinstance(image, torch.Tensor):
595
+ width = image.shape[3]
596
+ width = (width // 8) * 8 # round down to nearest multiple of 8
597
+ return height, width
598
+
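As a hedged illustration of the fallback above (the 1004×763 size is arbitrary and `pipe` stands for any instance of this pipeline): `_default_height_width` takes the dimensions from the image and rounds each down to the nearest multiple of 8.

```py
>>> import PIL.Image
>>> img = PIL.Image.new('RGB', (1004, 763))  # PIL sizes are (width, height)
>>> pipe._default_height_width(None, None, img)  # pipe: hypothetical pipeline instance
(760, 1000)
```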
599
+ def _encode_prompt(
600
+ self,
601
+ prompt,
602
+ device,
603
+ num_images_per_prompt,
604
+ do_classifier_free_guidance,
605
+ negative_prompt=None,
606
+ max_embeddings_multiples=3,
607
+ prompt_embeds: Optional[torch.FloatTensor] = None,
608
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
609
+ lora_scale: Optional[float] = None,
610
+ ):
611
+ r"""
612
+ Encodes the prompt into text encoder hidden states.
613
+
614
+ Args:
615
+ prompt (`str` or `list(int)`):
616
+ prompt to be encoded
617
+ device: (`torch.device`):
618
+ torch device
619
+ num_images_per_prompt (`int`):
620
+ number of images that should be generated per prompt
621
+ do_classifier_free_guidance (`bool`):
622
+ whether to use classifier free guidance or not
623
+ negative_prompt (`str` or `List[str]`):
624
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
625
+ if `guidance_scale` is less than `1`).
626
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
627
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
628
+ """
629
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
630
+ self._lora_scale = lora_scale
631
+
632
+ if prompt is not None and isinstance(prompt, str):
633
+ batch_size = 1
634
+ elif prompt is not None and isinstance(prompt, list):
635
+ batch_size = len(prompt)
636
+ else:
637
+ batch_size = prompt_embeds.shape[0]
638
+
639
+ if negative_prompt_embeds is None:
640
+ if negative_prompt is None:
641
+ negative_prompt = [''] * batch_size
642
+ elif isinstance(negative_prompt, str):
643
+ negative_prompt = [negative_prompt] * batch_size
644
+ if batch_size != len(negative_prompt):
645
+ raise ValueError(
646
+ f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:'
647
+ f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches'
648
+ ' the batch size of `prompt`.')
649
+ if prompt_embeds is None or negative_prompt_embeds is None:
650
+ if isinstance(self, TextualInversionLoaderMixin):
651
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
652
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
653
+ negative_prompt = self.maybe_convert_prompt(
654
+ negative_prompt, self.tokenizer)
655
+
656
+ prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
657
+ pipe=self,
658
+ prompt=prompt,
659
+ uncond_prompt=negative_prompt
660
+ if do_classifier_free_guidance else None,
661
+ max_embeddings_multiples=max_embeddings_multiples,
662
+ )
663
+ if prompt_embeds is None:
664
+ prompt_embeds = prompt_embeds1
665
+ if negative_prompt_embeds is None:
666
+ negative_prompt_embeds = negative_prompt_embeds1
667
+
668
+ bs_embed, seq_len, _ = prompt_embeds.shape
669
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
670
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
671
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt,
672
+ seq_len, -1)
673
+
674
+ if do_classifier_free_guidance:
675
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
676
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
677
+ 1, num_images_per_prompt, 1)
678
+ negative_prompt_embeds = negative_prompt_embeds.view(
679
+ bs_embed * num_images_per_prompt, seq_len, -1)
680
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
681
+
682
+ return prompt_embeds
683
+
684
+ def denoise_latents(self, latents, t, prompt_embeds, control_image,
685
+ controlnet_conditioning_scale, guess_mode,
686
+ cross_attention_kwargs, do_classifier_free_guidance,
687
+ guidance_scale, extra_step_kwargs,
688
+ views_scheduler_status):
689
+ # expand the latents if we are doing classifier free guidance
690
+ latent_model_input = torch.cat(
691
+ [latents] * 2) if do_classifier_free_guidance else latents
692
+ self.scheduler.__dict__.update(views_scheduler_status[0])
693
+ latent_model_input = self.scheduler.scale_model_input(
694
+ latent_model_input, t)
695
+ # controlnet(s) inference
696
+ if guess_mode and do_classifier_free_guidance:
697
+ # Infer ControlNet only for the conditional batch.
698
+ controlnet_latent_model_input = latents
699
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
700
+ else:
701
+ controlnet_latent_model_input = latent_model_input
702
+ controlnet_prompt_embeds = prompt_embeds
703
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
704
+ controlnet_latent_model_input,
705
+ t,
706
+ encoder_hidden_states=controlnet_prompt_embeds,
707
+ controlnet_cond=control_image,
708
+ conditioning_scale=controlnet_conditioning_scale,
709
+ guess_mode=guess_mode,
710
+ return_dict=False,
711
+ )
712
+ if guess_mode and do_classifier_free_guidance:
713
+ # Inferred ControlNet only for the conditional batch.
714
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
715
+ # add 0 to the unconditional batch to keep it unchanged.
716
+ down_block_res_samples = [
717
+ torch.cat([torch.zeros_like(d), d])
718
+ for d in down_block_res_samples
719
+ ]
720
+ mid_block_res_sample = torch.cat(
721
+ [torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
722
+ # predict the noise residual
723
+ noise_pred = self.unet(
724
+ latent_model_input,
725
+ t,
726
+ encoder_hidden_states=prompt_embeds,
727
+ cross_attention_kwargs=cross_attention_kwargs,
728
+ down_block_additional_residuals=down_block_res_samples,
729
+ mid_block_additional_residual=mid_block_res_sample,
730
+ return_dict=False,
731
+ )[0]
732
+ # perform guidance
733
+ if do_classifier_free_guidance:
734
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
735
+ noise_pred = noise_pred_uncond + guidance_scale * (
736
+ noise_pred_text - noise_pred_uncond)
737
+ # compute the previous noisy sample x_t -> x_t-1
738
+ latents = self.scheduler.step(
739
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
740
+ return latents
741
+
742
+ def blend_v(self, a, b, blend_extent):
743
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
744
+ for y in range(blend_extent):
745
+ b[:, :,
746
+ y, :] = a[:, :, -blend_extent
747
+ + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
748
+ y / blend_extent)
749
+ return b
750
+
751
+ def blend_h(self, a, b, blend_extent):
752
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
753
+ for x in range(blend_extent):
754
+ b[:, :, :, x] = a[:, :, :, -blend_extent
755
+ + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
756
+ x / blend_extent)
757
+ return b
758
+
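A small sketch of the linear cross-fade these helpers implement: `blend_h` blends the right edge of `a` into the left edge of `b` over `blend_extent` columns. Since `self` is unused, the method can be exercised on tiny tensors via an unbound call (printed values follow torch's default rounding).

```py
>>> import torch
>>> a = torch.zeros(1, 1, 1, 4)
>>> b = torch.ones(1, 1, 1, 4)
>>> StableDiffusionControlNetImg2ImgPanoPipeline.blend_h(None, a, b, 2)
tensor([[[[0.0000, 0.5000, 1.0000, 1.0000]]]])
```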
759
+ def get_blocks(self, latents, control_image, tile_latent_min_size,
760
+ overlap_size):
761
+ rows_latents = []
762
+ rows_control_images = []
763
+ for i in range(0, latents.shape[2] - overlap_size, overlap_size):
764
+ row_latents = []
765
+ row_control_images = []
766
+ for j in range(0, latents.shape[3] - overlap_size, overlap_size):
767
+ latents_input = latents[:, :, i:i + tile_latent_min_size,
768
+ j:j + tile_latent_min_size]
769
+ c_start_i = self.vae_scale_factor * i
770
+ c_end_i = self.vae_scale_factor * (i + tile_latent_min_size)
771
+ c_start_j = self.vae_scale_factor * j
772
+ c_end_j = self.vae_scale_factor * (j + tile_latent_min_size)
773
+ control_image_input = control_image[:, :, c_start_i:c_end_i,
774
+ c_start_j:c_end_j]
775
+ row_latents.append(latents_input)
776
+ row_control_images.append(control_image_input)
777
+ rows_latents.append(row_latents)
778
+ rows_control_images.append(row_control_images)
779
+ return rows_latents, rows_control_images
780
+
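A rough sketch of how `get_blocks` tiles the latent and control tensors, assuming `pipe` is a loaded instance with `vae_scale_factor == 8` and using illustrative shapes: with `context_size=768` the latent tile size is 96 with an overlap of 48, so a 96×192 latent splits into one row of three overlapping tiles.

```py
>>> import torch
>>> latents = torch.zeros(1, 4, 96, 192)          # dummy latents
>>> control = torch.zeros(1, 3, 768, 1536)        # dummy control image in pixel space
>>> rows_lat, rows_ctl = pipe.get_blocks(latents, control,
...                                      tile_latent_min_size=96, overlap_size=48)
>>> len(rows_lat), len(rows_lat[0]), tuple(rows_lat[0][0].shape)
(1, 3, (1, 4, 96, 96))
```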
781
+ @torch.no_grad()
782
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
783
+ def __call__(
784
+ self,
785
+ prompt: Union[str, List[str]] = None,
786
+ image: Union[torch.FloatTensor, PIL.Image.Image,
787
+ List[torch.FloatTensor], List[PIL.Image.Image]] = None,
788
+ control_image: Union[torch.FloatTensor, PIL.Image.Image,
789
+ List[torch.FloatTensor],
790
+ List[PIL.Image.Image]] = None,
791
+ height: Optional[int] = None,
792
+ width: Optional[int] = None,
793
+ strength: float = 0.8,
794
+ num_inference_steps: int = 50,
795
+ guidance_scale: float = 7.5,
796
+ negative_prompt: Optional[Union[str, List[str]]] = None,
797
+ num_images_per_prompt: Optional[int] = 1,
798
+ eta: float = 0.0,
799
+ generator: Optional[Union[torch.Generator,
800
+ List[torch.Generator]]] = None,
801
+ latents: Optional[torch.FloatTensor] = None,
802
+ prompt_embeds: Optional[torch.FloatTensor] = None,
803
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
804
+ output_type: Optional[str] = 'pil',
805
+ return_dict: bool = True,
806
+ callback: Optional[Callable[[int, int, torch.FloatTensor],
807
+ None]] = None,
808
+ callback_steps: int = 1,
809
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
810
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
811
+ guess_mode: bool = False,
812
+ context_size: int = 768,
813
+ ):
814
+ r"""
815
+ Function invoked when calling the pipeline for generation.
816
+
817
+ Args:
818
+ prompt (`str` or `List[str]`, *optional*):
819
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
820
+ instead.
821
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
822
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
823
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
824
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
825
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
826
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
827
+ specified in init, images must be passed as a list such that each element of the list can be correctly
828
+ batched for input to a single controlnet.
829
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
830
+ The height in pixels of the generated image.
831
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
832
+ The width in pixels of the generated image.
833
+ num_inference_steps (`int`, *optional*, defaults to 50):
834
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
835
+ expense of slower inference.
836
+ guidance_scale (`float`, *optional*, defaults to 7.5):
837
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
838
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
839
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
840
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
841
+ usually at the expense of lower image quality.
842
+ negative_prompt (`str` or `List[str]`, *optional*):
843
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
844
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
845
+ less than `1`).
846
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
847
+ The number of images to generate per prompt.
848
+ eta (`float`, *optional*, defaults to 0.0):
849
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
850
+ [`schedulers.DDIMScheduler`], will be ignored for others.
851
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
852
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
853
+ to make generation deterministic.
854
+ latents (`torch.FloatTensor`, *optional*):
855
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
856
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
857
+ tensor will be generated by sampling using the supplied random `generator`.
858
+ prompt_embeds (`torch.FloatTensor`, *optional*):
859
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
860
+ provided, text embeddings will be generated from `prompt` input argument.
861
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
862
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
863
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
864
+ argument.
865
+ output_type (`str`, *optional*, defaults to `"pil"`):
866
+ The output format of the generated image. Choose between
867
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
868
+ return_dict (`bool`, *optional*, defaults to `True`):
869
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
870
+ plain tuple.
871
+ callback (`Callable`, *optional*):
872
+ A function that will be called every `callback_steps` steps during inference. The function will be
873
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
874
+ callback_steps (`int`, *optional*, defaults to 1):
875
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
876
+ called at every step.
877
+ cross_attention_kwargs (`dict`, *optional*):
878
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
879
+ `self.processor` in
880
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/
881
+ src/diffusers/models/cross_attention.py).
882
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
883
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
884
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
885
+ corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
886
+ than for [`~StableDiffusionControlNetPipeline.__call__`].
887
+ guess_mode (`bool`, *optional*, defaults to `False`):
888
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
889
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
890
+ context_size (`int`, *optional*, defaults to `768`):
891
+ Tile size (in pixels) used when denoising the latents.
892
+
893
+ Examples:
894
+
895
+ Returns:
896
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
897
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
898
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
899
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
900
+ (nsfw) content, according to the `safety_checker`.
901
+ """
902
+
903
+ def tiled_decode(
904
+ self,
905
+ z: torch.FloatTensor,
906
+ return_dict: bool = True
907
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
908
+ r"""Decode a batch of images using a tiled decoder.
909
+
910
+ When this option is enabled, the VAE splits the input tensor into tiles and decodes them in several
911
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled
912
+ decoding is slightly different from non-tiled decoding because each tile is decoded separately.
913
+ To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output.
914
+ You may still see tile-sized changes in the look of the output, but they should be much less noticeable.
915
+
916
+ Args:
917
+ z (`torch.FloatTensor`): Input batch of latent vectors.
918
+ return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
919
+ """
920
+ _tile_overlap_factor = 1 - self.tile_overlap_factor
921
+ overlap_size = int(self.tile_latent_min_size
922
+ * _tile_overlap_factor)
923
+ blend_extent = int(self.tile_sample_min_size
924
+ * self.tile_overlap_factor)
925
+ row_limit = self.tile_sample_min_size - blend_extent
926
+ w = z.shape[3]
927
+ z = torch.cat([z, z[:, :, :, :w // 4]], dim=-1)
928
+ # Split z into overlapping 64x64 tiles and decode them separately.
929
+ # The tiles have an overlap to avoid seams between tiles.
930
+
931
+ rows = []
932
+ for i in range(0, z.shape[2], overlap_size):
933
+ row = []
934
+ tile = z[:, :, i:i + self.tile_latent_min_size, :]
935
+ tile = self.post_quant_conv(tile)
936
+ decoded = self.decoder(tile)
937
+ vae_scale_factor = decoded.shape[-1] // tile.shape[-1]
938
+ row.append(decoded)
939
+ rows.append(row)
940
+ result_rows = []
941
+ for i, row in enumerate(rows):
942
+ result_row = []
943
+ for j, tile in enumerate(row):
944
+ # blend the above tile and the left tile
945
+ # to the current tile and add the current tile to the result row
946
+ if i > 0:
947
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
948
+ if j > 0:
949
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
950
+ result_row.append(
951
+ self.blend_h(
952
+ tile[:, :, :row_limit, w * vae_scale_factor:],
953
+ tile[:, :, :row_limit, :w * vae_scale_factor],
954
+ tile.shape[-1] - w * vae_scale_factor))
955
+ result_rows.append(torch.cat(result_row, dim=3))
956
+
957
+ dec = torch.cat(result_rows, dim=2)
958
+ if not return_dict:
959
+ return (dec, )
960
+
961
+ return DecoderOutput(sample=dec)
962
+
963
+ self.vae.tiled_decode = tiled_decode.__get__(self.vae, AutoencoderKL)
964
+
965
+ # 0. Default height and width to unet
966
+ height, width = self._default_height_width(height, width, image)
967
+
968
+ # 1. Check inputs. Raise error if not correct
969
+ self.check_inputs(
970
+ prompt,
971
+ control_image,
972
+ height,
973
+ width,
974
+ callback_steps,
975
+ negative_prompt,
976
+ prompt_embeds,
977
+ negative_prompt_embeds,
978
+ controlnet_conditioning_scale,
979
+ )
980
+
981
+ # 2. Define call parameters
982
+ if prompt is not None and isinstance(prompt, str):
983
+ batch_size = 1
984
+ elif prompt is not None and isinstance(prompt, list):
985
+ batch_size = len(prompt)
986
+ else:
987
+ batch_size = prompt_embeds.shape[0]
988
+
989
+ device = self._execution_device
990
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
991
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
992
+ # corresponds to doing no classifier free guidance.
993
+ do_classifier_free_guidance = guidance_scale > 1.0
994
+
995
+ controlnet = self.controlnet._orig_mod if is_compiled_module(
996
+ self.controlnet) else self.controlnet
997
+
998
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(
999
+ controlnet_conditioning_scale, float):
1000
+ controlnet_conditioning_scale = [controlnet_conditioning_scale
1001
+ ] * len(controlnet.nets)
1002
+
1003
+ global_pool_conditions = (
1004
+ controlnet.config.global_pool_conditions if isinstance(
1005
+ controlnet, ControlNetModel) else
1006
+ controlnet.nets[0].config.global_pool_conditions)
1007
+ guess_mode = guess_mode or global_pool_conditions
1008
+
1009
+ # 3. Encode input prompt
1010
+ prompt_embeds = self._encode_prompt(
1011
+ prompt,
1012
+ device,
1013
+ num_images_per_prompt,
1014
+ do_classifier_free_guidance,
1015
+ negative_prompt,
1016
+ prompt_embeds=prompt_embeds,
1017
+ negative_prompt_embeds=negative_prompt_embeds,
1018
+ )
1019
+ # 4. Prepare the input image
1020
+ image = prepare_image(image)
1021
+
1022
+ # 5. Prepare the ControlNet conditioning image
1023
+ if isinstance(controlnet, ControlNetModel):
1024
+ control_image = self.prepare_control_image(
1025
+ image=control_image,
1026
+ width=width,
1027
+ height=height,
1028
+ batch_size=batch_size * num_images_per_prompt,
1029
+ num_images_per_prompt=num_images_per_prompt,
1030
+ device=device,
1031
+ dtype=controlnet.dtype,
1032
+ do_classifier_free_guidance=do_classifier_free_guidance,
1033
+ guess_mode=guess_mode,
1034
+ )
1035
+ elif isinstance(controlnet, MultiControlNetModel):
1036
+ control_images = []
1037
+
1038
+ for control_image_ in control_image:
1039
+ control_image_ = self.prepare_control_image(
1040
+ image=control_image_,
1041
+ width=width,
1042
+ height=height,
1043
+ batch_size=batch_size * num_images_per_prompt,
1044
+ num_images_per_prompt=num_images_per_prompt,
1045
+ device=device,
1046
+ dtype=controlnet.dtype,
1047
+ do_classifier_free_guidance=do_classifier_free_guidance,
1048
+ guess_mode=guess_mode,
1049
+ )
1050
+
1051
+ control_images.append(control_image_)
1052
+
1053
+ control_image = control_images
1054
+ else:
1055
+ assert False
1056
+
1057
+ # 5. Prepare timesteps
1058
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1059
+ timesteps, num_inference_steps = self.get_timesteps(
1060
+ num_inference_steps, strength, device)
1061
+ latent_timestep = timesteps[:1].repeat(batch_size
1062
+ * num_images_per_prompt)
1063
+
1064
+ # 6. Prepare latent variables
1065
+ latents = self.prepare_latents(
1066
+ image,
1067
+ latent_timestep,
1068
+ batch_size,
1069
+ num_images_per_prompt,
1070
+ prompt_embeds.dtype,
1071
+ device,
1072
+ generator,
1073
+ )
1074
+
1075
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1076
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1077
+
1078
+ views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)]
1079
+ # value = torch.zeros_like(latents)
1080
+ _, _, height, width = control_image.size()
1081
+ tile_latent_min_size = context_size // self.vae_scale_factor
1082
+ tile_overlap_factor = 0.5
1083
+ overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor))
1084
+ blend_extent = int(tile_latent_min_size * tile_overlap_factor)
1085
+ row_limit = tile_latent_min_size - blend_extent
1086
+ w = latents.shape[3]
1087
+ latents = torch.cat([latents, latents[:, :, :, :overlap_size]], dim=-1)
1088
+ control_image_extend = control_image[:, :, :, :overlap_size
1089
+ * self.vae_scale_factor]
1090
+ control_image = torch.cat([control_image, control_image_extend],
1091
+ dim=-1)
1092
+
1093
+ # 8. Denoising loop
1094
+ num_warmup_steps = len(
1095
+ timesteps) - num_inference_steps * self.scheduler.order
1096
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1097
+ for i, t in enumerate(timesteps):
1098
+ latents_input, control_image_input = self.get_blocks(
1099
+ latents, control_image, tile_latent_min_size, overlap_size)
1100
+ rows = []
1101
+ for latents_input_, control_image_input_ in zip(
1102
+ latents_input, control_image_input):
1103
+ num_block = len(latents_input_)
1104
+ # get batched latents_input
1105
+ latents_input_ = torch.cat(
1106
+ latents_input_[:num_block], dim=0)
1107
+ # get batched prompt_embeds
1108
+ prompt_embeds_ = torch.cat(
1109
+ [prompt_embeds.chunk(2)[0]] * num_block
1110
+ + [prompt_embeds.chunk(2)[1]] * num_block,
1111
+ dim=0)
1112
+ # get batched control_image_input
1113
+ control_image_input_ = torch.cat(
1114
+ [
1115
+ x[0, :, :, ][None, :, :, :]
1116
+ for x in control_image_input_[:num_block]
1117
+ ] + [
1118
+ x[1, :, :, ][None, :, :, :]
1119
+ for x in control_image_input_[:num_block]
1120
+ ],
1121
+ dim=0)
1122
+ latents_output = self.denoise_latents(
1123
+ latents_input_, t, prompt_embeds_,
1124
+ control_image_input_, controlnet_conditioning_scale,
1125
+ guess_mode, cross_attention_kwargs,
1126
+ do_classifier_free_guidance, guidance_scale,
1127
+ extra_step_kwargs, views_scheduler_status)
1128
+ rows.append(list(latents_output.chunk(num_block)))
1129
+ result_rows = []
1130
+ for i, row in enumerate(rows):
1131
+ result_row = []
1132
+ for j, tile in enumerate(row):
1133
+ # blend the above tile and the left tile
1134
+ # to the current tile and add the current tile to the result row
1135
+ if i > 0:
1136
+ tile = self.blend_v(rows[i - 1][j], tile,
1137
+ blend_extent)
1138
+ if j > 0:
1139
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
1140
+ if j == 0:
1141
+ tile = self.blend_h(row[-1], tile, blend_extent)
1142
+ if i != len(rows) - 1:
1143
+ if j == len(row) - 1:
1144
+ result_row.append(tile[:, :, :row_limit, :])
1145
+ else:
1146
+ result_row.append(
1147
+ tile[:, :, :row_limit, :row_limit])
1148
+ else:
1149
+ if j == len(row) - 1:
1150
+ result_row.append(tile[:, :, :, :])
1151
+ else:
1152
+ result_row.append(tile[:, :, :, :row_limit])
1153
+ result_rows.append(torch.cat(result_row, dim=3))
1154
+ latents = torch.cat(result_rows, dim=2)
1155
+
1156
+ # call the callback, if provided
1157
+ condition_i = i == len(timesteps) - 1
1158
+ condition_warm = (i + 1) > num_warmup_steps and (
1159
+ i + 1) % self.scheduler.order == 0
1160
+ if condition_i or condition_warm:
1161
+ progress_bar.update()
1162
+ if callback is not None and i % callback_steps == 0:
1163
+ callback(i, t, latents)
1164
+ latents = latents[:, :, :, :w]
1165
+
1166
+ # If we do sequential model offloading, let's offload unet and controlnet
1167
+ # manually for max memory savings
1168
+ if hasattr(
1169
+ self,
1170
+ 'final_offload_hook') and self.final_offload_hook is not None:
1171
+ self.unet.to('cpu')
1172
+ self.controlnet.to('cpu')
1173
+ torch.cuda.empty_cache()
1174
+
1175
+ if not output_type == 'latent':
1176
+ image = self.vae.decode(
1177
+ latents / self.vae.config.scaling_factor, return_dict=False)[0]
1178
+ image, has_nsfw_concept = self.run_safety_checker(
1179
+ image, device, prompt_embeds.dtype)
1180
+ else:
1181
+ image = latents
1182
+ has_nsfw_concept = None
1183
+
1184
+ if has_nsfw_concept is None:
1185
+ do_denormalize = [True] * image.shape[0]
1186
+ else:
1187
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1188
+
1189
+ image = self.image_processor.postprocess(
1190
+ image, output_type=output_type, do_denormalize=do_denormalize)
1191
+
1192
+ # Offload last model to CPU
1193
+ if hasattr(
1194
+ self,
1195
+ 'final_offload_hook') and self.final_offload_hook is not None:
1196
+ self.final_offload_hook.offload()
1197
+
1198
+ if not return_dict:
1199
+ return (image, has_nsfw_concept)
1200
+
1201
+ return StableDiffusionPipelineOutput(
1202
+ images=image, nsfw_content_detected=has_nsfw_concept)
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ diffusers==0.26.0
2
+ accelerate
3
+ xformers
4
+ triton
5
+ transformers
6
+ git+https://github.com/doevent/Real-ESRGAN.git
7
+ py360convert
8
+ numpy==1.23.5
9
+ basicsr
10
+ streamlit
11
+ streamlit_pannellum
txt2panoimg/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .pipeline_base import StableDiffusionBlendExtendPipeline
2
+ from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
3
+ from .text_to_360panorama_image_pipeline import Text2360PanoramaImagePipeline
txt2panoimg/pipeline_base.py ADDED
@@ -0,0 +1,849 @@
1
+ # Copyright © Alibaba, Inc. and its affiliates.
2
+ # The implementation here is modified based on diffusers.StableDiffusionPipeline,
3
+ # originally Apache 2.0 License and public available at
4
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
5
+
6
+ import re
7
+ from typing import Any, Callable, Dict, List, Optional, Union
8
+
9
+ import torch
10
+ from diffusers import (AutoencoderKL, DiffusionPipeline,
11
+ StableDiffusionPipeline)
12
+
13
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
14
+ try:
15
+ from diffusers.models.autoencoders.vae import DecoderOutput
16
+ except ImportError:
17
+ from diffusers.models.vae import DecoderOutput
18
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
19
+ from diffusers.utils import logging, replace_example_docstring
20
+ from transformers import CLIPTokenizer
21
+
22
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
+
24
+ EXAMPLE_DOC_STRING = """
25
+ Examples:
26
+ ```py
27
+ >>> import torch
28
+ >>> from diffusers import EulerAncestralDiscreteScheduler
29
+ >>> from txt2panoimage.pipeline_base import StableDiffusionBlendExtendPipeline
30
+ >>> model_id = "models/sd-base"
31
+ >>> pipe = StableDiffusionBlendExtendPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
32
+ >>> pipe = pipe.to("cuda")
33
+ >>> pipe.vae.enable_tiling()
34
+ >>> pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
35
+ >>> # remove following line if xformers is not installed
36
+ >>> pipe.enable_xformers_memory_efficient_attention()
37
+ >>> pipe.enable_model_cpu_offload()
38
+ >>> prompt = "a living room"
39
+ >>> image = pipe(prompt).images[0]
40
+ ```
41
+ """
42
+
43
+ re_attention = re.compile(
44
+ r"""
45
+ \\\(|
46
+ \\\)|
47
+ \\\[|
48
+ \\]|
49
+ \\\\|
50
+ \\|
51
+ \(|
52
+ \[|
53
+ :([+-]?[.\d]+)\)|
54
+ \)|
55
+ ]|
56
+ [^\\()\[\]:]+|
57
+ :
58
+ """,
59
+ re.X,
60
+ )
61
+
62
+
63
+ def parse_prompt_attention(text):
64
+ """
65
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
66
+ Accepted tokens are:
67
+ (abc) - increases attention to abc by a multiplier of 1.1
68
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
69
+ [abc] - decreases attention to abc by a multiplier of 1.1 (i.e. its weight is divided by 1.1)
70
+ """
71
+
72
+ res = []
73
+ round_brackets = []
74
+ square_brackets = []
75
+
76
+ round_bracket_multiplier = 1.1
77
+ square_bracket_multiplier = 1 / 1.1
78
+
79
+ def multiply_range(start_position, multiplier):
80
+ for p in range(start_position, len(res)):
81
+ res[p][1] *= multiplier
82
+
83
+ for m in re_attention.finditer(text):
84
+ text = m.group(0)
85
+ weight = m.group(1)
86
+
87
+ if text.startswith('\\'):
88
+ res.append([text[1:], 1.0])
89
+ elif text == '(':
90
+ round_brackets.append(len(res))
91
+ elif text == '[':
92
+ square_brackets.append(len(res))
93
+ elif weight is not None and len(round_brackets) > 0:
94
+ multiply_range(round_brackets.pop(), float(weight))
95
+ elif text == ')' and len(round_brackets) > 0:
96
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
97
+ elif text == ']' and len(square_brackets) > 0:
98
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
99
+ else:
100
+ res.append([text, 1.0])
101
+
102
+ for pos in round_brackets:
103
+ multiply_range(pos, round_bracket_multiplier)
104
+
105
+ for pos in square_brackets:
106
+ multiply_range(pos, square_bracket_multiplier)
107
+
108
+ if len(res) == 0:
109
+ res = [['', 1.0]]
110
+
111
+ # merge runs of identical weights
112
+ i = 0
113
+ while i + 1 < len(res):
114
+ if res[i][1] == res[i + 1][1]:
115
+ res[i][0] += res[i + 1][0]
116
+ res.pop(i + 1)
117
+ else:
118
+ i += 1
119
+
120
+ return res
121
+
122
+
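A doctest-style sketch of the parser's output for the bracket syntax described above (weights shown as plain Python floats):

```py
>>> parse_prompt_attention('a (very beautiful) masterpiece')
[['a ', 1.0], ['very beautiful', 1.1], [' masterpiece', 1.0]]
>>> parse_prompt_attention('sharp focus, (photorealistic:1.4)')
[['sharp focus, ', 1.0], ['photorealistic', 1.4]]
```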
123
+ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str],
124
+ max_length: int):
125
+ r"""
126
+ Tokenize a list of prompts and return their tokens together with the weight of each token.
127
+
128
+ No padding, starting or ending token is included.
129
+ """
130
+ tokens = []
131
+ weights = []
132
+ truncated = False
133
+ for text in prompt:
134
+ texts_and_weights = parse_prompt_attention(text)
135
+ text_token = []
136
+ text_weight = []
137
+ for word, weight in texts_and_weights:
138
+ # tokenize and discard the starting and the ending token
139
+ token = pipe.tokenizer(word).input_ids[1:-1]
140
+ text_token += token
141
+ # copy the weight by length of token
142
+ text_weight += [weight] * len(token)
143
+ # stop if the text is too long (longer than truncation limit)
144
+ if len(text_token) > max_length:
145
+ truncated = True
146
+ break
147
+ # truncate
148
+ if len(text_token) > max_length:
149
+ truncated = True
150
+ text_token = text_token[:max_length]
151
+ text_weight = text_weight[:max_length]
152
+ tokens.append(text_token)
153
+ weights.append(text_weight)
154
+ if truncated:
155
+ logger.warning(
156
+ 'Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples'
157
+ )
158
+ return tokens, weights
159
+
160
+
161
+ def pad_tokens_and_weights(tokens,
162
+ weights,
163
+ max_length,
164
+ bos,
165
+ eos,
166
+ pad,
167
+ no_boseos_middle=True,
168
+ chunk_length=77):
169
+ r"""
170
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
171
+ """
172
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
173
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
174
+ for i in range(len(tokens)):
175
+ tokens[i] = [
176
+ bos
177
+ ] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
178
+ if no_boseos_middle:
179
+ weights[i] = [1.0] + weights[i] + [1.0] * (
180
+ max_length - 1 - len(weights[i]))
181
+ else:
182
+ w = []
183
+ if len(weights[i]) == 0:
184
+ w = [1.0] * weights_length
185
+ else:
186
+ for j in range(max_embeddings_multiples):
187
+ w.append(1.0) # weight for starting token in this chunk
188
+ w += weights[i][j * (chunk_length - 2):min(
189
+ len(weights[i]), (j + 1) * (chunk_length - 2))]
190
+ w.append(1.0) # weight for ending token in this chunk
191
+ w += [1.0] * (weights_length - len(w))
192
+ weights[i] = w[:]
193
+
194
+ return tokens, weights
195
+
196
+
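A minimal sketch of the default (`no_boseos_middle=True`) padding path; the special-token ids shown are the usual CLIP bos/eos/pad ids (49406/49407), and the two word tokens and the 1.2 weight are arbitrary:

```py
>>> toks, wts = pad_tokens_and_weights([[320, 1125]], [[1.0, 1.2]], max_length=8,
...                                    bos=49406, eos=49407, pad=49407)
>>> toks[0]
[49406, 320, 1125, 49407, 49407, 49407, 49407, 49407]
>>> wts[0]
[1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0]
```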
197
+ def get_unweighted_text_embeddings(
198
+ pipe: DiffusionPipeline,
199
+ text_input: torch.Tensor,
200
+ chunk_length: int,
201
+ no_boseos_middle: Optional[bool] = True,
202
+ ):
203
+ """
204
+ When the token sequence exceeds the capacity of the text encoder, it is split into chunks
205
+ and each chunk is sent to the text encoder individually.
206
+ """
207
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
208
+ if max_embeddings_multiples > 1:
209
+ text_embeddings = []
210
+ for i in range(max_embeddings_multiples):
211
+ # extract the i-th chunk
212
+ text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1)
213
+ * (chunk_length - 2) + 2].clone()
214
+
215
+ # cover the head and the tail by the starting and the ending tokens
216
+ text_input_chunk[:, 0] = text_input[0, 0]
217
+ text_input_chunk[:, -1] = text_input[0, -1]
218
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
219
+
220
+ if no_boseos_middle:
221
+ if i == 0:
222
+ # discard the ending token
223
+ text_embedding = text_embedding[:, :-1]
224
+ elif i == max_embeddings_multiples - 1:
225
+ # discard the starting token
226
+ text_embedding = text_embedding[:, 1:]
227
+ else:
228
+ # discard both starting and ending tokens
229
+ text_embedding = text_embedding[:, 1:-1]
230
+
231
+ text_embeddings.append(text_embedding)
232
+ text_embeddings = torch.concat(text_embeddings, axis=1)
233
+ else:
234
+ text_embeddings = pipe.text_encoder(text_input)[0]
235
+ return text_embeddings
236
+
237
+
238
+ def get_weighted_text_embeddings(
239
+ pipe: DiffusionPipeline,
240
+ prompt: Union[str, List[str]],
241
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
242
+ max_embeddings_multiples: Optional[int] = 3,
243
+ no_boseos_middle: Optional[bool] = False,
244
+ skip_parsing: Optional[bool] = False,
245
+ skip_weighting: Optional[bool] = False,
246
+ ):
247
+ r"""
248
+ Prompts can be assigned with local weights using brackets. For example,
249
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
250
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
251
+
252
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
253
+
254
+ Args:
255
+ pipe (`DiffusionPipeline`):
256
+ Pipe to provide access to the tokenizer and the text encoder.
257
+ prompt (`str` or `List[str]`):
258
+ The prompt or prompts to guide the image generation.
259
+ uncond_prompt (`str` or `List[str]`):
260
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
261
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
262
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
263
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
264
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
265
+ If the length of the tokenized text is a multiple of the capacity of the text encoder, whether to keep the
266
+ starting and ending tokens of each chunk in the middle.
267
+ skip_parsing (`bool`, *optional*, defaults to `False`):
268
+ Skip the parsing of brackets.
269
+ skip_weighting (`bool`, *optional*, defaults to `False`):
270
+ Skip the weighting. When the parsing is skipped, it is forced True.
271
+ """
272
+ max_length = (pipe.tokenizer.model_max_length
273
+ - 2) * max_embeddings_multiples + 2
274
+ if isinstance(prompt, str):
275
+ prompt = [prompt]
276
+
277
+ if not skip_parsing:
278
+ prompt_tokens, prompt_weights = get_prompts_with_weights(
279
+ pipe, prompt, max_length - 2)
280
+ if uncond_prompt is not None:
281
+ if isinstance(uncond_prompt, str):
282
+ uncond_prompt = [uncond_prompt]
283
+ uncond_tokens, uncond_weights = get_prompts_with_weights(
284
+ pipe, uncond_prompt, max_length - 2)
285
+ else:
286
+ prompt_tokens = [
287
+ token[1:-1] for token in pipe.tokenizer(
288
+ prompt, max_length=max_length, truncation=True).input_ids
289
+ ]
290
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
291
+ if uncond_prompt is not None:
292
+ if isinstance(uncond_prompt, str):
293
+ uncond_prompt = [uncond_prompt]
294
+ uncond_tokens = [
295
+ token[1:-1] for token in pipe.tokenizer(
296
+ uncond_prompt, max_length=max_length,
297
+ truncation=True).input_ids
298
+ ]
299
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
300
+
301
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
302
+ max_length = max([len(token) for token in prompt_tokens])
303
+ if uncond_prompt is not None:
304
+ max_length = max(max_length,
305
+ max([len(token) for token in uncond_tokens]))
306
+
307
+ max_embeddings_multiples = min(
308
+ max_embeddings_multiples,
309
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
310
+ )
311
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
312
+ max_length = (pipe.tokenizer.model_max_length
313
+ - 2) * max_embeddings_multiples + 2
314
+
315
+ # pad the length of tokens and weights
316
+ bos = pipe.tokenizer.bos_token_id
317
+ eos = pipe.tokenizer.eos_token_id
318
+ pad = getattr(pipe.tokenizer, 'pad_token_id', eos)
319
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
320
+ prompt_tokens,
321
+ prompt_weights,
322
+ max_length,
323
+ bos,
324
+ eos,
325
+ pad,
326
+ no_boseos_middle=no_boseos_middle,
327
+ chunk_length=pipe.tokenizer.model_max_length,
328
+ )
329
+ prompt_tokens = torch.tensor(
330
+ prompt_tokens, dtype=torch.long, device=pipe.device)
331
+ if uncond_prompt is not None:
332
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
333
+ uncond_tokens,
334
+ uncond_weights,
335
+ max_length,
336
+ bos,
337
+ eos,
338
+ pad,
339
+ no_boseos_middle=no_boseos_middle,
340
+ chunk_length=pipe.tokenizer.model_max_length,
341
+ )
342
+ uncond_tokens = torch.tensor(
343
+ uncond_tokens, dtype=torch.long, device=pipe.device)
344
+
345
+ # get the embeddings
346
+ text_embeddings = get_unweighted_text_embeddings(
347
+ pipe,
348
+ prompt_tokens,
349
+ pipe.tokenizer.model_max_length,
350
+ no_boseos_middle=no_boseos_middle,
351
+ )
352
+ prompt_weights = torch.tensor(
353
+ prompt_weights,
354
+ dtype=text_embeddings.dtype,
355
+ device=text_embeddings.device)
356
+ if uncond_prompt is not None:
357
+ uncond_embeddings = get_unweighted_text_embeddings(
358
+ pipe,
359
+ uncond_tokens,
360
+ pipe.tokenizer.model_max_length,
361
+ no_boseos_middle=no_boseos_middle,
362
+ )
363
+ uncond_weights = torch.tensor(
364
+ uncond_weights,
365
+ dtype=uncond_embeddings.dtype,
366
+ device=uncond_embeddings.device)
367
+
368
+ # assign weights to the prompts and renormalize to preserve the original mean
369
+ # TODO: should we normalize by chunk or as a whole (current implementation)?
370
+ if (not skip_parsing) and (not skip_weighting):
371
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
372
+ text_embeddings.dtype)
373
+ text_embeddings *= prompt_weights.unsqueeze(-1)
374
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
375
+ text_embeddings.dtype)
376
+ text_embeddings *= (previous_mean
377
+ / current_mean).unsqueeze(-1).unsqueeze(-1)
378
+ if uncond_prompt is not None:
379
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
380
+ uncond_embeddings.dtype)
381
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
382
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
383
+ uncond_embeddings.dtype)
384
+ uncond_embeddings *= (previous_mean
385
+ / current_mean).unsqueeze(-1).unsqueeze(-1)
386
+
387
+ if uncond_prompt is not None:
388
+ return text_embeddings, uncond_embeddings
389
+ return text_embeddings, None
390
+
391
+
392
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
393
+ """
394
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
395
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
396
+ """
397
+ std_text = noise_pred_text.std(
398
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
399
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
400
+ # rescale the results from guidance (fixes overexposure)
401
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
402
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
403
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (
404
+ 1 - guidance_rescale) * noise_cfg
405
+ return noise_cfg
406
+
407
+
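A toy, deterministic example of the rescaling above (values chosen so the text-branch std is half the CFG std; printed numbers follow torch's default rounding): the result is a `guidance_rescale`-weighted mix of the std-matched prediction and the raw CFG prediction.

```py
>>> import torch
>>> noise_cfg  = torch.tensor([[[[ 2.0, -2.0]]]])
>>> noise_text = torch.tensor([[[[ 1.0, -1.0]]]])
>>> rescale_noise_cfg(noise_cfg, noise_text, guidance_rescale=0.7)  # 0.7*1.0 + 0.3*2.0 per element
tensor([[[[ 1.3000, -1.3000]]]])
```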
408
+ class StableDiffusionBlendExtendPipeline(StableDiffusionPipeline):
409
+ r"""
410
+ Pipeline for text-to-image generation using Stable Diffusion.
411
+
412
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
413
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
414
+
415
+ In addition the pipeline inherits the following loading methods:
416
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
417
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
418
+ - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
419
+
420
+ as well as the following saving methods:
421
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
422
+
423
+ Args:
424
+ vae ([`AutoencoderKL`]):
425
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
426
+ text_encoder ([`CLIPTextModel`]):
427
+ Frozen text-encoder. Stable Diffusion uses the text portion of
428
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
429
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
430
+ tokenizer (`CLIPTokenizer`):
431
+ Tokenizer of class
432
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/
433
+ en/model_doc/clip#transformers.CLIPTokenizer).
434
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
435
+ scheduler ([`SchedulerMixin`]):
436
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
437
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
438
+ safety_checker ([`StableDiffusionSafetyChecker`]):
439
+ Classification module that estimates whether generated images could be considered offensive or harmful.
440
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
441
+ feature_extractor ([`CLIPImageProcessor`]):
442
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
443
+ """
444
+ _optional_components = ['safety_checker', 'feature_extractor']
445
+
446
+ def _encode_prompt(
447
+ self,
448
+ prompt,
449
+ device,
450
+ num_images_per_prompt,
451
+ do_classifier_free_guidance,
452
+ negative_prompt=None,
453
+ max_embeddings_multiples=3,
454
+ prompt_embeds: Optional[torch.FloatTensor] = None,
455
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
456
+ lora_scale: Optional[float] = None,
457
+ ):
458
+ r"""
459
+ Encodes the prompt into text encoder hidden states.
460
+
461
+ Args:
462
+ prompt (`str` or `list(int)`):
463
+ prompt to be encoded
464
+ device: (`torch.device`):
465
+ torch device
466
+ num_images_per_prompt (`int`):
467
+ number of images that should be generated per prompt
468
+ do_classifier_free_guidance (`bool`):
469
+ whether to use classifier free guidance or not
470
+ negative_prompt (`str` or `List[str]`):
471
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
472
+ if `guidance_scale` is less than `1`).
473
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
474
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
475
+ """
476
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
477
+ self._lora_scale = lora_scale
478
+
479
+ if prompt is not None and isinstance(prompt, str):
480
+ batch_size = 1
481
+ elif prompt is not None and isinstance(prompt, list):
482
+ batch_size = len(prompt)
483
+ else:
484
+ batch_size = prompt_embeds.shape[0]
485
+
486
+ if negative_prompt_embeds is None:
487
+ if negative_prompt is None:
488
+ negative_prompt = [''] * batch_size
489
+ elif isinstance(negative_prompt, str):
490
+ negative_prompt = [negative_prompt] * batch_size
491
+ if batch_size != len(negative_prompt):
492
+ raise ValueError(
493
+ f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:'
494
+ f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches'
495
+ ' the batch size of `prompt`.')
496
+ if prompt_embeds is None or negative_prompt_embeds is None:
497
+ if isinstance(self, TextualInversionLoaderMixin):
498
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
499
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
500
+ negative_prompt = self.maybe_convert_prompt(
501
+ negative_prompt, self.tokenizer)
502
+
503
+ prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
504
+ pipe=self,
505
+ prompt=prompt,
506
+ uncond_prompt=negative_prompt
507
+ if do_classifier_free_guidance else None,
508
+ max_embeddings_multiples=max_embeddings_multiples,
509
+ )
510
+ if prompt_embeds is None:
511
+ prompt_embeds = prompt_embeds1
512
+ if negative_prompt_embeds is None:
513
+ negative_prompt_embeds = negative_prompt_embeds1
514
+
515
+ bs_embed, seq_len, _ = prompt_embeds.shape
516
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
517
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
518
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt,
519
+ seq_len, -1)
520
+
521
+ if do_classifier_free_guidance:
522
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
523
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
524
+ 1, num_images_per_prompt, 1)
525
+ negative_prompt_embeds = negative_prompt_embeds.view(
526
+ bs_embed * num_images_per_prompt, seq_len, -1)
527
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
528
+
529
+ return prompt_embeds
530
+
531
+ def blend_v(self, a, b, blend_extent):
532
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
533
+ for y in range(blend_extent):
534
+ b[:, :,
535
+ y, :] = a[:, :, -blend_extent
536
+ + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
537
+ y / blend_extent)
538
+ return b
539
+
540
+ def blend_h(self, a, b, blend_extent):
541
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
542
+ for x in range(blend_extent):
543
+ b[:, :, :, x] = a[:, :, :, -blend_extent
544
+ + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
545
+ x / blend_extent)
546
+ return b
547
+
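For reference, `blend_h` above is a plain linear cross-fade over the overlapping columns of two NCHW tensors. A minimal standalone sketch on toy tensors (not the pipeline's own data) behaves like this:

```py
import torch

def blend_h_sketch(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Cross-fade the last `blend_extent` columns of `a` into the first columns of `b` (NCHW).
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - w) + b[:, :, :, x] * w
    return b

left = torch.zeros(1, 1, 2, 4)   # toy "left" tile
right = torch.ones(1, 1, 2, 4)   # toy "right" tile
print(blend_h_sketch(left, right, 4)[0, 0, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500])
```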
548
+ @torch.no_grad()
549
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
550
+ def __call__(
551
+ self,
552
+ prompt: Union[str, List[str]] = None,
553
+ height: Optional[int] = None,
554
+ width: Optional[int] = None,
555
+ num_inference_steps: int = 50,
556
+ guidance_scale: float = 7.5,
557
+ negative_prompt: Optional[Union[str, List[str]]] = None,
558
+ num_images_per_prompt: Optional[int] = 1,
559
+ eta: float = 0.0,
560
+ generator: Optional[Union[torch.Generator,
561
+ List[torch.Generator]]] = None,
562
+ latents: Optional[torch.FloatTensor] = None,
563
+ prompt_embeds: Optional[torch.FloatTensor] = None,
564
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
565
+ output_type: Optional[str] = 'pil',
566
+ return_dict: bool = True,
567
+ callback: Optional[Callable[[int, int, torch.FloatTensor],
568
+ None]] = None,
569
+ callback_steps: int = 1,
570
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
571
+ guidance_rescale: float = 0.0,
572
+ ):
573
+ r"""
574
+ Function invoked when calling the pipeline for generation.
575
+
576
+ Args:
577
+ prompt (`str` or `List[str]`, *optional*):
578
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
579
+ instead.
580
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
581
+ The height in pixels of the generated image.
582
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
583
+ The width in pixels of the generated image.
584
+ num_inference_steps (`int`, *optional*, defaults to 50):
585
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
586
+ expense of slower inference.
587
+ guidance_scale (`float`, *optional*, defaults to 7.5):
588
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
589
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
590
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
591
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
592
+ usually at the expense of lower image quality.
593
+ negative_prompt (`str` or `List[str]`, *optional*):
594
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
595
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
596
+ less than `1`).
597
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
598
+ The number of images to generate per prompt.
599
+ eta (`float`, *optional*, defaults to 0.0):
600
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
601
+ [`schedulers.DDIMScheduler`], will be ignored for others.
602
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
603
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
604
+ to make generation deterministic.
605
+ latents (`torch.FloatTensor`, *optional*):
606
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
607
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
608
+ tensor will be generated by sampling using the supplied random `generator`.
609
+ prompt_embeds (`torch.FloatTensor`, *optional*):
610
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
611
+ provided, text embeddings will be generated from `prompt` input argument.
612
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
613
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
614
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
615
+ argument.
616
+ output_type (`str`, *optional*, defaults to `"pil"`):
617
+ The output format of the generated image. Choose between
618
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
619
+ return_dict (`bool`, *optional*, defaults to `True`):
620
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
621
+ plain tuple.
622
+ callback (`Callable`, *optional*):
623
+ A function that will be called every `callback_steps` steps during inference. The function will be
624
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
625
+ callback_steps (`int`, *optional*, defaults to 1):
626
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
627
+ called at every step.
628
+ cross_attention_kwargs (`dict`, *optional*):
629
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
630
+ `self.processor` in
631
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
632
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
633
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
634
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
635
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
636
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
637
+
638
+ Examples:
639
+
640
+ Returns:
641
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
642
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
643
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
644
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
645
+ (nsfw) content, according to the `safety_checker`.
646
+ """
647
+
648
+ def tiled_decode(
649
+ self,
650
+ z: torch.FloatTensor,
651
+ return_dict: bool = True
652
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
653
+ r"""Decode a batch of images using a tiled decoder.
654
+
655
+ Args:
656
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
657
+ steps. This is useful to keep memory use constant regardless of image size.
658
+ The end result of tiled decoding is different from non-tiled decoding because each tile uses a different
659
+ decoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output.
660
+ You may still see tile-sized changes in the look of the output, but they should be much less noticeable.
661
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to
662
+ `True`):
663
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
664
+ """
665
+ _tile_overlap_factor = 1 - self.tile_overlap_factor
666
+ overlap_size = int(self.tile_latent_min_size
667
+ * _tile_overlap_factor)
668
+ blend_extent = int(self.tile_sample_min_size
669
+ * self.tile_overlap_factor)
670
+ row_limit = self.tile_sample_min_size - blend_extent
671
+ w = z.shape[3]
672
+ z = torch.cat([z, z[:, :, :, :w // 4]], dim=-1)
673
+ # Split z into overlapping 64x64 tiles and decode them separately.
674
+ # The tiles have an overlap to avoid seams between tiles.
675
+
676
+ rows = []
677
+ for i in range(0, z.shape[2], overlap_size):
678
+ row = []
679
+ tile = z[:, :, i:i + self.tile_latent_min_size, :]
680
+ tile = self.post_quant_conv(tile)
681
+ decoded = self.decoder(tile)
682
+ vae_scale_factor = decoded.shape[-1] // tile.shape[-1]
683
+ row.append(decoded)
684
+ rows.append(row)
685
+ result_rows = []
686
+ for i, row in enumerate(rows):
687
+ result_row = []
688
+ for j, tile in enumerate(row):
689
+ # blend the above tile and the left tile
690
+ # to the current tile and add the current tile to the result row
691
+ if i > 0:
692
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
693
+ if j > 0:
694
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
695
+ result_row.append(
696
+ self.blend_h(
697
+ tile[:, :, :row_limit, w * vae_scale_factor:],
698
+ tile[:, :, :row_limit, :w * vae_scale_factor],
699
+ tile.shape[-1] - w * vae_scale_factor))
700
+ result_rows.append(torch.cat(result_row, dim=3))
701
+
702
+ dec = torch.cat(result_rows, dim=2)
703
+ if not return_dict:
704
+ return (dec, )
705
+
706
+ return DecoderOutput(sample=dec)
707
+
708
+ self.vae.tiled_decode = tiled_decode.__get__(self.vae, AutoencoderKL)
709
+
710
+ # 0. Default height and width to unet
711
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
712
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
713
+
714
+ # 1. Check inputs. Raise error if not correct
715
+ self.check_inputs(prompt, height, width, callback_steps,
716
+ negative_prompt, prompt_embeds,
717
+ negative_prompt_embeds)
718
+ self.blend_extend = width // self.vae_scale_factor // 32
719
+
720
+ # 2. Define call parameters
721
+ if prompt is not None and isinstance(prompt, str):
722
+ batch_size = 1
723
+ elif prompt is not None and isinstance(prompt, list):
724
+ batch_size = len(prompt)
725
+ else:
726
+ batch_size = prompt_embeds.shape[0]
727
+
728
+ device = self._execution_device
729
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
730
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
731
+ # corresponds to doing no classifier free guidance.
732
+ do_classifier_free_guidance = guidance_scale > 1.0
733
+
734
+ # 3. Encode input prompt
735
+ text_encoder_lora_scale = (
736
+ cross_attention_kwargs.get('scale', None)
737
+ if cross_attention_kwargs is not None else None)
738
+ prompt_embeds = self._encode_prompt(
739
+ prompt,
740
+ device,
741
+ num_images_per_prompt,
742
+ do_classifier_free_guidance,
743
+ negative_prompt,
744
+ prompt_embeds=prompt_embeds,
745
+ negative_prompt_embeds=negative_prompt_embeds,
746
+ lora_scale=text_encoder_lora_scale,
747
+ )
748
+
749
+ # 4. Prepare timesteps
750
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
751
+ timesteps = self.scheduler.timesteps
752
+
753
+ # 5. Prepare latent variables
754
+ num_channels_latents = self.unet.config.in_channels
755
+ latents = self.prepare_latents(
756
+ batch_size * num_images_per_prompt,
757
+ num_channels_latents,
758
+ height,
759
+ width,
760
+ prompt_embeds.dtype,
761
+ device,
762
+ generator,
763
+ latents,
764
+ )
765
+
766
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
767
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
768
+
769
+ # 7. Denoising loop
770
+ num_warmup_steps = len(
771
+ timesteps) - num_inference_steps * self.scheduler.order
772
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
773
+ for i, t in enumerate(timesteps):
774
+ # expand the latents if we are doing classifier free guidance
775
+ latent_model_input = torch.cat(
776
+ [latents] * 2) if do_classifier_free_guidance else latents
777
+ latent_model_input = self.scheduler.scale_model_input(
778
+ latent_model_input, t)
779
+
780
+ # predict the noise residual
781
+ noise_pred = self.unet(
782
+ latent_model_input,
783
+ t,
784
+ encoder_hidden_states=prompt_embeds,
785
+ cross_attention_kwargs=cross_attention_kwargs,
786
+ return_dict=False,
787
+ )[0]
788
+
789
+ # perform guidance
790
+ if do_classifier_free_guidance:
791
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
792
+ noise_pred = noise_pred_uncond + guidance_scale * (
793
+ noise_pred_text - noise_pred_uncond)
794
+
795
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
796
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
797
+ noise_pred = rescale_noise_cfg(
798
+ noise_pred,
799
+ noise_pred_text,
800
+ guidance_rescale=guidance_rescale)
801
+
802
+ # compute the previous noisy sample x_t -> x_t-1
803
+ latents = self.scheduler.step(
804
+ noise_pred,
805
+ t,
806
+ latents,
807
+ **extra_step_kwargs,
808
+ return_dict=False)[0]
809
+
810
+ # call the callback, if provided
811
+ condition_i = i == len(timesteps) - 1
812
+ condition_warm = (i + 1) > num_warmup_steps and (
813
+ i + 1) % self.scheduler.order == 0
814
+ if condition_i or condition_warm:
815
+ progress_bar.update()
816
+ if callback is not None and i % callback_steps == 0:
817
+ callback(i, t, latents)
818
+ latents = self.blend_h(latents, latents, self.blend_extend)
819
+ latents = self.blend_h(latents, latents, self.blend_extend)
820
+ latents = latents[:, :, :, :width // self.vae_scale_factor]
821
+
822
+ if not output_type == 'latent':
823
+ image = self.vae.decode(
824
+ latents / self.vae.config.scaling_factor, return_dict=False)[0]
825
+ image, has_nsfw_concept = self.run_safety_checker(
826
+ image, device, prompt_embeds.dtype)
827
+ else:
828
+ image = latents
829
+ has_nsfw_concept = None
830
+
831
+ if has_nsfw_concept is None:
832
+ do_denormalize = [True] * image.shape[0]
833
+ else:
834
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
835
+
836
+ image = self.image_processor.postprocess(
837
+ image, output_type=output_type, do_denormalize=do_denormalize)
838
+
839
+ # Offload last model to CPU
840
+ if hasattr(
841
+ self,
842
+ 'final_offload_hook') and self.final_offload_hook is not None:
843
+ self.final_offload_hook.offload()
844
+
845
+ if not return_dict:
846
+ return (image, has_nsfw_concept)
847
+
848
+ return StableDiffusionPipelineOutput(
849
+ images=image, nsfw_content_detected=has_nsfw_concept)
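The `tiled_decode.__get__(self.vae, AutoencoderKL)` assignment inside `__call__` above binds a locally defined function as a method of the existing VAE instance, so the patched decoder wraps latents around the seam before tiling. A minimal sketch of that binding pattern, using a purely hypothetical `Greeter` class:

```py
class Greeter:
    def __init__(self, name):
        self.name = name

def shout(self):
    # `self` is whatever instance the function gets bound to.
    return self.name.upper()

g = Greeter("panorama")
g.shout = shout.__get__(g, Greeter)  # same pattern as tiled_decode.__get__(self.vae, AutoencoderKL)
print(g.shout())  # PANORAMA
```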
txt2panoimg/pipeline_sr.py ADDED
@@ -0,0 +1,1202 @@
1
+ # Copyright © Alibaba, Inc. and its affiliates.
2
+ # The implementation here is modified based on diffusers.StableDiffusionControlNetImg2ImgPipeline,
3
+ # originally Apache 2.0 License and public available at
4
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
5
+
6
+ import copy
7
+ import re
8
+ from typing import Any, Callable, Dict, List, Optional, Union
9
+
10
+ import numpy as np
11
+ import PIL.Image
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from diffusers import (AutoencoderKL, DiffusionPipeline,
15
+ StableDiffusionControlNetImg2ImgPipeline)
16
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
17
+ from diffusers.models import ControlNetModel
18
+ try:
19
+ from diffusers.models.autoencoders.vae import DecoderOutput
20
+ except ImportError:
21
+ from diffusers.models.vae import DecoderOutput
22
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
23
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
24
+ from diffusers.utils import logging, replace_example_docstring
25
+ from diffusers.utils.torch_utils import is_compiled_module
26
+
27
+ from transformers import CLIPTokenizer
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+ Examples:
33
+ ```py
34
+ >>> import torch
35
+ >>> from PIL import Image
36
+ >>> from diffusers import ControlNetModel
+ >>> from txt2panoimg.pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
37
+ >>> base_model_path = "models/sr-base"
38
+ >>> controlnet_path = "models/sr-control"
39
+ >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
40
+ >>> pipe = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(base_model_path, controlnet=controlnet,
41
+ ... torch_dtype=torch.float16)
42
+ >>> pipe.vae.enable_tiling()
43
+ >>> # remove following line if xformers is not installed
44
+ >>> pipe.enable_xformers_memory_efficient_attention()
45
+ >>> pipe.enable_model_cpu_offload()
46
+ >>> input_image_path = 'data/test.png'
47
+ >>> image = Image.open(input_image_path)
48
+ >>> image = pipe(
49
+ ... "futuristic-looking woman",
50
+ ... num_inference_steps=20,
51
+ ... image=image,
52
+ ... height=768,
53
+ ... width=1536,
54
+ ... control_image=image,
55
+ ... ).images[0]
56
+
57
+ ```
58
+ """
59
+
60
+ re_attention = re.compile(
61
+ r"""
62
+ \\\(|
63
+ \\\)|
64
+ \\\[|
65
+ \\]|
66
+ \\\\|
67
+ \\|
68
+ \(|
69
+ \[|
70
+ :([+-]?[.\d]+)\)|
71
+ \)|
72
+ ]|
73
+ [^\\()\[\]:]+|
74
+ :
75
+ """,
76
+ re.X,
77
+ )
78
+
79
+
80
+ def parse_prompt_attention(text):
81
+ """
82
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
83
+ Accepted tokens are:
84
+ (abc) - increases attention to abc by a multiplier of 1.1
85
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
86
+ [abc] - decreases attention to abc by a multiplier of 1.1
87
+ """
88
+
89
+ res = []
90
+ round_brackets = []
91
+ square_brackets = []
92
+
93
+ round_bracket_multiplier = 1.1
94
+ square_bracket_multiplier = 1 / 1.1
95
+
96
+ def multiply_range(start_position, multiplier):
97
+ for p in range(start_position, len(res)):
98
+ res[p][1] *= multiplier
99
+
100
+ for m in re_attention.finditer(text):
101
+ text = m.group(0)
102
+ weight = m.group(1)
103
+
104
+ if text.startswith('\\'):
105
+ res.append([text[1:], 1.0])
106
+ elif text == '(':
107
+ round_brackets.append(len(res))
108
+ elif text == '[':
109
+ square_brackets.append(len(res))
110
+ elif weight is not None and len(round_brackets) > 0:
111
+ multiply_range(round_brackets.pop(), float(weight))
112
+ elif text == ')' and len(round_brackets) > 0:
113
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
114
+ elif text == ']' and len(square_brackets) > 0:
115
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
116
+ else:
117
+ res.append([text, 1.0])
118
+
119
+ for pos in round_brackets:
120
+ multiply_range(pos, round_bracket_multiplier)
121
+
122
+ for pos in square_brackets:
123
+ multiply_range(pos, square_bracket_multiplier)
124
+
125
+ if len(res) == 0:
126
+ res = [['', 1.0]]
127
+
128
+ # merge runs of identical weights
129
+ i = 0
130
+ while i + 1 < len(res):
131
+ if res[i][1] == res[i + 1][1]:
132
+ res[i][0] += res[i + 1][0]
133
+ res.pop(i + 1)
134
+ else:
135
+ i += 1
136
+
137
+ return res
138
+
139
+
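For illustration, here is roughly what the parser above returns for weighted prompts. This is a sketch that assumes `parse_prompt_attention` is importable from this module's path:

```py
from txt2panoimg.pipeline_sr import parse_prompt_attention  # assumed import path

print(parse_prompt_attention("a (very beautiful) sunset"))
# [['a ', 1.0], ['very beautiful', 1.1], [' sunset', 1.0]]

print(parse_prompt_attention("(masterpiece:2.0) sunset"))
# [['masterpiece', 2.0], [' sunset', 1.0]]
```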
140
+ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str],
141
+ max_length: int):
142
+ r"""
143
+ Tokenize a list of prompts and return its tokens with weights of each token.
144
+
145
+ No padding, starting or ending token is included.
146
+ """
147
+ tokens = []
148
+ weights = []
149
+ truncated = False
150
+ for text in prompt:
151
+ texts_and_weights = parse_prompt_attention(text)
152
+ text_token = []
153
+ text_weight = []
154
+ for word, weight in texts_and_weights:
155
+ # tokenize and discard the starting and the ending token
156
+ token = pipe.tokenizer(word).input_ids[1:-1]
157
+ text_token += token
158
+ # copy the weight by length of token
159
+ text_weight += [weight] * len(token)
160
+ # stop if the text is too long (longer than truncation limit)
161
+ if len(text_token) > max_length:
162
+ truncated = True
163
+ break
164
+ # truncate
165
+ if len(text_token) > max_length:
166
+ truncated = True
167
+ text_token = text_token[:max_length]
168
+ text_weight = text_weight[:max_length]
169
+ tokens.append(text_token)
170
+ weights.append(text_weight)
171
+ if truncated:
172
+ logger.warning(
173
+ 'Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples'
174
+ )
175
+ return tokens, weights
176
+
177
+
178
+ def pad_tokens_and_weights(tokens,
179
+ weights,
180
+ max_length,
181
+ bos,
182
+ eos,
183
+ pad,
184
+ no_boseos_middle=True,
185
+ chunk_length=77):
186
+ r"""
187
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
188
+ """
189
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
190
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
191
+ for i in range(len(tokens)):
192
+ tokens[i] = [
193
+ bos
194
+ ] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
195
+ if no_boseos_middle:
196
+ weights[i] = [1.0] + weights[i] + [1.0] * (
197
+ max_length - 1 - len(weights[i]))
198
+ else:
199
+ w = []
200
+ if len(weights[i]) == 0:
201
+ w = [1.0] * weights_length
202
+ else:
203
+ for j in range(max_embeddings_multiples):
204
+ w.append(1.0) # weight for starting token in this chunk
205
+ w += weights[i][j * (chunk_length - 2):min(
206
+ len(weights[i]), (j + 1) * (chunk_length - 2))]
207
+ w.append(1.0) # weight for ending token in this chunk
208
+ w += [1.0] * (weights_length - len(w))
209
+ weights[i] = w[:]
210
+
211
+ return tokens, weights
212
+
213
+
214
+ def get_unweighted_text_embeddings(
215
+ pipe: DiffusionPipeline,
216
+ text_input: torch.Tensor,
217
+ chunk_length: int,
218
+ no_boseos_middle: Optional[bool] = True,
219
+ ):
220
+ """
221
+ When the length of tokens is a multiple of the capacity of the text encoder,
222
+ it should be split into chunks and sent to the text encoder individually.
223
+ """
224
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
225
+ if max_embeddings_multiples > 1:
226
+ text_embeddings = []
227
+ for i in range(max_embeddings_multiples):
228
+ # extract the i-th chunk
229
+ text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1)
230
+ * (chunk_length - 2) + 2].clone()
231
+
232
+ # cover the head and the tail by the starting and the ending tokens
233
+ text_input_chunk[:, 0] = text_input[0, 0]
234
+ text_input_chunk[:, -1] = text_input[0, -1]
235
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
236
+
237
+ if no_boseos_middle:
238
+ if i == 0:
239
+ # discard the ending token
240
+ text_embedding = text_embedding[:, :-1]
241
+ elif i == max_embeddings_multiples - 1:
242
+ # discard the starting token
243
+ text_embedding = text_embedding[:, 1:]
244
+ else:
245
+ # discard both starting and ending tokens
246
+ text_embedding = text_embedding[:, 1:-1]
247
+
248
+ text_embeddings.append(text_embedding)
249
+ text_embeddings = torch.concat(text_embeddings, axis=1)
250
+ else:
251
+ text_embeddings = pipe.text_encoder(text_input)[0]
252
+ return text_embeddings
253
+
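The chunked encoding above relies on a small piece of arithmetic: a padded sequence of `(chunk_length - 2) * k + 2` tokens splits into `k` chunks. A quick sketch, assuming CLIP's usual `model_max_length` of 77:

```py
chunk_length = 77                               # pipe.tokenizer.model_max_length for CLIP
padded_length = (chunk_length - 2) * 3 + 2      # 227 tokens including BOS/EOS
max_embeddings_multiples = (padded_length - 2) // (chunk_length - 2)
print(padded_length, max_embeddings_multiples)  # 227 3
```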
254
+
255
+ def get_weighted_text_embeddings(
256
+ pipe: DiffusionPipeline,
257
+ prompt: Union[str, List[str]],
258
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
259
+ max_embeddings_multiples: Optional[int] = 3,
260
+ no_boseos_middle: Optional[bool] = False,
261
+ skip_parsing: Optional[bool] = False,
262
+ skip_weighting: Optional[bool] = False,
263
+ ):
264
+ r"""
265
+ Prompts can be assigned with local weights using brackets. For example,
266
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
267
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
268
+
269
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
270
+
271
+ Args:
272
+ pipe (`DiffusionPipeline`):
273
+ Pipe to provide access to the tokenizer and the text encoder.
274
+ prompt (`str` or `List[str]`):
275
+ The prompt or prompts to guide the image generation.
276
+ uncond_prompt (`str` or `List[str]`):
277
+ The unconditional prompt or prompts for guide the image generation. If unconditional prompt
278
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
279
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
280
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
281
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
282
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
283
+ ending token in each of the chunk in the middle.
284
+ skip_parsing (`bool`, *optional*, defaults to `False`):
285
+ Skip the parsing of brackets.
286
+ skip_weighting (`bool`, *optional*, defaults to `False`):
287
+ Skip the weighting. When the parsing is skipped, it is forced True.
288
+ """
289
+ max_length = (pipe.tokenizer.model_max_length
290
+ - 2) * max_embeddings_multiples + 2
291
+ if isinstance(prompt, str):
292
+ prompt = [prompt]
293
+
294
+ if not skip_parsing:
295
+ prompt_tokens, prompt_weights = get_prompts_with_weights(
296
+ pipe, prompt, max_length - 2)
297
+ if uncond_prompt is not None:
298
+ if isinstance(uncond_prompt, str):
299
+ uncond_prompt = [uncond_prompt]
300
+ uncond_tokens, uncond_weights = get_prompts_with_weights(
301
+ pipe, uncond_prompt, max_length - 2)
302
+ else:
303
+ prompt_tokens = [
304
+ token[1:-1] for token in pipe.tokenizer(
305
+ prompt, max_length=max_length, truncation=True).input_ids
306
+ ]
307
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
308
+ if uncond_prompt is not None:
309
+ if isinstance(uncond_prompt, str):
310
+ uncond_prompt = [uncond_prompt]
311
+ uncond_tokens = [
312
+ token[1:-1] for token in pipe.tokenizer(
313
+ uncond_prompt, max_length=max_length,
314
+ truncation=True).input_ids
315
+ ]
316
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
317
+
318
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
319
+ max_length = max([len(token) for token in prompt_tokens])
320
+ if uncond_prompt is not None:
321
+ max_length = max(max_length,
322
+ max([len(token) for token in uncond_tokens]))
323
+
324
+ max_embeddings_multiples = min(
325
+ max_embeddings_multiples,
326
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
327
+ )
328
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
329
+ max_length = (pipe.tokenizer.model_max_length
330
+ - 2) * max_embeddings_multiples + 2
331
+
332
+ # pad the length of tokens and weights
333
+ bos = pipe.tokenizer.bos_token_id
334
+ eos = pipe.tokenizer.eos_token_id
335
+ pad = getattr(pipe.tokenizer, 'pad_token_id', eos)
336
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
337
+ prompt_tokens,
338
+ prompt_weights,
339
+ max_length,
340
+ bos,
341
+ eos,
342
+ pad,
343
+ no_boseos_middle=no_boseos_middle,
344
+ chunk_length=pipe.tokenizer.model_max_length,
345
+ )
346
+ prompt_tokens = torch.tensor(
347
+ prompt_tokens, dtype=torch.long, device=pipe.device)
348
+ if uncond_prompt is not None:
349
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
350
+ uncond_tokens,
351
+ uncond_weights,
352
+ max_length,
353
+ bos,
354
+ eos,
355
+ pad,
356
+ no_boseos_middle=no_boseos_middle,
357
+ chunk_length=pipe.tokenizer.model_max_length,
358
+ )
359
+ uncond_tokens = torch.tensor(
360
+ uncond_tokens, dtype=torch.long, device=pipe.device)
361
+
362
+ # get the embeddings
363
+ text_embeddings = get_unweighted_text_embeddings(
364
+ pipe,
365
+ prompt_tokens,
366
+ pipe.tokenizer.model_max_length,
367
+ no_boseos_middle=no_boseos_middle,
368
+ )
369
+ prompt_weights = torch.tensor(
370
+ prompt_weights,
371
+ dtype=text_embeddings.dtype,
372
+ device=text_embeddings.device)
373
+ if uncond_prompt is not None:
374
+ uncond_embeddings = get_unweighted_text_embeddings(
375
+ pipe,
376
+ uncond_tokens,
377
+ pipe.tokenizer.model_max_length,
378
+ no_boseos_middle=no_boseos_middle,
379
+ )
380
+ uncond_weights = torch.tensor(
381
+ uncond_weights,
382
+ dtype=uncond_embeddings.dtype,
383
+ device=uncond_embeddings.device)
384
+
385
+ # assign weights to the prompts and normalize in the sense of mean
386
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
387
+ if (not skip_parsing) and (not skip_weighting):
388
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
389
+ text_embeddings.dtype)
390
+ text_embeddings *= prompt_weights.unsqueeze(-1)
391
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
392
+ text_embeddings.dtype)
393
+ text_embeddings *= (previous_mean
394
+ / current_mean).unsqueeze(-1).unsqueeze(-1)
395
+ if uncond_prompt is not None:
396
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
397
+ uncond_embeddings.dtype)
398
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
399
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
400
+ uncond_embeddings.dtype)
401
+ uncond_embeddings *= (previous_mean
402
+ / current_mean).unsqueeze(-1).unsqueeze(-1)
403
+
404
+ if uncond_prompt is not None:
405
+ return text_embeddings, uncond_embeddings
406
+ return text_embeddings, None
407
+
408
+
409
+ def prepare_image(image):
410
+ if isinstance(image, torch.Tensor):
411
+ # Batch single image
412
+ if image.ndim == 3:
413
+ image = image.unsqueeze(0)
414
+
415
+ image = image.to(dtype=torch.float32)
416
+ else:
417
+ # preprocess image
418
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
419
+ image = [image]
420
+
421
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
422
+ image = [np.array(i.convert('RGB'))[None, :] for i in image]
423
+ image = np.concatenate(image, axis=0)
424
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
425
+ image = np.concatenate([i[None, :] for i in image], axis=0)
426
+
427
+ image = image.transpose(0, 3, 1, 2)
428
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
429
+
430
+ return image
431
+
432
+
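As a quick check of the PIL branch above, the normalization maps an 8-bit image into an NCHW float tensor in [-1, 1]; a minimal sketch:

```py
import numpy as np
import PIL.Image
import torch

img = PIL.Image.new("RGB", (8, 4), color=(255, 0, 127))
arr = np.array(img.convert("RGB"))[None, :]        # (1, H, W, C), uint8
arr = arr.transpose(0, 3, 1, 2)                    # (1, C, H, W)
tensor = torch.from_numpy(arr).to(dtype=torch.float32) / 127.5 - 1.0
print(tensor.shape, tensor.min().item(), tensor.max().item())
# torch.Size([1, 3, 4, 8]) -1.0 1.0
```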
433
+ class StableDiffusionControlNetImg2ImgPanoPipeline(
434
+ StableDiffusionControlNetImg2ImgPipeline):
435
+ r"""
436
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
437
+
438
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
439
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
440
+
441
+ In addition the pipeline inherits the following loading methods:
442
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
443
+
444
+ Args:
445
+ vae ([`AutoencoderKL`]):
446
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
447
+ text_encoder ([`CLIPTextModel`]):
448
+ Frozen text-encoder. Stable Diffusion uses the text portion of
449
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
450
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
451
+ tokenizer (`CLIPTokenizer`):
452
+ Tokenizer of class
453
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/
454
+ model_doc/clip#transformers.CLIPTokenizer).
455
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
456
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
457
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
458
+ as a list, the outputs from each ControlNet are added together to create one combined additional
459
+ conditioning.
460
+ scheduler ([`SchedulerMixin`]):
461
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
462
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
463
+ safety_checker ([`StableDiffusionSafetyChecker`]):
464
+ Classification module that estimates whether generated images could be considered offensive or harmful.
465
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
466
+ feature_extractor ([`CLIPImageProcessor`]):
467
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
468
+ """
469
+ _optional_components = ['safety_checker', 'feature_extractor']
470
+
471
+ def check_inputs(
472
+ self,
473
+ prompt,
474
+ image,
475
+ height,
476
+ width,
477
+ callback_steps,
478
+ negative_prompt=None,
479
+ prompt_embeds=None,
480
+ negative_prompt_embeds=None,
481
+ controlnet_conditioning_scale=1.0,
482
+ ):
483
+ if height % 8 != 0 or width % 8 != 0:
484
+ raise ValueError(
485
+ f'`height` and `width` have to be divisible by 8 but are {height} and {width}.'
486
+ )
487
+ condition_1 = callback_steps is not None
488
+ condition_2 = not isinstance(callback_steps,
489
+ int) or callback_steps <= 0
490
+ if (callback_steps is None) or (condition_1 and condition_2):
491
+ raise ValueError(
492
+ f'`callback_steps` has to be a positive integer but is {callback_steps} of type'
493
+ f' {type(callback_steps)}.')
494
+ if prompt is not None and prompt_embeds is not None:
495
+ raise ValueError(
496
+ f'Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to'
497
+ ' only forward one of the two.')
498
+ elif prompt is None and prompt_embeds is None:
499
+ raise ValueError(
500
+ 'Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.'
501
+ )
502
+ elif prompt is not None and (not isinstance(prompt, str)
503
+ and not isinstance(prompt, list)):
504
+ raise ValueError(
505
+ f'`prompt` has to be of type `str` or `list` but is {type(prompt)}'
506
+ )
507
+ if negative_prompt is not None and negative_prompt_embeds is not None:
508
+ raise ValueError(
509
+ f'Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:'
510
+ f' {negative_prompt_embeds}. Please make sure to only forward one of the two.'
511
+ )
512
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
513
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
514
+ raise ValueError(
515
+ '`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but'
516
+ f' got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`'
517
+ f' {negative_prompt_embeds.shape}.')
518
+ # `prompt` needs more sophisticated handling when there are multiple
519
+ # conditionings.
520
+ if isinstance(self.controlnet, MultiControlNetModel):
521
+ if isinstance(prompt, list):
522
+ logger.warning(
523
+ f'You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}'
524
+ ' prompts. The conditionings will be fixed across the prompts.'
525
+ )
526
+ # Check `image`
527
+ is_compiled = hasattr(
528
+ F, 'scaled_dot_product_attention') and isinstance(
529
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule)
530
+ if (isinstance(self.controlnet, ControlNetModel) or is_compiled
531
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)):
532
+ self.check_image(image, prompt, prompt_embeds)
533
+ elif (isinstance(self.controlnet, MultiControlNetModel) or is_compiled
534
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)):
535
+ if not isinstance(image, list):
536
+ raise TypeError(
537
+ 'For multiple controlnets: `image` must be type `list`')
538
+ # When `image` is a nested list:
539
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
540
+ elif any(isinstance(i, list) for i in image):
541
+ raise ValueError(
542
+ 'A single batch of multiple conditionings are supported at the moment.'
543
+ )
544
+ elif len(image) != len(self.controlnet.nets):
545
+ raise ValueError(
546
+ 'For multiple controlnets: `image` must have the same length as the number of controlnets.'
547
+ )
548
+ for image_ in image:
549
+ self.check_image(image_, prompt, prompt_embeds)
550
+ else:
551
+ assert False
552
+ # Check `controlnet_conditioning_scale`
553
+ if (isinstance(self.controlnet, ControlNetModel) or is_compiled
554
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)):
555
+ if not isinstance(controlnet_conditioning_scale, float):
556
+ raise TypeError(
557
+ 'For single controlnet: `controlnet_conditioning_scale` must be type `float`.'
558
+ )
559
+ elif (isinstance(self.controlnet, MultiControlNetModel) or is_compiled
560
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)):
561
+ if isinstance(controlnet_conditioning_scale, list):
562
+ if any(
563
+ isinstance(i, list)
564
+ for i in controlnet_conditioning_scale):
565
+ raise ValueError(
566
+ 'A single batch of multiple conditionings are supported at the moment.'
567
+ )
568
+ elif isinstance(
569
+ controlnet_conditioning_scale,
570
+ list) and len(controlnet_conditioning_scale) != len(
571
+ self.controlnet.nets):
572
+ raise ValueError(
573
+ 'For multiple controlnets: When `controlnet_conditioning_scale` '
574
+ 'is specified as `list`, it must have'
575
+ ' the same length as the number of controlnets')
576
+ else:
577
+ assert False
578
+
579
+ def _default_height_width(self, height, width, image):
580
+ # NOTE: It is possible that a list of images have different
581
+ # dimensions for each image, so just checking the first image
582
+ # is not _exactly_ correct, but it is simple.
583
+ while isinstance(image, list):
584
+ image = image[0]
585
+ if height is None:
586
+ if isinstance(image, PIL.Image.Image):
587
+ height = image.height
588
+ elif isinstance(image, torch.Tensor):
589
+ height = image.shape[2]
590
+ height = (height // 8) * 8 # round down to nearest multiple of 8
591
+ if width is None:
592
+ if isinstance(image, PIL.Image.Image):
593
+ width = image.width
594
+ elif isinstance(image, torch.Tensor):
595
+ width = image.shape[3]
596
+ width = (width // 8) * 8 # round down to nearest multiple of 8
597
+ return height, width
598
+
599
+ def _encode_prompt(
600
+ self,
601
+ prompt,
602
+ device,
603
+ num_images_per_prompt,
604
+ do_classifier_free_guidance,
605
+ negative_prompt=None,
606
+ max_embeddings_multiples=3,
607
+ prompt_embeds: Optional[torch.FloatTensor] = None,
608
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
609
+ lora_scale: Optional[float] = None,
610
+ ):
611
+ r"""
612
+ Encodes the prompt into text encoder hidden states.
613
+
614
+ Args:
615
+ prompt (`str` or `list(int)`):
616
+ prompt to be encoded
617
+ device: (`torch.device`):
618
+ torch device
619
+ num_images_per_prompt (`int`):
620
+ number of images that should be generated per prompt
621
+ do_classifier_free_guidance (`bool`):
622
+ whether to use classifier free guidance or not
623
+ negative_prompt (`str` or `List[str]`):
624
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
625
+ if `guidance_scale` is less than `1`).
626
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
627
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
628
+ """
629
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
630
+ self._lora_scale = lora_scale
631
+
632
+ if prompt is not None and isinstance(prompt, str):
633
+ batch_size = 1
634
+ elif prompt is not None and isinstance(prompt, list):
635
+ batch_size = len(prompt)
636
+ else:
637
+ batch_size = prompt_embeds.shape[0]
638
+
639
+ if negative_prompt_embeds is None:
640
+ if negative_prompt is None:
641
+ negative_prompt = [''] * batch_size
642
+ elif isinstance(negative_prompt, str):
643
+ negative_prompt = [negative_prompt] * batch_size
644
+ if batch_size != len(negative_prompt):
645
+ raise ValueError(
646
+ f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:'
647
+ f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches'
648
+ ' the batch size of `prompt`.')
649
+ if prompt_embeds is None or negative_prompt_embeds is None:
650
+ if isinstance(self, TextualInversionLoaderMixin):
651
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
652
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
653
+ negative_prompt = self.maybe_convert_prompt(
654
+ negative_prompt, self.tokenizer)
655
+
656
+ prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
657
+ pipe=self,
658
+ prompt=prompt,
659
+ uncond_prompt=negative_prompt
660
+ if do_classifier_free_guidance else None,
661
+ max_embeddings_multiples=max_embeddings_multiples,
662
+ )
663
+ if prompt_embeds is None:
664
+ prompt_embeds = prompt_embeds1
665
+ if negative_prompt_embeds is None:
666
+ negative_prompt_embeds = negative_prompt_embeds1
667
+
668
+ bs_embed, seq_len, _ = prompt_embeds.shape
669
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
670
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
671
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt,
672
+ seq_len, -1)
673
+
674
+ if do_classifier_free_guidance:
675
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
676
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
677
+ 1, num_images_per_prompt, 1)
678
+ negative_prompt_embeds = negative_prompt_embeds.view(
679
+ bs_embed * num_images_per_prompt, seq_len, -1)
680
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
681
+
682
+ return prompt_embeds
683
+
684
+ def denoise_latents(self, latents, t, prompt_embeds, control_image,
685
+ controlnet_conditioning_scale, guess_mode,
686
+ cross_attention_kwargs, do_classifier_free_guidance,
687
+ guidance_scale, extra_step_kwargs,
688
+ views_scheduler_status):
689
+ # expand the latents if we are doing classifier free guidance
690
+ latent_model_input = torch.cat(
691
+ [latents] * 2) if do_classifier_free_guidance else latents
692
+ self.scheduler.__dict__.update(views_scheduler_status[0])
693
+ latent_model_input = self.scheduler.scale_model_input(
694
+ latent_model_input, t)
695
+ # controlnet(s) inference
696
+ if guess_mode and do_classifier_free_guidance:
697
+ # Infer ControlNet only for the conditional batch.
698
+ controlnet_latent_model_input = latents
699
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
700
+ else:
701
+ controlnet_latent_model_input = latent_model_input
702
+ controlnet_prompt_embeds = prompt_embeds
703
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
704
+ controlnet_latent_model_input,
705
+ t,
706
+ encoder_hidden_states=controlnet_prompt_embeds,
707
+ controlnet_cond=control_image,
708
+ conditioning_scale=controlnet_conditioning_scale,
709
+ guess_mode=guess_mode,
710
+ return_dict=False,
711
+ )
712
+ if guess_mode and do_classifier_free_guidance:
713
+ # Inferred ControlNet only for the conditional batch.
714
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
715
+ # add 0 to the unconditional batch to keep it unchanged.
716
+ down_block_res_samples = [
717
+ torch.cat([torch.zeros_like(d), d])
718
+ for d in down_block_res_samples
719
+ ]
720
+ mid_block_res_sample = torch.cat(
721
+ [torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
722
+ # predict the noise residual
723
+ noise_pred = self.unet(
724
+ latent_model_input,
725
+ t,
726
+ encoder_hidden_states=prompt_embeds,
727
+ cross_attention_kwargs=cross_attention_kwargs,
728
+ down_block_additional_residuals=down_block_res_samples,
729
+ mid_block_additional_residual=mid_block_res_sample,
730
+ return_dict=False,
731
+ )[0]
732
+ # perform guidance
733
+ if do_classifier_free_guidance:
734
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
735
+ noise_pred = noise_pred_uncond + guidance_scale * (
736
+ noise_pred_text - noise_pred_uncond)
737
+ # compute the previous noisy sample x_t -> x_t-1
738
+ latents = self.scheduler.step(
739
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
740
+ return latents
741
+
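The guidance step in `denoise_latents` combines the unconditional and text-conditioned predictions with a single weighted difference; a sketch on random tensors:

```py
import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 96, 192)            # [uncond, text] stacked along the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 96, 192])
```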
742
+ def blend_v(self, a, b, blend_extent):
743
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
744
+ for y in range(blend_extent):
745
+ b[:, :,
746
+ y, :] = a[:, :, -blend_extent
747
+ + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
748
+ y / blend_extent)
749
+ return b
750
+
751
+ def blend_h(self, a, b, blend_extent):
752
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
753
+ for x in range(blend_extent):
754
+ b[:, :, :, x] = a[:, :, :, -blend_extent
755
+ + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
756
+ x / blend_extent)
757
+ return b
758
+
759
+ def get_blocks(self, latents, control_image, tile_latent_min_size,
760
+ overlap_size):
761
+ rows_latents = []
762
+ rows_control_images = []
763
+ for i in range(0, latents.shape[2] - overlap_size, overlap_size):
764
+ row_latents = []
765
+ row_control_images = []
766
+ for j in range(0, latents.shape[3] - overlap_size, overlap_size):
767
+ latents_input = latents[:, :, i:i + tile_latent_min_size,
768
+ j:j + tile_latent_min_size]
769
+ c_start_i = self.vae_scale_factor * i
770
+ c_end_i = self.vae_scale_factor * (i + tile_latent_min_size)
771
+ c_start_j = self.vae_scale_factor * j
772
+ c_end_j = self.vae_scale_factor * (j + tile_latent_min_size)
773
+ control_image_input = control_image[:, :, c_start_i:c_end_i,
774
+ c_start_j:c_end_j]
775
+ row_latents.append(latents_input)
776
+ row_control_images.append(control_image_input)
777
+ rows_latents.append(row_latents)
778
+ rows_control_images.append(row_control_images)
779
+ return rows_latents, rows_control_images
780
+
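The tiling in `get_blocks` walks the latent grid in strides of `overlap_size` and crops windows of `tile_latent_min_size`; the corresponding control-image crops are the same windows scaled by `vae_scale_factor`. A small sketch of the window arithmetic with made-up sizes:

```py
tile_latent_min_size, overlap_size, latent_width = 96, 48, 192
starts = list(range(0, latent_width - overlap_size, overlap_size))
windows = [(i, i + tile_latent_min_size) for i in starts]
print(windows)  # [(0, 96), (48, 144), (96, 192)]

vae_scale_factor = 8
pixel_windows = [(s * vae_scale_factor, e * vae_scale_factor) for s, e in windows]
print(pixel_windows)  # [(0, 768), (384, 1152), (768, 1536)]
```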
781
+ @torch.no_grad()
782
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
783
+ def __call__(
784
+ self,
785
+ prompt: Union[str, List[str]] = None,
786
+ image: Union[torch.FloatTensor, PIL.Image.Image,
787
+ List[torch.FloatTensor], List[PIL.Image.Image]] = None,
788
+ control_image: Union[torch.FloatTensor, PIL.Image.Image,
789
+ List[torch.FloatTensor],
790
+ List[PIL.Image.Image]] = None,
791
+ height: Optional[int] = None,
792
+ width: Optional[int] = None,
793
+ strength: float = 0.8,
794
+ num_inference_steps: int = 50,
795
+ guidance_scale: float = 7.5,
796
+ negative_prompt: Optional[Union[str, List[str]]] = None,
797
+ num_images_per_prompt: Optional[int] = 1,
798
+ eta: float = 0.0,
799
+ generator: Optional[Union[torch.Generator,
800
+ List[torch.Generator]]] = None,
801
+ latents: Optional[torch.FloatTensor] = None,
802
+ prompt_embeds: Optional[torch.FloatTensor] = None,
803
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
804
+ output_type: Optional[str] = 'pil',
805
+ return_dict: bool = True,
806
+ callback: Optional[Callable[[int, int, torch.FloatTensor],
807
+ None]] = None,
808
+ callback_steps: int = 1,
809
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
810
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
811
+ guess_mode: bool = False,
812
+ context_size: int = 768,
813
+ ):
814
+ r"""
815
+ Function invoked when calling the pipeline for generation.
816
+
817
+ Args:
818
+ prompt (`str` or `List[str]`, *optional*):
819
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
820
+ instead.
821
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
822
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
823
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
824
+ the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
825
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
826
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
827
+ specified in init, images must be passed as a list such that each element of the list can be correctly
828
+ batched for input to a single controlnet.
829
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
830
+ The height in pixels of the generated image.
831
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
832
+ The width in pixels of the generated image.
833
+ num_inference_steps (`int`, *optional*, defaults to 50):
834
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
835
+ expense of slower inference.
836
+ guidance_scale (`float`, *optional*, defaults to 7.5):
837
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
838
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
839
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
840
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
841
+ usually at the expense of lower image quality.
842
+ negative_prompt (`str` or `List[str]`, *optional*):
843
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
844
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
845
+ less than `1`).
846
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
847
+ The number of images to generate per prompt.
848
+ eta (`float`, *optional*, defaults to 0.0):
849
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
850
+ [`schedulers.DDIMScheduler`], will be ignored for others.
851
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
852
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
853
+ to make generation deterministic.
854
+ latents (`torch.FloatTensor`, *optional*):
855
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
856
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
857
+ tensor will ge generated by sampling using the supplied random `generator`.
858
+ prompt_embeds (`torch.FloatTensor`, *optional*):
859
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
860
+ provided, text embeddings will be generated from `prompt` input argument.
861
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
862
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
863
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
864
+ argument.
865
+ output_type (`str`, *optional*, defaults to `"pil"`):
866
+ The output format of the generate image. Choose between
867
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
868
+ return_dict (`bool`, *optional*, defaults to `True`):
869
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
870
+ plain tuple.
871
+ callback (`Callable`, *optional*):
872
+ A function that will be called every `callback_steps` steps during inference. The function will be
873
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
874
+ callback_steps (`int`, *optional*, defaults to 1):
875
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
876
+ called at every step.
877
+ cross_attention_kwargs (`dict`, *optional*):
878
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
879
+ `self.processor` in
880
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/
881
+ src/diffusers/models/cross_attention.py).
882
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
883
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
884
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
885
+ corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
886
+ than for [`~StableDiffusionControlNetPipeline.__call__`].
887
+ guess_mode (`bool`, *optional*, defaults to `False`):
888
+ In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
889
+ you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
890
+ context_size (`int`, *optional*, defaults to `768`):
891
+ The tile size used when denoising the latents.
892
+
893
+ Examples:
894
+
895
+ Returns:
896
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
897
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
898
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
899
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
900
+ (nsfw) content, according to the `safety_checker`.
901
+ """
902
+
903
+ def tiled_decode(
904
+ self,
905
+ z: torch.FloatTensor,
906
+ return_dict: bool = True
907
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
908
+ r"""Decode a batch of images using a tiled decoder.
909
+
910
+ Args:
911
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
912
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled
913
+ decoding is: different from non-tiled decoding due to each tile using a different decoder.
914
+ To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output.
915
+ You may still see tile-sized changes in the look of the output, but they should be much less noticeable.
916
+ z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to
917
+ `True`):
918
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
919
+ """
920
+ _tile_overlap_factor = 1 - self.tile_overlap_factor
921
+ overlap_size = int(self.tile_latent_min_size
922
+ * _tile_overlap_factor)
923
+ blend_extent = int(self.tile_sample_min_size
924
+ * self.tile_overlap_factor)
925
+ row_limit = self.tile_sample_min_size - blend_extent
926
+ w = z.shape[3]
927
+ z = torch.cat([z, z[:, :, :, :w // 4]], dim=-1)
928
+ # Split z into overlapping full-width horizontal strips of height `tile_latent_min_size` and decode them separately.
929
+ # The strips overlap to avoid seams between tiles.
930
+
931
+ rows = []
932
+ for i in range(0, z.shape[2], overlap_size):
933
+ row = []
934
+ tile = z[:, :, i:i + self.tile_latent_min_size, :]
935
+ tile = self.post_quant_conv(tile)
936
+ decoded = self.decoder(tile)
937
+ vae_scale_factor = decoded.shape[-1] // tile.shape[-1]
938
+ row.append(decoded)
939
+ rows.append(row)
940
+ result_rows = []
941
+ for i, row in enumerate(rows):
942
+ result_row = []
943
+ for j, tile in enumerate(row):
944
+ # blend the above tile and the left tile
945
+ # to the current tile and add the current tile to the result row
946
+ if i > 0:
947
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
948
+ if j > 0:
949
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
950
+ result_row.append(
951
+ self.blend_h(
952
+ tile[:, :, :row_limit, w * vae_scale_factor:],
953
+ tile[:, :, :row_limit, :w * vae_scale_factor],
954
+ tile.shape[-1] - w * vae_scale_factor))
955
+ result_rows.append(torch.cat(result_row, dim=3))
956
+
957
+ dec = torch.cat(result_rows, dim=2)
958
+ if not return_dict:
959
+ return (dec, )
960
+
961
+ return DecoderOutput(sample=dec)
962
+
963
+ self.vae.tiled_decode = tiled_decode.__get__(self.vae, AutoencoderKL)
964
+
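The assignment just above attaches the locally defined `tiled_decode` to this particular VAE instance via the descriptor protocol. A minimal, self-contained sketch of that `__get__` binding trick (toy class and names, not part of the pipeline):

```python
class Greeter:
    def hello(self):
        return "hello"

def shout(self):
    # replacement behaviour; `self` is the Greeter instance the function is bound to
    return self.hello().upper() + "!"

g = Greeter()
# __get__ turns the plain function into a bound method of this one instance,
# mirroring `tiled_decode.__get__(self.vae, AutoencoderKL)` above.
g.hello_loud = shout.__get__(g, Greeter)
print(g.hello_loud())  # HELLO!
```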
965
+ # 0. Default height and width to unet
966
+ height, width = self._default_height_width(height, width, image)
967
+
968
+ # 1. Check inputs. Raise error if not correct
969
+ self.check_inputs(
970
+ prompt,
971
+ control_image,
972
+ height,
973
+ width,
974
+ callback_steps,
975
+ negative_prompt,
976
+ prompt_embeds,
977
+ negative_prompt_embeds,
978
+ controlnet_conditioning_scale,
979
+ )
980
+
981
+ # 2. Define call parameters
982
+ if prompt is not None and isinstance(prompt, str):
983
+ batch_size = 1
984
+ elif prompt is not None and isinstance(prompt, list):
985
+ batch_size = len(prompt)
986
+ else:
987
+ batch_size = prompt_embeds.shape[0]
988
+
989
+ device = self._execution_device
990
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
991
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
992
+ # corresponds to doing no classifier free guidance.
993
+ do_classifier_free_guidance = guidance_scale > 1.0
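The comment above refers to classifier-free guidance from the Imagen paper. As a generic illustration (not this pipeline's exact code), the guidance weight linearly combines the unconditional and text-conditional noise predictions:

```python
import torch

# Toy tensors standing in for the two halves of the batched noise prediction.
noise_pred_uncond = torch.zeros(1, 4, 64, 64)
noise_pred_text = torch.randn(1, 4, 64, 64)
guidance_scale = 7.5

# guidance_scale == 1.0 reduces this to the text-conditional prediction alone,
# i.e. no classifier-free guidance, which is why the flag above checks `> 1.0`.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```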
994
+
995
+ controlnet = self.controlnet._orig_mod if is_compiled_module(
996
+ self.controlnet) else self.controlnet
997
+
998
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(
999
+ controlnet_conditioning_scale, float):
1000
+ controlnet_conditioning_scale = [controlnet_conditioning_scale
1001
+ ] * len(controlnet.nets)
1002
+
1003
+ global_pool_conditions = (
1004
+ controlnet.config.global_pool_conditions if isinstance(
1005
+ controlnet, ControlNetModel) else
1006
+ controlnet.nets[0].config.global_pool_conditions)
1007
+ guess_mode = guess_mode or global_pool_conditions
1008
+
1009
+ # 3. Encode input prompt
1010
+ prompt_embeds = self._encode_prompt(
1011
+ prompt,
1012
+ device,
1013
+ num_images_per_prompt,
1014
+ do_classifier_free_guidance,
1015
+ negative_prompt,
1016
+ prompt_embeds=prompt_embeds,
1017
+ negative_prompt_embeds=negative_prompt_embeds,
1018
+ )
1019
+ # 4. Prepare image, and controlnet_conditioning_image
1020
+ image = prepare_image(image)
1021
+
1022
+ # 5. Prepare image
1023
+ if isinstance(controlnet, ControlNetModel):
1024
+ control_image = self.prepare_control_image(
1025
+ image=control_image,
1026
+ width=width,
1027
+ height=height,
1028
+ batch_size=batch_size * num_images_per_prompt,
1029
+ num_images_per_prompt=num_images_per_prompt,
1030
+ device=device,
1031
+ dtype=controlnet.dtype,
1032
+ do_classifier_free_guidance=do_classifier_free_guidance,
1033
+ guess_mode=guess_mode,
1034
+ )
1035
+ elif isinstance(controlnet, MultiControlNetModel):
1036
+ control_images = []
1037
+
1038
+ for control_image_ in control_image:
1039
+ control_image_ = self.prepare_control_image(
1040
+ image=control_image_,
1041
+ width=width,
1042
+ height=height,
1043
+ batch_size=batch_size * num_images_per_prompt,
1044
+ num_images_per_prompt=num_images_per_prompt,
1045
+ device=device,
1046
+ dtype=controlnet.dtype,
1047
+ do_classifier_free_guidance=do_classifier_free_guidance,
1048
+ guess_mode=guess_mode,
1049
+ )
1050
+
1051
+ control_images.append(control_image_)
1052
+
1053
+ control_image = control_images
1054
+ else:
1055
+ assert False
1056
+
1057
+ # 5. Prepare timesteps
1058
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1059
+ timesteps, num_inference_steps = self.get_timesteps(
1060
+ num_inference_steps, strength, device)
1061
+ latent_timestep = timesteps[:1].repeat(batch_size
1062
+ * num_images_per_prompt)
1063
+
1064
+ # 6. Prepare latent variables
1065
+ latents = self.prepare_latents(
1066
+ image,
1067
+ latent_timestep,
1068
+ batch_size,
1069
+ num_images_per_prompt,
1070
+ prompt_embeds.dtype,
1071
+ device,
1072
+ generator,
1073
+ )
1074
+
1075
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1076
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1077
+
1078
+ views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)]
1079
+ # value = torch.zeros_like(latents)
1080
+ _, _, height, width = control_image.size()
1081
+ tile_latent_min_size = context_size // self.vae_scale_factor
1082
+ tile_overlap_factor = 0.5
1083
+ overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor))
1084
+ blend_extent = int(tile_latent_min_size * tile_overlap_factor)
1085
+ row_limit = tile_latent_min_size - blend_extent
1086
+ w = latents.shape[3]
1087
+ latents = torch.cat([latents, latents[:, :, :, :overlap_size]], dim=-1)
1088
+ control_image_extend = control_image[:, :, :, :overlap_size
1089
+ * self.vae_scale_factor]
1090
+ control_image = torch.cat([control_image, control_image_extend],
1091
+ dim=-1)
1092
+
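For orientation, here is the tiling arithmetic above worked through with concrete numbers. The 1536x3072 working resolution and the VAE scale factor of 8 are assumptions borrowed from the text-to-panorama wrapper later in this commit; only `context_size` and `tile_overlap_factor` come from the code itself:

```python
vae_scale_factor = 8                      # standard SD VAE downsampling factor (assumed)
context_size = 768                        # default from the docstring above
tile_overlap_factor = 0.5

tile_latent_min_size = context_size // vae_scale_factor               # 96 latent pixels per tile
overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor))  # 48 -> stride between tiles
blend_extent = int(tile_latent_min_size * tile_overlap_factor)        # 48 -> blended border width
row_limit = tile_latent_min_size - blend_extent                       # 48 -> pixels kept per tile

latent_h, latent_w = 1536 // vae_scale_factor, 3072 // vae_scale_factor  # 192 x 384 (assumed panorama size)
padded_w = latent_w + overlap_size  # 432: the latent is wrapped by one stride so the pano edges meet seamlessly
print(tile_latent_min_size, overlap_size, blend_extent, row_limit, padded_w)  # 96 48 48 48 432
```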
1093
+ # 8. Denoising loop
1094
+ num_warmup_steps = len(
1095
+ timesteps) - num_inference_steps * self.scheduler.order
1096
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1097
+ for i, t in enumerate(timesteps):
1098
+ latents_input, control_image_input = self.get_blocks(
1099
+ latents, control_image, tile_latent_min_size, overlap_size)
1100
+ rows = []
1101
+ for latents_input_, control_image_input_ in zip(
1102
+ latents_input, control_image_input):
1103
+ num_block = len(latents_input_)
1104
+ # get batched latents_input
1105
+ latents_input_ = torch.cat(
1106
+ latents_input_[:num_block], dim=0)
1107
+ # get batched prompt_embeds
1108
+ prompt_embeds_ = torch.cat(
1109
+ [prompt_embeds.chunk(2)[0]] * num_block
1110
+ + [prompt_embeds.chunk(2)[1]] * num_block,
1111
+ dim=0)
1112
+ # get batched control_image_input
1113
+ control_image_input_ = torch.cat(
1114
+ [
1115
+ x[0, :, :, ][None, :, :, :]
1116
+ for x in control_image_input_[:num_block]
1117
+ ] + [
1118
+ x[1, :, :, ][None, :, :, :]
1119
+ for x in control_image_input_[:num_block]
1120
+ ],
1121
+ dim=0)
1122
+ latents_output = self.denoise_latents(
1123
+ latents_input_, t, prompt_embeds_,
1124
+ control_image_input_, controlnet_conditioning_scale,
1125
+ guess_mode, cross_attention_kwargs,
1126
+ do_classifier_free_guidance, guidance_scale,
1127
+ extra_step_kwargs, views_scheduler_status)
1128
+ rows.append(list(latents_output.chunk(num_block)))
1129
+ result_rows = []
1130
+ for i, row in enumerate(rows):
1131
+ result_row = []
1132
+ for j, tile in enumerate(row):
1133
+ # blend the above tile and the left tile
1134
+ # to the current tile and add the current tile to the result row
1135
+ if i > 0:
1136
+ tile = self.blend_v(rows[i - 1][j], tile,
1137
+ blend_extent)
1138
+ if j > 0:
1139
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
1140
+ if j == 0:
1141
+ tile = self.blend_h(row[-1], tile, blend_extent)
1142
+ if i != len(rows) - 1:
1143
+ if j == len(row) - 1:
1144
+ result_row.append(tile[:, :, :row_limit, :])
1145
+ else:
1146
+ result_row.append(
1147
+ tile[:, :, :row_limit, :row_limit])
1148
+ else:
1149
+ if j == len(row) - 1:
1150
+ result_row.append(tile[:, :, :, :])
1151
+ else:
1152
+ result_row.append(tile[:, :, :, :row_limit])
1153
+ result_rows.append(torch.cat(result_row, dim=3))
1154
+ latents = torch.cat(result_rows, dim=2)
1155
+
1156
+ # call the callback, if provided
1157
+ condition_i = i == len(timesteps) - 1
1158
+ condition_warm = (i + 1) > num_warmup_steps and (
1159
+ i + 1) % self.scheduler.order == 0
1160
+ if condition_i or condition_warm:
1161
+ progress_bar.update()
1162
+ if callback is not None and i % callback_steps == 0:
1163
+ callback(i, t, latents)
1164
+ latents = latents[:, :, :, :w]
1165
+
1166
+ # If we do sequential model offloading, let's offload unet and controlnet
1167
+ # manually for max memory savings
1168
+ if hasattr(
1169
+ self,
1170
+ 'final_offload_hook') and self.final_offload_hook is not None:
1171
+ self.unet.to('cpu')
1172
+ self.controlnet.to('cpu')
1173
+ torch.cuda.empty_cache()
1174
+
1175
+ if not output_type == 'latent':
1176
+ image = self.vae.decode(
1177
+ latents / self.vae.config.scaling_factor, return_dict=False)[0]
1178
+ image, has_nsfw_concept = self.run_safety_checker(
1179
+ image, device, prompt_embeds.dtype)
1180
+ else:
1181
+ image = latents
1182
+ has_nsfw_concept = None
1183
+
1184
+ if has_nsfw_concept is None:
1185
+ do_denormalize = [True] * image.shape[0]
1186
+ else:
1187
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1188
+
1189
+ image = self.image_processor.postprocess(
1190
+ image, output_type=output_type, do_denormalize=do_denormalize)
1191
+
1192
+ # Offload last model to CPU
1193
+ if hasattr(
1194
+ self,
1195
+ 'final_offload_hook') and self.final_offload_hook is not None:
1196
+ self.final_offload_hook.offload()
1197
+
1198
+ if not return_dict:
1199
+ return (image, has_nsfw_concept)
1200
+
1201
+ return StableDiffusionPipelineOutput(
1202
+ images=image, nsfw_content_detected=has_nsfw_concept)
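This closes the tiled ControlNet img2img panorama call defined above. A hedged usage sketch follows; the import path, checkpoint layout (`sr-base`, `sr-control`) and the 3072x1536 working size are assumptions borrowed from the text-to-panorama wrapper added below, while the call arguments mirror how that wrapper invokes this pipeline:

```python
import torch
from diffusers import ControlNetModel
from PIL import Image

# assumed import path; the wrapper below imports the class as `from .pipeline_sr import ...`
from txt2panoimg.pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline

controlnet = ControlNetModel.from_pretrained("models/sr-control", torch_dtype=torch.float16)
pipe_sr = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(
    "models/sr-base", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

pano = Image.open("pano_512x1024.png").resize((3072, 1536))  # low-res panorama to refine
result = pipe_sr(
    "the living room",
    negative_prompt="blur, worst quality",
    image=pano,                         # image to re-noise and refine
    control_image=pano,                 # same panorama as the ControlNet condition
    strength=0.8,
    num_inference_steps=7,
    guidance_scale=15,
    controlnet_conditioning_scale=1.0,
    context_size=768,                   # tile size used for the blended, tiled denoising
).images[0]
result.save("pano_sr.png")
```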
txt2panoimg/text_to_360panorama_image_pipeline.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright © Alibaba, Inc. and its affiliates.
2
+ import random
3
+ from typing import Any, Dict
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import (ControlNetModel, DiffusionPipeline,
8
+ EulerAncestralDiscreteScheduler,
9
+ UniPCMultistepScheduler)
10
+ from PIL import Image
11
+ from RealESRGAN import RealESRGAN
12
+
13
+ from .pipeline_base import StableDiffusionBlendExtendPipeline
14
+ from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
15
+
16
+ class LazyRealESRGAN:
17
+ def __init__(self, device, scale):
18
+ self.device = device
19
+ self.scale = scale
20
+ self.model = None
21
+ self.model_path = None
22
+
23
+ def load_model(self):
24
+ if self.model is None:
25
+ self.model = RealESRGAN(self.device, scale=self.scale)
26
+ self.model.load_weights(self.model_path, download=False)
27
+
28
+ def predict(self, img):
29
+ self.load_model()
30
+ return self.model.predict(img)
31
+
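`LazyRealESRGAN` defers building the Real-ESRGAN network until the first `predict` call, so constructing the pipeline stays cheap. A minimal usage sketch mirroring how `Text2360PanoramaImagePipeline.__init__` wires it up below (the weight path is an assumed model-directory layout):

```python
import numpy as np
import torch
from PIL import Image

from txt2panoimg.text_to_360panorama_image_pipeline import LazyRealESRGAN  # module added in this commit

upsampler = LazyRealESRGAN(device=torch.device("cuda"), scale=2)
upsampler.model_path = "models/RealESRGAN_x2plus.pth"  # assumed path; set before the first predict()

low_res = np.array(Image.open("pano.png").convert("RGB"))
high_res = upsampler.predict(low_res)  # weights load lazily on this first call, then the image is upscaled 2x
```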
32
+ class Text2360PanoramaImagePipeline(DiffusionPipeline):
33
+ """ Stable Diffusion for 360 Panorama Image Generation Pipeline.
34
+ Example:
35
+ >>> import torch
36
+ >>> from txt2panoimg import Text2360PanoramaImagePipeline
37
+ >>> prompt = 'The mountains'
38
+ >>> input = {'prompt': prompt, 'upscale': True}
39
+ >>> model_id = 'models/'
40
+ >>> txt2panoimg = Text2360PanoramaImagePipeline(model_id, torch_dtype=torch.float16)
41
+ >>> output = txt2panoimg(input)
42
+ >>> output.save('result.png')
43
+ """
44
+
45
+ def __init__(self, model: str, device: str = 'cuda', **kwargs):
46
+ """
47
+ Use `model` to create a stable diffusion pipeline for 360 panorama image generation.
48
+ Args:
49
+ model: path to the model directory (containing `sd-base`, `sr-base`, `sr-control` and the Real-ESRGAN weights).
50
+ device: device to run on, defaults to 'cuda'.
51
+ """
52
+ super().__init__()
53
+
54
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'
55
+ ) if device is None else device
56
+ if device == 'gpu':
57
+ device = torch.device('cuda')
58
+
59
+ torch_dtype = kwargs.get('torch_dtype', torch.float16)
60
+ enable_xformers_memory_efficient_attention = kwargs.get(
61
+ 'enable_xformers_memory_efficient_attention', True)
62
+
63
+ model_id = model + '/sd-base/'
64
+
65
+ # init base model
66
+ self.pipe = StableDiffusionBlendExtendPipeline.from_pretrained(
67
+ model_id, torch_dtype=torch_dtype).to(device)
68
+ self.pipe.vae.enable_tiling()
69
+ self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
70
+ self.pipe.scheduler.config)
71
+ # enabling xformers is optional; failures are caught and ignored below
72
+ try:
73
+ if enable_xformers_memory_efficient_attention:
74
+ self.pipe.enable_xformers_memory_efficient_attention()
75
+ except Exception as e:
76
+ print(e)
77
+ self.pipe.enable_model_cpu_offload()
78
+
79
+ # init controlnet-sr model
80
+ base_model_path = model + '/sr-base'
81
+ controlnet_path = model + '/sr-control'
82
+ controlnet = ControlNetModel.from_pretrained(
83
+ controlnet_path, torch_dtype=torch_dtype)
84
+ self.pipe_sr = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(
85
+ base_model_path, controlnet=controlnet,
86
+ torch_dtype=torch_dtype).to(device)
87
+ self.pipe_sr.scheduler = UniPCMultistepScheduler.from_config(
88
+ self.pipe.scheduler.config)
89
+ self.pipe_sr.vae.enable_tiling()
90
+ # enabling xformers is optional; failures are caught and ignored below
91
+ try:
92
+ if enable_xformers_memory_efficient_attention:
93
+ self.pipe_sr.enable_xformers_memory_efficient_attention()
94
+ except Exception as e:
95
+ print(e)
96
+ self.pipe_sr.enable_model_cpu_offload()
97
+ device = torch.device("cuda")
98
+ model_path = model + '/RealESRGAN_x2plus.pth'
99
+ self.upsampler = LazyRealESRGAN(device=device, scale=2)
100
+ self.upsampler.model_path = model_path
101
+
102
+ @staticmethod
103
+ def blend_h(a, b, blend_extent):
104
+ a = np.array(a)
105
+ b = np.array(b)
106
+ blend_extent = min(a.shape[1], b.shape[1], blend_extent)
107
+ for x in range(blend_extent):
108
+ b[:, x, :] = (a[:, -blend_extent + x, :]
109
+ * (1 - x / blend_extent)
110
+ + b[:, x, :] * (x / blend_extent))
111
+ return b
112
+
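As a quick check of what `blend_h` computes, here is a toy example: the first `blend_extent` columns of `b` are linearly cross-faded from the trailing columns of `a` into `b`'s own values, which is how the panorama's wrap-around seam is hidden after upscaling:

```python
import numpy as np

a = np.zeros((1, 8, 3), dtype=np.float32)  # "left" image: all zeros
b = np.ones((1, 8, 3), dtype=np.float32)   # "right" image: all ones
blend_extent = 4

# same formula as blend_h above
for x in range(blend_extent):
    b[:, x, :] = a[:, -blend_extent + x, :] * (1 - x / blend_extent) + b[:, x, :] * (x / blend_extent)

print(b[0, :, 0])  # [0.   0.25 0.5  0.75 1.   1.   1.   1.  ] -> smooth ramp across the seam
```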
113
+ def __call__(self, inputs: Dict[str, Any],
114
+ **forward_params) -> Dict[str, Any]:
115
+ if not isinstance(inputs, dict):
116
+ raise ValueError(
117
+ f'Expected the input to be a dictionary, but got {type(input)}'
118
+ )
119
+ num_inference_steps = inputs.get('num_inference_steps', 20)
120
+ guidance_scale = inputs.get('guidance_scale', 7.5)
121
+ preset_a_prompt = 'photorealistic, trend on artstation, ((best quality)), ((ultra high res))'
122
+ add_prompt = inputs.get('add_prompt', preset_a_prompt)
123
+ preset_n_prompt = 'persons, complex texture, small objects, sheltered, blur, worst quality, '\
124
+ 'low quality, zombie, logo, text, watermark, username, monochrome, '\
125
+ 'complex lighting'
126
+ negative_prompt = inputs.get('negative_prompt', preset_n_prompt)
127
+ seed = inputs.get('seed', -1)
128
+ upscale = inputs.get('upscale', True)
129
+ refinement = inputs.get('refinement', True)
130
+
131
+ guidance_scale_sr_step1 = inputs.get('guidance_scale_sr_step1', 15)
132
+ guidance_scale_sr_step2 = inputs.get('guidance_scale_sr_step1', 17)
133
+
134
+ if 'prompt' in inputs.keys():
135
+ prompt = inputs['prompt']
136
+ else:
137
+ # for demo_service
138
+ prompt = forward_params.get('prompt', 'the living room')
139
+
140
+ print(f'Test with prompt: {prompt}')
141
+
142
+ if seed == -1:
143
+ seed = random.randint(0, 65535)
144
+ print(f'global seed: {seed}')
145
+
146
+ generator = torch.manual_seed(seed)
147
+
148
+ prompt = '<360panorama>, ' + prompt + ', ' + add_prompt
149
+ output_img = self.pipe(
150
+ prompt,
151
+ negative_prompt=negative_prompt,
152
+ num_inference_steps=num_inference_steps,
153
+ height=512,
154
+ width=1024,
155
+ guidance_scale=guidance_scale,
156
+ generator=generator).images[0]
157
+
158
+ if not upscale:
159
+ print('finished')
160
+ else:
161
+ print('inputs: upscale=True, running upscaler.')
162
+ print('running upscaler step1. Initial super-resolution')
163
+ sr_scale = 2.0
164
+ output_img = self.pipe_sr(
165
+ prompt.replace('<360panorama>, ', ''),
166
+ negative_prompt=negative_prompt,
167
+ image=output_img.resize(
168
+ (int(1536 * sr_scale), int(768 * sr_scale))),
169
+ num_inference_steps=7,
170
+ generator=generator,
171
+ control_image=output_img.resize(
172
+ (int(1536 * sr_scale), int(768 * sr_scale))),
173
+ strength=0.8,
174
+ controlnet_conditioning_scale=1.0,
175
+ guidance_scale=guidance_scale_sr_step1,
176
+ ).images[0]
177
+
178
+ print('running upscaler step2. Super-resolution with Real-ESRGAN')
179
+ output_img = output_img.resize((1536 * 2, 768 * 2))
180
+ w = output_img.size[0]
181
+ blend_extend = 10
182
+ outscale = 2
183
+ output_img = np.array(output_img)
184
+ output_img = np.concatenate(
185
+ [output_img, output_img[:, :blend_extend, :]], axis=1)
186
+ output_img = self.upsampler.predict(
187
+ output_img)
188
+ output_img = self.blend_h(output_img, output_img,
189
+ blend_extend * outscale)
190
+ output_img = Image.fromarray(output_img[:, :w * outscale, :])
191
+
192
+ if refinement:
193
+ print(
194
+ 'inputs: refinement=True, running refinement. This is a bit time-consuming.'
195
+ )
196
+ sr_scale = 4
197
+ output_img = self.pipe_sr(
198
+ prompt.replace('<360panorama>, ', ''),
199
+ negative_prompt=negative_prompt,
200
+ image=output_img.resize(
201
+ (int(1536 * sr_scale), int(768 * sr_scale))),
202
+ num_inference_steps=7,
203
+ generator=generator,
204
+ control_image=output_img.resize(
205
+ (int(1536 * sr_scale), int(768 * sr_scale))),
206
+ strength=0.8,
207
+ controlnet_conditioning_scale=1.0,
208
+ guidance_scale=guidance_scale_sr_step2,
209
+ ).images[0]
210
+ print('finished')
211
+
212
+ return output_img
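Putting it together, a minimal usage sketch of this pipeline's `__call__` showing the optional keys it reads from the input dictionary (the `models/` directory layout is assumed to match the loader above):

```python
import torch

from txt2panoimg import Text2360PanoramaImagePipeline

pipeline = Text2360PanoramaImagePipeline("models/", torch_dtype=torch.float16)

inputs = {
    "prompt": "a cozy mountain cabin interior",
    "num_inference_steps": 20,
    "guidance_scale": 7.5,
    "seed": 42,                     # -1 (the default) picks a random seed
    "upscale": True,                # run the two-stage super-resolution branch
    "refinement": True,             # extra ControlNet pass after Real-ESRGAN (slow)
    "guidance_scale_sr_step1": 15,
    "guidance_scale_sr_step2": 17,
}

pano = pipeline(inputs)             # PIL.Image, 6144x3072 when upscaling is enabled
pano.save("result.png")
```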