aredden committed
Commit ac049be · 1 Parent(s): 37bd8c1

Move compile out of FluxPipeline init
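With this change, compilation is an explicit FluxPipeline.compile() step that __init__ merely delegates to when config.compile_blocks or config.compile_extras is set, so callers can also defer or re-trigger it themselves. A self-contained sketch of the pattern (generic names, not this repo's constructor API):

    import torch
    from dataclasses import dataclass

    @dataclass
    class Config:
        compile_blocks: bool = False
        compile_extras: bool = False

    class Pipeline:
        def __init__(self, config: Config):
            self.config = config
            self.model = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(2))
            # __init__ only delegates; compilation is no longer baked into construction.
            if config.compile_blocks or config.compile_extras:
                self.compile()

        def compile(self):
            # Explicit, separately callable compilation step, as FluxPipeline.compile() now is.
            for block in self.model:
                block.compile()  # nn.Module.compile() wraps torch.compile in place (PyTorch >= 2.2)

    pipe = Pipeline(Config())  # construct without compiling...
    pipe.compile()             # ...and compile later, which this commit makes possible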

Files changed (1):
  1. flux_pipeline.py +81 -57
flux_pipeline.py CHANGED
@@ -31,6 +31,7 @@ from torchvision.transforms import functional as TF
 from tqdm import tqdm
 from util import (
     ModelSpec,
+    ModelVersion,
     into_device,
     into_dtype,
     load_config_from_path,
@@ -80,29 +81,17 @@ class FluxPipeline:
         This class is responsible for preparing input tensors for the Flux model, generating
         timesteps and noise, and handling device management for model offloading.
         """
+
+        if config is None:
+            raise ValueError("ModelSpec config is required!")
+
         self.debug = debug
         self.name = name
-        self.device_flux = (
-            flux_device
-            if isinstance(flux_device, torch.device)
-            else torch.device(flux_device)
-        )
-        self.device_ae = (
-            ae_device
-            if isinstance(ae_device, torch.device)
-            else torch.device(ae_device)
-        )
-        self.device_clip = (
-            clip_device
-            if isinstance(clip_device, torch.device)
-            else torch.device(clip_device)
-        )
-        self.device_t5 = (
-            t5_device
-            if isinstance(t5_device, torch.device)
-            else torch.device(t5_device)
-        )
-        self.dtype = dtype
+        self.device_flux = into_device(flux_device)
+        self.device_ae = into_device(ae_device)
+        self.device_clip = into_device(clip_device)
+        self.device_t5 = into_device(t5_device)
+        self.dtype = into_dtype(dtype)
         self.offload = offload
         self.clip: "HFEmbedder" = clip
         self.t5: "HFEmbedder" = t5
@@ -116,6 +105,8 @@ class FluxPipeline:
         self.offload_text_encoder = config.offload_text_encoder
         self.offload_vae = config.offload_vae
         self.offload_flow = config.offload_flow
+        # If models are not offloaded, move them to the appropriate devices
+
         if not self.offload_flow:
             self.model.to(self.device_flux)
         if not self.offload_vae:
@@ -124,40 +115,16 @@ class FluxPipeline:
             self.clip.to(self.device_clip)
             self.t5.to(self.device_t5)

-        if self.config.compile_blocks or self.config.compile_extras:
-            if not self.config.prequantized_flow:
-                logger.info("Running warmups for compile...")
-                warmup_dict = dict(
-                    prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
-                    height=768,
-                    width=768,
-                    num_steps=25,
-                    guidance=3.5,
-                    seed=10,
-                )
-                self.generate(**warmup_dict)
-            to_gpu_extras = [
-                "vector_in",
-                "img_in",
-                "txt_in",
-                "time_in",
-                "guidance_in",
-                "final_layer",
-                "pe_embedder",
-            ]
-            if self.config.compile_blocks:
-                for block in self.model.double_blocks:
-                    block.compile()
-                for block in self.model.single_blocks:
-                    block.compile()
-            if self.config.compile_extras:
-                for extra in to_gpu_extras:
-                    getattr(self.model, extra).compile()
-
-    def set_seed(self, seed: int | None = None) -> torch.Generator:
+        # compile the model if needed
+        if config.compile_blocks or config.compile_extras:
+            self.compile()
+
+    def set_seed(
+        self, seed: int | None = None, seed_globally: bool = False
+    ) -> tuple[torch.Generator, int]:
         if isinstance(seed, (int, float)):
             seed = int(abs(seed)) % MAX_RAND
-            self.rng = torch.manual_seed(seed)
+            cuda_generator = torch.Generator("cuda").manual_seed(seed)
         elif isinstance(seed, str):
             try:
                 seed = abs(int(seed)) % MAX_RAND
@@ -166,14 +133,71 @@ class FluxPipeline:
                     f"Received string representation of seed, but was not able to convert to int: {seed}, using random seed"
                 )
                 seed = abs(self.rng.seed()) % MAX_RAND
+                cuda_generator = torch.Generator("cuda").manual_seed(seed)
         else:
             seed = abs(self.rng.seed()) % MAX_RAND
-        torch.cuda.manual_seed_all(seed)
-        np.random.seed(seed)
-        random.seed(seed)
-        cuda_generator = torch.Generator("cuda").manual_seed(seed)
+            cuda_generator = torch.Generator("cuda").manual_seed(seed)
+
+        if seed_globally:
+            torch.cuda.manual_seed_all(seed)
+            np.random.seed(seed)
+            random.seed(seed)
         return cuda_generator, seed

+    @torch.inference_mode()
+    def compile(self):
+        """
+        Compiles the model and extras.
+
+        Two cases:
+
+        - A) The checkpoint already has float8 quantized weights and tuned input scales,
+          in which case no warmups are run, since the input scales are assumed to be tuned.
+
+        - B) The checkpoint has not been quantized, in which case it is quantized and the
+          input scales are tuned by running a warmup loop:
+          - flux-schnell runs 3 warmup loops, since each loop is only 4 steps.
+          - flux-dev runs 1 warmup loop of 12 steps.
+        """
+
+        # Run warmups if the checkpoint is not prequantized
+        if not self.config.prequantized_flow:
+            logger.info("Running warmups for compile...")
+            warmup_dict = dict(
+                prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
+                height=768,
+                width=768,
+                num_steps=12,
+                guidance=3.5,
+                seed=10,
+            )
+            if self.config.version == ModelVersion.flux_schnell:
+                warmup_dict["num_steps"] = 4
+                for _ in range(3):
+                    self.generate(**warmup_dict)
+            else:
+                self.generate(**warmup_dict)
+
+        # Compile the model and extras
+        to_gpu_extras = [
+            "vector_in",
+            "img_in",
+            "txt_in",
+            "time_in",
+            "guidance_in",
+            "final_layer",
+            "pe_embedder",
+        ]
+        if self.config.compile_blocks:
+            for block in self.model.double_blocks:
+                block.compile()
+            for block in self.model.single_blocks:
+                block.compile()
+        if self.config.compile_extras:
+            for extra in to_gpu_extras:
+                getattr(self.model, extra).compile()
+
     @torch.inference_mode()
     def prepare(
         self,
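A standalone sketch of the seeding behavior the reworked set_seed implements: every call returns its own CUDA generator, and the process-wide RNGs are only seeded on request (MAX_RAND is the repo's constant; the value below is illustrative):

    import random
    import numpy as np
    import torch

    MAX_RAND = 2**32  # illustrative; the repo defines its own MAX_RAND

    def set_seed(seed: int | None = None, seed_globally: bool = False) -> tuple[torch.Generator, int]:
        # Normalize to a usable seed, falling back to a fresh random one (the rng.seed() path).
        seed = abs(int(seed)) % MAX_RAND if seed is not None else torch.seed() % MAX_RAND
        # Per-call generator: concurrent generate() calls no longer share RNG state.
        cuda_generator = torch.Generator("cuda").manual_seed(seed)
        if seed_globally:
            # Opt-in global seeding; the old code did this unconditionally on every call.
            torch.cuda.manual_seed_all(seed)
            np.random.seed(seed)
            random.seed(seed)
        return cuda_generator, seed

Scoping the generator to the call is what makes per-request seeds safe without clobbering global RNG state.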
 
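The warmup pass before compilation exists because the fp8 nn.Linear layers calibrate their input scales from live activations; compiling first would freeze miscalibrated scale constants into the graphs. A generic calibrate-then-compile sketch (CalibratedLinear is hypothetical, not this repo's fp8 implementation):

    import torch
    import torch.nn.functional as F

    class CalibratedLinear(torch.nn.Module):
        """Hypothetical stand-in: tracks a running input scale during warmup forward passes."""

        def __init__(self, in_features: int, out_features: int):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
            self.register_buffer("input_scale", torch.ones(()))
            self.calibrating = True

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.calibrating:
                # Warmup passes widen the scale to cover observed activation magnitudes.
                self.input_scale.copy_(torch.maximum(self.input_scale, x.abs().max()))
            return F.linear(x / self.input_scale, self.weight)

    layer = CalibratedLinear(8, 8)
    for _ in range(3):                # analogous to the 3 schnell warmup loops
        layer(torch.randn(4, 8) * 5.0)
    layer.calibrating = False         # scales are now stable...
    layer.compile()                   # ...so compiling won't bake in bad constants

This is also why prequantized checkpoints (config.prequantized_flow) skip the warmup entirely: their input scales ship already tuned.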
 
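Compiling each double/single block individually, rather than wrapping the whole model in one torch.compile call, keeps recompilations local and lets blocks that share a class reuse one compiled graph. A minimal sketch of the same per-submodule pattern (the block layout is generic, not Flux's actual architecture):

    import torch
    from torch import nn

    model = nn.ModuleDict(
        {
            "double_blocks": nn.ModuleList(nn.Linear(8, 8) for _ in range(2)),
            "single_blocks": nn.ModuleList(nn.Linear(8, 8) for _ in range(2)),
            "img_in": nn.Linear(8, 8),  # stands in for the "extras" in to_gpu_extras
        }
    )

    # Per-block compilation, as in FluxPipeline.compile():
    for block in model["double_blocks"]:
        block.compile()
    for block in model["single_blocks"]:
        block.compile()
    model["img_in"].compile()  # extras are compiled individually by attribute name

    out = model["img_in"](torch.randn(1, 8))  # the first call triggers the actual compile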