from typing import Any

import torch
from diffusers import (
    AutoencoderTiny,
    DiffusionPipeline,
    LCMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
    UNet2DConditionModel,
)

from backend.tiny_decoder import get_tiny_decoder_vae_model
from constants import LCM_DEFAULT_MODEL

# LCM model ids that are loaded as a standalone UNet (UNet2DConditionModel)
# and combined with the remaining components of the mapped base model.
_LCM_UNET_BASE_MODELS = {
    "latent-consistency/lcm-sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
    "latent-consistency/lcm-ssd-1b": "segmind/SSD-1B",
}

# Text-to-image pipeline class name -> image-to-image pipeline class that is
# constructed from the same components.
_IMG2IMG_PIPELINES = {
    "LatentConsistencyModelPipeline": StableDiffusionImg2ImgPipeline,
    "StableDiffusionPipeline": StableDiffusionImg2ImgPipeline,
    "StableDiffusionXLPipeline": StableDiffusionXLImg2ImgPipeline,
}


def _get_lcm_pipeline_from_base_model(
    lcm_model_id: str,
    base_model_id: str,
    use_local_model: bool,
):
    """Build a pipeline from ``base_model_id`` using the UNet of ``lcm_model_id``.

    The UNet is loaded from the LCM checkpoint. All other components come from
    the base model. The scheduler is replaced by an ``LCMScheduler`` created
    from the base scheduler's config. All weights are loaded as float32.

    Args:
        lcm_model_id: Model id or path of the LCM UNet checkpoint.
        base_model_id: Model id or path of the base diffusion pipeline.
        use_local_model: Passed as ``local_files_only`` to both loaders.

    Returns:
        The assembled ``DiffusionPipeline`` instance.
    """
    unet = UNet2DConditionModel.from_pretrained(
        lcm_model_id,
        torch_dtype=torch.float32,
        local_files_only=use_local_model,
    )
    pipeline = DiffusionPipeline.from_pretrained(
        base_model_id,
        unet=unet,
        torch_dtype=torch.float32,
        local_files_only=use_local_model,
    )
    # Build the LCM scheduler from the base scheduler's config so the
    # scheduler settings stay consistent with the base model.
    pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
    return pipeline


def load_taesd(
    pipeline: Any,
    use_local_model: bool = False,
    torch_data_type: torch.dtype = torch.float32,
) -> None:
    """Replace ``pipeline.vae`` in place with an ``AutoencoderTiny`` model.

    The checkpoint to load is chosen by ``get_tiny_decoder_vae_model`` based on
    the pipeline's class name. Presumably this returns a model id or path for a
    tiny autoencoder matching the pipeline family; confirm in
    ``backend.tiny_decoder``.

    Args:
        pipeline: Pipeline whose ``vae`` attribute is overwritten.
        use_local_model: Passed as ``local_files_only`` to the loader.
        torch_data_type: dtype used when loading the autoencoder weights.
    """
    vae_model = get_tiny_decoder_vae_model(pipeline.__class__.__name__)
    pipeline.vae = AutoencoderTiny.from_pretrained(
        vae_model,
        torch_dtype=torch_data_type,
        local_files_only=use_local_model,
    )


def get_lcm_model_pipeline(
    model_id: str = LCM_DEFAULT_MODEL,
    use_local_model: bool = False,
):
    """Load the text-to-image pipeline for ``model_id``.

    Model ids listed in ``_LCM_UNET_BASE_MODELS`` are loaded as an LCM UNet on
    top of their base model (see ``_get_lcm_pipeline_from_base_model``). Any
    other id is loaded directly with ``DiffusionPipeline.from_pretrained``
    without an explicit ``torch_dtype``.

    Args:
        model_id: Model id or path to load.
        use_local_model: Passed as ``local_files_only`` to the loaders.

    Returns:
        The loaded pipeline instance.
    """
    base_model_id = _LCM_UNET_BASE_MODELS.get(model_id)
    if base_model_id is not None:
        return _get_lcm_pipeline_from_base_model(
            model_id,
            base_model_id,
            use_local_model,
        )
    return DiffusionPipeline.from_pretrained(
        model_id,
        local_files_only=use_local_model,
    )


def get_image_to_image_pipeline(pipeline: Any) -> Any:
    """Create an image-to-image pipeline that shares ``pipeline``'s components.

    The img2img class is picked from ``_IMG2IMG_PIPELINES`` by the class name
    of ``pipeline``. It is instantiated with ``pipeline.components``, so the
    model objects are reused rather than reloaded.

    Args:
        pipeline: A loaded text-to-image pipeline.

    Returns:
        A new ``StableDiffusionImg2ImgPipeline`` or
        ``StableDiffusionXLImg2ImgPipeline`` instance.

    Raises:
        ValueError: If the pipeline class is not supported. This is a subclass
            of ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    pipeline_class = pipeline.__class__.__name__
    img2img_class = _IMG2IMG_PIPELINES.get(pipeline_class)
    if img2img_class is None:
        raise ValueError(f"Unknown pipeline {pipeline_class}")
    return img2img_class(**pipeline.components)