Use ConditionalUnetModel
- rct_diffusion_pipeline.py +24 -13
- test_pipeline.py +80 -0
- train_model.py +4 -4
rct_diffusion_pipeline.py
CHANGED
@@ -1,6 +1,6 @@
 from diffusers import DiffusionPipeline
 from diffusers import DDPMPipeline
-from diffusers import DDPMScheduler,
+from diffusers import DDPMScheduler, UNet2DConditionModel
 import torch
 import torchvision.transforms as T
 from PIL import Image
@@ -24,11 +24,13 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         self.scheduler = DDPMScheduler()

         # the number of hidden features is dependant on the loaded dictionaries!
-
-
-
+        hidden_dim = self.get_class_labels_size()
+        self.unet = UNet2DConditionModel(sample_size=256, in_channels=12, out_channels=12, \
+            down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
+            up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
+            block_out_channels=(12, 24, 30), norm_num_groups=6)

-        self.unet.to('cuda')
+        self.unet.to(device='cuda', dtype=torch.float16)

     def load_dictionaries_from_dataset(self):
         dataset = load_dataset('frutiemax/rct_dataset')
@@ -118,6 +120,8 @@ class RCTDiffusionPipeline(DiffusionPipeline):

             offset += len(self.color2_dict.items())
             class_labels[batch_index, offset:offset + len(self.color3_dict)] = torch.from_numpy(colors3[batch_index])
+
+        class_labels = torch.reshape(class_labels, (num_images, 1, self.get_class_labels_size()))
         return class_labels

     def __call__(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]], \
@@ -164,29 +168,36 @@ class RCTDiffusionPipeline(DiffusionPipeline):

         # now put those weights into a tensor
         class_labels = self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3)
-
+
+        # we need those class labels for the 12 channels
+        #new_class_labels = torch.Tensor(size=(batch_size, 12, self.get_class_labels_size()))
+        #new_class_labels[:, :] = class_labels
+        #class_labels = new_class_labels.to(device='cuda', dtype=torch.float16)
+        #del new_class_labels

         # set the inference steps
         self.scheduler.set_timesteps(num_inference_steps)

-        noise_batches = torch.Tensor(size=(batch_size, 4, 3, 256, 256)).to('cuda')
+        noise_batches = torch.Tensor(size=(batch_size, 4, 3, 256, 256)).to(dtype=torch.float16, device='cuda')
         for batch_index in range(batch_size):
             for view_index in range(4):
-                noise = torch.randn(3, 256, 256).to('cuda')
+                noise = torch.randn(3, 256, 256).to(dtype=torch.float16, device='cuda')
                 noise_batches[batch_index, view_index] = noise

         # reshape the data so it's (batch_size, 12, 256, 256)
-        noise_batches = torch.reshape(noise_batches, (batch_size, 12, 256, 256)).to('cuda')
+        noise_batches = torch.reshape(noise_batches, (batch_size, 1, 12, 256, 256)).to(dtype=torch.float16, device='cuda')

         # now call the model for the n interations
         progress_bar = tqdm(total=num_inference_steps)
         epoch = 0
         for t in self.scheduler.timesteps:
             progress_bar.set_description(f'Inference step {epoch}')
-
-
-
-
+
+            for batch_index in range(batch_size):
+                with torch.no_grad():
+                    noise_residual = self.unet(noise_batches[batch_index], t, encoder_hidden_states=class_labels).sample
+                    previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batches[batch_index]).prev_sample
+                    noise_batches[batch_index] = previous_noisy_sample
             progress_bar.update(1)
             epoch = epoch + 1

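For reference, the denoising loop added to __call__ follows the standard diffusers pattern: a UNet2DConditionModel predicts the noise residual from the current sample and the class labels passed as encoder_hidden_states, and DDPMScheduler.step walks the sample back one timestep. Below is a minimal, self-contained sketch of that pattern; the sizes (64x64 samples, block_out_channels=(32, 64, 64), a single conditioning token of width 160) and the random conditioning tensor are illustrative assumptions, not the repository's configuration.

    import torch
    from diffusers import DDPMScheduler, UNet2DConditionModel

    # Toy conditional UNet: 12-channel samples (4 views x RGB), cross-attention
    # conditioning on a single 160-wide embedding per image.
    unet = UNet2DConditionModel(
        sample_size=64,
        in_channels=12,
        out_channels=12,
        down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),
        up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'),
        block_out_channels=(32, 64, 64),
        cross_attention_dim=160,
        norm_num_groups=8,
    )
    scheduler = DDPMScheduler()

    batch_size = 1
    # encoder_hidden_states must be (batch, sequence_length, cross_attention_dim);
    # the pipeline reshapes its packed class labels to (num_images, 1, size) for the same reason.
    class_labels = torch.randn(batch_size, 1, 160)

    sample = torch.randn(batch_size, 12, 64, 64)
    scheduler.set_timesteps(10)

    for t in scheduler.timesteps:
        with torch.no_grad():
            noise_residual = unet(sample, t, encoder_hidden_states=class_labels).sample
        sample = scheduler.step(noise_residual, t, sample).prev_sample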
test_pipeline.py
CHANGED
@@ -1,4 +1,5 @@
 from rct_diffusion_pipeline import RCTDiffusionPipeline
+from diffusers import UNet2DConditionModel


 torch_device = "cuda"
@@ -6,4 +7,83 @@ torch_device = "cuda"
 pipeline = RCTDiffusionPipeline()
 pipeline.print_class_tokens_to_csv()
 output = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
+
+# from PIL import Image
+# import torch
+# from transformers import CLIPTextModel, CLIPTokenizer
+# from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
+
+# vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=True)
+# tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
+# text_encoder = CLIPTextModel.from_pretrained(
+#     "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
+# )
+# unet = UNet2DConditionModel.from_pretrained(
+#     "CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=True
+# )
+
+# from diffusers import UniPCMultistepScheduler
+
+# scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+# torch_device = "cuda"
+# vae.to(torch_device)
+# text_encoder.to(torch_device)
+# unet.to(torch_device)
+
+# prompt = ["a photograph of an astronaut riding a horse"]
+# height = 512  # default height of Stable Diffusion
+# width = 512  # default width of Stable Diffusion
+# num_inference_steps = 25  # Number of denoising steps
+# guidance_scale = 7.5  # Scale for classifier-free guidance
+# generator = torch.manual_seed(0)  # Seed generator to create the inital latent noise
+# batch_size = len(prompt)
+
+# text_input = tokenizer(
+#     prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
+# )
+
+# with torch.no_grad():
+#     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+
+# text_input = tokenizer(
+#     prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
+# )
+
+# with torch.no_grad():
+#     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+
+# max_length = text_input.input_ids.shape[-1]
+# uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
+# uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+
+# text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+# latents = torch.randn(
+#     (batch_size, unet.in_channels, height // 8, width // 8),
+#     generator=generator,
+# )
+# latents = latents.to(torch_device)
+
+# latents = latents * scheduler.init_noise_sigma
+
+# from tqdm.auto import tqdm
+
+# scheduler.set_timesteps(num_inference_steps)
+
+# for t in tqdm(scheduler.timesteps):
+#     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+#     latent_model_input = torch.cat([latents] * 2)
+
+#     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
+
+#     # predict the noise residual
+#     with torch.no_grad():
+#         noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+#     # perform guidance
+#     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+#     # compute the previous noisy sample x_t -> x_t-1
+#     latents = scheduler.step(noise_pred, t, latents).prev_sample
 print('test')
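The call pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]]) passes weighted (token, weight) pairs; pack_labels_to_tensor writes those weights into one flat class-label vector (one contiguous slot range per dictionary) and, after this commit, reshapes it to (num_images, 1, get_class_labels_size()) so it can serve as encoder_hidden_states. The sketch below illustrates that packing idea with made-up dictionaries and a hypothetical helper; it is not the repository's implementation.

    import torch

    # Hypothetical dictionaries standing in for the object-description and colour
    # dictionaries the pipeline loads from the frutiemax/rct_dataset dataset.
    object_dict = {'aleppo pine tree': 0, 'oak tree': 1}
    color_dict = {'dark green': 0, 'brown': 1}

    def pack_labels(object_descriptions, colors1):
        # One row per image; each dictionary owns a contiguous block of slots and
        # every (token, weight) pair writes its weight at the token's index.
        size = len(object_dict) + len(color_dict)
        labels = torch.zeros(len(object_descriptions), size)
        for i, (desc, col) in enumerate(zip(object_descriptions, colors1)):
            for token, weight in desc:
                labels[i, object_dict[token]] = weight
            offset = len(object_dict)
            for token, weight in col:
                labels[i, offset + color_dict[token]] = weight
        # Add a sequence dimension: (batch, sequence_length=1, cross_attention_dim).
        return torch.reshape(labels, (len(object_descriptions), 1, size))

    class_labels = pack_labels([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
    print(class_labels.shape)  # torch.Size([1, 1, 4])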
train_model.py
CHANGED
@@ -105,13 +105,13 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
         for batch_index in range(0, num_images, batch_size):
             progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
             batch_end = np.minimum(num_images, batch_index + batch_size)
-            clean_images = targets[batch_index:batch_end].to('cuda')
-
+            clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
+            clean_images = torch.reshape(clean_images, (batch_size, 12, 256, 256))

             noise = torch.randn(clean_images.shape).to('cuda')
             timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_size, )).to('cuda')
-            noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps)
-            noise_pred = model.unet(noisy_images, timesteps,
+            noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
+            noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
             loss = F.mse_loss(noise_pred, noise)
             loss.backward()

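The updated training loop is the usual conditional DDPM objective: noise the clean targets to a random timestep with scheduler.add_noise, ask the UNet for the noise residual given the class labels, and minimise the MSE against the true noise. A minimal sketch of one such step follows; the model sizes, dummy tensors, and the AdamW optimiser are placeholder assumptions rather than the repository's actual configuration.

    import torch
    import torch.nn.functional as F
    from diffusers import DDPMScheduler, UNet2DConditionModel

    # Toy model reusing the conditional-UNet layout from the pipeline sketch above.
    unet = UNet2DConditionModel(
        sample_size=64, in_channels=12, out_channels=12,
        down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),
        up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'),
        block_out_channels=(32, 64, 64), cross_attention_dim=160, norm_num_groups=8,
    )
    scheduler = DDPMScheduler()
    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

    batch_size = 2
    clean_images = torch.randn(batch_size, 12, 64, 64)   # stand-in for the 4-view training targets
    class_labels = torch.randn(batch_size, 1, 160)       # stand-in for the packed label vectors

    # One training step: noise the targets to random timesteps and regress the noise.
    noise = torch.randn_like(clean_images)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,))
    noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

    noise_pred = unet(noisy_images, timesteps, encoder_hidden_states=class_labels).sample
    loss = F.mse_loss(noise_pred, noise)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()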