from rct_diffusion_pipeline import RCTDiffusionPipeline
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
import torch
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn as nn

torch_device = "cuda"

# Test of the text tokenizer and encoder from Stable Diffusion v1-4.
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
).to(torch_device)

# Tokenize a prompt, then encode it into CLIP hidden states.
test1 = tokenizer(['aleppo pine tree, common oak tree'], padding="max_length",
                  max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
#test3 = tokenizer([1.0, 0.0, .05], is_split_into_words=True, padding="max_length",
#                  max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    test1 = text_encoder(test1.input_ids.to(torch_device))[0]

test2 = tokenizer('dark green', padding="max_length", max_length=tokenizer.model_max_length,
                  truncation=True, return_tensors="pt")
with torch.no_grad():
    test2 = text_encoder(test2.input_ids.to(torch_device))[0]

# Build a UNet sized for 32x32 latents, with cross-attention over the
# 768-dim CLIP embeddings produced above.
unet = UNet2DConditionModel(
    sample_size=32, in_channels=4, out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=768, block_out_channels=(320, 640, 1280, 1280), norm_num_groups=32)
unet = unet.to(torch_device, dtype=torch.float16)

# Keep batch-norm layers in float32 so their statistics accumulate at full
# precision. Note: UNet2DConditionModel uses GroupNorm rather than BatchNorm2d,
# so this loop is effectively a no-op here; it is kept as a safeguard.
for layer in unet.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.float()

scheduler = DDPMScheduler(num_train_timesteps=20)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
vae = vae.to(torch_device, dtype=torch.float16)

# Load the trained pipeline from a checkpoint instead of assembling it from
# the components built above.
#pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
pipeline = RCTDiffusionPipeline.from_pretrained('rct_foliage_999')

# Generate one image from a description prompt and two color prompts, then
# exercise save_pretrained.
output = pipeline(['pagoda pine tree'], ['green'], ['grey'])
output[0].save('out.png')
pipeline.save_pretrained('test')
print('test')
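
# --- Added sanity check (a sketch, not part of the original test) -----------
# The SD v1-4 CLIP text encoder emits hidden states of shape
# (batch, tokenizer.model_max_length, 768), and that 768 must match the
# UNet's cross_attention_dim above for the cross-attention blocks to accept
# the conditioning. These asserts make that assumption explicit:
assert test1.shape == (1, tokenizer.model_max_length, 768)
assert test2.shape == (1, tokenizer.model_max_length, 768)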