Update src/pipeline.py
src/pipeline.py  +73 -33  CHANGED
@@ -27,47 +27,87 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
 
-# preconfigs
-import os
-os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
-os.environ["TOKENIZERS_PARALLELISM"] = "True"
-torch._dynamo.config.suppress_errors = True
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.enabled = True
 # torch.backends.cudnn.benchmark = True
+from collections import namedtuple
+import os
+Config = namedtuple('Config', ['model_id', 'revision', 'text_encoder_id', 'text_encoder_rev', 'vae_id', 'vae_rev'])
 
-# globals
 Pipeline = None
-ckpt_id = "black-forest-labs/FLUX.1-schnell"
-ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
 TinyVAE = "madebyollin/taef1"
 TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
 
-def empty_cache():
-    gc.collect()
-    torch.cuda.empty_cache()
-    torch.cuda.reset_max_memory_allocated()
-    torch.cuda.reset_peak_memory_stats()
 
-def
-
-
-
-
-
+def setup_environment():
+    os.environ.update({
+        'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True',
+        'TOKENIZERS_PARALLELISM': 'True',
+    })
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.enabled = True
+    torch._dynamo.config.suppress_errors = True
+
+def get_model_components(cfg):
+    text_encoder = T5EncoderModel.from_pretrained(
+        cfg.text_encoder_id,
+        revision=cfg.text_encoder_rev,
+        subfolder="text_encoder_2",
+        torch_dtype=torch.bfloat16
+    )
+    path = os.path.join(HF_HUB_CACHE,
+        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer")
+    return text_encoder, path
+
+def initialize_pipeline(cfg, text_encoder, transformer_path):
+    transformer = FluxTransformer2DModel.from_pretrained(transformer_path, torch_dtype=torch.bfloat16, use_safetensors=False)
+    pipeline = FluxPipeline.from_pretrained(
+        cfg.model_id,
+        revision=cfg.revision,
+        text_encoder_2=text_encoder,
+        transformer=transformer,
+        torch_dtype=torch.bfloat16
+    )
+    basepath = os.path.join(HF_HUB_CACHE,
+        "models--manbeast3b--Flux.1.schnell-vae-kl-unst0_1_iter0/snapshots/b586f7e1125722a242c38fe963904f453095903f")
+
     pipeline.vae.encoder.load_state_dict(torch.load(os.path.join(basepath, "encoder.pth")), strict=False)
     pipeline.vae.decoder.load_state_dict(torch.load(os.path.join(basepath, "decoder.pth")), strict=False)
-    pipeline.to("cuda")
-
-
-
+    return pipeline.to("cuda", memory_format=torch.channels_last)
+
+def load_pipeline():
+    setup_environment()
+    torch.cuda.empty_cache()
+    cfg = Config(
+        "black-forest-labs/FLUX.1-schnell",
+        "741f7c3ce8b383c54771c7003378a50191e9efe9",
+        "manbeast3b/flux.1-schnell-full1",
+        "cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
+        "manbeast3b/Flux.1.schnell-vae-kl-unst0_1_iter0",
+        "b586f7e1125722a242c38fe963904f453095903f"
+    )
+    text_encoder, transformer_path = get_model_components(cfg)
+    pipeline = initialize_pipeline(cfg, text_encoder, transformer_path)
+    warmup_ = "insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus"
+    pipeline(
+        prompt=warmup_,
+        width=1024, height=1024,
+        guidance_scale=0.0,
+        num_inference_steps=4,
+        max_sequence_length=256
+    )
+    pipeline("")
     return pipeline
 
-
-
-
-
-
-
-
-
+
+def infer(request, pipeline, generator):
+    with torch.no_grad():
+        result = pipeline(
+            prompt=request.prompt,
+            generator=generator,
+            guidance_scale=0.0,
+            num_inference_steps=4,
+            max_sequence_length=256,
+            height=request.height,
+            width=request.width,
+            output_type="pil"
+        )
+        return result.images[0]