manbeast3b committed on
Commit
a22cf6c
1 Parent(s): 58d3b2e

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +15 -9
src/pipeline.py CHANGED
@@ -16,7 +16,9 @@ from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
16
  import os
17
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
18
  torch._dynamo.config.suppress_errors = True
19
-
 
 
20
  Pipeline = None
21
 
22
  ckpt_id = "black-forest-labs/FLUX.1-schnell"
@@ -35,26 +37,30 @@ def load_pipeline() -> Pipeline:
35
  text_encoder_2 = T5EncoderModel.from_pretrained(
36
  "city96/t5-v1_1-xxl-encoder-bf16", revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=torch.bfloat16
37
  ).to(memory_format=torch.channels_last)
38
-
39
  path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
40
- model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False).to(memory_format=torch.channels_last)
41
  pipeline = FluxPipeline.from_pretrained(
42
  ckpt_id,
43
  revision=ckpt_revision,
44
- transformer=model,
45
  text_encoder_2=text_encoder_2,
46
  torch_dtype=dtype,
47
  ).to(device)
48
  pipeline.transformer = torch.compile(pipeline.transformer, mode="reduce-overhead")
49
  quantize_(pipeline.vae, int8_weight_only())
50
- for _ in range(5):
 
51
  pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
52
 
53
  empty_cache()
54
  return pipeline
55
 
56
-
57
- @torch.no_grad()
58
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
59
- image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
60
- return(image)
 
 
 
 
 
16
  import os
17
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
18
  torch._dynamo.config.suppress_errors = True
19
+ torch.backends.cudnn.benchmark = True
20
+ torch.backends.cuda.matmul.allow_tf32 = True
21
+ torch.cuda.set_per_process_memory_fraction(0.99)
22
  Pipeline = None
23
 
24
  ckpt_id = "black-forest-labs/FLUX.1-schnell"
 
37
  text_encoder_2 = T5EncoderModel.from_pretrained(
38
  "city96/t5-v1_1-xxl-encoder-bf16", revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=torch.bfloat16
39
  ).to(memory_format=torch.channels_last)
 
40
  path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
41
+ transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False).to(memory_format=torch.channels_last)
42
  pipeline = FluxPipeline.from_pretrained(
43
  ckpt_id,
44
  revision=ckpt_revision,
45
+ transformer=transformer,
46
  text_encoder_2=text_encoder_2,
47
  torch_dtype=dtype,
48
  ).to(device)
49
  pipeline.transformer = torch.compile(pipeline.transformer, mode="reduce-overhead")
50
  quantize_(pipeline.vae, int8_weight_only())
51
+
52
+ for _ in range(4):
53
  pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
54
 
55
  empty_cache()
56
  return pipeline
57
 
58
+ sample = True
59
+ @torch.inference_mode()
60
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
61
+ global sample
62
+ if sample:
63
+ empty_cache()
64
+ sample = None
65
+ image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
66
+ return(image)