# fire_stang1 / src/pipeline.py
# manbeast3b — "Initial commit" (019d58e)
import os
import gc
import time
import torch
import torch.nn.functional as F
from PIL import Image as img
from PIL.Image import Image
from typing import Optional, Type
from dataclasses import dataclass
from diffusers import (
FluxTransformer2DModel,
DiffusionPipeline,
AutoencoderTiny
)
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE
from torchao.quantization import quantize_, int8_weight_only
from first_block_cache.diffusers_adapters import apply_cache_on_pipe
from pipelines.models import TextToImageRequest
from torch import Generator
# Configuration
@dataclass
class Config:
    """Static settings used to build and run the FLUX.1-schnell pipeline.

    Values are read as class attributes (``Config.X``) by the rest of this
    module.
    """

    # Base pipeline repository and the exact commit it is pinned to.
    CKPT_ID: str = "black-forest-labs/FLUX.1-schnell"
    CKPT_REVISION: str = "741f7c3ce8b383c54771c7003378a50191e9efe9"

    # Device the assembled pipeline is moved to.
    DEVICE: str = "cuda"

    # No annotation, so this is a plain class attribute rather than a
    # dataclass field (it is not an ``__init__`` parameter).
    DTYPE = torch.bfloat16

    # Exported to the PYTORCH_CUDA_ALLOC_CONF environment variable by
    # init_global_settings().
    PYTORCH_CUDA_ALLOC_CONF: str = "expandable_segments:True"
# Initialize global settings
def init_global_settings():
    """Apply process-wide PyTorch performance settings.

    Exports the CUDA caching-allocator configuration from ``Config`` into the
    environment, allows TF32 for CUDA matmuls, and enables cuDNN with
    ``benchmark`` autotuning.
    """
    # The allocator settings travel through the environment; PyTorch reads the
    # variable when its CUDA caching allocator starts up, so this has to run
    # before the first CUDA allocation to have any effect.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = Config.PYTORCH_CUDA_ALLOC_CONF

    cuda_matmul = torch.backends.cuda.matmul
    cuda_matmul.allow_tf32 = True

    cudnn = torch.backends.cudnn
    cudnn.enabled = True
    cudnn.benchmark = True
# Tensor comparison utilities
class TensorComparator:
    """Threshold predicates answering "are these two tensors close enough?".

    Every method returns a Python ``bool`` that is ``True`` when the chosen
    distance between ``t1`` and ``t2`` is strictly below ``threshold``.
    Where the metric is asymmetric, ``t1`` is the reference. All tensor math
    runs under ``torch.no_grad()``.
    """

    # Added to denominators so an all-zero reference (or zero spread) cannot
    # turn the ratio into NaN/inf.
    _EPS = 1e-8

    @staticmethod
    def _relative_mean_abs_diff(t1, t2):
        """Return ``mean|t1 - t2| / (mean|t1| + eps)`` as a Python float."""
        with torch.no_grad():
            mean_diff = (t1 - t2).abs().mean()
            mean_t1 = t1.abs().mean()
            return (mean_diff / (mean_t1 + TensorComparator._EPS)).item()

    @staticmethod
    def orig_comparison(t1, t2, *, threshold=0.85):
        """Relative mean-absolute difference (w.r.t. ``t1``) below ``threshold``.

        Uses the same epsilon-guarded ratio as ``relative_comparison``. Without
        the guard, an all-zero ``t1`` produced 0/0 = NaN, so two identical
        zero tensors were reported as dissimilar.
        """
        return TensorComparator._relative_mean_abs_diff(t1, t2) < threshold

    @staticmethod
    def mse_comparison(t1, t2, threshold=0.95):
        """Mean squared error between ``t1`` and ``t2`` below ``threshold``."""
        with torch.no_grad():
            mse = F.mse_loss(t1, t2)
        return mse.item() < threshold

    @staticmethod
    def relative_comparison(t1, t2, threshold=0.15):
        """``mean|t1 - t2| / (mean|t1| + 1e-8)`` below ``threshold``."""
        return TensorComparator._relative_mean_abs_diff(t1, t2) < threshold

    @staticmethod
    def normalized_comparison(t1, t2, threshold=0.85):
        """Mean absolute difference of the z-score-normalized tensors.

        Each tensor is shifted by its mean and scaled by its standard deviation
        (plus epsilon) before comparison, so pure offset/scale changes count as
        similar. NOTE: with the default (Bessel-corrected) ``std``, a
        single-element tensor gives NaN and the result is ``False``.
        """
        eps = TensorComparator._EPS
        with torch.no_grad():
            t1_norm = (t1 - t1.mean()) / (t1.std() + eps)
            t2_norm = (t2 - t2.mean()) / (t2.std() + eps)
            diff = torch.mean(torch.abs(t1_norm - t2_norm))
        return diff.item() < threshold

    @staticmethod
    def l1_comparison(t1, t2, threshold=0.85):
        """Mean absolute (L1) difference below ``threshold``."""
        with torch.no_grad():
            l1_dist = torch.nn.L1Loss()(t1, t2)
        return l1_dist.item() < threshold

    @staticmethod
    def max_diff_comparison(t1, t2, threshold=0.85):
        """Largest element-wise absolute difference below ``threshold``."""
        with torch.no_grad():
            max_diff = torch.max(torch.abs(t1 - t2))
        return max_diff.item() < threshold
# Memory management
class MemoryManager:
    """Helpers for releasing memory between pipeline stages."""

    @staticmethod
    def empty_cache():
        """Collect Python garbage, release cached CUDA blocks, reset peak stats.

        The CUDA steps are skipped when CUDA is unavailable; previously the
        peak-statistics reset raised on CPU-only hosts. The former call to
        ``torch.cuda.reset_max_memory_allocated`` was dropped. It is a
        deprecated alias that only forwards to ``reset_peak_memory_stats``,
        which is still called here, and it emits a FutureWarning.
        """
        # Drop unreachable Python objects first so their CUDA tensors are
        # actually freed before the allocator cache is emptied.
        gc.collect()
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
# Pipeline management
class PipelineManager:
    """Builds the FLUX.1-schnell pipeline and runs text-to-image inference."""

    # Prompt used only to exercise the pipeline during warm-up.
    _WARMUP_PROMPT = "onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness"

    @staticmethod
    def load_pipeline() -> DiffusionPipeline:
        """Assemble, optimize and warm up the inference pipeline.

        Loads a bf16 T5 text encoder and a tiny VAE from pinned revisions, and
        a FLUX transformer from a pinned local HF-cache snapshot. These are
        plugged into the pinned FLUX.1-schnell pipeline, which is moved to
        ``Config.DEVICE``. The function then applies first-block caching,
        int8-weight-only quantizes and compiles the VAE, and runs warm-up
        generations.

        Returns:
            The ready-to-use ``DiffusionPipeline``.
        """
        MemoryManager.empty_cache()
        text_encoder_2 = T5EncoderModel.from_pretrained(
            "city96/t5-v1_1-xxl-encoder-bf16",
            revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
            torch_dtype=Config.DTYPE,
        ).to(memory_format=torch.channels_last)
        vae = AutoencoderTiny.from_pretrained(
            "RobertML/FLUX.1-schnell-vae_e3m2",
            revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d",
            torch_dtype=Config.DTYPE,
        )
        # The transformer is read from a snapshot directory that is expected to
        # already exist in the local HF cache; use_safetensors=False loads the
        # non-safetensors weight files.
        path = os.path.join(
            HF_HUB_CACHE,
            "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a",
        )
        model = FluxTransformer2DModel.from_pretrained(
            path,
            torch_dtype=Config.DTYPE,
            use_safetensors=False,
        ).to(memory_format=torch.channels_last)
        pipeline = DiffusionPipeline.from_pretrained(
            Config.CKPT_ID,
            vae=vae,
            revision=Config.CKPT_REVISION,
            transformer=model,
            text_encoder_2=text_encoder_2,
            torch_dtype=Config.DTYPE,
        ).to(Config.DEVICE)
        apply_cache_on_pipe(pipeline)
        # NOTE(review): DiffusionPipeline.to() is documented for device/dtype;
        # a memory_format keyword may simply be ignored — confirm against the
        # installed diffusers version.
        pipeline.to(memory_format=torch.channels_last)
        # Quantize first, then compile (torchao's documented order), so the
        # compiled module wraps the already-quantized weights.
        quantize_(pipeline.vae, int8_weight_only())
        pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
        PipelineManager._warmup(pipeline)
        return pipeline

    @staticmethod
    def _warmup(pipeline, num_runs: int = 3):
        """Run ``num_runs`` throwaway 1024x1024, 4-step generations.

        Presumably this pays first-call costs (lazy ``torch.compile`` with
        max-autotune, cuDNN benchmark selection) before real requests arrive.
        """
        for _ in range(num_runs):
            pipeline(
                prompt=PipelineManager._WARMUP_PROMPT,
                width=1024,
                height=1024,
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=256,
            )

    @staticmethod
    @torch.no_grad()
    def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
        """Generate a single PIL image for ``request``.

        Runs 4 denoising steps with ``guidance_scale=0.0`` at the requested
        height/width, seeded via ``generator``.

        If generation raises an ``Exception``, the error is printed. A fallback
        image is then loaded from ``./RobertML.png`` (relative to the current
        working directory) and returned. ``KeyboardInterrupt`` and
        ``SystemExit`` now propagate; the previous bare ``except:`` swallowed
        them too.
        """
        try:
            return pipeline(
                request.prompt,
                generator=generator,
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=256,
                height=request.height,
                width=request.width,
                output_type="pil",
            ).images[0]
        except Exception as exc:
            # Best-effort fallback: a failed generation still yields an image,
            # but the cause is reported instead of being silently discarded.
            print(f"using backup ({type(exc).__name__}: {exc})")
            # Load fully into memory and close the file, rather than returning
            # a lazily-loaded image that keeps the file handle open.
            with img.open("./RobertML.png") as fallback:
                return fallback.copy()
# Apply the process-wide PyTorch settings as soon as this module is imported.
init_global_settings()

# Module-level aliases that keep the original function-style interface.
load_pipeline = PipelineManager.load_pipeline
infer = PipelineManager.infer
are_two_tensors_similar = TensorComparator.orig_comparison
are_two_tensors_similar_relative = TensorComparator.relative_comparison
are_two_tensors_similar_normalized = TensorComparator.normalized_comparison
are_two_tensors_similar_l1 = TensorComparator.l1_comparison
are_two_tensors_similar_max_diff = TensorComparator.max_diff_comparison
# Previously the only comparator without an alias; added for consistency.
are_two_tensors_similar_mse = TensorComparator.mse_comparison
empty_cache = MemoryManager.empty_cache