Spaces:
Running
on
Zero
Running
on
Zero
tight-inversion
committed on
Commit
·
4d0ddc3
1
Parent(s):
10d3d92
Align with pulid demo
Browse files- app.py +1 -0
- flux/util.py +3 -22
app.py
CHANGED
@@ -431,6 +431,7 @@ if __name__ == "__main__":
|
|
431 |
args.offload = True
|
432 |
|
433 |
print(f"Using device: {args.device}")
|
|
|
434 |
print(f"Offload: {args.offload}")
|
435 |
|
436 |
demo = create_demo(args, args.name, args.device, args.offload, args.aggressive_offload)
|
|
|
431 |
args.offload = True
|
432 |
|
433 |
print(f"Using device: {args.device}")
|
434 |
+
print(f"fp8: {args.fp8}")
|
435 |
print(f"Offload: {args.offload}")
|
436 |
|
437 |
demo = create_demo(args, args.name, args.device, args.offload, args.aggressive_offload)
|
flux/util.py
CHANGED
@@ -123,36 +123,17 @@ def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True):
|
|
123 |
):
|
124 |
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
model = Flux(configs[name].params)
|
129 |
-
model = model.to_empty(device=device)
|
130 |
|
131 |
if ckpt_path is not None:
|
132 |
print("Loading checkpoint")
|
133 |
-
#
|
134 |
sd = load_sft(ckpt_path, device=str(device))
|
135 |
-
# Load the state dictionary into the model
|
136 |
missing, unexpected = model.load_state_dict(sd, strict=False)
|
137 |
print_load_warning(missing, unexpected)
|
138 |
-
model.to(torch.bfloat16)
|
139 |
return model
|
140 |
|
141 |
-
# from XLabs-AI https://github.com/XLabs-AI/x-flux/blob/1f8ef54972105ad9062be69fe6b7f841bce02a08/src/flux/util.py#L330
|
142 |
-
def load_flow_model_quintized(name: str, device: str = "cuda", hf_download: bool = True):
|
143 |
-
# Loading Flux
|
144 |
-
print("Init model")
|
145 |
-
ckpt_path = 'models/flux-dev-fp8.safetensors'
|
146 |
-
if (
|
147 |
-
not os.path.exists(ckpt_path)
|
148 |
-
and hf_download
|
149 |
-
):
|
150 |
-
print("Downloading model")
|
151 |
-
ckpt_path = hf_hub_download("XLabs-AI/flux-dev-fp8", "flux-dev-fp8.safetensors")
|
152 |
-
print("Model downloaded to", ckpt_path)
|
153 |
-
json_path = hf_hub_download("XLabs-AI/flux-dev-fp8", 'flux_dev_quantization_map.json')
|
154 |
-
|
155 |
-
model = Flux(configs[name].params).to(torch.bfloat16)
|
156 |
def load_flow_model_quintized(
|
157 |
name: str,
|
158 |
device: str = "cuda",
|
|
|
123 |
):
|
124 |
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
|
125 |
|
126 |
+
with torch.device(device):
|
127 |
+
model = Flux(configs[name].params).to(torch.bfloat16)
|
|
|
|
|
128 |
|
129 |
if ckpt_path is not None:
|
130 |
print("Loading checkpoint")
|
131 |
+
# load_sft doesn't support torch.device
|
132 |
sd = load_sft(ckpt_path, device=str(device))
|
|
|
133 |
missing, unexpected = model.load_state_dict(sd, strict=False)
|
134 |
print_load_warning(missing, unexpected)
|
|
|
135 |
return model
|
136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
def load_flow_model_quintized(
|
138 |
name: str,
|
139 |
device: str = "cuda",
|