manbeast3b committed (verified)
Commit 252dcbb · 1 Parent(s): c19c9a6

Update src/pipeline.py

Files changed (1): src/pipeline.py +73 -33
src/pipeline.py CHANGED
@@ -27,47 +27,87 @@ import torch.nn as nn
  import torch.nn.functional as F
  from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
 
- # preconfigs
- import os
- os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
- os.environ["TOKENIZERS_PARALLELISM"] = "True"
- torch._dynamo.config.suppress_errors = True
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.enabled = True
  # torch.backends.cudnn.benchmark = True
+ from collections import namedtuple
+ import os
+ Config = namedtuple('Config', ['model_id', 'revision', 'text_encoder_id', 'text_encoder_rev', 'vae_id', 'vae_rev'])
 
- # globals
  Pipeline = None
- ckpt_id = "black-forest-labs/FLUX.1-schnell"
- ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
  TinyVAE = "madebyollin/taef1"
  TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
 
- def empty_cache():
-     gc.collect()
-     torch.cuda.empty_cache()
-     torch.cuda.reset_max_memory_allocated()
-     torch.cuda.reset_peak_memory_stats()
-
- def load_pipeline() -> Pipeline:
-     text_encoder_2 = T5EncoderModel.from_pretrained("manbeast3b/flux.1-schnell-full1", revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146", subfolder="text_encoder_2",torch_dtype=torch.bfloat16)
-     path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer")
-     transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
-     pipeline = FluxPipeline.from_pretrained(ckpt_id, revision=ckpt_revision, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16,)
-     basepath = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.schnell-vae-kl-unst0_1_iter0/snapshots/b586f7e1125722a242c38fe963904f453095903f")
+
+ def setup_environment():
+     os.environ.update({
+         'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True',
+         'TOKENIZERS_PARALLELISM': 'True',
+     })
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.enabled = True
+     torch._dynamo.config.suppress_errors = True
+
+ def get_model_components(cfg):
+     text_encoder = T5EncoderModel.from_pretrained(
+         cfg.text_encoder_id,
+         revision=cfg.text_encoder_rev,
+         subfolder="text_encoder_2",
+         torch_dtype=torch.bfloat16
+     )
+     path = os.path.join(HF_HUB_CACHE,
+         "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer")
+     return text_encoder, path
+
+ def initialize_pipeline(cfg, text_encoder, transformer_path):
+     transformer = FluxTransformer2DModel.from_pretrained(transformer_path, torch_dtype=torch.bfloat16, use_safetensors=False)
+     pipeline = FluxPipeline.from_pretrained(
+         cfg.model_id,
+         revision=cfg.revision,
+         text_encoder_2=text_encoder,
+         transformer=transformer,
+         torch_dtype=torch.bfloat16
+     )
+     basepath = os.path.join(HF_HUB_CACHE,
+         "models--manbeast3b--Flux.1.schnell-vae-kl-unst0_1_iter0/snapshots/b586f7e1125722a242c38fe963904f453095903f")
+
      pipeline.vae.encoder.load_state_dict(torch.load(os.path.join(basepath, "encoder.pth")), strict=False)
      pipeline.vae.decoder.load_state_dict(torch.load(os.path.join(basepath, "decoder.pth")), strict=False)
-     pipeline.to("cuda")
-     pipeline.to(memory_format=torch.channels_last)
-     for _ in range(1):
-         pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
+     return pipeline.to("cuda", memory_format=torch.channels_last)
+
+ def load_pipeline():
+     setup_environment()
+     torch.cuda.empty_cache()
+     cfg = Config(
+         "black-forest-labs/FLUX.1-schnell",
+         "741f7c3ce8b383c54771c7003378a50191e9efe9",
+         "manbeast3b/flux.1-schnell-full1",
+         "cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
+         "manbeast3b/Flux.1.schnell-vae-kl-unst0_1_iter0",
+         "b586f7e1125722a242c38fe963904f453095903f"
+     )
+     text_encoder, transformer_path = get_model_components(cfg)
+     pipeline = initialize_pipeline(cfg, text_encoder, transformer_path)
+     warmup_ = "insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus"
+     pipeline(
+         prompt=warmup_,
+         width=1024, height=1024,
+         guidance_scale=0.0,
+         num_inference_steps=4,
+         max_sequence_length=256
+     )
+     pipeline("")
      return pipeline
 
- sample = 1
- @torch.no_grad()
- def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
-     global sample
-     if not sample:
-         sample=1
-         empty_cache()
-     return pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
+
+ def infer(request, pipeline, generator):
+     with torch.no_grad():
+         result = pipeline(
+             prompt=request.prompt,
+             generator=generator,
+             guidance_scale=0.0,
+             num_inference_steps=4,
+             max_sequence_length=256,
+             height=request.height,
+             width=request.width,
+             output_type="pil"
+         )
+     return result.images[0]
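
For reference, a minimal driver sketch (not part of the commit) of how the refactored load_pipeline()/infer() pair might be exercised. The TextToImageRequest dataclass below is a hypothetical stand-in modeling only the fields infer() actually reads (prompt, height, width); the real request type comes from the host harness, and the generator is assumed to be a seeded torch.Generator, as the removed type hints suggest.

from dataclasses import dataclass

import torch


@dataclass
class TextToImageRequest:
    # Hypothetical stand-in for the harness request type;
    # only the fields read by infer() are modeled here.
    prompt: str
    height: int = 1024
    width: int = 1024


if __name__ == "__main__":
    pipeline = load_pipeline()                          # build, patch VAE weights, warm up
    generator = torch.Generator("cuda").manual_seed(0)  # seeded for reproducible sampling
    request = TextToImageRequest(prompt="a watercolor fox in a snowy forest")
    image = infer(request, pipeline, generator)         # returns a PIL image
    image.save("sample.png")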