Spaces:
Paused
Paused
shariqfarooq
commited on
Commit
•
13f1a87
1
Parent(s):
5a85f92
add demo
Browse files- app.py +125 -0
- cross_frame_attention.py +120 -0
- loosecontrol.py +135 -0
app.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from dataclasses import dataclass
|
3 |
+
import PIL
|
4 |
+
import PIL.Image
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from gradio_editor3d import Editor3D as g3deditor
|
9 |
+
import copy
|
10 |
+
from loosecontrol import LooseControlNet
|
11 |
+
|
12 |
+
|
13 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
14 |
+
cn = LooseControlNet()
|
15 |
+
cn.pipe = cn.pipe.to(torch_device=device, torch_dtype=torch.float16)
|
16 |
+
|
17 |
+
# Need to figure out a better way how to do this per user, making 'cf attention' act like a state per user.
|
18 |
+
# For now, we just copy the model.
|
19 |
+
cn_with_cf = copy.deepcopy(cn)
|
20 |
+
cn_with_cf.set_cf_attention()
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class FixedInputs:
|
25 |
+
prompt: str
|
26 |
+
seed: int
|
27 |
+
depth: PIL.Image.Image
|
28 |
+
|
29 |
+
|
30 |
+
negative_prompt = "blurry, text, caption, lowquality, lowresolution, low res, grainy, ugly"
|
31 |
+
def depth2image(prompt, seed, depth):
|
32 |
+
seed = int(seed)
|
33 |
+
gen = cn(prompt, control_image=depth, controlnet_conditioning_scale=1.0, generator=torch.Generator().manual_seed(seed), num_inference_steps=20, negative_prompt=negative_prompt)
|
34 |
+
return gen
|
35 |
+
|
36 |
+
def edit_previous(prompt, seed, depth, fixed_inputs):
|
37 |
+
seed = int(seed)
|
38 |
+
control_image = [fixed_inputs.depth, depth]
|
39 |
+
prompt = [fixed_inputs.prompt, prompt]
|
40 |
+
neg_prompt = [negative_prompt, negative_prompt]
|
41 |
+
generator = [torch.Generator().manual_seed(fixed_inputs.seed), torch.Generator().manual_seed(seed)]
|
42 |
+
gen = cn_with_cf(prompt, control_image=control_image, controlnet_conditioning_scale=1.0, generator=generator, num_inference_steps=20, negative_prompt=neg_prompt)[-1]
|
43 |
+
return gen
|
44 |
+
|
45 |
+
def run(prompt, seed, depth, should_edit, fixed_inputs):
|
46 |
+
depth = depth.convert("RGB")
|
47 |
+
# all values below [3,3,3] in depth should actually be set to [255,255,255]
|
48 |
+
# This is to due the nature of training data and is experimental right now.
|
49 |
+
# Not in use for now.
|
50 |
+
# depth = np.array(depth)
|
51 |
+
# depth[depth < 3] = 255
|
52 |
+
# depth = PIL.Image.fromarray(depth)
|
53 |
+
|
54 |
+
fixed_inputs = fixed_inputs[0]
|
55 |
+
if should_edit and fixed_inputs is not None:
|
56 |
+
return edit_previous(prompt, seed, depth, fixed_inputs)
|
57 |
+
else:
|
58 |
+
return depth2image(prompt, seed, depth)
|
59 |
+
|
60 |
+
def handle_edit_change(edit, prompt, seed, image_input, fixed_inputs):
|
61 |
+
if edit:
|
62 |
+
fixed_inputs[0] = FixedInputs(prompt, int(seed), image_input)
|
63 |
+
else:
|
64 |
+
fixed_inputs[0] = None
|
65 |
+
return fixed_inputs
|
66 |
+
|
67 |
+
|
68 |
+
css = """
|
69 |
+
|
70 |
+
#image_output {
|
71 |
+
width: 512px;
|
72 |
+
height: 512px;
|
73 |
+
"""
|
74 |
+
|
75 |
+
|
76 |
+
main_description = """
|
77 |
+
# LooseControl
|
78 |
+
|
79 |
+
This is the official demo for the paper [LooseControl: Lifting ControlNet for Generalized Depth Conditioning](https://shariqfarooq123.github.io/loose-control/).
|
80 |
+
Our 3D Box Editing allows users to interactively edit the 3D boxes representing objects in the scene. Users can change the position, size, and orientation of 3D boxes, allowing to quickly create and edit the scenes to their liking in a 3D-aware manner.
|
81 |
+
Best viewed on desktop.
|
82 |
+
"""
|
83 |
+
|
84 |
+
instructions_editor3d = """
|
85 |
+
## Instructions for Editor3D UI
|
86 |
+
- Use 'WASD' keys to move the camera.
|
87 |
+
- Click on an object to select it.
|
88 |
+
- Use the sliders to change the position, size, and orientation of the selected object. Sliders support click and drag for faster editing.
|
89 |
+
- Use the 'Add Box', 'Delete', and 'Duplicate' buttons to add, delete, and duplicate objects.
|
90 |
+
- Delete and Duplicate buttons work on the selected object. Duplicate creates a copy and selects it.
|
91 |
+
- Use the 'Toggle Mode' to switch between "normal" and "depth" mode. Final image sent to the model should be in "depth" mode.
|
92 |
+
- Use the 'Render' button to render the scene and send it to the model for generation.
|
93 |
+
|
94 |
+
### Lock style checkbox - Fixes the style of the latest generated image.
|
95 |
+
This allows users to edit the 3D boxes without changing the style of the generated image. This is useful when the user is satisfied with the style/content of the generated image and wants to edit the 3D boxes without changing the overall essence of the scene.
|
96 |
+
It can be used to create stop motion videos like those shown [here](https://shariqfarooq123.github.io/loose-control/).
|
97 |
+
|
98 |
+
"""
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
with gr.Blocks(css=css) as demo:
|
103 |
+
gr.Markdown(main_description)
|
104 |
+
|
105 |
+
fixed_inputs = gr.State([None])
|
106 |
+
with gr.Row():
|
107 |
+
prompt = gr.Textbox(label="Prompt", placeholder="Write your prompt", elem_id="input")
|
108 |
+
seed = gr.Textbox(value=42, label="Seed", elem_id="seed")
|
109 |
+
should_edit = gr.Checkbox(label="Lock style", elem_id="edit")
|
110 |
+
|
111 |
+
with gr.Row():
|
112 |
+
image_input = g3deditor(elem_id="image_input")
|
113 |
+
|
114 |
+
with gr.Row():
|
115 |
+
image_output = gr.Image(elem_id="image_output", type='pil')
|
116 |
+
|
117 |
+
should_edit.change(fn=handle_edit_change, inputs=[should_edit, prompt, seed, image_input, fixed_inputs], outputs=[fixed_inputs])
|
118 |
+
image_input.change(fn=run, inputs=[prompt, seed, image_input, should_edit, fixed_inputs], outputs=[image_output])
|
119 |
+
with gr.Accordion("Instructions"):
|
120 |
+
gr.Markdown(instructions_editor3d)
|
121 |
+
|
122 |
+
demo.queue().launch()
|
123 |
+
|
124 |
+
|
125 |
+
|
cross_frame_attention.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
|
4 |
+
class CrossFrameAttnProcessor:
|
5 |
+
def __init__(self, unet_chunk_size=2):
|
6 |
+
self.unet_chunk_size = unet_chunk_size
|
7 |
+
|
8 |
+
def __call__(
|
9 |
+
self,
|
10 |
+
attn,
|
11 |
+
hidden_states,
|
12 |
+
encoder_hidden_states=None,
|
13 |
+
attention_mask=None, **kwargs):
|
14 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
15 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
16 |
+
query = attn.to_q(hidden_states)
|
17 |
+
|
18 |
+
is_cross_attention = encoder_hidden_states is not None
|
19 |
+
if encoder_hidden_states is None:
|
20 |
+
encoder_hidden_states = hidden_states
|
21 |
+
elif attn.norm_cross:
|
22 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
23 |
+
key = attn.to_k(encoder_hidden_states)
|
24 |
+
value = attn.to_v(encoder_hidden_states)
|
25 |
+
# Sparse Attention
|
26 |
+
if not is_cross_attention:
|
27 |
+
video_length = key.size()[0] // self.unet_chunk_size
|
28 |
+
# print("Video length is", video_length)
|
29 |
+
# former_frame_index = torch.arange(video_length) - 1
|
30 |
+
# former_frame_index[0] = 0
|
31 |
+
former_frame_index = [0] * video_length
|
32 |
+
key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
|
33 |
+
key = key[:, former_frame_index]
|
34 |
+
key = rearrange(key, "b f d c -> (b f) d c")
|
35 |
+
value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
|
36 |
+
value = value[:, former_frame_index]
|
37 |
+
value = rearrange(value, "b f d c -> (b f) d c")
|
38 |
+
|
39 |
+
query = attn.head_to_batch_dim(query)
|
40 |
+
key = attn.head_to_batch_dim(key)
|
41 |
+
value = attn.head_to_batch_dim(value)
|
42 |
+
|
43 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
44 |
+
hidden_states = torch.bmm(attention_probs, value)
|
45 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
46 |
+
|
47 |
+
# linear proj
|
48 |
+
hidden_states = attn.to_out[0](hidden_states)
|
49 |
+
# dropout
|
50 |
+
hidden_states = attn.to_out[1](hidden_states)
|
51 |
+
|
52 |
+
return hidden_states
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
class AttnProcessorX:
|
57 |
+
r"""
|
58 |
+
Default processor for performing attention-related computations.
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __call__(
|
62 |
+
self,
|
63 |
+
attn,
|
64 |
+
hidden_states,
|
65 |
+
encoder_hidden_states=None,
|
66 |
+
attention_mask=None,
|
67 |
+
temb=None,
|
68 |
+
scale=1.0,
|
69 |
+
):
|
70 |
+
residual = hidden_states
|
71 |
+
|
72 |
+
if attn.spatial_norm is not None:
|
73 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
74 |
+
|
75 |
+
input_ndim = hidden_states.ndim
|
76 |
+
|
77 |
+
if input_ndim == 4:
|
78 |
+
batch_size, channel, height, width = hidden_states.shape
|
79 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
80 |
+
|
81 |
+
batch_size, sequence_length, _ = (
|
82 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
83 |
+
)
|
84 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
85 |
+
|
86 |
+
if attn.group_norm is not None:
|
87 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
88 |
+
|
89 |
+
query = attn.to_q(hidden_states, scale=scale)
|
90 |
+
|
91 |
+
if encoder_hidden_states is None:
|
92 |
+
encoder_hidden_states = hidden_states
|
93 |
+
elif attn.norm_cross:
|
94 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
95 |
+
|
96 |
+
key = attn.to_k(encoder_hidden_states, scale=scale)
|
97 |
+
value = attn.to_v(encoder_hidden_states, scale=scale)
|
98 |
+
|
99 |
+
query = attn.head_to_batch_dim(query)
|
100 |
+
key = attn.head_to_batch_dim(key)
|
101 |
+
value = attn.head_to_batch_dim(value)
|
102 |
+
|
103 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
104 |
+
hidden_states = torch.bmm(attention_probs, value)
|
105 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
106 |
+
|
107 |
+
# linear proj
|
108 |
+
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
109 |
+
# dropout
|
110 |
+
hidden_states = attn.to_out[1](hidden_states)
|
111 |
+
|
112 |
+
if input_ndim == 4:
|
113 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
114 |
+
|
115 |
+
if attn.residual_connection:
|
116 |
+
hidden_states = hidden_states + residual
|
117 |
+
|
118 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
119 |
+
|
120 |
+
return hidden_states
|
loosecontrol.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import (
|
2 |
+
ControlNetModel,
|
3 |
+
StableDiffusionControlNetPipeline,
|
4 |
+
UniPCMultistepScheduler,
|
5 |
+
)
|
6 |
+
import torch
|
7 |
+
import PIL
|
8 |
+
import PIL.Image
|
9 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
10 |
+
from typing import Dict
|
11 |
+
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
12 |
+
import functools
|
13 |
+
from cross_frame_attention import CrossFrameAttnProcessor
|
14 |
+
|
15 |
+
TEXT_ENCODER_NAME = "text_encoder"
|
16 |
+
UNET_NAME = "unet"
|
17 |
+
NEGATIVE_PROMPT = "blurry, text, caption, lowquality, lowresolution, low res, grainy, ugly"
|
18 |
+
|
19 |
+
def attach_loaders_mixin(model):
|
20 |
+
# hacky way to make ControlNet work with LoRA. This may not be required in future versions of diffusers.
|
21 |
+
model.text_encoder_name = TEXT_ENCODER_NAME
|
22 |
+
model.unet_name = UNET_NAME
|
23 |
+
r"""
|
24 |
+
Attach the [`UNet2DConditionLoadersMixin`] to a model. This will add the
|
25 |
+
all the methods from the mixin 'UNet2DConditionLoadersMixin' to the model.
|
26 |
+
"""
|
27 |
+
# mixin_instance = UNet2DConditionLoadersMixin()
|
28 |
+
for attr_name, attr_value in vars(UNet2DConditionLoadersMixin).items():
|
29 |
+
# print(attr_name)
|
30 |
+
if callable(attr_value):
|
31 |
+
# setattr(model, attr_name, functools.partialmethod(attr_value, model).__get__(model, model.__class__))
|
32 |
+
setattr(model, attr_name, functools.partial(attr_value, model))
|
33 |
+
return model
|
34 |
+
|
35 |
+
def set_attn_processor(module, processor, _remove_lora=False):
|
36 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
37 |
+
if hasattr(module, "set_processor"):
|
38 |
+
if not isinstance(processor, dict):
|
39 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
40 |
+
else:
|
41 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
42 |
+
|
43 |
+
for sub_name, child in module.named_children():
|
44 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
45 |
+
|
46 |
+
for name, module in module.named_children():
|
47 |
+
fn_recursive_attn_processor(name, module, processor)
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
class ControlNetX(ControlNetModel, UNet2DConditionLoadersMixin):
|
52 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
53 |
+
# This may not be required in future versions of diffusers.
|
54 |
+
@property
|
55 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
56 |
+
r"""
|
57 |
+
Returns:
|
58 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
59 |
+
indexed by its weight name.
|
60 |
+
"""
|
61 |
+
# set recursively
|
62 |
+
processors = {}
|
63 |
+
|
64 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
65 |
+
if hasattr(module, "get_processor"):
|
66 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
67 |
+
|
68 |
+
for sub_name, child in module.named_children():
|
69 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
70 |
+
|
71 |
+
return processors
|
72 |
+
|
73 |
+
for name, module in self.named_children():
|
74 |
+
fn_recursive_add_processors(name, module, processors)
|
75 |
+
|
76 |
+
return processors
|
77 |
+
|
78 |
+
class ControlNetPipeline:
|
79 |
+
def __init__(self, checkpoint="lllyasviel/control_v11f1p_sd15_depth", sd_checkpoint="runwayml/stable-diffusion-v1-5") -> None:
|
80 |
+
controlnet = ControlNetX.from_pretrained(checkpoint)
|
81 |
+
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
82 |
+
sd_checkpoint, controlnet=controlnet, requires_safety_checker=False, safety_checker=None,
|
83 |
+
torch_dtype=torch.float16)
|
84 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
85 |
+
|
86 |
+
@torch.no_grad()
|
87 |
+
def __call__(self,
|
88 |
+
prompt: str="",
|
89 |
+
height=512,
|
90 |
+
width=512,
|
91 |
+
control_image=None,
|
92 |
+
controlnet_conditioning_scale=1.0,
|
93 |
+
num_inference_steps: int=20,
|
94 |
+
**kwargs) -> PIL.Image.Image:
|
95 |
+
|
96 |
+
out = self.pipe(prompt, control_image,
|
97 |
+
height=height, width=width,
|
98 |
+
num_inference_steps=num_inference_steps,
|
99 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
100 |
+
**kwargs).images
|
101 |
+
|
102 |
+
return out[0] if len(out) == 1 else out
|
103 |
+
|
104 |
+
def to(self, *args, **kwargs):
|
105 |
+
self.pipe.to(*args, **kwargs)
|
106 |
+
return self
|
107 |
+
|
108 |
+
|
109 |
+
class LooseControlNet(ControlNetPipeline):
|
110 |
+
def __init__(self, loose_control_weights="shariqfarooq/loose-control-3dbox", cn_checkpoint="lllyasviel/control_v11f1p_sd15_depth", sd_checkpoint="runwayml/stable-diffusion-v1-5") -> None:
|
111 |
+
super().__init__(cn_checkpoint, sd_checkpoint)
|
112 |
+
self.pipe.controlnet = attach_loaders_mixin(self.pipe.controlnet)
|
113 |
+
self.pipe.controlnet.load_attn_procs(loose_control_weights)
|
114 |
+
|
115 |
+
def set_normal_attention(self):
|
116 |
+
self.pipe.unet.set_attn_processor(AttnProcessor())
|
117 |
+
|
118 |
+
def set_cf_attention(self, _remove_lora=False):
|
119 |
+
for upblocks in self.pipe.unet.up_blocks[-2:]:
|
120 |
+
set_attn_processor(upblocks, CrossFrameAttnProcessor(), _remove_lora=_remove_lora)
|
121 |
+
|
122 |
+
def edit(self, depth, depth_edit, prompt, prompt_edit=None, seed=42, seed_edit=None, negative_prompt=NEGATIVE_PROMPT, controlnet_conditioning_scale=1.0, num_inference_steps=20, **kwargs):
|
123 |
+
if prompt_edit is None:
|
124 |
+
prompt_edit = prompt
|
125 |
+
|
126 |
+
if seed_edit is None:
|
127 |
+
seed_edit = seed
|
128 |
+
|
129 |
+
seed = int(seed)
|
130 |
+
seed_edit = int(seed_edit)
|
131 |
+
control_image = [depth, depth_edit]
|
132 |
+
prompt = [prompt, prompt_edit]
|
133 |
+
generator = [torch.Generator().manual_seed(seed), torch.Generator().manual_seed(seed_edit)]
|
134 |
+
gen = self.pipe(prompt, control_image=control_image, controlnet_conditioning_scale=controlnet_conditioning_scale, generator=generator, num_inference_steps=num_inference_steps, negative_prompt=negative_prompt, **kwargs)[-1]
|
135 |
+
return gen
|