wondervictor committed
Commit a49d0a8 · verified · 1 Parent(s): ca14b5b

Update model.py

Files changed (1)
  1. model.py +242 -242
model.py CHANGED
@@ -1,242 +1,242 @@
import gc
import spaces
from safetensors.torch import load_file
from autoregressive.models.gpt_t2i import GPT_models
from tokenizer.tokenizer_image.vq_model import VQ_models
from language.t5 import T5Embedder
import torch
import numpy as np
import PIL
from PIL import Image
from condition.canny import CannyDetector
import time
from autoregressive.models.generate import generate
from condition.midas.depth import MidasDetector
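# checkpoint paths for the two control-conditioned GPT models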
models = {
    "canny": "checkpoints/t2i/canny_MR.safetensors",
    "depth": "checkpoints/t2i/depth_MR.safetensors",
}

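# round the image size up to the nearest multiple of 16 (32 for the depth model) so it maps cleanly onto the token grid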
def resize_image_to_16_multiple(image, condition_type='canny'):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # image = Image.open(image_path)
    width, height = image.size

    if condition_type == 'depth':  # The depth model requires a side length that is a multiple of 32
        new_width = (width + 31) // 32 * 32
        new_height = (height + 31) // 32 * 32
    else:
        new_width = (width + 15) // 16 * 16
        new_height = (height + 15) // 16 * 16

    resized_image = image.resize((new_width, new_height))
    return resized_image

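# bundles the VQ image tokenizer, the T5 text encoder and the canny/depth-conditioned GPT models used for inference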
class Model:

    def __init__(self):
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.base_model_id = ""
        self.task_name = ""
        self.vq_model = self.load_vq()
        self.t5_model = self.load_t5()
        self.gpt_model_canny = self.load_gpt(condition_type='canny')
        self.gpt_model_depth = self.load_gpt(condition_type='depth')
        self.get_control_canny = CannyDetector()
        self.get_control_depth = MidasDetector(device=self.device)
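    # VQ-16 image tokenizer: 16384-entry codebook, 8-dim code embeddings, 16x downsampling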
    def load_vq(self):
        vq_model = VQ_models["VQ-16"](codebook_size=16384,
                                      codebook_embed_dim=8)
        vq_model.to(self.device)
        vq_model.eval()
        checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
                                map_location="cpu")
        vq_model.load_state_dict(checkpoint["model"])
        del checkpoint
        print(f"image tokenizer is loaded")
        return vq_model
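    # control-conditioned GPT-XL, loaded from the safetensors checkpoint above and run in bfloat16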
    def load_gpt(self, condition_type='canny'):
        gpt_ckpt = models[condition_type]
        precision = torch.bfloat16
        latent_size = 768 // 16
        gpt_model = GPT_models["GPT-XL"](
            block_size=latent_size**2,
            cls_token_num=120,
            model_type='t2i',
            condition_type=condition_type,
        ).to(device=self.device, dtype=precision)

        model_weight = load_file(gpt_ckpt)
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
        print(f"gpt model is loaded")
        return gpt_model
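    # Flan-T5-XL text encoder used to embed prompts (up to 120 tokens)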
    def load_t5(self):
        precision = torch.bfloat16
        t5_model = T5Embedder(
            device=self.device,
            local_cache=True,
-           # cache_dir='checkpoints/t5-ckpt',
-           dir_or_name='flan-t5-xl',
+           cache_dir='checkpoints/t5-ckpt',
+           dir_or_name='google/flan-t5-xl',
            torch_dtype=precision,
            model_max_length=120,
        )
        return t5_model
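    # canny-conditioned generation: edge map + prompt -> images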
    @torch.no_grad()
    @spaces.GPU(enable_queue=True)
    def process_canny(
        self,
        image: np.ndarray,
        prompt: str,
        cfg_scale: float,
        temperature: float,
        top_k: int,
        top_p: int,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> list[PIL.Image.Image]:

        image = resize_image_to_16_multiple(image, 'canny')
        W, H = image.size
        print(W, H)
        condition_img = self.get_control_canny(np.array(image), low_threshold,
                                               high_threshold)
        condition_img = torch.from_numpy(condition_img[None, None,
                                                       ...]).repeat(
                                                           2, 3, 1, 1)
        condition_img = condition_img.to(self.device)
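        # scale the edge map to [-1, 1]; prompt and condition are batched in pairs, so two samples are generated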
        condition_img = 2 * (condition_img / 255 - 0.5)
        prompts = [prompt] * 2
        caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)

        print(f"processing left-padding...")
        new_emb_masks = torch.flip(emb_masks, dims=[-1])
        new_caption_embs = []
        for idx, (caption_emb,
                  emb_mask) in enumerate(zip(caption_embs, emb_masks)):
            valid_num = int(emb_mask.sum().item())
            print(f' prompt {idx} token len: {valid_num}')
            new_caption_emb = torch.cat(
                [caption_emb[valid_num:], caption_emb[:valid_num]])
            new_caption_embs.append(new_caption_emb)
        new_caption_embs = torch.stack(new_caption_embs)
        c_indices = new_caption_embs * new_emb_masks[:, :, None]
        c_emb_masks = new_emb_masks
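        # latent grid passed to the VQ decoder: (batch, 8, H/16, W/16)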
        qzshape = [len(c_indices), 8, H // 16, W // 16]
        t1 = time.time()
        index_sample = generate(
            self.gpt_model_canny,
            c_indices,
            (H // 16) * (W // 16),
            c_emb_masks,
            condition=condition_img,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        t2 = time.time()
        print(index_sample.shape)
        samples = self.vq_model.decode_code(
            index_sample, qzshape)  # output value is between [-1, 1]
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")
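        # prepend the condition image, rescale from [-1, 1] to [0, 255] and convert everything to PIL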
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        samples = 255 * (samples * 0.5 + 0.5)
        samples = [image] + [
            Image.fromarray(
                sample.permute(1, 2, 0).cpu().detach().numpy().clip(
                    0, 255).astype(np.uint8)) for sample in samples
        ]
        del condition_img
        torch.cuda.empty_cache()
        return samples
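    # depth-conditioned generation: MiDaS depth map + prompt -> images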
    @torch.no_grad()
    @spaces.GPU(enable_queue=True)
    def process_depth(
        self,
        image: np.ndarray,
        prompt: str,
        cfg_scale: float,
        temperature: float,
        top_k: int,
        top_p: int,
        seed: int,
    ) -> list[PIL.Image.Image]:
        image = resize_image_to_16_multiple(image, 'depth')
        W, H = image.size
        print(W, H)
        image_tensor = torch.from_numpy(np.array(image)).to(self.device)
        condition_img = torch.from_numpy(
            self.get_control_depth(image_tensor)).unsqueeze(0)
        condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1)
        condition_img = condition_img.to(self.device)
        condition_img = 2 * (condition_img / 255 - 0.5)
        prompts = [prompt] * 2
        caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)

        print(f"processing left-padding...")
        new_emb_masks = torch.flip(emb_masks, dims=[-1])
        new_caption_embs = []
        for idx, (caption_emb,
                  emb_mask) in enumerate(zip(caption_embs, emb_masks)):
            valid_num = int(emb_mask.sum().item())
            print(f' prompt {idx} token len: {valid_num}')
            new_caption_emb = torch.cat(
                [caption_emb[valid_num:], caption_emb[:valid_num]])
            new_caption_embs.append(new_caption_emb)
        new_caption_embs = torch.stack(new_caption_embs)

        c_indices = new_caption_embs * new_emb_masks[:, :, None]
        c_emb_masks = new_emb_masks
        qzshape = [len(c_indices), 8, H // 16, W // 16]
        t1 = time.time()
        index_sample = generate(
            self.gpt_model_depth,
            c_indices,
            (H // 16) * (W // 16),
            c_emb_masks,
            condition=condition_img,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        t2 = time.time()
        print(index_sample.shape)
        samples = self.vq_model.decode_code(index_sample, qzshape)
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")
        condition_img = condition_img.cpu()
        samples = samples.cpu()
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        samples = 255 * (samples * 0.5 + 0.5)
        samples = [image] + [
            Image.fromarray(
                sample.permute(1, 2, 0).numpy().clip(0, 255).astype(np.uint8))
            for sample in samples
        ]
        del image_tensor
        del condition_img
        torch.cuda.empty_cache()
        return samples