"""Port KerasCV Stable Diffusion weights into a diffusers StableDiffusionPipeline."""

import os
from typing import Optional

from conversion_utils import populate_text_encoder, populate_unet, run_assertion
from diffusers import (
    AutoencoderKL,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPTextModel
import keras_cv
import tensorflow as tf

# Hugging Face Hub checkpoint the diffusers components are loaded from.
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
# Hub revision for the text encoder, VAE, safety checker and pipeline
# (None uses the Hub default).
REVISION = None
# Hub revision used only for the UNet; presumably meant for selecting a
# non-EMA variant of the UNet weights — confirm against the checkpoint.
NON_EMA_REVISION = None
# Resolution the KerasCV model is built for.
IMG_HEIGHT = IMG_WIDTH = 512


def initialize_pt_models():
    """Initialize the diffusers (PyTorch) Stable Diffusion components.

    Each component is loaded from its subfolder of ``PRETRAINED_CKPT`` with
    ``from_pretrained``, which downloads the pre-trained weights.

    Returns:
        Tuple ``(text_encoder, vae, unet, safety_checker)``.
    """
    pt_text_encoder = CLIPTextModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
    )
    pt_vae = AutoencoderKL.from_pretrained(
        PRETRAINED_CKPT, subfolder="vae", revision=REVISION
    )
    # Only the UNet is loaded from the separate (non-EMA) revision.
    pt_unet = UNet2DConditionModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
    )
    pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        PRETRAINED_CKPT, subfolder="safety_checker", revision=REVISION
    )
    return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker


def initialize_tf_models():
    """Initialize the KerasCV Stable Diffusion model and its sub-models.

    Returns:
        Tuple ``(sd_model, text_encoder, image_encoder, diffusion_model)``.
        The third element is the KerasCV ``image_encoder`` attribute.
    """
    tf_sd_model = keras_cv.models.StableDiffusion(img_height=IMG_HEIGHT, img_width=IMG_WIDTH)
    # A throwaway generation so that the pre-trained weights get downloaded.
    _ = tf_sd_model.text_to_image("Cartoon")
    tf_text_encoder = tf_sd_model.text_encoder
    tf_vae = tf_sd_model.image_encoder
    tf_unet = tf_sd_model.diffusion_model
    return tf_sd_model, tf_text_encoder, tf_vae, tf_unet


def _resolve_weights_path(weights: str) -> str:
    """Return a local path for ``weights``, downloading it when it is not a local file."""
    if os.path.isfile(weights):
        return weights
    # `get_file`'s first positional parameter is `fname`; the URL has to be
    # passed as `origin`, otherwise Keras refuses to download anything.
    return tf.keras.utils.get_file(origin=weights)


def run_conversion(
    text_encoder_weights: Optional[str] = None, unet_weights: Optional[str] = None
):
    """Convert KerasCV Stable Diffusion weights into a diffusers pipeline.

    Args:
        text_encoder_weights: Optional URL (or local file path) of fine-tuned
            KerasCV text-encoder weights. When None, the pre-trained KerasCV
            text encoder is converted and checked against the PyTorch one.
        unet_weights: Optional URL (or local file path) of fine-tuned KerasCV
            diffusion-model weights. When None, the pre-trained KerasCV
            diffusion model is converted and checked against the PyTorch UNet.

    Returns:
        A ``StableDiffusionPipeline`` whose text encoder and UNet hold the
        parameters converted from KerasCV.
    """
    pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
    _, tf_text_encoder, _, tf_unet = initialize_tf_models()
    print("Pre-trained model weights downloaded.")

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        tf_text_encoder.load_weights(_resolve_weights_path(text_encoder_weights))
    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        tf_unet.load_weights(_resolve_weights_path(unet_weights))

    text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
    unet_state_dict_from_tf = populate_unet(tf_unet)
    print("Conversion done, now running assertions...")

    # Fine-tuned weights have no PyTorch reference to compare against, so only
    # components that still carry pre-trained weights are checked.
    if text_encoder_weights is None:
        run_assertion(pt_text_encoder.state_dict(), text_encoder_state_dict_from_tf)
    if unet_weights is None:
        # The converted UNet must be compared with the PyTorch UNet.
        run_assertion(pt_unet.state_dict(), unet_state_dict_from_tf)

    print("Assertions successful, populating the converted parameters into the diffusers models...")
    pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
    pt_unet.load_state_dict(unet_state_dict_from_tf)

    print("Parameters ported, preparing StableDiffusionPipeline...")
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=REVISION,
    )
    return pipeline