# rct_model / test_pipeline.py
# Author: frutiemax
# Last commit: 9bde8da -- "Use vae for encoding and decoding for training"
from rct_diffusion_pipeline import RCTDiffusionPipeline
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
import torch
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn as nn
# Smoke test for RCTDiffusionPipeline: runs the CLIP text encoder on a couple of
# prompts, builds fresh UNet / scheduler / VAE components (used only when
# BUILD_FROM_COMPONENTS is True), then loads or assembles a pipeline, renders one
# image to 'out.png', and saves the pipeline to the 'test' directory.

torch_device = "cuda"

# --- Text tokenizer / encoder check ------------------------------------------
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
).to(torch_device)


def _encode_text(text):
    """Tokenize *text* and return the text encoder's first output.

    The text is padded/truncated to ``tokenizer.model_max_length`` and the encoder
    runs under ``torch.no_grad()`` on ``torch_device``.
    """
    tokens = tokenizer(
        text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return text_encoder(tokens.input_ids.to(torch_device))[0]


test1 = _encode_text(['aleppo pine tree, common oak tree'])
test2 = _encode_text('dark green')

# --- Freshly built components (consumed only when BUILD_FROM_COMPONENTS) -------
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    # presumably chosen to match the CLIP text-embedding width -- confirm
    cross_attention_dim=768,
    block_out_channels=(320, 640, 1280, 1280),
    norm_num_groups=32,
)
unet = unet.to(torch_device, dtype=torch.float16)

# Intent: keep BatchNorm layers in float32 for the accumulation while the rest
# of the UNet runs in float16.
# NOTE(review): diffusers' UNet2DConditionModel normalizes with GroupNorm, so this
# loop likely matches no layers for this architecture -- confirm before relying on it.
for layer in unet.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.float()

scheduler = DDPMScheduler(num_train_timesteps=20)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
vae = vae.to(torch_device, dtype=torch.float16)

# --- Pipeline ------------------------------------------------------------------
# Set to True to assemble the pipeline from the components built above instead of
# loading the saved 'rct_foliage_999' pipeline.
BUILD_FROM_COMPONENTS = False
if BUILD_FROM_COMPONENTS:
    pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
else:
    pipeline = RCTDiffusionPipeline.from_pretrained('rct_foliage_999')

# The meaning of the three prompt lists is defined by RCTDiffusionPipeline.__call__
# (not visible here); output[0] is assumed to expose .save() (e.g. a PIL image) -- confirm.
output = pipeline(['pagoda pine tree'], ['green'], ['grey'])
output[0].save('out.png')
pipeline.save_pretrained('test')
print('test')