import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

from rct_diffusion_pipeline import RCTDiffusionPipeline


# Device that the text encoder and the prompt token ids are moved to.
torch_device = "cuda"

# CLIP tokenizer / text encoder taken from the Stable Diffusion v1.4 repo.
tokenizer = CLIPTokenizer.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="tokenizer"
)
text_encoder = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
).to(torch_device)


def _encode_prompt(prompt):
    """Tokenize *prompt* and return the text encoder's first output.

    Args:
        prompt: A string or a list of strings, passed straight to
            ``tokenizer``.

    Returns:
        Element ``[0]`` of the ``text_encoder`` output for the padded token
        ids (presumably the last hidden state -- confirm against the
        transformers documentation).
    """
    # Pad/truncate to the tokenizer's maximum length so every prompt yields
    # the same number of token ids.
    tokens = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: no gradients are needed for the embeddings.
    with torch.no_grad():
        return text_encoder(tokens.input_ids.to(torch_device))[0]


# Sample embeddings: one from a single-element list, one from a bare string.
test1 = _encode_prompt(['aleppo pine tree, common oak tree'])
test2 = _encode_prompt('dark green')


# Conditional UNet built from an explicit configuration (no pretrained
# weights are loaded here). cross_attention_dim=768 presumably matches the
# hidden size of the CLIP text encoder used elsewhere in this script --
# TODO confirm.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    cross_attention_dim=768,
    block_out_channels=(320, 640, 1280, 1280),
    norm_num_groups=32,
)
# Move to the script's device and switch the parameters to half precision.
unet = unet.to(torch_device, dtype=torch.float16)

# Cast any BatchNorm2d sub-modules back to float32 after the half-precision
# conversion above.
# NOTE(review): the diffusers UNet blocks appear to use GroupNorm/LayerNorm
# rather than BatchNorm2d, so this loop may match nothing -- confirm.
for layer in unet.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.float()


# DDPM noise scheduler configured with 20 training timesteps.
scheduler = DDPMScheduler(num_train_timesteps=20)

# Pretrained VAE, moved to the script's device in half precision.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse", use_safetensors=True
)
vae = vae.to(torch_device, dtype=torch.float16)


# Load the saved pipeline from the 'rct_foliage_999' checkpoint.
pipeline = RCTDiffusionPipeline.from_pretrained('rct_foliage_999')

# Run one generation. The three single-element lists are passed positionally;
# their meaning is defined by RCTDiffusionPipeline.__call__ (presumably an
# object description followed by two colour prompts -- verify there).
output = pipeline(
    ['pagoda pine tree'],
    ['green'],
    ['grey'],
)

# Persist the first result, then re-save the pipeline under 'test'.
first_result = output[0]
first_result.save('out.png')
pipeline.save_pretrained('test')

print('test')