frutiemax committed
Commit 88deab4 · 1 Parent(s): d751051

Use ConditionalUnetModel

Files changed (3)
  1. rct_diffusion_pipeline.py +24 -13
  2. test_pipeline.py +80 -0
  3. train_model.py +4 -4
rct_diffusion_pipeline.py CHANGED
@@ -1,6 +1,6 @@
 from diffusers import DiffusionPipeline
 from diffusers import DDPMPipeline
-from diffusers import DDPMScheduler, UNet2DModel
+from diffusers import DDPMScheduler, UNet2DConditionModel
 import torch
 import torchvision.transforms as T
 from PIL import Image
@@ -24,11 +24,13 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         self.scheduler = DDPMScheduler()
 
         # the number of hidden features is dependant on the loaded dictionaries!
-        self.unet = UNet2DModel(sample_size=256, in_channels=12, out_channels=12, \
-            down_block_types=('DownBlock2D', 'DownBlock2D', 'AttnDownBlock2D'), up_block_types=('UpBlock2D', 'UpBlock2D', 'AttnUpBlock2D'), \
-            block_out_channels=(16, 32, 64), norm_num_groups=16)
+        hidden_dim = self.get_class_labels_size()
+        self.unet = UNet2DConditionModel(sample_size=256, in_channels=12, out_channels=12, \
+            down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'), \
+            up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
+            block_out_channels=(12, 24, 30), norm_num_groups=6)
 
-        self.unet.to('cuda')
+        self.unet.to(device='cuda', dtype=torch.float16)
 
     def load_dictionaries_from_dataset(self):
         dataset = load_dataset('frutiemax/rct_dataset')
@@ -118,6 +120,8 @@ class RCTDiffusionPipeline(DiffusionPipeline):
 
             offset += len(self.color2_dict.items())
             class_labels[batch_index, offset:offset + len(self.color3_dict)] = torch.from_numpy(colors3[batch_index])
+
+        class_labels = torch.reshape(class_labels, (num_images, 1, self.get_class_labels_size()))
         return class_labels
 
     def __call__(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]], \
@@ -164,29 +168,36 @@ class RCTDiffusionPipeline(DiffusionPipeline):
 
         # now put those weights into a tensor
         class_labels = self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3)
-        class_labels = class_labels.to('cuda')
+
+        # we need those class labels for the 12 channels
+        #new_class_labels = torch.Tensor(size=(batch_size, 12, self.get_class_labels_size()))
+        #new_class_labels[:, :] = class_labels
+        #class_labels = new_class_labels.to(device='cuda', dtype=torch.float16)
+        #del new_class_labels
 
         # set the inference steps
        self.scheduler.set_timesteps(num_inference_steps)
 
-        noise_batches = torch.Tensor(size=(batch_size, 4, 3, 256, 256)).to('cuda')
+        noise_batches = torch.Tensor(size=(batch_size, 4, 3, 256, 256)).to(dtype=torch.float16, device='cuda')
         for batch_index in range(batch_size):
             for view_index in range(4):
-                noise = torch.randn(3, 256, 256).to('cuda')
+                noise = torch.randn(3, 256, 256).to(dtype=torch.float16, device='cuda')
                 noise_batches[batch_index, view_index] = noise
 
         # reshape the data so it's (batch_size, 12, 256, 256)
-        noise_batches = torch.reshape(noise_batches, (batch_size, 12, 256, 256)).to('cuda')
+        noise_batches = torch.reshape(noise_batches, (batch_size, 1, 12, 256, 256)).to(dtype=torch.float16, device='cuda')
 
         # now call the model for the n interations
         progress_bar = tqdm(total=num_inference_steps)
         epoch = 0
         for t in self.scheduler.timesteps:
             progress_bar.set_description(f'Inference step {epoch}')
-            with torch.no_grad():
-                noise_residual = self.unet(noise_batches, t, class_labels=class_labels).sample
-                previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batches).prev_sample
-                noise_batches = previous_noisy_sample
+
+            for batch_index in range(batch_size):
+                with torch.no_grad():
+                    noise_residual = self.unet(noise_batches[batch_index], t, encoder_hidden_states=class_labels).sample
+                    previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batches[batch_index]).prev_sample
+                    noise_batches[batch_index] = previous_noisy_sample
             progress_bar.update(1)
             epoch = epoch + 1
test_pipeline.py CHANGED
@@ -1,4 +1,5 @@
 from rct_diffusion_pipeline import RCTDiffusionPipeline
+from diffusers import UNet2DConditionModel
 
 
 torch_device = "cuda"
@@ -6,4 +7,83 @@ torch_device = "cuda"
 pipeline = RCTDiffusionPipeline()
 pipeline.print_class_tokens_to_csv()
 output = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
+
+# from PIL import Image
+# import torch
+# from transformers import CLIPTextModel, CLIPTokenizer
+# from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
+
+# vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=True)
+# tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
+# text_encoder = CLIPTextModel.from_pretrained(
+#     "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
+# )
+# unet = UNet2DConditionModel.from_pretrained(
+#     "CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=True
+# )
+
+# from diffusers import UniPCMultistepScheduler
+
+# scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+# torch_device = "cuda"
+# vae.to(torch_device)
+# text_encoder.to(torch_device)
+# unet.to(torch_device)
+
+# prompt = ["a photograph of an astronaut riding a horse"]
+# height = 512  # default height of Stable Diffusion
+# width = 512  # default width of Stable Diffusion
+# num_inference_steps = 25  # Number of denoising steps
+# guidance_scale = 7.5  # Scale for classifier-free guidance
+# generator = torch.manual_seed(0)  # Seed generator to create the inital latent noise
+# batch_size = len(prompt)
+
+# text_input = tokenizer(
+#     prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
+# )
+
+# with torch.no_grad():
+#     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+
+# text_input = tokenizer(
+#     prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
+# )
+
+# with torch.no_grad():
+#     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+
+# max_length = text_input.input_ids.shape[-1]
+# uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
+# uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+
+# text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+# latents = torch.randn(
+#     (batch_size, unet.in_channels, height // 8, width // 8),
+#     generator=generator,
+# )
+# latents = latents.to(torch_device)
+
+# latents = latents * scheduler.init_noise_sigma
+
+# from tqdm.auto import tqdm
+
+# scheduler.set_timesteps(num_inference_steps)
+
+# for t in tqdm(scheduler.timesteps):
+#     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+#     latent_model_input = torch.cat([latents] * 2)
+
+#     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
+
+#     # predict the noise residual
+#     with torch.no_grad():
+#         noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+#     # perform guidance
+#     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+#     # compute the previous noisy sample x_t -> x_t-1
+#     latents = scheduler.step(noise_pred, t, latents).prev_sample
 print('test')
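
The large commented-out block is the text-to-image denoising loop from the diffusers docs, kept as a reference for how encoder_hidden_states feeds a UNet2DConditionModel. For the pipeline itself, each argument is a per-image list of weighted (token, weight) tuples, so prompts can blend several dictionary entries. A short sketch of a blended call ('snow tree' is a hypothetical token, used only to illustrate the weighting):

from rct_diffusion_pipeline import RCTDiffusionPipeline

pipeline = RCTDiffusionPipeline()

# blend two object descriptions at equal weight; 'snow tree' is hypothetical
# and would need to exist in the dataset dictionaries for the lookup to succeed
output = pipeline(
    object_description=[[('aleppo pine tree', 0.5), ('snow tree', 0.5)]],
    color1=[[('dark green', 1.0)]])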
train_model.py CHANGED
@@ -105,13 +105,13 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
     for batch_index in range(0, num_images, batch_size):
         progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
         batch_end = np.minimum(num_images, batch_index + batch_size)
-        clean_images = targets[batch_index:batch_end].to('cuda')
-        batch_labels = class_labels[batch_index:batch_end].to('cuda')
+        clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
+        clean_images = torch.reshape(clean_images, (batch_size, 12, 256, 256))
 
         noise = torch.randn(clean_images.shape).to('cuda')
         timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_size, )).to('cuda')
-        noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps)
-        noise_pred = model.unet(noisy_images, timesteps, batch_labels, return_dict=False)[0]
+        noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
+        noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
         loss = F.mse_loss(noise_pred, noise)
         loss.backward()
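
UNet2DConditionModel.forward takes encoder_hidden_states as its third positional argument, so the positional call above still routes the sliced labels into cross-attention. A condensed sketch of the training step, assuming clean_images is already reshaped to (batch, 12, 256, 256) and batch_labels to (batch, 1, cross_attention_dim); using torch.randn_like keeps the noise in the same dtype as the half-precision images:

import torch
import torch.nn.functional as F

def train_step(model, clean_images, batch_labels):
    # sample one random diffusion timestep per image in the batch
    noise = torch.randn_like(clean_images)
    timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps,
                              (clean_images.shape[0],), device=clean_images.device)
    # forward-diffuse the clean images, then predict the added noise
    noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps)
    noise_pred = model.unet(noisy_images, timesteps, batch_labels, return_dict=False)[0]
    return F.mse_loss(noise_pred, noise)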