Deadmon commited on
Commit
e794576
1 Parent(s): 4fd60a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -123
app.py CHANGED
@@ -1,132 +1,30 @@
1
- import os
2
  import torch
3
- import gc
4
- import gradio as gr
5
- import numpy as np
6
- from PIL import Image
7
- from einops import rearrange
8
- import io
9
- import requests
10
- import spaces
11
- from huggingface_hub import login
12
- from gradio_imageslider import ImageSlider
13
  from diffusers.utils import load_image
14
  from diffusers import FluxControlNetPipeline, FluxControlNetModel
15
 
16
- # Device settings: CPU for loading, GPU for inference
17
- device_cpu = torch.device("cpu")
18
- device_gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
-
20
- # Model identifiers
21
  base_model = 'black-forest-labs/FLUX.1-dev'
22
  controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'
23
 
24
- # Load the ControlNet model and pipeline on CPU
25
- controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16).to(device_cpu)
26
- pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16).to(device_cpu)
27
 
 
28
  controlnet_conditioning_scale = 0.5
29
-
30
- control_modes = {
31
- "canny": 0,
32
- "tile": 1,
33
- "depth": 2,
34
- "blur": 3,
35
- "pose": 4,
36
- "gray": 5,
37
- "lq": 6,
38
- }
39
-
40
- def load_and_convert_image(image):
41
- """Load and convert images to a format that PIL can handle."""
42
- if isinstance(image, str):
43
- image = Image.open(image)
44
- elif isinstance(image, bytes):
45
- image = Image.open(io.BytesIO(image))
46
- # Convert AVIF to PNG if necessary
47
- if image.format == 'AVIF':
48
- image = image.convert("RGB") # Convert to a format PIL can handle
49
- return image
50
-
51
- def preprocess_image(image, target_width, target_height, crop=True):
52
- """Preprocess image to match the target dimensions."""
53
- image = load_and_convert_image(image)
54
- if crop:
55
- original_width, original_height = image.size
56
-
57
- # Resize to match the target size without stretching
58
- scale = max(target_width / original_width, target_height / original_height)
59
- resized_width = int(scale * original_width)
60
- resized_height = int(scale * original_height)
61
-
62
- image = image.resize((resized_width, resized_height), Image.LANCZOS)
63
-
64
- # Center crop to match the target dimensions
65
- left = (resized_width - target_width) // 2
66
- top = (resized_height - target_height) // 2
67
- image = image.crop((left, top, left + target_width, top + target_height))
68
- else:
69
- image = image.resize((target_width, target_height), Image.LANCZOS)
70
-
71
- return image
72
-
73
- def clear_cuda_memory():
74
- """Clear CUDA memory."""
75
- gc.collect()
76
- torch.cuda.empty_cache()
77
- torch.cuda.ipc_collect()
78
-
79
- @spaces.GPU(duration=120)
80
- def generate_image(prompt, control_image, control_mode, num_steps=50, guidance=4, width=512, height=512, seed=42, random_seed=False):
81
- """Generate image using the FLUX.1 ControlNet model."""
82
- clear_cuda_memory()
83
-
84
- if random_seed:
85
- seed = np.random.randint(0, 10000)
86
-
87
- if not os.path.isdir("./controlnet_results/"):
88
- os.makedirs("./controlnet_results/")
89
-
90
- # Move model to GPU for inference
91
- pipe.to(device_gpu)
92
-
93
- control_image = preprocess_image(control_image, width, height)
94
-
95
- torch.manual_seed(seed)
96
- with torch.no_grad():
97
- image = pipe(
98
- prompt,
99
- control_image=control_image,
100
- control_mode=control_modes[control_mode],
101
- width=width,
102
- height=height,
103
- controlnet_conditioning_scale=controlnet_conditioning_scale,
104
- num_inference_steps=num_steps,
105
- guidance_scale=guidance,
106
- ).images[0]
107
-
108
- # Move model back to CPU after inference
109
- pipe.to(device_cpu)
110
-
111
- return [control_image, image] # Return both images for slider
112
-
113
- interface = gr.Interface(
114
- fn=generate_image,
115
- inputs=[
116
- gr.Textbox(label="Prompt"),
117
- gr.Image(type="pil", label="Control Image"),
118
- gr.Dropdown(choices=list(control_modes.keys()), label="Control Mode", value="canny"),
119
- gr.Slider(step=1, minimum=1, maximum=64, value=28, label="Num Steps"),
120
- gr.Slider(minimum=0.1, maximum=10, value=4, label="Guidance"),
121
- gr.Slider(minimum=128, maximum=2048, step=128, value=1024, label="Width"),
122
- gr.Slider(minimum=128, maximum=2048, step=128, value=1024, label="Height"),
123
- gr.Number(value=42, label="Seed"),
124
- gr.Checkbox(label="Random Seed")
125
- ],
126
- outputs=ImageSlider(label="Before / After"), # Use ImageSlider as the output
127
- title="FLUX.1 Controlnet Canny",
128
- description="Generate images using ControlNet and a text prompt.\n[[non-commercial license, Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]"
129
- )
130
-
131
- if __name__ == "__main__":
132
- interface.launch(share=True)
 
 
1
  import torch
 
 
 
 
 
 
 
 
 
 
2
  from diffusers.utils import load_image
3
  from diffusers import FluxControlNetPipeline, FluxControlNetModel
4
 
 
 
 
 
 
5
  base_model = 'black-forest-labs/FLUX.1-dev'
6
  controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'
7
 
8
+ controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
9
+ pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
10
+ pipe.to("cuda")
11
 
12
+ control_image_canny = load_image("https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union-alpha/resolve/main/images/canny.jpg")
13
  controlnet_conditioning_scale = 0.5
14
+ control_mode = 0
15
+
16
+ width, height = control_image.size
17
+
18
+ prompt = 'A bohemian-style female travel blogger with sun-kissed skin and messy beach waves.'
19
+
20
+ image = pipe(
21
+ prompt,
22
+ control_image=control_image,
23
+ control_mode=control_mode,
24
+ width=width,
25
+ height=height,
26
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
27
+ num_inference_steps=24,
28
+ guidance_scale=3.5,
29
+ ).images[0]
30
+ image.save("image.jpg")