Update pipeline.py
pipeline.py  +8 −11
@@ -203,9 +203,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_2_embed: Optional[torch.Tensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_2_embed: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
     ):
@@ -268,26 +266,25 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                 " the batch size of `prompt`."
             )
 
-
-
+        negative_clip_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
+            prompt=negative_prompt,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
         )
-
-
-        t5_negative_prompt_embed = self._get_t5_prompt_embeds(
+
+        t5_negative_prompt_embed, negative_pooled_prompt_2_embed = self._get_t5_prompt_embeds(
            prompt=negative_prompt_2,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
        )
 
-
-
-            (0, t5_negative_prompt_embed.shape[-1] -
+        negative_clip_prompt_embed = torch.nn.functional.pad(
+            negative_clip_prompt_embed,
+            (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embed.shape[-1]),
        )
 
-        negative_prompt_embeds = torch.cat([
+        negative_prompt_embeds = torch.cat([negative_clip_prompt_embed, t5_negative_prompt_embed], dim=-2)
         negative_pooled_prompt_embeds = torch.cat(
             [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
         )
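
The net effect of the second hunk is that the negative CLIP and T5 embeddings are computed inside the pipeline and merged the same way the positive ones are: the per-token CLIP hidden states are right-padded along the feature axis up to the T5 hidden size, concatenated with the T5 hidden states along the sequence axis, and the two pooled vectors are concatenated along the feature axis. Below is a minimal sketch of that tensor plumbing, using dummy tensors with assumed sizes (77 and 512 tokens, hidden dims 768 and 4096) in place of the pipeline's real encoder outputs; it illustrates the shapes only, not the encoders themselves.

import torch
import torch.nn.functional as F

# Dummy stand-ins for the encoder outputs; the sizes are illustrative
# assumptions, not values taken from the diff.
batch = 2
clip_embed = torch.randn(batch, 77, 768)      # per-token CLIP hidden states
t5_embed = torch.randn(batch, 512, 4096)      # per-token T5 hidden states
pooled_clip = torch.randn(batch, 768)         # pooled CLIP embedding
pooled_t5 = torch.randn(batch, 4096)          # pooled counterpart used by this pipeline

# Right-pad the CLIP hidden states along the last (feature) dimension so they
# match the T5 hidden size; F.pad takes (left, right) padding for the last dim.
clip_embed = F.pad(clip_embed, (0, t5_embed.shape[-1] - clip_embed.shape[-1]))

# Stack the two token sequences along the sequence dimension (dim=-2) ...
negative_prompt_embeds = torch.cat([clip_embed, t5_embed], dim=-2)

# ... and join the pooled vectors along the feature dimension (dim=-1).
negative_pooled_prompt_embeds = torch.cat([pooled_clip, pooled_t5], dim=-1)

print(negative_prompt_embeds.shape)         # torch.Size([2, 589, 4096])
print(negative_pooled_prompt_embeds.shape)  # torch.Size([2, 4864])

Zero-padding rather than projecting keeps the merge parameter-free: the transformer simply sees the CLIP tokens as a prefix of the sequence whose trailing feature dimensions are zero.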