Commit daf9c75
Parent(s): 07dc8e6
Refactor ControlNetReq class to remove unused import and add controlnets, control_images, and controlnet_conditioning_scale attributes
Files changed:
- modules/events/flux_events.py +9 -8
- modules/helpers/flux_helpers.py +2 -54
- modules/pipelines/common_pipelines.py +19 -0
- modules/pipelines/flux_pipelines.py +51 -12
- tabs/image_tab.py +2 -2
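
The diffs below only show call sites for the new ControlNetReq attributes (`control_images.append(...)`, `controlnet_conditioning_scale.append(...)`); the class itself lives in modules/helpers/common_helpers.py, which is not part of this commit view. For orientation, a minimal sketch of what the refactored class plausibly looks like; the field names come from the commit message and the call sites, while the types and defaults are assumptions:

```python
# Hypothetical sketch only: field names are taken from the commit message,
# types and defaults are assumptions (the real class is in
# modules/helpers/common_helpers.py and is not shown in this commit).
from dataclasses import dataclass, field
from typing import List

from PIL import Image


@dataclass
class ControlNetReq:
    controlnets: List[str] = field(default_factory=list)  # e.g. ["canny", "depth"]
    control_images: List[Image.Image] = field(default_factory=list)
    controlnet_conditioning_scale: List[float] = field(default_factory=list)
```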
modules/events/flux_events.py (CHANGED)

```diff
@@ -5,14 +5,15 @@ import spaces
 import gradio as gr
 from huggingface_hub import ModelCard
 
-from modules.helpers.
-from
+from modules.helpers.common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq
+from modules.helpers.flux_helpers import gen_img
+from config import flux_loras
 
 loras = flux_loras
 
 
 # Event functions
-def update_fast_generation(
+def update_fast_generation(fast_generation):
     if fast_generation:
         return (
             gr.update(
@@ -125,7 +126,7 @@ def update_selected_lora(custom_lora):
     )
 
 
-def add_to_enabled_loras(
+def add_to_enabled_loras(selected_lora, enabled_loras):
     lora_data = loras
     try:
         selected_lora = int(selected_lora)
@@ -233,7 +234,7 @@ def generate_image(
         "vae": vae,
         "controlnet_config": None,
     }
-    base_args =
+    base_args = BaseReq(**base_args)
 
     if len(enabled_loras) > 0:
         base_args.loras = []
@@ -252,7 +253,7 @@ def generate_image(
         image = img2img_image
         strength = float(img2img_strength)
 
-        base_args =
+        base_args = BaseImg2ImgReq(
             **base_args.__dict__,
             image=image,
             strength=strength
@@ -263,7 +264,7 @@ def generate_image(
         strength = float(inpaint_strength)
 
         if image and mask_image:
-            base_args =
+            base_args = BaseInpaintReq(
                 **base_args.__dict__,
                 image=image,
                 mask_image=mask_image,
@@ -289,7 +290,7 @@ def generate_image(
         base_args.controlnet_config.control_images.append(depth_image)
         base_args.controlnet_config.controlnet_conditioning_scale.append(float(depth_strength))
     else:
-        base_args =
+        base_args = BaseReq(**base_args.__dict__)
 
     return gr.update(
         value=gen_img(base_args),
```
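The repeated `base_args = SomeReq(**base_args.__dict__, ...)` assignments promote the request object step by step: every field already set on the plain BaseReq is copied into the more specific request class, and the mode-specific fields (`image`, `strength`, `mask_image`) are added on top. A self-contained illustration of the idiom, using stand-in class names rather than the repo's real ones:

```python
from dataclasses import dataclass


@dataclass
class Base:
    prompt: str = ""


@dataclass
class Img2Img(Base):
    image: object = None
    strength: float = 0.8


req = Base(prompt="a cat")
# Re-splat the existing fields, then add the subclass-only ones; this only
# works when Img2Img accepts every field Base has (e.g. via inheritance).
req = Img2Img(**req.__dict__, image="<PIL image>", strength=0.6)
```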
modules/helpers/flux_helpers.py (CHANGED)

```diff
@@ -6,10 +6,6 @@ from diffusers import (
     AutoPipelineForText2Image,
     AutoPipelineForImage2Image,
     AutoPipelineForInpainting,
-    DiffusionPipeline,
-    AutoencoderKL,
-    FluxControlNetModel,
-    FluxMultiControlNetModel,
 )
 from huggingface_hub import hf_hub_download
 from diffusers.schedulers import *
@@ -17,56 +13,8 @@ from huggingface_hub import hf_hub_download
 from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
 
 from .common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq, cleanup, get_controlnet_images, resize_images
-
-
-def load_sd():
-    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    # Models
-    models = [
-        {
-            "repo_id": "black-forest-labs/FLUX.1-dev",
-            "loader": "flux",
-            "compute_type": torch.bfloat16,
-        }
-    ]
-
-    for model in models:
-        try:
-            model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
-                model['repo_id'],
-                vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
-                torch_dtype=model['compute_type'],
-                safety_checker=None,
-                variant="fp16"
-            ).to(device)
-        except:
-            model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
-                model['repo_id'],
-                vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
-                torch_dtype=model['compute_type'],
-                safety_checker=None
-            ).to(device)
-
-        model["pipeline"].enable_model_cpu_offload()
-
-    # VAE n Refiner
-    flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
-    sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
-    refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
-    refiner.enable_model_cpu_offload()
-
-    # ControlNet
-    controlnet = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
-        "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
-        torch_dtype=torch.bfloat16
-    ).to(device)])
-
-    return device, models, flux_vae, sdxl_vae, refiner, controlnet
-
-
-device, models, flux_vae, sdxl_vae, refiner, controlnet = load_sd()
+from modules.pipelines.flux_pipelines import device, models, flux_vae, controlnet
+from modules.pipelines.common_pipelines import refiner
 
 
 def get_control_mode(controlnet_config: ControlNetReq):
```
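With `load_sd()` gone, model construction now lives in the modules/pipelines package and this helper only consumes the results. The loads still happen at import time, but they are split across two modules: Flux-specific objects in flux_pipelines.py and the shared SDXL refiner in common_pipelines.py, so other model families can reuse the refiner without importing the Flux stack.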
modules/pipelines/common_pipelines.py (ADDED)

```diff
@@ -0,0 +1,19 @@
+import torch
+from diffusers import (
+    DiffusionPipeline,
+    AutoencoderKL,
+)
+from diffusers.schedulers import *
+
+
+def load_common():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # VAE n Refiner
+    sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
+    refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
+    refiner.enable_model_cpu_offload()
+
+    return refiner, sdxl_vae
+
+refiner, sdxl_vae = load_common()
```
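The trailing `refiner, sdxl_vae = load_common()` runs at import time, so the refiner is built exactly once per process and shared by every importer (flux_helpers.py above pulls it in this way), because Python caches a module after its first import:

```python
# Both names below are the same object: Python caches the module on first
# import, so load_common() (and the weight download) runs only once.
from modules.pipelines.common_pipelines import refiner as r1
from modules.pipelines.common_pipelines import refiner as r2
assert r1 is r2
```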
modules/pipelines/flux_pipelines.py (CHANGED)

```diff
@@ -1,19 +1,58 @@
-# modules/pipelines/flux_pipelines.py
 
 import torch
-from diffusers import
+from diffusers import (
+    AutoPipelineForText2Image,
+    DiffusionPipeline,
+    AutoencoderKL,
+    FluxControlNetModel,
+    FluxMultiControlNetModel,
+)
+from diffusers.schedulers import *
+
+
 
 def load_flux():
-    #
-
-    return device, models, flux_vae, controlnet
+    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
-#
+    # Models
+    models = [
+        {
+            "repo_id": "black-forest-labs/FLUX.1-dev",
+            "loader": "flux",
+            "compute_type": torch.bfloat16,
+        }
+    ]
+
+    for model in models:
+        try:
+            model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
+                model['repo_id'],
+                vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
+                torch_dtype=model['compute_type'],
+                safety_checker=None,
+                variant="fp16"
+            ).to(device)
+        except:
+            model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
+                model['repo_id'],
+                vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
+                torch_dtype=model['compute_type'],
+                safety_checker=None
+            ).to(device)
+
+        model["pipeline"].enable_model_cpu_offload()
+
+    # VAE n Refiner
+    flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
+
+    # ControlNet
+    controlnet = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
+        "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
+        torch_dtype=torch.bfloat16
+    ).to(device)])
+
+    return device, models, flux_vae, controlnet
 
-import torch
-from diffusers import AutoPipelineForText2Image, AutoencoderKL
 
-
-# Load SDXL models and pipelines
-# ...
-    return device, models, sdxl_vae, controlnet
+device, models, flux_vae, controlnet = load_flux()
```
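The try/except around `from_pretrained` in `load_flux()` is a variant fallback: the first call asks for the `variant="fp16"` weight files, and if the repo does not ship them the bare `except:` retries without the variant. A sketch of the same fallback with a narrower except clause (the exact exception types are an assumption; the bare `except:` in the commit also swallows unrelated failures such as network errors):

```python
import torch
from diffusers import AutoPipelineForText2Image

repo_id = "black-forest-labs/FLUX.1-dev"
try:
    # Prefer the fp16 weight files when the repo provides them.
    pipe = AutoPipelineForText2Image.from_pretrained(
        repo_id, torch_dtype=torch.bfloat16, variant="fp16"
    )
except (OSError, ValueError):
    # No fp16 variant published: fall back to the default weight files.
    pipe = AutoPipelineForText2Image.from_pretrained(
        repo_id, torch_dtype=torch.bfloat16
    )
```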
tabs/image_tab.py (CHANGED)

```diff
@@ -144,13 +144,13 @@ def flux_tab():
 
     # Events
     # Base Options
-    fast_generation.change(update_fast_generation, [
+    fast_generation.change(update_fast_generation, [fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
 
 
    # Lora Gallery
     lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
     custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
-    add_lora.click(add_to_enabled_loras, [
+    add_lora.click(add_to_enabled_loras, [selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
     enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
 
     for i in range(6):
```
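All of these Gradio bindings follow the same `(fn, inputs, outputs)` shape: the current values of the `inputs` components are passed to `fn` positionally, and `fn`'s return values are assigned to the `outputs` components in order, which is why `update_fast_generation` now takes `fast_generation` as a parameter and returns two `gr.update(...)` objects. A self-contained miniature of the pattern (the slider values are illustrative, not taken from this repo):

```python
import gradio as gr


def update_fast_generation(fast_generation):
    # One return value per output component, in order.
    guidance = 0.0 if fast_generation else 3.5  # illustrative values
    steps = 8 if fast_generation else 28
    return gr.update(value=guidance), gr.update(value=steps)


with gr.Blocks() as demo:
    fast_generation = gr.Checkbox(label="Fast generation")
    image_guidance_scale = gr.Slider(0, 10, label="Guidance scale")
    image_num_inference_steps = gr.Slider(1, 50, step=1, label="Inference steps")
    fast_generation.change(
        update_fast_generation,
        [fast_generation],
        [image_guidance_scale, image_num_inference_steps],
    )
```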