Move compile out of FluxPipeline init
Browse files- flux_pipeline.py +81 -57
flux_pipeline.py
CHANGED
@@ -31,6 +31,7 @@ from torchvision.transforms import functional as TF
|
|
31 |
from tqdm import tqdm
|
32 |
from util import (
|
33 |
ModelSpec,
|
|
|
34 |
into_device,
|
35 |
into_dtype,
|
36 |
load_config_from_path,
|
@@ -80,29 +81,17 @@ class FluxPipeline:
|
|
80 |
This class is responsible for preparing input tensors for the Flux model, generating
|
81 |
timesteps and noise, and handling device management for model offloading.
|
82 |
"""
|
|
|
|
|
|
|
|
|
83 |
self.debug = debug
|
84 |
self.name = name
|
85 |
-
self.device_flux = (
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
)
|
90 |
-
self.device_ae = (
|
91 |
-
ae_device
|
92 |
-
if isinstance(ae_device, torch.device)
|
93 |
-
else torch.device(ae_device)
|
94 |
-
)
|
95 |
-
self.device_clip = (
|
96 |
-
clip_device
|
97 |
-
if isinstance(clip_device, torch.device)
|
98 |
-
else torch.device(clip_device)
|
99 |
-
)
|
100 |
-
self.device_t5 = (
|
101 |
-
t5_device
|
102 |
-
if isinstance(t5_device, torch.device)
|
103 |
-
else torch.device(t5_device)
|
104 |
-
)
|
105 |
-
self.dtype = dtype
|
106 |
self.offload = offload
|
107 |
self.clip: "HFEmbedder" = clip
|
108 |
self.t5: "HFEmbedder" = t5
|
@@ -116,6 +105,8 @@ class FluxPipeline:
|
|
116 |
self.offload_text_encoder = config.offload_text_encoder
|
117 |
self.offload_vae = config.offload_vae
|
118 |
self.offload_flow = config.offload_flow
|
|
|
|
|
119 |
if not self.offload_flow:
|
120 |
self.model.to(self.device_flux)
|
121 |
if not self.offload_vae:
|
@@ -124,40 +115,16 @@ class FluxPipeline:
|
|
124 |
self.clip.to(self.device_clip)
|
125 |
self.t5.to(self.device_t5)
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
num_steps=25,
|
135 |
-
guidance=3.5,
|
136 |
-
seed=10,
|
137 |
-
)
|
138 |
-
self.generate(**warmup_dict)
|
139 |
-
to_gpu_extras = [
|
140 |
-
"vector_in",
|
141 |
-
"img_in",
|
142 |
-
"txt_in",
|
143 |
-
"time_in",
|
144 |
-
"guidance_in",
|
145 |
-
"final_layer",
|
146 |
-
"pe_embedder",
|
147 |
-
]
|
148 |
-
if self.config.compile_blocks:
|
149 |
-
for block in self.model.double_blocks:
|
150 |
-
block.compile()
|
151 |
-
for block in self.model.single_blocks:
|
152 |
-
block.compile()
|
153 |
-
if self.config.compile_extras:
|
154 |
-
for extra in to_gpu_extras:
|
155 |
-
getattr(self.model, extra).compile()
|
156 |
-
|
157 |
-
def set_seed(self, seed: int | None = None) -> torch.Generator:
|
158 |
if isinstance(seed, (int, float)):
|
159 |
seed = int(abs(seed)) % MAX_RAND
|
160 |
-
|
161 |
elif isinstance(seed, str):
|
162 |
try:
|
163 |
seed = abs(int(seed)) % MAX_RAND
|
@@ -166,14 +133,71 @@ class FluxPipeline:
|
|
166 |
f"Recieved string representation of seed, but was not able to convert to int: {seed}, using random seed"
|
167 |
)
|
168 |
seed = abs(self.rng.seed()) % MAX_RAND
|
|
|
169 |
else:
|
170 |
seed = abs(self.rng.seed()) % MAX_RAND
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
175 |
return cuda_generator, seed
|
176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
@torch.inference_mode()
|
178 |
def prepare(
|
179 |
self,
|
|
|
31 |
from tqdm import tqdm
|
32 |
from util import (
|
33 |
ModelSpec,
|
34 |
+
ModelVersion,
|
35 |
into_device,
|
36 |
into_dtype,
|
37 |
load_config_from_path,
|
|
|
81 |
This class is responsible for preparing input tensors for the Flux model, generating
|
82 |
timesteps and noise, and handling device management for model offloading.
|
83 |
"""
|
84 |
+
|
85 |
+
if config is None:
|
86 |
+
raise ValueError("ModelSpec config is required!")
|
87 |
+
|
88 |
self.debug = debug
|
89 |
self.name = name
|
90 |
+
self.device_flux = into_device(flux_device)
|
91 |
+
self.device_ae = into_device(ae_device)
|
92 |
+
self.device_clip = into_device(clip_device)
|
93 |
+
self.device_t5 = into_device(t5_device)
|
94 |
+
self.dtype = into_dtype(dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
self.offload = offload
|
96 |
self.clip: "HFEmbedder" = clip
|
97 |
self.t5: "HFEmbedder" = t5
|
|
|
105 |
self.offload_text_encoder = config.offload_text_encoder
|
106 |
self.offload_vae = config.offload_vae
|
107 |
self.offload_flow = config.offload_flow
|
108 |
+
# If models are not offloaded, move them to the appropriate devices
|
109 |
+
|
110 |
if not self.offload_flow:
|
111 |
self.model.to(self.device_flux)
|
112 |
if not self.offload_vae:
|
|
|
115 |
self.clip.to(self.device_clip)
|
116 |
self.t5.to(self.device_t5)
|
117 |
|
118 |
+
# compile the model if needed
|
119 |
+
if config.compile_blocks or config.compile_extras:
|
120 |
+
self.compile()
|
121 |
+
|
122 |
+
def set_seed(
|
123 |
+
self, seed: int | None = None, seed_globally: bool = False
|
124 |
+
) -> torch.Generator:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
if isinstance(seed, (int, float)):
|
126 |
seed = int(abs(seed)) % MAX_RAND
|
127 |
+
cuda_generator = torch.Generator("cuda").manual_seed(seed)
|
128 |
elif isinstance(seed, str):
|
129 |
try:
|
130 |
seed = abs(int(seed)) % MAX_RAND
|
|
|
133 |
f"Recieved string representation of seed, but was not able to convert to int: {seed}, using random seed"
|
134 |
)
|
135 |
seed = abs(self.rng.seed()) % MAX_RAND
|
136 |
+
cuda_generator = torch.Generator("cuda").manual_seed(seed)
|
137 |
else:
|
138 |
seed = abs(self.rng.seed()) % MAX_RAND
|
139 |
+
cuda_generator = torch.Generator("cuda").manual_seed(seed)
|
140 |
+
|
141 |
+
if seed_globally:
|
142 |
+
torch.cuda.manual_seed_all(seed)
|
143 |
+
np.random.seed(seed)
|
144 |
+
random.seed(seed)
|
145 |
return cuda_generator, seed
|
146 |
|
147 |
+
@torch.inference_mode()
def compile(self):
    """
    Optionally warm the pipeline up, then compile the flow model's submodules.

    Warmup phase (skipped when ``self.config.prequantized_flow`` is truthy):
    a prequantized checkpoint is assumed to already carry tuned float8 input
    scales, so no warmup is run for it. Otherwise ``self.generate`` is called
    with a fixed 768x768 prompt so the input scales settle before compilation:

    - flux-schnell (``ModelVersion.flux_schnell``): 3 passes of 4 steps each.
    - any other version (e.g. flux-dev): 1 pass of 12 steps.

    Compile phase:

    - ``self.config.compile_blocks``: calls ``.compile()`` on every entry of
      ``self.model.double_blocks`` and then ``self.model.single_blocks``.
    - ``self.config.compile_extras``: calls ``.compile()`` on the named
      auxiliary submodules of ``self.model`` (input/embedding layers, final
      layer and positional embedder).
    """
    if not self.config.prequantized_flow:
        logger.info("Running warmups for compile...")

        # schnell runs short 4-step schedules, so repeat it three times to
        # cover the same 12 denoising steps a single non-schnell pass uses.
        if self.config.version == ModelVersion.flux_schnell:
            steps_per_pass, passes = 4, 3
        else:
            steps_per_pass, passes = 12, 1

        warmup_kwargs = {
            "prompt": "A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
            "height": 768,
            "width": 768,
            "num_steps": steps_per_pass,
            "guidance": 3.5,
            "seed": 10,
        }
        for _ in range(passes):
            self.generate(**warmup_kwargs)

    if self.config.compile_blocks:
        # Double-stream blocks first, then single-stream blocks.
        for block_group in (self.model.double_blocks, self.model.single_blocks):
            for block in block_group:
                block.compile()

    if self.config.compile_extras:
        extra_module_names = (
            "vector_in",
            "img_in",
            "txt_in",
            "time_in",
            "guidance_in",
            "final_layer",
            "pe_embedder",
        )
        for module_name in extra_module_names:
            getattr(self.model, module_name).compile()
|
200 |
+
|
201 |
@torch.inference_mode()
|
202 |
def prepare(
|
203 |
self,
|