SunderAli17 commited on
Commit
3edb05e
·
verified ·
1 Parent(s): e57c305

Delete eva_clip/flux/util.py

Browse files
Files changed (1) hide show
  1. eva_clip/flux/util.py +0 -156
eva_clip/flux/util.py DELETED
@@ -1,156 +0,0 @@
1
- import os
2
- from dataclasses import dataclass
3
-
4
- import torch
5
- from einops import rearrange
6
- from huggingface_hub import hf_hub_download
7
- from safetensors.torch import load_file as load_sft
8
-
9
- from flux.model import Flux, FluxParams
10
- from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
11
- from flux.modules.conditioner import HFEmbedder
12
-
13
-
14
- @dataclass
15
- class ModelSpec:
16
- params: FluxParams
17
- ae_params: AutoEncoderParams
18
- ckpt_path: str
19
- ae_path: str
20
- repo_id: str
21
- repo_flow: str
22
- repo_ae: str
23
-
24
-
25
- configs = {
26
- "flux-dev": ModelSpec(
27
- repo_id="black-forest-labs/FLUX.1-dev",
28
- repo_flow="flux1-dev.safetensors",
29
- repo_ae="ae.safetensors",
30
- ckpt_path='models/flux1-dev.safetensors',
31
- params=FluxParams(
32
- in_channels=64,
33
- vec_in_dim=768,
34
- context_in_dim=4096,
35
- hidden_size=3072,
36
- mlp_ratio=4.0,
37
- num_heads=24,
38
- depth=19,
39
- depth_single_blocks=38,
40
- axes_dim=[16, 56, 56],
41
- theta=10_000,
42
- qkv_bias=True,
43
- guidance_embed=True,
44
- ),
45
- ae_path='models/ae.safetensors',
46
- ae_params=AutoEncoderParams(
47
- resolution=256,
48
- in_channels=3,
49
- ch=128,
50
- out_ch=3,
51
- ch_mult=[1, 2, 4, 4],
52
- num_res_blocks=2,
53
- z_channels=16,
54
- scale_factor=0.3611,
55
- shift_factor=0.1159,
56
- ),
57
- ),
58
- "flux-schnell": ModelSpec(
59
- repo_id="black-forest-labs/FLUX.1-schnell",
60
- repo_flow="flux1-schnell.safetensors",
61
- repo_ae="ae.safetensors",
62
- ckpt_path=os.getenv("FLUX_SCHNELL"),
63
- params=FluxParams(
64
- in_channels=64,
65
- vec_in_dim=768,
66
- context_in_dim=4096,
67
- hidden_size=3072,
68
- mlp_ratio=4.0,
69
- num_heads=24,
70
- depth=19,
71
- depth_single_blocks=38,
72
- axes_dim=[16, 56, 56],
73
- theta=10_000,
74
- qkv_bias=True,
75
- guidance_embed=False,
76
- ),
77
- ae_path=os.getenv("AE"),
78
- ae_params=AutoEncoderParams(
79
- resolution=256,
80
- in_channels=3,
81
- ch=128,
82
- out_ch=3,
83
- ch_mult=[1, 2, 4, 4],
84
- num_res_blocks=2,
85
- z_channels=16,
86
- scale_factor=0.3611,
87
- shift_factor=0.1159,
88
- ),
89
- ),
90
- }
91
-
92
-
93
- def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
94
- if len(missing) > 0 and len(unexpected) > 0:
95
- print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
96
- print("\n" + "-" * 79 + "\n")
97
- print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
98
- elif len(missing) > 0:
99
- print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
100
- elif len(unexpected) > 0:
101
- print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
102
-
103
-
104
- def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True):
105
- # Loading Flux
106
- print("Init model")
107
- ckpt_path = configs[name].ckpt_path
108
- if (
109
- not os.path.exists(ckpt_path)
110
- and configs[name].repo_id is not None
111
- and configs[name].repo_flow is not None
112
- and hf_download
113
- ):
114
- ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
115
-
116
- with torch.device(device):
117
- model = Flux(configs[name].params).to(torch.bfloat16)
118
-
119
- if ckpt_path is not None:
120
- print("Loading checkpoint")
121
- # load_sft doesn't support torch.device
122
- sd = load_sft(ckpt_path, device=str(device))
123
- missing, unexpected = model.load_state_dict(sd, strict=False)
124
- print_load_warning(missing, unexpected)
125
- return model
126
-
127
-
128
- def load_t5(device: str = "cuda", max_length: int = 512) -> HFEmbedder:
129
- # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
130
- return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
131
-
132
-
133
- def load_clip(device: str = "cuda") -> HFEmbedder:
134
- return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
135
-
136
-
137
- def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEncoder:
138
- ckpt_path = configs[name].ae_path
139
- if (
140
- not os.path.exists(ckpt_path)
141
- and configs[name].repo_id is not None
142
- and configs[name].repo_ae is not None
143
- and hf_download
144
- ):
145
- ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae, local_dir='models')
146
-
147
- # Loading the autoencoder
148
- print("Init AE")
149
- with torch.device(device):
150
- ae = AutoEncoder(configs[name].ae_params)
151
-
152
- if ckpt_path is not None:
153
- sd = load_sft(ckpt_path, device=str(device))
154
- missing, unexpected = ae.load_state_dict(sd, strict=False)
155
- print_load_warning(missing, unexpected)
156
- return ae