Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,17 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import random
|
4 |
import spaces
|
5 |
import torch
|
6 |
-
from
|
|
|
|
|
7 |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
8 |
|
9 |
-
from model import Flux
|
10 |
|
11 |
def calculate_shift(
|
12 |
image_seq_len,
|
@@ -174,20 +179,17 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
|
|
174 |
@dataclass
|
175 |
class ModelSpec:
|
176 |
params: FluxParams
|
177 |
-
ae_params: AutoEncoderParams
|
178 |
-
ckpt_path: str
|
179 |
-
ae_path: str
|
180 |
repo_id: str
|
181 |
repo_flow: str
|
182 |
repo_ae: str
|
183 |
repo_id_ae: str
|
184 |
|
|
|
185 |
config = ModelSpec(
|
186 |
repo_id="TencentARC/flux-mini",
|
187 |
repo_flow="flux-mini.safetensors",
|
188 |
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
189 |
repo_ae="ae.safetensors",
|
190 |
-
ckpt_path=os.getenv("FLUX_MINI", None),
|
191 |
params=FluxParams(
|
192 |
in_channels=64,
|
193 |
vec_in_dim=768,
|
@@ -202,35 +204,33 @@ config = ModelSpec(
|
|
202 |
qkv_bias=True,
|
203 |
guidance_embed=True,
|
204 |
)
|
|
|
205 |
|
206 |
|
207 |
-
def load_flow_model2(device: str = "cuda", hf_download: bool = True):
|
208 |
-
if (
|
209 |
-
and config.repo_id is not None
|
210 |
and config.repo_flow is not None
|
211 |
and hf_download
|
212 |
):
|
213 |
-
ckpt_path = hf_hub_download(
|
214 |
|
215 |
-
model = Flux(params)
|
216 |
if ckpt_path is not None:
|
217 |
sd = load_sft(ckpt_path, device=str(device))
|
218 |
missing, unexpected = model.load_state_dict(sd, strict=True)
|
219 |
return model
|
220 |
|
221 |
|
222 |
-
|
223 |
-
|
224 |
dtype = torch.bfloat16
|
225 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
226 |
|
227 |
-
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
|
228 |
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
|
229 |
text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder").to(device)
|
230 |
-
tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
|
231 |
text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2").to(device)
|
232 |
-
tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2")
|
233 |
-
transformer = load_flow_model2(device)
|
234 |
|
235 |
pipe = FluxPipeline(
|
236 |
scheduler,
|
@@ -238,7 +238,7 @@ pipe = FluxPipeline(
|
|
238 |
text_encoder,
|
239 |
tokenizer,
|
240 |
text_encoder_2,
|
241 |
-
tokenizer_2
|
242 |
transformer
|
243 |
)
|
244 |
torch.cuda.empty_cache()
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Union, Optional, List, Any, Dict
|
3 |
+
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
import random
|
7 |
import spaces
|
8 |
import torch
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
|
11 |
+
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, FluxPipeline
|
12 |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
13 |
|
14 |
+
from model import Flux, FluxParams
|
15 |
|
16 |
def calculate_shift(
|
17 |
image_seq_len,
|
|
|
179 |
@dataclass
class ModelSpec:
    """Configuration describing where to find the Flux model weights.

    Consumed by ``load_flow_model2`` (uses ``repo_id``/``repo_flow`` to
    download the transformer checkpoint and ``params`` to construct the
    model). The autoencoder fields name the companion VAE checkpoint.
    """

    # Architecture hyperparameters forwarded to Flux(...).
    params: FluxParams
    # Hub repository and filename of the transformer checkpoint.
    repo_id: str
    repo_flow: str
    # Filename and Hub repository of the autoencoder checkpoint.
    # NOTE(review): these two fields are not read by any code visible in
    # this chunk — confirm they are still used elsewhere in the file.
    repo_ae: str
    repo_id_ae: str
|
186 |
|
187 |
+
|
188 |
config = ModelSpec(
|
189 |
repo_id="TencentARC/flux-mini",
|
190 |
repo_flow="flux-mini.safetensors",
|
191 |
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
192 |
repo_ae="ae.safetensors",
|
|
|
193 |
params=FluxParams(
|
194 |
in_channels=64,
|
195 |
vec_in_dim=768,
|
|
|
204 |
qkv_bias=True,
|
205 |
guidance_embed=True,
|
206 |
)
|
207 |
+
)
|
208 |
|
209 |
|
210 |
+
def load_flow_model2(config, device: str = "cuda", hf_download: bool = True):
    """Construct the Flux transformer and load its checkpoint weights.

    Parameters
    ----------
    config:
        A ``ModelSpec`` providing ``repo_id``/``repo_flow`` (Hub location
        of the checkpoint) and ``params`` (architecture hyperparameters).
    device:
        Device string handed to the safetensors loader.
    hf_download:
        When True and the spec names a Hub repo/file, fetch the checkpoint
        from the Hugging Face Hub.

    Returns
    -------
    Flux
        The model with checkpoint weights loaded (strict) when a checkpoint
        was resolved, otherwise with freshly initialized weights.
    """
    # Bug fix: ckpt_path was previously assigned only inside the download
    # branch, so skipping it (hf_download=False, or repo fields unset)
    # raised UnboundLocalError at the `if ckpt_path is not None` check.
    ckpt_path = None
    if (config.repo_id is not None
        and config.repo_flow is not None
        and hf_download
    ):
        # Normalize the legacy "*.sft" extension to ".safetensors".
        ckpt_path = hf_hub_download(config.repo_id, config.repo_flow.replace("sft", "safetensors"))

    model = Flux(config.params)
    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        # strict=True raises on any mismatch, so on success both lists
        # are empty; kept for parity with the original code.
        missing, unexpected = model.load_state_dict(sd, strict=True)
    return model
|
222 |
|
223 |
|
|
|
|
|
224 |
# Precision and placement shared by every pretrained component below.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# All auxiliary components are pulled from the FLUX.1-dev repository;
# only the transformer itself is the flux-mini model loaded further down.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="scheduler",
)
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=dtype,
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder",
).to(device)
tokenizer = CLIPTokenizer.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="tokenizer",
)
text_encoder_2 = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
).to(device)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="tokenizer_2",
)

# Flux-mini transformer, resolved via the ModelSpec config.
transformer = load_flow_model2(config, device)
|
234 |
|
235 |
pipe = FluxPipeline(
|
236 |
scheduler,
|
|
|
238 |
text_encoder,
|
239 |
tokenizer,
|
240 |
text_encoder_2,
|
241 |
+
tokenizer_2,
|
242 |
transformer
|
243 |
)
|
244 |
torch.cuda.empty_cache()
|