callMeHeisenberg
commited on
Commit
•
daf0991
1
Parent(s):
3df0a69
Upload 10 files
Browse files- OmniGen/__init__.py +4 -0
- OmniGen/model.py +406 -0
- OmniGen/pipeline.py +307 -0
- OmniGen/processor.py +338 -0
- OmniGen/scheduler.py +181 -0
- OmniGen/train_helper/__init__.py +2 -0
- OmniGen/train_helper/data.py +116 -0
- OmniGen/train_helper/loss.py +68 -0
- OmniGen/transformer.py +194 -0
- OmniGen/utils.py +110 -0
OmniGen/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .model import OmniGen
|
2 |
+
from .processor import OmniGenProcessor
|
3 |
+
from .scheduler import OmniGenScheduler
|
4 |
+
from .pipeline import OmniGenPipeline
|
OmniGen/model.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The code is revised from DiT
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
import math
|
7 |
+
from typing import Dict
|
8 |
+
|
9 |
+
from diffusers.loaders import PeftAdapterMixin
|
10 |
+
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
11 |
+
from huggingface_hub import snapshot_download
|
12 |
+
from safetensors.torch import load_file
|
13 |
+
|
14 |
+
from OmniGen.transformer import Phi3Config, Phi3Transformer
|
15 |
+
|
16 |
+
|
17 |
+
def modulate(x, shift, scale):
|
18 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
19 |
+
|
20 |
+
|
21 |
+
class TimestepEmbedder(nn.Module):
|
22 |
+
"""
|
23 |
+
Embeds scalar timesteps into vector representations.
|
24 |
+
"""
|
25 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
26 |
+
super().__init__()
|
27 |
+
self.mlp = nn.Sequential(
|
28 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
29 |
+
nn.SiLU(),
|
30 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
31 |
+
)
|
32 |
+
self.frequency_embedding_size = frequency_embedding_size
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def timestep_embedding(t, dim, max_period=10000):
|
36 |
+
"""
|
37 |
+
Create sinusoidal timestep embeddings.
|
38 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
39 |
+
These may be fractional.
|
40 |
+
:param dim: the dimension of the output.
|
41 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
42 |
+
:return: an (N, D) Tensor of positional embeddings.
|
43 |
+
"""
|
44 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
45 |
+
half = dim // 2
|
46 |
+
freqs = torch.exp(
|
47 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
48 |
+
).to(device=t.device)
|
49 |
+
args = t[:, None].float() * freqs[None]
|
50 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
51 |
+
if dim % 2:
|
52 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
53 |
+
return embedding
|
54 |
+
|
55 |
+
def forward(self, t, dtype=torch.float32):
|
56 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
57 |
+
t_emb = self.mlp(t_freq)
|
58 |
+
return t_emb
|
59 |
+
|
60 |
+
|
61 |
+
class FinalLayer(nn.Module):
|
62 |
+
"""
|
63 |
+
The final layer of DiT.
|
64 |
+
"""
|
65 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
66 |
+
super().__init__()
|
67 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
68 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
69 |
+
self.adaLN_modulation = nn.Sequential(
|
70 |
+
nn.SiLU(),
|
71 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
72 |
+
)
|
73 |
+
|
74 |
+
def forward(self, x, c):
|
75 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
76 |
+
x = modulate(self.norm_final(x), shift, scale)
|
77 |
+
x = self.linear(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
|
82 |
+
"""
|
83 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
84 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
85 |
+
"""
|
86 |
+
if isinstance(grid_size, int):
|
87 |
+
grid_size = (grid_size, grid_size)
|
88 |
+
|
89 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
|
90 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
|
91 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
92 |
+
grid = np.stack(grid, axis=0)
|
93 |
+
|
94 |
+
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
95 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
96 |
+
if cls_token and extra_tokens > 0:
|
97 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
98 |
+
return pos_embed
|
99 |
+
|
100 |
+
|
101 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
102 |
+
assert embed_dim % 2 == 0
|
103 |
+
|
104 |
+
# use half of dimensions to encode grid_h
|
105 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
106 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
107 |
+
|
108 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
109 |
+
return emb
|
110 |
+
|
111 |
+
|
112 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
113 |
+
"""
|
114 |
+
embed_dim: output dimension for each position
|
115 |
+
pos: a list of positions to be encoded: size (M,)
|
116 |
+
out: (M, D)
|
117 |
+
"""
|
118 |
+
assert embed_dim % 2 == 0
|
119 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
120 |
+
omega /= embed_dim / 2.
|
121 |
+
omega = 1. / 10000**omega # (D/2,)
|
122 |
+
|
123 |
+
pos = pos.reshape(-1) # (M,)
|
124 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
125 |
+
|
126 |
+
emb_sin = np.sin(out) # (M, D/2)
|
127 |
+
emb_cos = np.cos(out) # (M, D/2)
|
128 |
+
|
129 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
130 |
+
return emb
|
131 |
+
|
132 |
+
|
133 |
+
class PatchEmbedMR(nn.Module):
|
134 |
+
""" 2D Image to Patch Embedding
|
135 |
+
"""
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
patch_size: int = 2,
|
139 |
+
in_chans: int = 4,
|
140 |
+
embed_dim: int = 768,
|
141 |
+
bias: bool = True,
|
142 |
+
):
|
143 |
+
super().__init__()
|
144 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
145 |
+
|
146 |
+
def forward(self, x):
|
147 |
+
x = self.proj(x)
|
148 |
+
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
149 |
+
return x
|
150 |
+
|
151 |
+
|
152 |
+
class OmniGen(nn.Module, PeftAdapterMixin):
|
153 |
+
"""
|
154 |
+
Diffusion model with a Transformer backbone.
|
155 |
+
"""
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
transformer_config: Phi3Config,
|
159 |
+
patch_size=2,
|
160 |
+
in_channels=4,
|
161 |
+
pe_interpolation: float = 1.0,
|
162 |
+
pos_embed_max_size: int = 192,
|
163 |
+
):
|
164 |
+
super().__init__()
|
165 |
+
self.in_channels = in_channels
|
166 |
+
self.out_channels = in_channels
|
167 |
+
self.patch_size = patch_size
|
168 |
+
self.pos_embed_max_size = pos_embed_max_size
|
169 |
+
|
170 |
+
hidden_size = transformer_config.hidden_size
|
171 |
+
|
172 |
+
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
173 |
+
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
174 |
+
|
175 |
+
self.time_token = TimestepEmbedder(hidden_size)
|
176 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
177 |
+
|
178 |
+
self.pe_interpolation = pe_interpolation
|
179 |
+
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
|
180 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
181 |
+
|
182 |
+
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
183 |
+
|
184 |
+
self.initialize_weights()
|
185 |
+
|
186 |
+
self.llm = Phi3Transformer(config=transformer_config)
|
187 |
+
self.llm.config.use_cache = False
|
188 |
+
|
189 |
+
@classmethod
|
190 |
+
def from_pretrained(cls, model_name):
|
191 |
+
if not os.path.exists(model_name):
|
192 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
193 |
+
model_name = snapshot_download(repo_id=model_name,
|
194 |
+
cache_dir=cache_folder,
|
195 |
+
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
196 |
+
config = Phi3Config.from_pretrained(model_name)
|
197 |
+
model = cls(config)
|
198 |
+
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
|
199 |
+
print("Loading safetensors")
|
200 |
+
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
|
201 |
+
else:
|
202 |
+
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
|
203 |
+
model.load_state_dict(ckpt)
|
204 |
+
return model
|
205 |
+
|
206 |
+
def initialize_weights(self):
|
207 |
+
assert not hasattr(self, "llama")
|
208 |
+
|
209 |
+
# Initialize transformer layers:
|
210 |
+
def _basic_init(module):
|
211 |
+
if isinstance(module, nn.Linear):
|
212 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
213 |
+
if module.bias is not None:
|
214 |
+
nn.init.constant_(module.bias, 0)
|
215 |
+
self.apply(_basic_init)
|
216 |
+
|
217 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
218 |
+
w = self.x_embedder.proj.weight.data
|
219 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
220 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
221 |
+
|
222 |
+
w = self.input_x_embedder.proj.weight.data
|
223 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
224 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
225 |
+
|
226 |
+
|
227 |
+
# Initialize timestep embedding MLP:
|
228 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
229 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
230 |
+
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
|
231 |
+
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
|
232 |
+
|
233 |
+
# Zero-out output layers:
|
234 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
235 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
236 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
237 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
238 |
+
|
239 |
+
def unpatchify(self, x, h, w):
|
240 |
+
"""
|
241 |
+
x: (N, T, patch_size**2 * C)
|
242 |
+
imgs: (N, H, W, C)
|
243 |
+
"""
|
244 |
+
c = self.out_channels
|
245 |
+
|
246 |
+
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
|
247 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
248 |
+
imgs = x.reshape(shape=(x.shape[0], c, h, w))
|
249 |
+
return imgs
|
250 |
+
|
251 |
+
|
252 |
+
def cropped_pos_embed(self, height, width):
|
253 |
+
"""Crops positional embeddings for SD3 compatibility."""
|
254 |
+
if self.pos_embed_max_size is None:
|
255 |
+
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
256 |
+
|
257 |
+
height = height // self.patch_size
|
258 |
+
width = width // self.patch_size
|
259 |
+
if height > self.pos_embed_max_size:
|
260 |
+
raise ValueError(
|
261 |
+
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
262 |
+
)
|
263 |
+
if width > self.pos_embed_max_size:
|
264 |
+
raise ValueError(
|
265 |
+
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
266 |
+
)
|
267 |
+
|
268 |
+
top = (self.pos_embed_max_size - height) // 2
|
269 |
+
left = (self.pos_embed_max_size - width) // 2
|
270 |
+
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
271 |
+
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
272 |
+
# print(top, top + height, left, left + width, spatial_pos_embed.size())
|
273 |
+
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
274 |
+
return spatial_pos_embed
|
275 |
+
|
276 |
+
|
277 |
+
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
|
278 |
+
if isinstance(latents, list):
|
279 |
+
return_list = False
|
280 |
+
if padding_latent is None:
|
281 |
+
padding_latent = [None] * len(latents)
|
282 |
+
return_list = True
|
283 |
+
patched_latents, num_tokens, shapes = [], [], []
|
284 |
+
for latent, padding in zip(latents, padding_latent):
|
285 |
+
height, width = latent.shape[-2:]
|
286 |
+
if is_input_images:
|
287 |
+
latent = self.input_x_embedder(latent)
|
288 |
+
else:
|
289 |
+
latent = self.x_embedder(latent)
|
290 |
+
pos_embed = self.cropped_pos_embed(height, width)
|
291 |
+
latent = latent + pos_embed
|
292 |
+
if padding is not None:
|
293 |
+
latent = torch.cat([latent, padding], dim=-2)
|
294 |
+
patched_latents.append(latent)
|
295 |
+
|
296 |
+
num_tokens.append(pos_embed.size(1))
|
297 |
+
shapes.append([height, width])
|
298 |
+
if not return_list:
|
299 |
+
latents = torch.cat(patched_latents, dim=0)
|
300 |
+
else:
|
301 |
+
latents = patched_latents
|
302 |
+
else:
|
303 |
+
height, width = latents.shape[-2:]
|
304 |
+
if is_input_images:
|
305 |
+
latents = self.input_x_embedder(latents)
|
306 |
+
else:
|
307 |
+
latents = self.x_embedder(latents)
|
308 |
+
pos_embed = self.cropped_pos_embed(height, width)
|
309 |
+
latents = latents + pos_embed
|
310 |
+
num_tokens = latents.size(1)
|
311 |
+
shapes = [height, width]
|
312 |
+
return latents, num_tokens, shapes
|
313 |
+
|
314 |
+
|
315 |
+
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
316 |
+
"""
|
317 |
+
|
318 |
+
"""
|
319 |
+
input_is_list = isinstance(x, list)
|
320 |
+
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
321 |
+
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
322 |
+
|
323 |
+
if input_img_latents is not None:
|
324 |
+
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
325 |
+
if input_ids is not None:
|
326 |
+
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
327 |
+
input_img_inx = 0
|
328 |
+
for b_inx in input_image_sizes.keys():
|
329 |
+
for start_inx, end_inx in input_image_sizes[b_inx]:
|
330 |
+
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
331 |
+
input_img_inx += 1
|
332 |
+
if input_img_latents is not None:
|
333 |
+
assert input_img_inx == len(input_latents)
|
334 |
+
|
335 |
+
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
336 |
+
else:
|
337 |
+
input_emb = torch.cat([time_token, x], dim=1)
|
338 |
+
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
339 |
+
output, past_key_values = output.last_hidden_state, output.past_key_values
|
340 |
+
if input_is_list:
|
341 |
+
image_embedding = output[:, -max(num_tokens):]
|
342 |
+
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
343 |
+
x = self.final_layer(image_embedding, time_emb)
|
344 |
+
latents = []
|
345 |
+
for i in range(x.size(0)):
|
346 |
+
latent = x[i:i+1, :num_tokens[i]]
|
347 |
+
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
348 |
+
latents.append(latent)
|
349 |
+
else:
|
350 |
+
image_embedding = output[:, -num_tokens:]
|
351 |
+
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
352 |
+
x = self.final_layer(image_embedding, time_emb)
|
353 |
+
latents = self.unpatchify(x, shapes[0], shapes[1])
|
354 |
+
|
355 |
+
if return_past_key_values:
|
356 |
+
return latents, past_key_values
|
357 |
+
return latents
|
358 |
+
|
359 |
+
@torch.no_grad()
|
360 |
+
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
361 |
+
self.llm.config.use_cache = use_kv_cache
|
362 |
+
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
|
363 |
+
if use_img_cfg:
|
364 |
+
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
365 |
+
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
366 |
+
model_out = [cond, cond, cond]
|
367 |
+
else:
|
368 |
+
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
|
369 |
+
cond = uncond + cfg_scale * (cond - uncond)
|
370 |
+
model_out = [cond, cond]
|
371 |
+
|
372 |
+
return torch.cat(model_out, dim=0), past_key_values
|
373 |
+
|
374 |
+
|
375 |
+
@torch.no_grad()
|
376 |
+
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
377 |
+
self.llm.config.use_cache = use_kv_cache
|
378 |
+
if past_key_values is None:
|
379 |
+
past_key_values = [None] * len(attention_mask)
|
380 |
+
|
381 |
+
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
382 |
+
timestep = timestep.to(x[0].dtype)
|
383 |
+
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
384 |
+
|
385 |
+
model_out, pask_key_values = [], []
|
386 |
+
for i in range(len(input_ids)):
|
387 |
+
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
|
388 |
+
model_out.append(temp_out)
|
389 |
+
pask_key_values.append(temp_pask_key_values)
|
390 |
+
|
391 |
+
if len(model_out) == 3:
|
392 |
+
cond, uncond, img_cond = model_out
|
393 |
+
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
394 |
+
model_out = [cond, cond, cond]
|
395 |
+
elif len(model_out) == 2:
|
396 |
+
cond, uncond = model_out
|
397 |
+
cond = uncond + cfg_scale * (cond - uncond)
|
398 |
+
model_out = [cond, cond]
|
399 |
+
else:
|
400 |
+
return model_out[0]
|
401 |
+
|
402 |
+
return torch.cat(model_out, dim=0), pask_key_values
|
403 |
+
|
404 |
+
|
405 |
+
|
406 |
+
|
OmniGen/pipeline.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import inspect
|
3 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
4 |
+
import gc
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from huggingface_hub import snapshot_download
|
10 |
+
from peft import LoraConfig, PeftModel
|
11 |
+
from diffusers.models import AutoencoderKL
|
12 |
+
from diffusers.utils import (
|
13 |
+
USE_PEFT_BACKEND,
|
14 |
+
is_torch_xla_available,
|
15 |
+
logging,
|
16 |
+
replace_example_docstring,
|
17 |
+
scale_lora_layers,
|
18 |
+
unscale_lora_layers,
|
19 |
+
)
|
20 |
+
from safetensors.torch import load_file
|
21 |
+
|
22 |
+
from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
EXAMPLE_DOC_STRING = """
|
28 |
+
Examples:
|
29 |
+
```py
|
30 |
+
>>> from OmniGen import OmniGenPipeline
|
31 |
+
>>> pipe = FluxControlNetPipeline.from_pretrained(
|
32 |
+
... base_model
|
33 |
+
... )
|
34 |
+
>>> prompt = "A woman holds a bouquet of flowers and faces the camera"
|
35 |
+
>>> image = pipe(
|
36 |
+
... prompt,
|
37 |
+
... guidance_scale=2.5,
|
38 |
+
... num_inference_steps=50,
|
39 |
+
... ).images[0]
|
40 |
+
>>> image.save("t2i.png")
|
41 |
+
```
|
42 |
+
"""
|
43 |
+
|
44 |
+
|
45 |
+
90
|
46 |
+
class OmniGenPipeline:
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
vae: AutoencoderKL,
|
50 |
+
model: OmniGen,
|
51 |
+
processor: OmniGenProcessor,
|
52 |
+
):
|
53 |
+
self.vae = vae
|
54 |
+
self.model = model
|
55 |
+
self.processor = processor
|
56 |
+
|
57 |
+
if torch.cuda.is_available():
|
58 |
+
self.device = torch.device("cuda")
|
59 |
+
elif torch.backends.mps.is_available():
|
60 |
+
self.device = torch.device("mps")
|
61 |
+
else:
|
62 |
+
logger.info("Don't detect any available GPUs, using CPU instead, this may take long time to generate image!!!")
|
63 |
+
self.device = torch.device("cpu")
|
64 |
+
|
65 |
+
self.model.to(torch.bfloat16)
|
66 |
+
self.model.eval()
|
67 |
+
self.vae.eval()
|
68 |
+
|
69 |
+
self.model_cpu_offload = False
|
70 |
+
|
71 |
+
@classmethod
|
72 |
+
def from_pretrained(cls, model_name, vae_path: str=None):
|
73 |
+
if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
|
74 |
+
logger.info("Model not found, downloading...")
|
75 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
76 |
+
model_name = snapshot_download(repo_id=model_name,
|
77 |
+
cache_dir=cache_folder,
|
78 |
+
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
|
79 |
+
logger.info(f"Downloaded model to {model_name}")
|
80 |
+
model = OmniGen.from_pretrained(model_name)
|
81 |
+
processor = OmniGenProcessor.from_pretrained(model_name)
|
82 |
+
|
83 |
+
if os.path.exists(os.path.join(model_name, "vae")):
|
84 |
+
vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
|
85 |
+
elif vae_path is not None:
|
86 |
+
vae = AutoencoderKL.from_pretrained(vae_path).to(device)
|
87 |
+
else:
|
88 |
+
logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
|
89 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
|
90 |
+
|
91 |
+
return cls(vae, model, processor)
|
92 |
+
|
93 |
+
def merge_lora(self, lora_path: str):
|
94 |
+
model = PeftModel.from_pretrained(self.model, lora_path)
|
95 |
+
model.merge_and_unload()
|
96 |
+
|
97 |
+
self.model = model
|
98 |
+
|
99 |
+
def to(self, device: Union[str, torch.device]):
|
100 |
+
if isinstance(device, str):
|
101 |
+
device = torch.device(device)
|
102 |
+
self.model.to(device)
|
103 |
+
self.vae.to(device)
|
104 |
+
self.device = device
|
105 |
+
|
106 |
+
def vae_encode(self, x, dtype):
|
107 |
+
if self.vae.config.shift_factor is not None:
|
108 |
+
x = self.vae.encode(x).latent_dist.sample()
|
109 |
+
x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
110 |
+
else:
|
111 |
+
x = self.vae.encode(x).latent_dist.sample().mul_(self.vae.config.scaling_factor)
|
112 |
+
x = x.to(dtype)
|
113 |
+
return x
|
114 |
+
|
115 |
+
def move_to_device(self, data):
|
116 |
+
if isinstance(data, list):
|
117 |
+
return [x.to(self.device) for x in data]
|
118 |
+
return data.to(self.device)
|
119 |
+
|
120 |
+
def enable_model_cpu_offload(self):
|
121 |
+
self.model_cpu_offload = True
|
122 |
+
self.model.to("cpu")
|
123 |
+
self.vae.to("cpu")
|
124 |
+
torch.cuda.empty_cache() # Clear VRAM
|
125 |
+
gc.collect() # Run garbage collection to free system RAM
|
126 |
+
|
127 |
+
def disable_model_cpu_offload(self):
|
128 |
+
self.model_cpu_offload = False
|
129 |
+
self.model.to(self.device)
|
130 |
+
self.vae.to(self.device)
|
131 |
+
|
132 |
+
@torch.no_grad()
|
133 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
134 |
+
def __call__(
|
135 |
+
self,
|
136 |
+
prompt: Union[str, List[str]],
|
137 |
+
input_images: Union[List[str], List[List[str]]] = None,
|
138 |
+
height: int = 1024,
|
139 |
+
width: int = 1024,
|
140 |
+
num_inference_steps: int = 50,
|
141 |
+
guidance_scale: float = 3,
|
142 |
+
use_img_guidance: bool = True,
|
143 |
+
img_guidance_scale: float = 1.6,
|
144 |
+
max_input_image_size: int = 1024,
|
145 |
+
separate_cfg_infer: bool = True,
|
146 |
+
offload_model: bool = False,
|
147 |
+
use_kv_cache: bool = True,
|
148 |
+
offload_kv_cache: bool = True,
|
149 |
+
use_input_image_size_as_output: bool = False,
|
150 |
+
dtype: torch.dtype = torch.bfloat16,
|
151 |
+
seed: int = None,
|
152 |
+
):
|
153 |
+
r"""
|
154 |
+
Function invoked when calling the pipeline for generation.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
prompt (`str` or `List[str]`):
|
158 |
+
The prompt or prompts to guide the image generation.
|
159 |
+
input_images (`List[str]` or `List[List[str]]`, *optional*):
|
160 |
+
The list of input images. We will replace the "<|image_i|>" in prompt with the 1-th image in list.
|
161 |
+
height (`int`, *optional*, defaults to 1024):
|
162 |
+
The height in pixels of the generated image. The number must be a multiple of 16.
|
163 |
+
width (`int`, *optional*, defaults to 1024):
|
164 |
+
The width in pixels of the generated image. The number must be a multiple of 16.
|
165 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
166 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
|
167 |
+
guidance_scale (`float`, *optional*, defaults to 4.0):
|
168 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
169 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
170 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
171 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
172 |
+
usually at the expense of lower image quality.
|
173 |
+
use_img_guidance (`bool`, *optional*, defaults to True):
|
174 |
+
Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
|
175 |
+
img_guidance_scale (`float`, *optional*, defaults to 1.6):
|
176 |
+
Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
|
177 |
+
max_input_image_size (`int`, *optional*, defaults to 1024): the maximum size of input image, which will be used to crop the input image to the maximum size
|
178 |
+
separate_cfg_infer (`bool`, *optional*, defaults to False):
|
179 |
+
Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
|
180 |
+
use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
|
181 |
+
offload_kv_cache (`bool`, *optional*, defaults to True): offload the cached key and value to cpu, which can save memory but slow down the generation silightly
|
182 |
+
offload_model (`bool`, *optional*, defaults to False): offload the model to cpu, which can save memory but slow down the generation
|
183 |
+
use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task
|
184 |
+
seed (`int`, *optional*):
|
185 |
+
A random seed for generating output.
|
186 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
187 |
+
data type for the model
|
188 |
+
Examples:
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
A list with the generated images.
|
192 |
+
"""
|
193 |
+
# check inputs:
|
194 |
+
if use_input_image_size_as_output:
|
195 |
+
assert isinstance(prompt, str) and len(input_images) == 1, "if you want to make sure the output image have the same size as the input image, please only input one image instead of multiple input images"
|
196 |
+
else:
|
197 |
+
assert height%16 == 0 and width%16 == 0, "The height and width must be a multiple of 16."
|
198 |
+
if input_images is None:
|
199 |
+
use_img_guidance = False
|
200 |
+
if isinstance(prompt, str):
|
201 |
+
prompt = [prompt]
|
202 |
+
input_images = [input_images] if input_images is not None else None
|
203 |
+
|
204 |
+
# set model and processor
|
205 |
+
if max_input_image_size != self.processor.max_image_size:
|
206 |
+
self.processor = OmniGenProcessor(self.processor.text_tokenizer, max_image_size=max_input_image_size)
|
207 |
+
if offload_model:
|
208 |
+
self.enable_model_cpu_offload()
|
209 |
+
else:
|
210 |
+
self.disable_model_cpu_offload()
|
211 |
+
|
212 |
+
input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer, use_input_image_size_as_output=use_input_image_size_as_output)
|
213 |
+
|
214 |
+
num_prompt = len(prompt)
|
215 |
+
num_cfg = 2 if use_img_guidance else 1
|
216 |
+
if use_input_image_size_as_output:
|
217 |
+
if separate_cfg_infer:
|
218 |
+
height, width = input_data['input_pixel_values'][0][0].shape[-2:]
|
219 |
+
else:
|
220 |
+
height, width = input_data['input_pixel_values'][0].shape[-2:]
|
221 |
+
latent_size_h, latent_size_w = height//8, width//8
|
222 |
+
|
223 |
+
if seed is not None:
|
224 |
+
generator = torch.Generator(device=self.device).manual_seed(seed)
|
225 |
+
else:
|
226 |
+
generator = None
|
227 |
+
latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
|
228 |
+
latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
|
229 |
+
|
230 |
+
if input_images is not None and self.model_cpu_offload: self.vae.to(self.device)
|
231 |
+
input_img_latents = []
|
232 |
+
if separate_cfg_infer:
|
233 |
+
for temp_pixel_values in input_data['input_pixel_values']:
|
234 |
+
temp_input_latents = []
|
235 |
+
for img in temp_pixel_values:
|
236 |
+
img = self.vae_encode(img.to(self.device), dtype)
|
237 |
+
temp_input_latents.append(img)
|
238 |
+
input_img_latents.append(temp_input_latents)
|
239 |
+
else:
|
240 |
+
for img in input_data['input_pixel_values']:
|
241 |
+
img = self.vae_encode(img.to(self.device), dtype)
|
242 |
+
input_img_latents.append(img)
|
243 |
+
if input_images is not None and self.model_cpu_offload:
|
244 |
+
self.vae.to('cpu')
|
245 |
+
torch.cuda.empty_cache() # Clear VRAM
|
246 |
+
gc.collect() # Run garbage collection to free system RAM
|
247 |
+
|
248 |
+
model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
|
249 |
+
input_img_latents=input_img_latents,
|
250 |
+
input_image_sizes=input_data['input_image_sizes'],
|
251 |
+
attention_mask=self.move_to_device(input_data["attention_mask"]),
|
252 |
+
position_ids=self.move_to_device(input_data["position_ids"]),
|
253 |
+
cfg_scale=guidance_scale,
|
254 |
+
img_cfg_scale=img_guidance_scale,
|
255 |
+
use_img_cfg=use_img_guidance,
|
256 |
+
use_kv_cache=use_kv_cache,
|
257 |
+
offload_model=offload_model,
|
258 |
+
)
|
259 |
+
|
260 |
+
if separate_cfg_infer:
|
261 |
+
func = self.model.forward_with_separate_cfg
|
262 |
+
else:
|
263 |
+
func = self.model.forward_with_cfg
|
264 |
+
self.model.to(dtype)
|
265 |
+
|
266 |
+
if self.model_cpu_offload:
|
267 |
+
for name, param in self.model.named_parameters():
|
268 |
+
if 'layers' in name and 'layers.0' not in name:
|
269 |
+
param.data = param.data.cpu()
|
270 |
+
else:
|
271 |
+
param.data = param.data.to(self.device)
|
272 |
+
for buffer_name, buffer in self.model.named_buffers():
|
273 |
+
setattr(self.model, buffer_name, buffer.to(self.device))
|
274 |
+
# else:
|
275 |
+
# self.model.to(self.device)
|
276 |
+
|
277 |
+
scheduler = OmniGenScheduler(num_steps=num_inference_steps)
|
278 |
+
samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache)
|
279 |
+
samples = samples.chunk((1+num_cfg), dim=0)[0]
|
280 |
+
|
281 |
+
if self.model_cpu_offload:
|
282 |
+
self.model.to('cpu')
|
283 |
+
torch.cuda.empty_cache()
|
284 |
+
gc.collect()
|
285 |
+
|
286 |
+
self.vae.to(self.device)
|
287 |
+
samples = samples.to(torch.float32)
|
288 |
+
if self.vae.config.shift_factor is not None:
|
289 |
+
samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
|
290 |
+
else:
|
291 |
+
samples = samples / self.vae.config.scaling_factor
|
292 |
+
samples = self.vae.decode(samples).sample
|
293 |
+
|
294 |
+
if self.model_cpu_offload:
|
295 |
+
self.vae.to('cpu')
|
296 |
+
torch.cuda.empty_cache()
|
297 |
+
gc.collect()
|
298 |
+
|
299 |
+
output_samples = (samples * 0.5 + 0.5).clamp(0, 1)*255
|
300 |
+
output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
301 |
+
output_images = []
|
302 |
+
for i, sample in enumerate(output_samples):
|
303 |
+
output_images.append(Image.fromarray(sample))
|
304 |
+
|
305 |
+
torch.cuda.empty_cache() # Clear VRAM
|
306 |
+
gc.collect() # Run garbage collection to free system RAM
|
307 |
+
return output_images
|
OmniGen/processor.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from typing import Dict, List
|
4 |
+
import json
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision import transforms
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
from huggingface_hub import snapshot_download
|
13 |
+
|
14 |
+
from OmniGen.utils import (
|
15 |
+
create_logger,
|
16 |
+
update_ema,
|
17 |
+
requires_grad,
|
18 |
+
center_crop_arr,
|
19 |
+
crop_arr,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
class OmniGenProcessor:
|
26 |
+
def __init__(self,
|
27 |
+
text_tokenizer,
|
28 |
+
max_image_size: int=1024):
|
29 |
+
self.text_tokenizer = text_tokenizer
|
30 |
+
self.max_image_size = max_image_size
|
31 |
+
|
32 |
+
self.image_transform = transforms.Compose([
|
33 |
+
transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
|
34 |
+
transforms.ToTensor(),
|
35 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
36 |
+
])
|
37 |
+
|
38 |
+
self.collator = OmniGenCollator()
|
39 |
+
self.separate_collator = OmniGenSeparateCollator()
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def from_pretrained(cls, model_name):
|
43 |
+
if not os.path.exists(model_name):
|
44 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
45 |
+
model_name = snapshot_download(repo_id=model_name,
|
46 |
+
cache_dir=cache_folder,
|
47 |
+
allow_patterns="*.json")
|
48 |
+
text_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
49 |
+
|
50 |
+
return cls(text_tokenizer)
|
51 |
+
|
52 |
+
|
53 |
+
def process_image(self, image):
|
54 |
+
image = Image.open(image).convert('RGB')
|
55 |
+
return self.image_transform(image)
|
56 |
+
|
57 |
+
def process_multi_modal_prompt(self, text, input_images):
|
58 |
+
text = self.add_prefix_instruction(text)
|
59 |
+
if input_images is None or len(input_images) == 0:
|
60 |
+
model_inputs = self.text_tokenizer(text)
|
61 |
+
return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
|
62 |
+
|
63 |
+
pattern = r"<\|image_\d+\|>"
|
64 |
+
prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
|
65 |
+
|
66 |
+
for i in range(1, len(prompt_chunks)):
|
67 |
+
if prompt_chunks[i][0] == 1:
|
68 |
+
prompt_chunks[i] = prompt_chunks[i][1:]
|
69 |
+
|
70 |
+
image_tags = re.findall(pattern, text)
|
71 |
+
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
|
72 |
+
|
73 |
+
unique_image_ids = sorted(list(set(image_ids)))
|
74 |
+
assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
|
75 |
+
# total images must be the same as the number of image tags
|
76 |
+
assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
|
77 |
+
|
78 |
+
input_images = [input_images[x-1] for x in image_ids]
|
79 |
+
|
80 |
+
all_input_ids = []
|
81 |
+
img_inx = []
|
82 |
+
idx = 0
|
83 |
+
for i in range(len(prompt_chunks)):
|
84 |
+
all_input_ids.extend(prompt_chunks[i])
|
85 |
+
if i != len(prompt_chunks) -1:
|
86 |
+
start_inx = len(all_input_ids)
|
87 |
+
size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
|
88 |
+
img_inx.append([start_inx, start_inx+size])
|
89 |
+
all_input_ids.extend([0]*size)
|
90 |
+
|
91 |
+
return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
|
92 |
+
|
93 |
+
|
94 |
+
def add_prefix_instruction(self, prompt):
|
95 |
+
user_prompt = '<|user|>\n'
|
96 |
+
generation_prompt = 'Generate an image according to the following instructions\n'
|
97 |
+
assistant_prompt = '<|assistant|>\n<|diffusion|>'
|
98 |
+
prompt_suffix = "<|end|>\n"
|
99 |
+
prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
|
100 |
+
return prompt
|
101 |
+
|
102 |
+
|
103 |
+
def __call__(self,
|
104 |
+
instructions: List[str],
|
105 |
+
input_images: List[List[str]] = None,
|
106 |
+
height: int = 1024,
|
107 |
+
width: int = 1024,
|
108 |
+
negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
|
109 |
+
use_img_cfg: bool = True,
|
110 |
+
separate_cfg_input: bool = False,
|
111 |
+
use_input_image_size_as_output: bool=False,
|
112 |
+
) -> Dict:
|
113 |
+
|
114 |
+
if input_images is None:
|
115 |
+
use_img_cfg = False
|
116 |
+
if isinstance(instructions, str):
|
117 |
+
instructions = [instructions]
|
118 |
+
input_images = [input_images]
|
119 |
+
|
120 |
+
input_data = []
|
121 |
+
for i in range(len(instructions)):
|
122 |
+
cur_instruction = instructions[i]
|
123 |
+
cur_input_images = None if input_images is None else input_images[i]
|
124 |
+
if cur_input_images is not None and len(cur_input_images) > 0:
|
125 |
+
cur_input_images = [self.process_image(x) for x in cur_input_images]
|
126 |
+
else:
|
127 |
+
cur_input_images = None
|
128 |
+
assert "<img><|image_1|></img>" not in cur_instruction
|
129 |
+
|
130 |
+
mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
|
131 |
+
|
132 |
+
|
133 |
+
neg_mllm_input, img_cfg_mllm_input = None, None
|
134 |
+
neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
|
135 |
+
if use_img_cfg:
|
136 |
+
if cur_input_images is not None and len(cur_input_images) >= 1:
|
137 |
+
img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
|
138 |
+
img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
|
139 |
+
else:
|
140 |
+
img_cfg_mllm_input = neg_mllm_input
|
141 |
+
|
142 |
+
if use_input_image_size_as_output:
|
143 |
+
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
|
144 |
+
else:
|
145 |
+
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
|
146 |
+
|
147 |
+
if separate_cfg_input:
|
148 |
+
return self.separate_collator(input_data)
|
149 |
+
return self.collator(input_data)
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
class OmniGenCollator:
|
155 |
+
def __init__(self, pad_token_id=2, hidden_size=3072):
|
156 |
+
self.pad_token_id = pad_token_id
|
157 |
+
self.hidden_size = hidden_size
|
158 |
+
|
159 |
+
def create_position(self, attention_mask, num_tokens_for_output_images):
|
160 |
+
position_ids = []
|
161 |
+
text_length = attention_mask.size(-1)
|
162 |
+
img_length = max(num_tokens_for_output_images)
|
163 |
+
for mask in attention_mask:
|
164 |
+
temp_l = torch.sum(mask)
|
165 |
+
temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
|
166 |
+
position_ids.append(temp_position)
|
167 |
+
return torch.LongTensor(position_ids)
|
168 |
+
|
169 |
+
def create_mask(self, attention_mask, num_tokens_for_output_images):
|
170 |
+
extended_mask = []
|
171 |
+
padding_images = []
|
172 |
+
text_length = attention_mask.size(-1)
|
173 |
+
img_length = max(num_tokens_for_output_images)
|
174 |
+
seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
|
175 |
+
inx = 0
|
176 |
+
for mask in attention_mask:
|
177 |
+
temp_l = torch.sum(mask)
|
178 |
+
pad_l = text_length - temp_l
|
179 |
+
|
180 |
+
temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
|
181 |
+
|
182 |
+
image_mask = torch.zeros(size=(temp_l+1, img_length))
|
183 |
+
temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
|
184 |
+
|
185 |
+
image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
|
186 |
+
temp_mask = torch.cat([temp_mask, image_mask], dim=0)
|
187 |
+
|
188 |
+
if pad_l > 0:
|
189 |
+
pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
|
190 |
+
temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
|
191 |
+
|
192 |
+
pad_mask = torch.ones(size=(pad_l, seq_len))
|
193 |
+
temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
|
194 |
+
|
195 |
+
true_img_length = num_tokens_for_output_images[inx]
|
196 |
+
pad_img_length = img_length - true_img_length
|
197 |
+
if pad_img_length > 0:
|
198 |
+
temp_mask[:, -pad_img_length:] = 0
|
199 |
+
temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
|
200 |
+
else:
|
201 |
+
temp_padding_imgs = None
|
202 |
+
|
203 |
+
extended_mask.append(temp_mask.unsqueeze(0))
|
204 |
+
padding_images.append(temp_padding_imgs)
|
205 |
+
inx += 1
|
206 |
+
return torch.cat(extended_mask, dim=0), padding_images
|
207 |
+
|
208 |
+
def adjust_attention_for_input_images(self, attention_mask, image_sizes):
|
209 |
+
for b_inx in image_sizes.keys():
|
210 |
+
for start_inx, end_inx in image_sizes[b_inx]:
|
211 |
+
attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
|
212 |
+
|
213 |
+
return attention_mask
|
214 |
+
|
215 |
+
def pad_input_ids(self, input_ids, image_sizes):
|
216 |
+
max_l = max([len(x) for x in input_ids])
|
217 |
+
padded_ids = []
|
218 |
+
attention_mask = []
|
219 |
+
new_image_sizes = []
|
220 |
+
|
221 |
+
for i in range(len(input_ids)):
|
222 |
+
temp_ids = input_ids[i]
|
223 |
+
temp_l = len(temp_ids)
|
224 |
+
pad_l = max_l - temp_l
|
225 |
+
if pad_l == 0:
|
226 |
+
attention_mask.append([1]*max_l)
|
227 |
+
padded_ids.append(temp_ids)
|
228 |
+
else:
|
229 |
+
attention_mask.append([0]*pad_l+[1]*temp_l)
|
230 |
+
padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
|
231 |
+
|
232 |
+
if i in image_sizes:
|
233 |
+
new_inx = []
|
234 |
+
for old_inx in image_sizes[i]:
|
235 |
+
new_inx.append([x+pad_l for x in old_inx])
|
236 |
+
image_sizes[i] = new_inx
|
237 |
+
|
238 |
+
return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
|
239 |
+
|
240 |
+
|
241 |
+
def process_mllm_input(self, mllm_inputs, target_img_size):
|
242 |
+
num_tokens_for_output_images = []
|
243 |
+
for img_size in target_img_size:
|
244 |
+
num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
|
245 |
+
|
246 |
+
pixel_values, image_sizes = [], {}
|
247 |
+
b_inx = 0
|
248 |
+
for x in mllm_inputs:
|
249 |
+
if x['pixel_values'] is not None:
|
250 |
+
pixel_values.extend(x['pixel_values'])
|
251 |
+
for size in x['image_sizes']:
|
252 |
+
if b_inx not in image_sizes:
|
253 |
+
image_sizes[b_inx] = [size]
|
254 |
+
else:
|
255 |
+
image_sizes[b_inx].append(size)
|
256 |
+
b_inx += 1
|
257 |
+
pixel_values = [x.unsqueeze(0) for x in pixel_values]
|
258 |
+
|
259 |
+
|
260 |
+
input_ids = [x['input_ids'] for x in mllm_inputs]
|
261 |
+
padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
|
262 |
+
position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
|
263 |
+
attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
|
264 |
+
attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
|
265 |
+
|
266 |
+
return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
|
267 |
+
|
268 |
+
|
269 |
+
def __call__(self, features):
|
270 |
+
mllm_inputs = [f[0] for f in features]
|
271 |
+
cfg_mllm_inputs = [f[1] for f in features]
|
272 |
+
img_cfg_mllm_input = [f[2] for f in features]
|
273 |
+
target_img_size = [f[3] for f in features]
|
274 |
+
|
275 |
+
|
276 |
+
if img_cfg_mllm_input[0] is not None:
|
277 |
+
mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
|
278 |
+
target_img_size = target_img_size + target_img_size + target_img_size
|
279 |
+
else:
|
280 |
+
mllm_inputs = mllm_inputs + cfg_mllm_inputs
|
281 |
+
target_img_size = target_img_size + target_img_size
|
282 |
+
|
283 |
+
|
284 |
+
all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
|
285 |
+
|
286 |
+
data = {"input_ids": all_padded_input_ids,
|
287 |
+
"attention_mask": all_attention_mask,
|
288 |
+
"position_ids": all_position_ids,
|
289 |
+
"input_pixel_values": all_pixel_values,
|
290 |
+
"input_image_sizes": all_image_sizes,
|
291 |
+
"padding_images": all_padding_images,
|
292 |
+
}
|
293 |
+
return data
|
294 |
+
|
295 |
+
|
296 |
+
class OmniGenSeparateCollator(OmniGenCollator):
|
297 |
+
def __call__(self, features):
|
298 |
+
mllm_inputs = [f[0] for f in features]
|
299 |
+
cfg_mllm_inputs = [f[1] for f in features]
|
300 |
+
img_cfg_mllm_input = [f[2] for f in features]
|
301 |
+
target_img_size = [f[3] for f in features]
|
302 |
+
|
303 |
+
all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
|
304 |
+
|
305 |
+
|
306 |
+
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
|
307 |
+
all_padded_input_ids.append(padded_input_ids)
|
308 |
+
all_attention_mask.append(attention_mask)
|
309 |
+
all_position_ids.append(position_ids)
|
310 |
+
all_pixel_values.append(pixel_values)
|
311 |
+
all_image_sizes.append(image_sizes)
|
312 |
+
all_padding_images.append(padding_images)
|
313 |
+
|
314 |
+
if cfg_mllm_inputs[0] is not None:
|
315 |
+
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
|
316 |
+
all_padded_input_ids.append(padded_input_ids)
|
317 |
+
all_attention_mask.append(attention_mask)
|
318 |
+
all_position_ids.append(position_ids)
|
319 |
+
all_pixel_values.append(pixel_values)
|
320 |
+
all_image_sizes.append(image_sizes)
|
321 |
+
all_padding_images.append(padding_images)
|
322 |
+
if img_cfg_mllm_input[0] is not None:
|
323 |
+
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
|
324 |
+
all_padded_input_ids.append(padded_input_ids)
|
325 |
+
all_attention_mask.append(attention_mask)
|
326 |
+
all_position_ids.append(position_ids)
|
327 |
+
all_pixel_values.append(pixel_values)
|
328 |
+
all_image_sizes.append(image_sizes)
|
329 |
+
all_padding_images.append(padding_images)
|
330 |
+
|
331 |
+
data = {"input_ids": all_padded_input_ids,
|
332 |
+
"attention_mask": all_attention_mask,
|
333 |
+
"position_ids": all_position_ids,
|
334 |
+
"input_pixel_values": all_pixel_values,
|
335 |
+
"input_image_sizes": all_image_sizes,
|
336 |
+
"padding_images": all_padding_images,
|
337 |
+
}
|
338 |
+
return data
|
OmniGen/scheduler.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
from typing import Optional, Dict, Any, Tuple, List
|
3 |
+
import gc
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
class OmniGenCache(DynamicCache):
|
11 |
+
def __init__(self,
|
12 |
+
num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
|
13 |
+
if not torch.cuda.is_available():
|
14 |
+
raise RuntimeError("OffloadedCache can only be used with a GPU")
|
15 |
+
super().__init__()
|
16 |
+
self.original_device = []
|
17 |
+
self.prefetch_stream = torch.cuda.Stream()
|
18 |
+
self.num_tokens_for_img = num_tokens_for_img
|
19 |
+
self.offload_kv_cache = offload_kv_cache
|
20 |
+
|
21 |
+
def prefetch_layer(self, layer_idx: int):
|
22 |
+
"Starts prefetching the next layer cache"
|
23 |
+
if layer_idx < len(self):
|
24 |
+
with torch.cuda.stream(self.prefetch_stream):
|
25 |
+
# Prefetch next layer tensors to GPU
|
26 |
+
device = self.original_device[layer_idx]
|
27 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
|
28 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
|
29 |
+
|
30 |
+
|
31 |
+
def evict_previous_layer(self, layer_idx: int):
|
32 |
+
"Moves the previous layer cache to the CPU"
|
33 |
+
if len(self) > 2:
|
34 |
+
# We do it on the default stream so it occurs after all earlier computations on these tensors are done
|
35 |
+
if layer_idx == 0:
|
36 |
+
prev_layer_idx = -1
|
37 |
+
else:
|
38 |
+
prev_layer_idx = (layer_idx - 1) % len(self)
|
39 |
+
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
|
40 |
+
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
|
41 |
+
|
42 |
+
|
43 |
+
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
44 |
+
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
|
45 |
+
if layer_idx < len(self):
|
46 |
+
if self.offload_kv_cache:
|
47 |
+
# Evict the previous layer if necessary
|
48 |
+
torch.cuda.current_stream().synchronize()
|
49 |
+
self.evict_previous_layer(layer_idx)
|
50 |
+
# Load current layer cache to its original device if not already there
|
51 |
+
original_device = self.original_device[layer_idx]
|
52 |
+
# self.prefetch_stream.synchronize(original_device)
|
53 |
+
torch.cuda.synchronize(self.prefetch_stream)
|
54 |
+
key_tensor = self.key_cache[layer_idx]
|
55 |
+
value_tensor = self.value_cache[layer_idx]
|
56 |
+
|
57 |
+
# Prefetch the next layer
|
58 |
+
self.prefetch_layer((layer_idx + 1) % len(self))
|
59 |
+
else:
|
60 |
+
key_tensor = self.key_cache[layer_idx]
|
61 |
+
value_tensor = self.value_cache[layer_idx]
|
62 |
+
return (key_tensor, value_tensor)
|
63 |
+
else:
|
64 |
+
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
65 |
+
|
66 |
+
|
67 |
+
def update(
|
68 |
+
self,
|
69 |
+
key_states: torch.Tensor,
|
70 |
+
value_states: torch.Tensor,
|
71 |
+
layer_idx: int,
|
72 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
73 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
74 |
+
"""
|
75 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
76 |
+
Parameters:
|
77 |
+
key_states (`torch.Tensor`):
|
78 |
+
The new key states to cache.
|
79 |
+
value_states (`torch.Tensor`):
|
80 |
+
The new value states to cache.
|
81 |
+
layer_idx (`int`):
|
82 |
+
The index of the layer to cache the states for.
|
83 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
84 |
+
Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
|
85 |
+
Return:
|
86 |
+
A tuple containing the updated key and value states.
|
87 |
+
"""
|
88 |
+
# Update the cache
|
89 |
+
if len(self.key_cache) < layer_idx:
|
90 |
+
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
|
91 |
+
elif len(self.key_cache) == layer_idx:
|
92 |
+
# only cache the states for condition tokens
|
93 |
+
key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
|
94 |
+
value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
|
95 |
+
|
96 |
+
# Update the number of seen tokens
|
97 |
+
if layer_idx == 0:
|
98 |
+
self._seen_tokens += key_states.shape[-2]
|
99 |
+
|
100 |
+
self.key_cache.append(key_states)
|
101 |
+
self.value_cache.append(value_states)
|
102 |
+
self.original_device.append(key_states.device)
|
103 |
+
if self.offload_kv_cache:
|
104 |
+
self.evict_previous_layer(layer_idx)
|
105 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
106 |
+
else:
|
107 |
+
# only cache the states for condition tokens
|
108 |
+
key_tensor, value_tensor = self[layer_idx]
|
109 |
+
k = torch.cat([key_tensor, key_states], dim=-2)
|
110 |
+
v = torch.cat([value_tensor, value_states], dim=-2)
|
111 |
+
return k, v
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
class OmniGenScheduler:
|
116 |
+
def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
|
117 |
+
self.num_steps = num_steps
|
118 |
+
self.time_shift = time_shifting_factor
|
119 |
+
|
120 |
+
t = torch.linspace(0, 1, num_steps+1)
|
121 |
+
t = t / (t + time_shifting_factor - time_shifting_factor * t)
|
122 |
+
self.sigma = t
|
123 |
+
|
124 |
+
def crop_kv_cache(self, past_key_values, num_tokens_for_img):
|
125 |
+
# return
|
126 |
+
crop_past_key_values = ()
|
127 |
+
for layer_idx in range(len(past_key_values)):
|
128 |
+
key_states, value_states = past_key_values[layer_idx][:2]
|
129 |
+
crop_past_key_values += ((key_states[..., :-(num_tokens_for_img+1), :], value_states[..., :-(num_tokens_for_img+1), :], ),)
|
130 |
+
# return crop_past_key_values
|
131 |
+
return DynamicCache.from_legacy_cache(crop_past_key_values)
|
132 |
+
|
133 |
+
def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
|
134 |
+
if isinstance(position_ids, list):
|
135 |
+
for i in range(len(position_ids)):
|
136 |
+
position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
|
137 |
+
else:
|
138 |
+
position_ids = position_ids[:, -(num_tokens_for_img+1):]
|
139 |
+
return position_ids
|
140 |
+
|
141 |
+
def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
|
142 |
+
if isinstance(attention_mask, list):
|
143 |
+
return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
|
144 |
+
return attention_mask[..., -(num_tokens_for_img+1):, :]
|
145 |
+
|
146 |
+
def crop_cache(self, cache, num_tokens_for_img):
|
147 |
+
for i in range(len(cache.key_cache)):
|
148 |
+
cache.key_cache[i] = cache.key_cache[i][..., :-(num_tokens_for_img+1), :]
|
149 |
+
cache.value_cache[i] = cache.value_cache[i][..., :-(num_tokens_for_img+1), :]
|
150 |
+
|
151 |
+
return cache
|
152 |
+
|
153 |
+
def __call__(self, z, func, model_kwargs, use_kv_cache: bool=True, offload_kv_cache: bool=True):
|
154 |
+
num_tokens_for_img = z.size(-1)*z.size(-2) // 4
|
155 |
+
if isinstance(model_kwargs['input_ids'], list):
|
156 |
+
cache = [OmniGenCache(num_tokens_for_img, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
|
157 |
+
else:
|
158 |
+
cache = OmniGenCache(num_tokens_for_img, offload_kv_cache) if use_kv_cache else None
|
159 |
+
results = {}
|
160 |
+
for i in tqdm(range(self.num_steps)):
|
161 |
+
timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
|
162 |
+
pred, cache = func(z, timesteps, past_key_values=cache, **model_kwargs)
|
163 |
+
sigma_next = self.sigma[i+1]
|
164 |
+
sigma = self.sigma[i]
|
165 |
+
z = z + (sigma_next - sigma) * pred
|
166 |
+
if i == 0 and use_kv_cache:
|
167 |
+
num_tokens_for_img = z.size(-1)*z.size(-2) // 4
|
168 |
+
if isinstance(cache, list):
|
169 |
+
model_kwargs['input_ids'] = [None] * len(cache)
|
170 |
+
else:
|
171 |
+
model_kwargs['input_ids'] = None
|
172 |
+
|
173 |
+
model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
|
174 |
+
model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
|
175 |
+
|
176 |
+
del cache
|
177 |
+
torch.cuda.empty_cache()
|
178 |
+
gc.collect()
|
179 |
+
return z
|
180 |
+
|
181 |
+
|
OmniGen/train_helper/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .data import DatasetFromJson, TrainDataCollator
|
2 |
+
from .loss import training_losses
|
OmniGen/train_helper/data.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import datasets
|
3 |
+
from datasets import load_dataset, ClassLabel, concatenate_datasets
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
from PIL import Image
|
8 |
+
import json
|
9 |
+
import copy
|
10 |
+
# import torchvision.transforms as T
|
11 |
+
from torchvision import transforms
|
12 |
+
import pickle
|
13 |
+
import re
|
14 |
+
|
15 |
+
from OmniGen import OmniGenProcessor
|
16 |
+
from OmniGen.processor import OmniGenCollator
|
17 |
+
|
18 |
+
|
19 |
+
class DatasetFromJson(torch.utils.data.Dataset):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
json_file: str,
|
23 |
+
image_path: str,
|
24 |
+
processer: OmniGenProcessor,
|
25 |
+
image_transform,
|
26 |
+
max_input_length_limit: int = 18000,
|
27 |
+
condition_dropout_prob: float = 0.1,
|
28 |
+
keep_raw_resolution: bool = True,
|
29 |
+
):
|
30 |
+
|
31 |
+
self.image_transform = image_transform
|
32 |
+
self.processer = processer
|
33 |
+
self.condition_dropout_prob = condition_dropout_prob
|
34 |
+
self.max_input_length_limit = max_input_length_limit
|
35 |
+
self.keep_raw_resolution = keep_raw_resolution
|
36 |
+
|
37 |
+
self.data = load_dataset('json', data_files=json_file)['train']
|
38 |
+
self.image_path = image_path
|
39 |
+
|
40 |
+
def process_image(self, image_file):
|
41 |
+
if self.image_path is not None:
|
42 |
+
image_file = os.path.join(self.image_path, image_file)
|
43 |
+
image = Image.open(image_file).convert('RGB')
|
44 |
+
return self.image_transform(image)
|
45 |
+
|
46 |
+
def get_example(self, index):
|
47 |
+
example = self.data[index]
|
48 |
+
|
49 |
+
instruction, input_images, output_image = example['instruction'], example['input_images'], example['output_image']
|
50 |
+
if random.random() < self.condition_dropout_prob:
|
51 |
+
instruction = '<cfg>'
|
52 |
+
input_images = None
|
53 |
+
if input_images is not None:
|
54 |
+
input_images = [self.process_image(x) for x in input_images]
|
55 |
+
mllm_input = self.processer.process_multi_modal_prompt(instruction, input_images)
|
56 |
+
|
57 |
+
output_image = self.process_image(output_image)
|
58 |
+
|
59 |
+
return (mllm_input, output_image)
|
60 |
+
|
61 |
+
|
62 |
+
def __getitem__(self, index):
|
63 |
+
return self.get_example(index)
|
64 |
+
for _ in range(8):
|
65 |
+
try:
|
66 |
+
mllm_input, output_image = self.get_example(index)
|
67 |
+
if len(mllm_input['input_ids']) > self.max_input_length_limit:
|
68 |
+
raise RuntimeError(f"cur number of tokens={len(mllm_input['input_ids'])}, larger than max_input_length_limit={self.max_input_length_limit}")
|
69 |
+
return mllm_input, output_image
|
70 |
+
except Exception as e:
|
71 |
+
print("error when loading data: ", e)
|
72 |
+
print(self.data[index])
|
73 |
+
index = random.randint(0, len(self.data)-1)
|
74 |
+
raise RuntimeError("Too many bad data.")
|
75 |
+
|
76 |
+
|
77 |
+
def __len__(self):
|
78 |
+
return len(self.data)
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
class TrainDataCollator(OmniGenCollator):
|
83 |
+
def __init__(self, pad_token_id: int, hidden_size: int, keep_raw_resolution: bool):
|
84 |
+
self.pad_token_id = pad_token_id
|
85 |
+
self.hidden_size = hidden_size
|
86 |
+
self.keep_raw_resolution = keep_raw_resolution
|
87 |
+
|
88 |
+
def __call__(self, features):
|
89 |
+
mllm_inputs = [f[0] for f in features]
|
90 |
+
|
91 |
+
output_images = [f[1].unsqueeze(0) for f in features]
|
92 |
+
target_img_size = [[x.size(-2), x.size(-1)] for x in output_images]
|
93 |
+
|
94 |
+
all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
|
95 |
+
|
96 |
+
if not self.keep_raw_resolution:
|
97 |
+
output_image = torch.cat(output_image, dim=0)
|
98 |
+
if len(pixel_values) > 0:
|
99 |
+
all_pixel_values = torch.cat(all_pixel_values, dim=0)
|
100 |
+
else:
|
101 |
+
all_pixel_values = None
|
102 |
+
|
103 |
+
data = {"input_ids": all_padded_input_ids,
|
104 |
+
"attention_mask": all_attention_mask,
|
105 |
+
"position_ids": all_position_ids,
|
106 |
+
"input_pixel_values": all_pixel_values,
|
107 |
+
"input_image_sizes": all_image_sizes,
|
108 |
+
"padding_images": all_padding_images,
|
109 |
+
"output_images": output_images,
|
110 |
+
}
|
111 |
+
return data
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
OmniGen/train_helper/loss.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def sample_x0(x1):
|
5 |
+
"""Sampling x0 & t based on shape of x1 (if needed)
|
6 |
+
Args:
|
7 |
+
x1 - data point; [batch, *dim]
|
8 |
+
"""
|
9 |
+
if isinstance(x1, (list, tuple)):
|
10 |
+
x0 = [torch.randn_like(img_start) for img_start in x1]
|
11 |
+
else:
|
12 |
+
x0 = torch.randn_like(x1)
|
13 |
+
|
14 |
+
return x0
|
15 |
+
|
16 |
+
def sample_timestep(x1):
|
17 |
+
u = torch.normal(mean=0.0, std=1.0, size=(len(x1),))
|
18 |
+
t = 1 / (1 + torch.exp(-u))
|
19 |
+
t = t.to(x1[0])
|
20 |
+
return t
|
21 |
+
|
22 |
+
|
23 |
+
def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
|
24 |
+
"""Loss for training torche score model
|
25 |
+
Args:
|
26 |
+
- model: backbone model; could be score, noise, or velocity
|
27 |
+
- x1: datapoint
|
28 |
+
- model_kwargs: additional arguments for torche model
|
29 |
+
"""
|
30 |
+
if model_kwargs == None:
|
31 |
+
model_kwargs = {}
|
32 |
+
|
33 |
+
B = len(x1)
|
34 |
+
|
35 |
+
x0 = sample_x0(x1)
|
36 |
+
t = sample_timestep(x1)
|
37 |
+
|
38 |
+
if isinstance(x1, (list, tuple)):
|
39 |
+
xt = [t[i] * x1[i] + (1 - t[i]) * x0[i] for i in range(B)]
|
40 |
+
ut = [x1[i] - x0[i] for i in range(B)]
|
41 |
+
else:
|
42 |
+
dims = [1] * (len(x1.size()) - 1)
|
43 |
+
t_ = t.view(t.size(0), *dims)
|
44 |
+
xt = t_ * x1 + (1 - t_) * x0
|
45 |
+
ut = x1 - x0
|
46 |
+
|
47 |
+
model_output = model(xt, t, **model_kwargs)
|
48 |
+
|
49 |
+
terms = {}
|
50 |
+
|
51 |
+
if isinstance(x1, (list, tuple)):
|
52 |
+
assert len(model_output) == len(ut) == len(x1)
|
53 |
+
for i in range(B):
|
54 |
+
terms["loss"] = torch.stack(
|
55 |
+
[((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
|
56 |
+
dim=0,
|
57 |
+
)
|
58 |
+
else:
|
59 |
+
terms["loss"] = mean_flat(((model_output - ut) ** 2))
|
60 |
+
|
61 |
+
return terms
|
62 |
+
|
63 |
+
|
64 |
+
def mean_flat(x):
|
65 |
+
"""
|
66 |
+
Take torche mean over all non-batch dimensions.
|
67 |
+
"""
|
68 |
+
return torch.mean(x, dim=list(range(1, len(x.size()))))
|
OmniGen/transformer.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.utils.checkpoint
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
9 |
+
from huggingface_hub import snapshot_download
|
10 |
+
|
11 |
+
from transformers.modeling_outputs import (
|
12 |
+
BaseModelOutputWithPast,
|
13 |
+
CausalLMOutputWithPast,
|
14 |
+
SequenceClassifierOutputWithPast,
|
15 |
+
TokenClassifierOutput,
|
16 |
+
)
|
17 |
+
from transformers.modeling_utils import PreTrainedModel
|
18 |
+
from transformers import Phi3Config, Phi3Model
|
19 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
20 |
+
from transformers.utils import logging
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class Phi3Transformer(Phi3Model):
|
26 |
+
"""
|
27 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
|
28 |
+
We only modified the attention mask
|
29 |
+
Args:
|
30 |
+
config: Phi3Config
|
31 |
+
"""
|
32 |
+
def prefetch_layer(self, layer_idx: int, device: torch.device):
|
33 |
+
"Starts prefetching the next layer cache"
|
34 |
+
with torch.cuda.stream(self.prefetch_stream):
|
35 |
+
# Prefetch next layer tensors to GPU
|
36 |
+
for name, param in self.layers[layer_idx].named_parameters():
|
37 |
+
param.data = param.data.to(device, non_blocking=True)
|
38 |
+
|
39 |
+
def evict_previous_layer(self, layer_idx: int):
|
40 |
+
"Moves the previous layer cache to the CPU"
|
41 |
+
prev_layer_idx = layer_idx - 1
|
42 |
+
for name, param in self.layers[prev_layer_idx].named_parameters():
|
43 |
+
param.data = param.data.to("cpu", non_blocking=True)
|
44 |
+
|
45 |
+
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
|
46 |
+
# init stream
|
47 |
+
if not hasattr(self, "prefetch_stream"):
|
48 |
+
self.prefetch_stream = torch.cuda.Stream()
|
49 |
+
|
50 |
+
# delete previous layer
|
51 |
+
torch.cuda.current_stream().synchronize()
|
52 |
+
self.evict_previous_layer(layer_idx)
|
53 |
+
|
54 |
+
# make sure the current layer is ready
|
55 |
+
torch.cuda.synchronize(self.prefetch_stream)
|
56 |
+
|
57 |
+
# load next layer
|
58 |
+
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
|
59 |
+
|
60 |
+
|
61 |
+
def forward(
|
62 |
+
self,
|
63 |
+
input_ids: torch.LongTensor = None,
|
64 |
+
attention_mask: Optional[torch.Tensor] = None,
|
65 |
+
position_ids: Optional[torch.LongTensor] = None,
|
66 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
67 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
68 |
+
use_cache: Optional[bool] = None,
|
69 |
+
output_attentions: Optional[bool] = None,
|
70 |
+
output_hidden_states: Optional[bool] = None,
|
71 |
+
return_dict: Optional[bool] = None,
|
72 |
+
cache_position: Optional[torch.LongTensor] = None,
|
73 |
+
offload_model: Optional[bool] = False,
|
74 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
75 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
76 |
+
output_hidden_states = (
|
77 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
78 |
+
)
|
79 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
80 |
+
|
81 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
82 |
+
|
83 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
84 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
85 |
+
|
86 |
+
if self.gradient_checkpointing and self.training:
|
87 |
+
if use_cache:
|
88 |
+
logger.warning_once(
|
89 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
90 |
+
)
|
91 |
+
use_cache = False
|
92 |
+
|
93 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
94 |
+
return_legacy_cache = False
|
95 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
96 |
+
return_legacy_cache = True
|
97 |
+
if past_key_values is None:
|
98 |
+
past_key_values = DynamicCache()
|
99 |
+
else:
|
100 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
101 |
+
logger.warning_once(
|
102 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
103 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
104 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
105 |
+
)
|
106 |
+
|
107 |
+
# if inputs_embeds is None:
|
108 |
+
# inputs_embeds = self.embed_tokens(input_ids)
|
109 |
+
|
110 |
+
# if cache_position is None:
|
111 |
+
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
112 |
+
# cache_position = torch.arange(
|
113 |
+
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
114 |
+
# )
|
115 |
+
# if position_ids is None:
|
116 |
+
# position_ids = cache_position.unsqueeze(0)
|
117 |
+
|
118 |
+
if attention_mask is not None and attention_mask.dim() == 3:
|
119 |
+
dtype = inputs_embeds.dtype
|
120 |
+
min_dtype = torch.finfo(dtype).min
|
121 |
+
attention_mask = (1 - attention_mask) * min_dtype
|
122 |
+
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
|
123 |
+
else:
|
124 |
+
raise
|
125 |
+
# causal_mask = self._update_causal_mask(
|
126 |
+
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
127 |
+
# )
|
128 |
+
|
129 |
+
hidden_states = inputs_embeds
|
130 |
+
|
131 |
+
# decoder layers
|
132 |
+
all_hidden_states = () if output_hidden_states else None
|
133 |
+
all_self_attns = () if output_attentions else None
|
134 |
+
next_decoder_cache = None
|
135 |
+
|
136 |
+
layer_idx = -1
|
137 |
+
for decoder_layer in self.layers:
|
138 |
+
layer_idx += 1
|
139 |
+
|
140 |
+
if output_hidden_states:
|
141 |
+
all_hidden_states += (hidden_states,)
|
142 |
+
|
143 |
+
if self.gradient_checkpointing and self.training:
|
144 |
+
layer_outputs = self._gradient_checkpointing_func(
|
145 |
+
decoder_layer.__call__,
|
146 |
+
hidden_states,
|
147 |
+
attention_mask,
|
148 |
+
position_ids,
|
149 |
+
past_key_values,
|
150 |
+
output_attentions,
|
151 |
+
use_cache,
|
152 |
+
cache_position,
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
if offload_model and not self.training:
|
156 |
+
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
|
157 |
+
layer_outputs = decoder_layer(
|
158 |
+
hidden_states,
|
159 |
+
attention_mask=attention_mask,
|
160 |
+
position_ids=position_ids,
|
161 |
+
past_key_value=past_key_values,
|
162 |
+
output_attentions=output_attentions,
|
163 |
+
use_cache=use_cache,
|
164 |
+
cache_position=cache_position,
|
165 |
+
)
|
166 |
+
|
167 |
+
hidden_states = layer_outputs[0]
|
168 |
+
|
169 |
+
if use_cache:
|
170 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
171 |
+
|
172 |
+
if output_attentions:
|
173 |
+
all_self_attns += (layer_outputs[1],)
|
174 |
+
|
175 |
+
hidden_states = self.norm(hidden_states)
|
176 |
+
|
177 |
+
# add hidden states from the last decoder layer
|
178 |
+
if output_hidden_states:
|
179 |
+
print('************')
|
180 |
+
all_hidden_states += (hidden_states,)
|
181 |
+
|
182 |
+
next_cache = next_decoder_cache if use_cache else None
|
183 |
+
if return_legacy_cache:
|
184 |
+
next_cache = next_cache.to_legacy_cache()
|
185 |
+
|
186 |
+
if not return_dict:
|
187 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
188 |
+
return BaseModelOutputWithPast(
|
189 |
+
last_hidden_state=hidden_states,
|
190 |
+
past_key_values=next_cache,
|
191 |
+
hidden_states=all_hidden_states,
|
192 |
+
attentions=all_self_attns,
|
193 |
+
)
|
194 |
+
|
OmniGen/utils.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def create_logger(logging_dir):
|
8 |
+
"""
|
9 |
+
Create a logger that writes to a log file and stdout.
|
10 |
+
"""
|
11 |
+
logging.basicConfig(
|
12 |
+
level=logging.INFO,
|
13 |
+
format='[\033[34m%(asctime)s\033[0m] %(message)s',
|
14 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
15 |
+
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
|
16 |
+
)
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
return logger
|
19 |
+
|
20 |
+
|
21 |
+
@torch.no_grad()
|
22 |
+
def update_ema(ema_model, model, decay=0.9999):
|
23 |
+
"""
|
24 |
+
Step the EMA model towards the current model.
|
25 |
+
"""
|
26 |
+
ema_params = dict(ema_model.named_parameters())
|
27 |
+
for name, param in model.named_parameters():
|
28 |
+
# TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
|
29 |
+
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
def requires_grad(model, flag=True):
|
35 |
+
"""
|
36 |
+
Set requires_grad flag for all parameters in a model.
|
37 |
+
"""
|
38 |
+
for p in model.parameters():
|
39 |
+
p.requires_grad = flag
|
40 |
+
|
41 |
+
|
42 |
+
def center_crop_arr(pil_image, image_size):
|
43 |
+
"""
|
44 |
+
Center cropping implementation from ADM.
|
45 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
46 |
+
"""
|
47 |
+
while min(*pil_image.size) >= 2 * image_size:
|
48 |
+
pil_image = pil_image.resize(
|
49 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
50 |
+
)
|
51 |
+
|
52 |
+
scale = image_size / min(*pil_image.size)
|
53 |
+
pil_image = pil_image.resize(
|
54 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
55 |
+
)
|
56 |
+
|
57 |
+
arr = np.array(pil_image)
|
58 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
59 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
60 |
+
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def crop_arr(pil_image, max_image_size):
|
65 |
+
while min(*pil_image.size) >= 2 * max_image_size:
|
66 |
+
pil_image = pil_image.resize(
|
67 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
68 |
+
)
|
69 |
+
|
70 |
+
if max(*pil_image.size) > max_image_size:
|
71 |
+
scale = max_image_size / max(*pil_image.size)
|
72 |
+
pil_image = pil_image.resize(
|
73 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
74 |
+
)
|
75 |
+
|
76 |
+
if min(*pil_image.size) < 16:
|
77 |
+
scale = 16 / min(*pil_image.size)
|
78 |
+
pil_image = pil_image.resize(
|
79 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
80 |
+
)
|
81 |
+
|
82 |
+
arr = np.array(pil_image)
|
83 |
+
crop_y1 = (arr.shape[0] % 16) // 2
|
84 |
+
crop_y2 = arr.shape[0] % 16 - crop_y1
|
85 |
+
|
86 |
+
crop_x1 = (arr.shape[1] % 16) // 2
|
87 |
+
crop_x2 = arr.shape[1] % 16 - crop_x1
|
88 |
+
|
89 |
+
arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
|
90 |
+
return Image.fromarray(arr)
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
def vae_encode(vae, x, weight_dtype):
|
95 |
+
if x is not None:
|
96 |
+
if vae.config.shift_factor is not None:
|
97 |
+
x = vae.encode(x).latent_dist.sample()
|
98 |
+
x = (x - vae.config.shift_factor) * vae.config.scaling_factor
|
99 |
+
else:
|
100 |
+
x = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor)
|
101 |
+
x = x.to(weight_dtype)
|
102 |
+
return x
|
103 |
+
|
104 |
+
def vae_encode_list(vae, x, weight_dtype):
|
105 |
+
latents = []
|
106 |
+
for img in x:
|
107 |
+
img = vae_encode(vae, img, weight_dtype)
|
108 |
+
latents.append(img)
|
109 |
+
return latents
|
110 |
+
|