yrr commited on
Commit
a713a09
·
1 Parent(s): 0d3229d

update inference code

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 VectorSpaceLab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
OmniGen/model.py CHANGED
@@ -312,7 +312,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
312
  return latents, num_tokens, shapes
313
 
314
 
315
- def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True):
316
  """
317
 
318
  """
@@ -335,7 +335,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
335
  input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
336
  else:
337
  input_emb = torch.cat([time_token, x], dim=1)
338
- output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
339
  output, past_key_values = output.last_hidden_state, output.past_key_values
340
  if input_is_list:
341
  image_embedding = output[:, -max(num_tokens):]
@@ -357,12 +357,9 @@ class OmniGen(nn.Module, PeftAdapterMixin):
357
  return latents
358
 
359
  @torch.no_grad()
360
- def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
361
- """
362
- Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
363
- """
364
  self.llm.config.use_cache = use_kv_cache
365
- model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True)
366
  if use_img_cfg:
367
  cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
368
  cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
@@ -376,10 +373,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
376
 
377
 
378
  @torch.no_grad()
379
- def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, return_past_key_values=True):
380
- """
381
- Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
382
- """
383
  self.llm.config.use_cache = use_kv_cache
384
  if past_key_values is None:
385
  past_key_values = [None] * len(attention_mask)
@@ -390,7 +384,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
390
 
391
  model_out, pask_key_values = [], []
392
  for i in range(len(input_ids)):
393
- temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values[i])
394
  model_out.append(temp_out)
395
  pask_key_values.append(temp_pask_key_values)
396
 
 
312
  return latents, num_tokens, shapes
313
 
314
 
315
+ def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
316
  """
317
 
318
  """
 
335
  input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
336
  else:
337
  input_emb = torch.cat([time_token, x], dim=1)
338
+ output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
339
  output, past_key_values = output.last_hidden_state, output.past_key_values
340
  if input_is_list:
341
  image_embedding = output[:, -max(num_tokens):]
 
357
  return latents
358
 
359
  @torch.no_grad()
360
+ def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
 
 
 
361
  self.llm.config.use_cache = use_kv_cache
362
+ model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
363
  if use_img_cfg:
364
  cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
365
  cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
 
373
 
374
 
375
  @torch.no_grad()
376
+ def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
 
 
 
377
  self.llm.config.use_cache = use_kv_cache
378
  if past_key_values is None:
379
  past_key_values = [None] * len(attention_mask)
 
384
 
385
  model_out, pask_key_values = [], []
386
  for i in range(len(input_ids)):
387
+ temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
388
  model_out.append(temp_out)
389
  pask_key_values.append(temp_pask_key_values)
390
 
OmniGen/pipeline.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import inspect
3
  from typing import Any, Callable, Dict, List, Optional, Union
 
4
 
5
  from PIL import Image
6
  import numpy as np
@@ -33,7 +34,7 @@ EXAMPLE_DOC_STRING = """
33
  >>> prompt = "A woman holds a bouquet of flowers and faces the camera"
34
  >>> image = pipe(
35
  ... prompt,
36
- ... guidance_scale=3.0,
37
  ... num_inference_steps=50,
38
  ... ).images[0]
39
  >>> image.save("t2i.png")
@@ -41,7 +42,7 @@ EXAMPLE_DOC_STRING = """
41
  """
42
 
43
 
44
-
45
  class OmniGenPipeline:
46
  def __init__(
47
  self,
@@ -53,10 +54,21 @@ class OmniGenPipeline:
53
  self.model = model
54
  self.processor = processor
55
 
56
- self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
57
- self.model.to(self.device)
 
 
 
 
 
 
 
 
 
58
  self.model.eval()
59
- self.vae.to(self.device)
 
 
60
 
61
  @classmethod
62
  def from_pretrained(cls, model_name, vae_path: str=None):
@@ -84,7 +96,6 @@ class OmniGenPipeline:
84
  model = PeftModel.from_pretrained(self.model, lora_path)
85
  model.merge_and_unload()
86
 
87
-
88
  self.model = model
89
 
90
  def to(self, device: Union[str, torch.device]):
@@ -92,6 +103,7 @@ class OmniGenPipeline:
92
  device = torch.device(device)
93
  self.model.to(device)
94
  self.vae.to(device)
 
95
 
96
  def vae_encode(self, x, dtype):
97
  if self.vae.config.shift_factor is not None:
@@ -107,6 +119,17 @@ class OmniGenPipeline:
107
  return [x.to(self.device) for x in data]
108
  return data.to(self.device)
109
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  @torch.no_grad()
112
  @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -120,8 +143,12 @@ class OmniGenPipeline:
120
  guidance_scale: float = 3,
121
  use_img_guidance: bool = True,
122
  img_guidance_scale: float = 1.6,
123
- separate_cfg_infer: bool = False,
 
 
124
  use_kv_cache: bool = True,
 
 
125
  dtype: torch.dtype = torch.bfloat16,
126
  seed: int = None,
127
  ):
@@ -149,31 +176,50 @@ class OmniGenPipeline:
149
  Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
150
  img_guidance_scale (`float`, *optional*, defaults to 1.6):
151
  Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
 
152
  separate_cfg_infer (`bool`, *optional*, defaults to False):
153
  Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
154
  use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
155
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
156
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
157
- to make generation deterministic.
 
 
 
 
158
  Examples:
159
 
160
  Returns:
161
  A list with the generated images.
162
  """
163
- assert height%16 == 0 and width%16 == 0
164
- if separate_cfg_infer:
165
- use_kv_cache = False
166
- # raise "Currently, don't support both use_kv_cache and separate_cfg_infer"
 
167
  if input_images is None:
168
  use_img_guidance = False
169
  if isinstance(prompt, str):
170
  prompt = [prompt]
171
  input_images = [input_images] if input_images is not None else None
 
 
 
 
 
 
 
 
172
 
173
- input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer)
174
 
175
  num_prompt = len(prompt)
176
  num_cfg = 2 if use_img_guidance else 1
 
 
 
 
 
177
  latent_size_h, latent_size_w = height//8, width//8
178
 
179
  if seed is not None:
@@ -183,6 +229,7 @@ class OmniGenPipeline:
183
  latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
184
  latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
185
 
 
186
  input_img_latents = []
187
  if separate_cfg_infer:
188
  for temp_pixel_values in input_data['input_pixel_values']:
@@ -195,6 +242,10 @@ class OmniGenPipeline:
195
  for img in input_data['input_pixel_values']:
196
  img = self.vae_encode(img.to(self.device), dtype)
197
  input_img_latents.append(img)
 
 
 
 
198
 
199
  model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
200
  input_img_latents=input_img_latents,
@@ -204,7 +255,9 @@ class OmniGenPipeline:
204
  cfg_scale=guidance_scale,
205
  img_cfg_scale=img_guidance_scale,
206
  use_img_cfg=use_img_guidance,
207
- use_kv_cache=use_kv_cache)
 
 
208
 
209
  if separate_cfg_infer:
210
  func = self.model.forward_with_separate_cfg
@@ -212,16 +265,38 @@ class OmniGenPipeline:
212
  func = self.model.forward_with_cfg
213
  self.model.to(dtype)
214
 
 
 
 
 
 
 
 
 
 
 
 
215
  scheduler = OmniGenScheduler(num_steps=num_inference_steps)
216
- samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache)
217
  samples = samples.chunk((1+num_cfg), dim=0)[0]
218
 
 
 
 
 
 
 
219
  samples = samples.to(torch.float32)
220
  if self.vae.config.shift_factor is not None:
221
  samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
222
  else:
223
  samples = samples / self.vae.config.scaling_factor
224
  samples = self.vae.decode(samples).sample
 
 
 
 
 
225
 
226
  output_samples = (samples * 0.5 + 0.5).clamp(0, 1)*255
227
  output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
@@ -229,4 +304,6 @@ class OmniGenPipeline:
229
  for i, sample in enumerate(output_samples):
230
  output_images.append(Image.fromarray(sample))
231
 
 
 
232
  return output_images
 
1
  import os
2
  import inspect
3
  from typing import Any, Callable, Dict, List, Optional, Union
4
+ import gc
5
 
6
  from PIL import Image
7
  import numpy as np
 
34
  >>> prompt = "A woman holds a bouquet of flowers and faces the camera"
35
  >>> image = pipe(
36
  ... prompt,
37
+ ... guidance_scale=2.5,
38
  ... num_inference_steps=50,
39
  ... ).images[0]
40
  >>> image.save("t2i.png")
 
42
  """
43
 
44
 
45
+ 90
46
  class OmniGenPipeline:
47
  def __init__(
48
  self,
 
54
  self.model = model
55
  self.processor = processor
56
 
57
+ if torch.cuda.is_available():
58
+ self.device = torch.device("cuda")
59
+ elif torch.backends.mps.is_available():
60
+ self.device = torch.device("mps")
61
+ elif is_torch_npu_available():
62
+ self.device = torch.device("npu")
63
+ else:
64
+ logger.info("Don't detect any available devices, using CPU instead")
65
+ self.device = torch.device("cpu")
66
+
67
+ self.model.to(torch.bfloat16)
68
  self.model.eval()
69
+ self.vae.eval()
70
+
71
+ self.model_cpu_offload = False
72
 
73
  @classmethod
74
  def from_pretrained(cls, model_name, vae_path: str=None):
 
96
  model = PeftModel.from_pretrained(self.model, lora_path)
97
  model.merge_and_unload()
98
 
 
99
  self.model = model
100
 
101
  def to(self, device: Union[str, torch.device]):
 
103
  device = torch.device(device)
104
  self.model.to(device)
105
  self.vae.to(device)
106
+ self.device = device
107
 
108
  def vae_encode(self, x, dtype):
109
  if self.vae.config.shift_factor is not None:
 
119
  return [x.to(self.device) for x in data]
120
  return data.to(self.device)
121
 
122
+ def enable_model_cpu_offload(self):
123
+ self.model_cpu_offload = True
124
+ self.model.to("cpu")
125
+ self.vae.to("cpu")
126
+ torch.cuda.empty_cache() # Clear VRAM
127
+ gc.collect() # Run garbage collection to free system RAM
128
+
129
+ def disable_model_cpu_offload(self):
130
+ self.model_cpu_offload = False
131
+ self.model.to(self.device)
132
+ self.vae.to(self.device)
133
 
134
  @torch.no_grad()
135
  @replace_example_docstring(EXAMPLE_DOC_STRING)
 
143
  guidance_scale: float = 3,
144
  use_img_guidance: bool = True,
145
  img_guidance_scale: float = 1.6,
146
+ max_input_image_size: int = 1024,
147
+ separate_cfg_infer: bool = True,
148
+ offload_model: bool = False,
149
  use_kv_cache: bool = True,
150
+ offload_kv_cache: bool = True,
151
+ use_input_image_size_as_output: bool = False,
152
  dtype: torch.dtype = torch.bfloat16,
153
  seed: int = None,
154
  ):
 
176
  Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
177
  img_guidance_scale (`float`, *optional*, defaults to 1.6):
178
  Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
179
+ max_input_image_size (`int`, *optional*, defaults to 1024): the maximum size of input image, which will be used to crop the input image to the maximum size
180
  separate_cfg_infer (`bool`, *optional*, defaults to False):
181
  Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
182
  use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
183
+ offload_kv_cache (`bool`, *optional*, defaults to True): offload the cached key and value to cpu, which can save memory but slow down the generation silightly
184
+ offload_model (`bool`, *optional*, defaults to False): offload the model to cpu, which can save memory but slow down the generation
185
+ use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task
186
+ seed (`int`, *optional*):
187
+ A random seed for generating output.
188
+ dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
189
+ data type for the model
190
  Examples:
191
 
192
  Returns:
193
  A list with the generated images.
194
  """
195
+ # check inputs:
196
+ if use_input_image_size_as_output:
197
+ assert isinstance(prompt, str) and len(input_images) == 1, "if you want to make sure the output image have the same size as the input image, please only input one image instead of multiple input images"
198
+ else:
199
+ assert height%16 == 0 and width%16 == 0, "The height and width must be a multiple of 16."
200
  if input_images is None:
201
  use_img_guidance = False
202
  if isinstance(prompt, str):
203
  prompt = [prompt]
204
  input_images = [input_images] if input_images is not None else None
205
+
206
+ # set model and processor
207
+ if max_input_image_size != self.processor.max_image_size:
208
+ self.processor = OmniGenProcessor(self.processor.text_tokenizer, max_image_size=max_input_image_size)
209
+ if offload_model:
210
+ self.enable_model_cpu_offload()
211
+ else:
212
+ self.disable_model_cpu_offload()
213
 
214
+ input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer, use_input_image_size_as_output=use_input_image_size_as_output)
215
 
216
  num_prompt = len(prompt)
217
  num_cfg = 2 if use_img_guidance else 1
218
+ if use_input_image_size_as_output:
219
+ if separate_cfg_infer:
220
+ height, width = input_data['input_pixel_values'][0][0].shape[-2:]
221
+ else:
222
+ height, width = input_data['input_pixel_values'][0].shape[-2:]
223
  latent_size_h, latent_size_w = height//8, width//8
224
 
225
  if seed is not None:
 
229
  latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
230
  latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
231
 
232
+ if input_images is not None and self.model_cpu_offload: self.vae.to(self.device)
233
  input_img_latents = []
234
  if separate_cfg_infer:
235
  for temp_pixel_values in input_data['input_pixel_values']:
 
242
  for img in input_data['input_pixel_values']:
243
  img = self.vae_encode(img.to(self.device), dtype)
244
  input_img_latents.append(img)
245
+ if input_images is not None and self.model_cpu_offload:
246
+ self.vae.to('cpu')
247
+ torch.cuda.empty_cache() # Clear VRAM
248
+ gc.collect() # Run garbage collection to free system RAM
249
 
250
  model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
251
  input_img_latents=input_img_latents,
 
255
  cfg_scale=guidance_scale,
256
  img_cfg_scale=img_guidance_scale,
257
  use_img_cfg=use_img_guidance,
258
+ use_kv_cache=use_kv_cache,
259
+ offload_model=offload_model,
260
+ )
261
 
262
  if separate_cfg_infer:
263
  func = self.model.forward_with_separate_cfg
 
265
  func = self.model.forward_with_cfg
266
  self.model.to(dtype)
267
 
268
+ if self.model_cpu_offload:
269
+ for name, param in self.model.named_parameters():
270
+ if 'layers' in name and 'layers.0' not in name:
271
+ param.data = param.data.cpu()
272
+ else:
273
+ param.data = param.data.to(self.device)
274
+ for buffer_name, buffer in self.model.named_buffers():
275
+ setattr(self.model, buffer_name, buffer.to(self.device))
276
+ # else:
277
+ # self.model.to(self.device)
278
+
279
  scheduler = OmniGenScheduler(num_steps=num_inference_steps)
280
+ samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache)
281
  samples = samples.chunk((1+num_cfg), dim=0)[0]
282
 
283
+ if self.model_cpu_offload:
284
+ self.model.to('cpu')
285
+ torch.cuda.empty_cache()
286
+ gc.collect()
287
+
288
+ self.vae.to(self.device)
289
  samples = samples.to(torch.float32)
290
  if self.vae.config.shift_factor is not None:
291
  samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
292
  else:
293
  samples = samples / self.vae.config.scaling_factor
294
  samples = self.vae.decode(samples).sample
295
+
296
+ if self.model_cpu_offload:
297
+ self.vae.to('cpu')
298
+ torch.cuda.empty_cache()
299
+ gc.collect()
300
 
301
  output_samples = (samples * 0.5 + 0.5).clamp(0, 1)*255
302
  output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
 
304
  for i, sample in enumerate(output_samples):
305
  output_images.append(Image.fromarray(sample))
306
 
307
+ torch.cuda.empty_cache() # Clear VRAM
308
+ gc.collect() # Run garbage collection to free system RAM
309
  return output_images
OmniGen/processor.py CHANGED
@@ -108,6 +108,7 @@ class OmniGenProcessor:
108
  negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
109
  use_img_cfg: bool = True,
110
  separate_cfg_input: bool = False,
 
111
  ) -> Dict:
112
 
113
  if input_images is None:
@@ -138,7 +139,10 @@ class OmniGenProcessor:
138
  else:
139
  img_cfg_mllm_input = neg_mllm_input
140
 
141
- input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
 
 
 
142
 
143
  if separate_cfg_input:
144
  return self.separate_collator(input_data)
@@ -295,7 +299,6 @@ class OmniGenSeparateCollator(OmniGenCollator):
295
  cfg_mllm_inputs = [f[1] for f in features]
296
  img_cfg_mllm_input = [f[2] for f in features]
297
  target_img_size = [f[3] for f in features]
298
-
299
 
300
  all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
301
 
 
108
  negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
109
  use_img_cfg: bool = True,
110
  separate_cfg_input: bool = False,
111
+ use_input_image_size_as_output: bool=False,
112
  ) -> Dict:
113
 
114
  if input_images is None:
 
139
  else:
140
  img_cfg_mllm_input = neg_mllm_input
141
 
142
+ if use_input_image_size_as_output:
143
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
144
+ else:
145
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
146
 
147
  if separate_cfg_input:
148
  return self.separate_collator(input_data)
 
299
  cfg_mllm_inputs = [f[1] for f in features]
300
  img_cfg_mllm_input = [f[2] for f in features]
301
  target_img_size = [f[3] for f in features]
 
302
 
303
  all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
304
 
OmniGen/scheduler.py CHANGED
@@ -1,6 +1,116 @@
1
- import torch
2
  from tqdm import tqdm
3
- from transformers.cache_utils import Cache, DynamicCache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  class OmniGenScheduler:
6
  def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
@@ -12,12 +122,13 @@ class OmniGenScheduler:
12
  self.sigma = t
13
 
14
  def crop_kv_cache(self, past_key_values, num_tokens_for_img):
 
15
  crop_past_key_values = ()
16
  for layer_idx in range(len(past_key_values)):
17
  key_states, value_states = past_key_values[layer_idx][:2]
18
  crop_past_key_values += ((key_states[..., :-(num_tokens_for_img+1), :], value_states[..., :-(num_tokens_for_img+1), :], ),)
19
- return crop_past_key_values
20
- # return DynamicCache.from_legacy_cache(crop_past_key_values)
21
 
22
  def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
23
  if isinstance(position_ids, list):
@@ -32,24 +143,39 @@ class OmniGenScheduler:
32
  return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
33
  return attention_mask[..., -(num_tokens_for_img+1):, :]
34
 
35
- def __call__(self, z, func, model_kwargs, use_kv_cache: bool=True):
36
- past_key_values = None
 
 
 
 
 
 
 
 
 
 
 
 
37
  for i in tqdm(range(self.num_steps)):
38
  timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
39
- pred, temp_past_key_values = func(z, timesteps, past_key_values=past_key_values, **model_kwargs)
40
  sigma_next = self.sigma[i+1]
41
  sigma = self.sigma[i]
42
  z = z + (sigma_next - sigma) * pred
43
  if i == 0 and use_kv_cache:
44
  num_tokens_for_img = z.size(-1)*z.size(-2) // 4
45
- if isinstance(temp_past_key_values, list):
46
- past_key_values = [self.crop_kv_cache(x, num_tokens_for_img) for x in temp_past_key_values]
47
- model_kwargs['input_ids'] = [None] * len(temp_past_key_values)
48
  else:
49
- past_key_values = self.crop_kv_cache(temp_past_key_values, num_tokens_for_img)
50
  model_kwargs['input_ids'] = None
51
 
52
  model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
53
  model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
 
 
 
 
54
  return z
55
 
 
 
 
1
  from tqdm import tqdm
2
+ from typing import Optional, Dict, Any, Tuple, List
3
+ import gc
4
+
5
+ import torch
6
+ from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
7
+
8
+
9
+
10
+ class OmniGenCache(DynamicCache):
11
+ def __init__(self,
12
+ num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
13
+ if not torch.cuda.is_available():
14
+ raise RuntimeError("OffloadedCache can only be used with a GPU")
15
+ super().__init__()
16
+ self.original_device = []
17
+ self.prefetch_stream = torch.cuda.Stream()
18
+ self.num_tokens_for_img = num_tokens_for_img
19
+ self.offload_kv_cache = offload_kv_cache
20
+
21
+ def prefetch_layer(self, layer_idx: int):
22
+ "Starts prefetching the next layer cache"
23
+ if layer_idx < len(self):
24
+ with torch.cuda.stream(self.prefetch_stream):
25
+ # Prefetch next layer tensors to GPU
26
+ device = self.original_device[layer_idx]
27
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
28
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
29
+
30
+
31
+ def evict_previous_layer(self, layer_idx: int):
32
+ "Moves the previous layer cache to the CPU"
33
+ if len(self) > 2:
34
+ # We do it on the default stream so it occurs after all earlier computations on these tensors are done
35
+ if layer_idx == 0:
36
+ prev_layer_idx = -1
37
+ else:
38
+ prev_layer_idx = (layer_idx - 1) % len(self)
39
+ self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
40
+ self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
41
+
42
+
43
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
44
+ "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
45
+ if layer_idx < len(self):
46
+ if self.offload_kv_cache:
47
+ # Evict the previous layer if necessary
48
+ torch.cuda.current_stream().synchronize()
49
+ self.evict_previous_layer(layer_idx)
50
+ # Load current layer cache to its original device if not already there
51
+ original_device = self.original_device[layer_idx]
52
+ # self.prefetch_stream.synchronize(original_device)
53
+ torch.cuda.synchronize(self.prefetch_stream)
54
+ key_tensor = self.key_cache[layer_idx]
55
+ value_tensor = self.value_cache[layer_idx]
56
+
57
+ # Prefetch the next layer
58
+ self.prefetch_layer((layer_idx + 1) % len(self))
59
+ else:
60
+ key_tensor = self.key_cache[layer_idx]
61
+ value_tensor = self.value_cache[layer_idx]
62
+ return (key_tensor, value_tensor)
63
+ else:
64
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
65
+
66
+
67
+ def update(
68
+ self,
69
+ key_states: torch.Tensor,
70
+ value_states: torch.Tensor,
71
+ layer_idx: int,
72
+ cache_kwargs: Optional[Dict[str, Any]] = None,
73
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
74
+ """
75
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
76
+ Parameters:
77
+ key_states (`torch.Tensor`):
78
+ The new key states to cache.
79
+ value_states (`torch.Tensor`):
80
+ The new value states to cache.
81
+ layer_idx (`int`):
82
+ The index of the layer to cache the states for.
83
+ cache_kwargs (`Dict[str, Any]`, `optional`):
84
+ Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
85
+ Return:
86
+ A tuple containing the updated key and value states.
87
+ """
88
+ # Update the cache
89
+ if len(self.key_cache) < layer_idx:
90
+ raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
91
+ elif len(self.key_cache) == layer_idx:
92
+ # only cache the states for condition tokens
93
+ key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
94
+ value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
95
+
96
+ # Update the number of seen tokens
97
+ if layer_idx == 0:
98
+ self._seen_tokens += key_states.shape[-2]
99
+
100
+ self.key_cache.append(key_states)
101
+ self.value_cache.append(value_states)
102
+ self.original_device.append(key_states.device)
103
+ if self.offload_kv_cache:
104
+ self.evict_previous_layer(layer_idx)
105
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
106
+ else:
107
+ # only cache the states for condition tokens
108
+ key_tensor, value_tensor = self[layer_idx]
109
+ k = torch.cat([key_tensor, key_states], dim=-2)
110
+ v = torch.cat([value_tensor, value_states], dim=-2)
111
+ return k, v
112
+
113
+
114
 
115
  class OmniGenScheduler:
116
  def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
 
122
  self.sigma = t
123
 
124
  def crop_kv_cache(self, past_key_values, num_tokens_for_img):
125
+ # return
126
  crop_past_key_values = ()
127
  for layer_idx in range(len(past_key_values)):
128
  key_states, value_states = past_key_values[layer_idx][:2]
129
  crop_past_key_values += ((key_states[..., :-(num_tokens_for_img+1), :], value_states[..., :-(num_tokens_for_img+1), :], ),)
130
+ # return crop_past_key_values
131
+ return DynamicCache.from_legacy_cache(crop_past_key_values)
132
 
133
  def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
134
  if isinstance(position_ids, list):
 
143
  return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
144
  return attention_mask[..., -(num_tokens_for_img+1):, :]
145
 
146
+ def crop_cache(self, cache, num_tokens_for_img):
147
+ for i in range(len(cache.key_cache)):
148
+ cache.key_cache[i] = cache.key_cache[i][..., :-(num_tokens_for_img+1), :]
149
+ cache.value_cache[i] = cache.value_cache[i][..., :-(num_tokens_for_img+1), :]
150
+
151
+ return cache
152
+
153
+ def __call__(self, z, func, model_kwargs, use_kv_cache: bool=True, offload_kv_cache: bool=True):
154
+ num_tokens_for_img = z.size(-1)*z.size(-2) // 4
155
+ if isinstance(model_kwargs['input_ids'], list):
156
+ cache = [OmniGenCache(num_tokens_for_img, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
157
+ else:
158
+ cache = OmniGenCache(num_tokens_for_img, offload_kv_cache) if use_kv_cache else None
159
+ results = {}
160
  for i in tqdm(range(self.num_steps)):
161
  timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
162
+ pred, cache = func(z, timesteps, past_key_values=cache, **model_kwargs)
163
  sigma_next = self.sigma[i+1]
164
  sigma = self.sigma[i]
165
  z = z + (sigma_next - sigma) * pred
166
  if i == 0 and use_kv_cache:
167
  num_tokens_for_img = z.size(-1)*z.size(-2) // 4
168
+ if isinstance(cache, list):
169
+ model_kwargs['input_ids'] = [None] * len(cache)
 
170
  else:
 
171
  model_kwargs['input_ids'] = None
172
 
173
  model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
174
  model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
175
+
176
+ del cache
177
+ torch.cuda.empty_cache()
178
+ gc.collect()
179
  return z
180
 
181
+
OmniGen/transformer.py CHANGED
@@ -29,6 +29,34 @@ class Phi3Transformer(Phi3Model):
29
  Args:
30
  config: Phi3Config
31
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def forward(
34
  self,
@@ -42,6 +70,7 @@ class Phi3Transformer(Phi3Model):
42
  output_hidden_states: Optional[bool] = None,
43
  return_dict: Optional[bool] = None,
44
  cache_position: Optional[torch.LongTensor] = None,
 
45
  ) -> Union[Tuple, BaseModelOutputWithPast]:
46
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
47
  output_hidden_states = (
@@ -75,16 +104,16 @@ class Phi3Transformer(Phi3Model):
75
  "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
76
  )
77
 
78
- if inputs_embeds is None:
79
- inputs_embeds = self.embed_tokens(input_ids)
80
 
81
- if cache_position is None:
82
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
83
- cache_position = torch.arange(
84
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
85
- )
86
- if position_ids is None:
87
- position_ids = cache_position.unsqueeze(0)
88
 
89
  if attention_mask is not None and attention_mask.dim() == 3:
90
  dtype = inputs_embeds.dtype
@@ -104,7 +133,10 @@ class Phi3Transformer(Phi3Model):
104
  all_self_attns = () if output_attentions else None
105
  next_decoder_cache = None
106
 
 
107
  for decoder_layer in self.layers:
 
 
108
  if output_hidden_states:
109
  all_hidden_states += (hidden_states,)
110
 
@@ -120,6 +152,8 @@ class Phi3Transformer(Phi3Model):
120
  cache_position,
121
  )
122
  else:
 
 
123
  layer_outputs = decoder_layer(
124
  hidden_states,
125
  attention_mask=attention_mask,
@@ -142,6 +176,7 @@ class Phi3Transformer(Phi3Model):
142
 
143
  # add hidden states from the last decoder layer
144
  if output_hidden_states:
 
145
  all_hidden_states += (hidden_states,)
146
 
147
  next_cache = next_decoder_cache if use_cache else None
 
29
  Args:
30
  config: Phi3Config
31
  """
32
+ def prefetch_layer(self, layer_idx: int, device: torch.device):
33
+ "Starts prefetching the next layer cache"
34
+ with torch.cuda.stream(self.prefetch_stream):
35
+ # Prefetch next layer tensors to GPU
36
+ for name, param in self.layers[layer_idx].named_parameters():
37
+ param.data = param.data.to(device, non_blocking=True)
38
+
39
+ def evict_previous_layer(self, layer_idx: int):
40
+ "Moves the previous layer cache to the CPU"
41
+ prev_layer_idx = layer_idx - 1
42
+ for name, param in self.layers[prev_layer_idx].named_parameters():
43
+ param.data = param.data.to("cpu", non_blocking=True)
44
+
45
+ def get_offlaod_layer(self, layer_idx: int, device: torch.device):
46
+ # init stream
47
+ if not hasattr(self, "prefetch_stream"):
48
+ self.prefetch_stream = torch.cuda.Stream()
49
+
50
+ # delete previous layer
51
+ torch.cuda.current_stream().synchronize()
52
+ self.evict_previous_layer(layer_idx)
53
+
54
+ # make sure the current layer is ready
55
+ torch.cuda.synchronize(self.prefetch_stream)
56
+
57
+ # load next layer
58
+ self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
59
+
60
 
61
  def forward(
62
  self,
 
70
  output_hidden_states: Optional[bool] = None,
71
  return_dict: Optional[bool] = None,
72
  cache_position: Optional[torch.LongTensor] = None,
73
+ offload_model: Optional[bool] = False,
74
  ) -> Union[Tuple, BaseModelOutputWithPast]:
75
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
76
  output_hidden_states = (
 
104
  "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
105
  )
106
 
107
+ # if inputs_embeds is None:
108
+ # inputs_embeds = self.embed_tokens(input_ids)
109
 
110
+ # if cache_position is None:
111
+ # past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
112
+ # cache_position = torch.arange(
113
+ # past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
114
+ # )
115
+ # if position_ids is None:
116
+ # position_ids = cache_position.unsqueeze(0)
117
 
118
  if attention_mask is not None and attention_mask.dim() == 3:
119
  dtype = inputs_embeds.dtype
 
133
  all_self_attns = () if output_attentions else None
134
  next_decoder_cache = None
135
 
136
+ layer_idx = -1
137
  for decoder_layer in self.layers:
138
+ layer_idx += 1
139
+
140
  if output_hidden_states:
141
  all_hidden_states += (hidden_states,)
142
 
 
152
  cache_position,
153
  )
154
  else:
155
+ if offload_model and not self.training:
156
+ self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
157
  layer_outputs = decoder_layer(
158
  hidden_states,
159
  attention_mask=attention_mask,
 
176
 
177
  # add hidden states from the last decoder layer
178
  if output_hidden_states:
179
+ print('************')
180
  all_hidden_states += (hidden_states,)
181
 
182
  next_cache = next_decoder_cache if use_cache else None
README.md CHANGED
@@ -1,13 +1,180 @@
1
- ---
2
- title: OmniGen
3
- emoji: 🖼
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.0.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center">OmniGen: Unified Image Generation</h1>
2
+
3
+
4
+ <p align="center">
5
+ <a href="">
6
+ <img alt="Build" src="https://img.shields.io/badge/Project%20Page-OmniGen-yellow">
7
+ </a>
8
+ <a href="https://arxiv.org/abs/2409.11340">
9
+ <img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-2409.11340-b31b1b.svg">
10
+ </a>
11
+ <a href="https://huggingface.co/spaces/Shitao/OmniGen">
12
+ <img alt="License" src="https://img.shields.io/badge/HF%20Demo-🤗-lightblue">
13
+ </a>
14
+ <a href="https://huggingface.co/Shitao/OmniGen-v1">
15
+ <img alt="Build" src="https://img.shields.io/badge/HF%20Model-🤗-yellow">
16
+ </a>
17
+ </p>
18
+
19
+ <h4 align="center">
20
+ <p>
21
+ <a href=#1-news>News</a> |
22
+ <a href=#3-methodology>Methodology</a> |
23
+ <a href=#4-what-can-omnigen-do>Capabilities</a> |
24
+ <a href=#5-quick-start>Quick Start</a> |
25
+ <a href="#6-finetune">Finetune</a> |
26
+ <a href="#license">License</a> |
27
+ <a href="#citation">Citation</a>
28
+ <p>
29
+ </h4>
30
+
31
+
32
+
33
+ ## 1. News
34
+ - 2024-10-28: We release new version of inference code, optimizing the memory usage and time cost. You can refer to [docs/inference.md](docs/inference.md#requiremented-resources) for detailed information.
35
+ - 2024-10-22: :fire: We release the code for OmniGen. Inference: [docs/inference.md](docs/inference.md) Train: [docs/fine-tuning.md](docs/fine-tuning.md)
36
+ - 2024-10-22: :fire: We release the first version of OmniGen. Model Weight: [Shitao/OmniGen-v1](https://huggingface.co/Shitao/OmniGen-v1) HF Demo: [🤗](https://huggingface.co/spaces/Shitao/OmniGen)
37
+
38
+
39
+ ## 2. Overview
40
+
41
+ OmniGen is a unified image generation model that can generate a wide range of images from multi-modal prompts. It is designed to be simple, flexible and easy to use. We provide [inference code](#5-quick-start) so that everyone can explore more functionalities of OmniGen.
42
+
43
+ Existing image generation models often require loading several additional network modules (such as ControlNet, IP-Adapter, Reference-Net, etc.) and performing extra preprocessing steps (e.g., face detection, pose estimation, cropping, etc.) to generate a satisfactory image. However, **we believe that the future image generation paradigm should be more simple and flexible, that is, generating various images directly through arbitrarily multi-modal instructions without the need for additional plugins and operations, similar to how GPT works in language generation.**
44
+
45
+ Due to the limited resources, OmniGen still has room for improvement. We will continue to optimize it, and hope it inspire more universal image generation models. You can also easily fine-tune OmniGen without worrying about designing networks for specific tasks; you just need to prepare the corresponding data, and then run the [script](#6-finetune). Imagination is no longer limited; everyone can construct any image generation task, and perhaps we can achieve very interesting, wonderful and creative things.
46
+
47
+ If you have any questions, ideas or interesting tasks you want OmniGen to accomplish, feel free to discuss with us: 2906698981@qq.com, wangyueze@tju.edu.cn, zhengliu1026@gmail.com. We welcome any feedback to help us improve the model.
48
+
49
+
50
+
51
+
52
+ ## 3. Methodology
53
+
54
+ You can see details in our [paper](https://arxiv.org/abs/2409.11340).
55
+
56
+
57
+ ## 4. What Can OmniGen do?
58
+
59
+
60
+ OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, image editing, and image-conditioned generation. **OmniGen don't need additional plugins or operations, it can automatically identify the features (e.g., required object, human pose, depth mapping) in input images according the text prompt.**
61
+ We showcase some examples in [inference.ipynb](inference.ipynb). And in [inference_demo.ipynb](inference_demo.ipynb), we show an interesting pipeline to generate and modify a image.
62
+
63
+ Here is the illustration of OmniGen's capabilities:
64
+ - You can control the image generation flexibly via OmniGen
65
+ ![demo](./imgs/demo_cases.png)
66
+ - Referring Expression Generation: You can input multiple images and use simple, general language to refer to the objects within those images. OmniGen can automatically recognize the necessary objects in each image and generate new images based on them. No additional operations, such as image cropping or face detection, are required.
67
+ ![demo](./imgs/referring.png)
68
+
69
+ If you are not entirely satisfied with certain functionalities or wish to add new capabilities, you can try [fine-tuning OmniGen](#6-finetune).
70
+
71
+
72
+
73
+ ## 5. Quick Start
74
+
75
+
76
+ ### Using OmniGen
77
+ Install via Github:
78
+ ```bash
79
+ git clone https://github.com/staoxiao/OmniGen.git
80
+ cd OmniGen
81
+ pip install -e .
82
+ ```
83
+
84
+ Here are some examples:
85
+ ```python
86
+ from OmniGen import OmniGenPipeline
87
+
88
+ pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
89
+
90
+ # Text to Image
91
+ images = pipe(
92
+ prompt="A curly-haired man in a red shirt is drinking tea.",
93
+ height=1024,
94
+ width=1024,
95
+ guidance_scale=2.5,
96
+ seed=0,
97
+ )
98
+ images[0].save("example_t2i.png") # save output PIL Image
99
+
100
+ # Multi-modal to Image
101
+ # In prompt, we use the placeholder to represent the image. The image placeholder should be in the format of <img><|image_*|></img>
102
+ # You can add multiple images in the input_images. Please ensure that each image has its placeholder. For example, for the list input_images [img1_path, img2_path], the prompt needs to have two placeholders: <img><|image_1|></img>, <img><|image_2|></img>.
103
+ images = pipe(
104
+ prompt="A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
105
+ input_images=["./imgs/test_cases/two_man.jpg"],
106
+ height=1024,
107
+ width=1024,
108
+ guidance_scale=2.5,
109
+ img_guidance_scale=1.6,
110
+ seed=0
111
+ )
112
+ images[0].save("example_ti2i.png") # save output PIL image
113
+ ```
114
+ - For thre required resources and the method to run OmniGen efficiently, please refer to [docs/inference.md#requiremented-resources](docs/inference.md#requiremented-resources).
115
+ - For more examples for image generation, you can refer to [inference.ipynb](inference.ipynb) and [inference_demo.ipynb](inference_demo.ipynb)
116
+ - For more details about the argument in inference, please refer to [docs/inference.md](docs/inference.md).
117
+
118
+
119
+ ### Using Diffusers
120
+ Coming soon.
121
+
122
+
123
+ ### Gradio Demo
124
+
125
+ We construct an online demo in [Huggingface](https://huggingface.co/spaces/Shitao/OmniGen).
126
+
127
+ For the local gradio demo, you need to install `pip install gradio spaces` , and then you can run:
128
+ ```python
129
+ pip install gradio spaces
130
+ python app.py
131
+ ```
132
+
133
+
134
+
135
+ ## 6. Finetune
136
+ We provide a training script `train.py` to fine-tune OmniGen.
137
+ Here is a toy example about LoRA finetune:
138
+ ```bash
139
+ accelerate launch --num_processes=1 train.py \
140
+ --model_name_or_path Shitao/OmniGen-v1 \
141
+ --batch_size_per_device 2 \
142
+ --condition_dropout_prob 0.01 \
143
+ --lr 1e-3 \
144
+ --use_lora \
145
+ --lora_rank 8 \
146
+ --json_file ./toy_data/toy_subject_data.jsonl \
147
+ --image_path ./toy_data/images \
148
+ --max_input_length_limit 18000 \
149
+ --keep_raw_resolution \
150
+ --max_image_size 1024 \
151
+ --gradient_accumulation_steps 1 \
152
+ --ckpt_every 10 \
153
+ --epochs 200 \
154
+ --log_every 1 \
155
+ --results_dir ./results/toy_finetune_lora
156
+ ```
157
+
158
+ Please refer to [docs/fine-tuning.md](docs/fine-tuning.md) for more details (e.g. full finetune).
159
+
160
+
161
+
162
+ ## License
163
+ This repo is licensed under the [MIT License](LICENSE).
164
+
165
+
166
+ ## Citation
167
+ If you find this repository useful, please consider giving a star ⭐ and citation
168
+ ```
169
+ @article{xiao2024omnigen,
170
+ title={Omnigen: Unified image generation},
171
+ author={Xiao, Shitao and Wang, Yueze and Zhou, Junjie and Yuan, Huaying and Xing, Xingrun and Yan, Ruiran and Wang, Shuting and Huang, Tiejun and Liu, Zheng},
172
+ journal={arXiv preprint arXiv:2409.11340},
173
+ year={2024}
174
+ }
175
+ ```
176
+
177
+
178
+
179
+
180
+
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from PIL import Image
3
  import os
 
4
  import spaces
5
 
6
  from OmniGen import OmniGenPipeline
@@ -9,11 +10,11 @@ pipe = OmniGenPipeline.from_pretrained(
9
  "Shitao/OmniGen-v1"
10
  )
11
 
12
- @spaces.GPU(duration=180)
13
- # 示例处理函数:生成图像
14
- def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer):
15
  input_images = [img1, img2, img3]
16
- # 去除 None
17
  input_images = [img for img in input_images if img is not None]
18
  if len(input_images) == 0:
19
  input_images = None
@@ -24,25 +25,18 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_gu
24
  height=height,
25
  width=width,
26
  guidance_scale=guidance_scale,
27
- img_guidance_scale=1.6,
28
  num_inference_steps=inference_steps,
29
- separate_cfg_infer=True, # set False can speed up the inference process
30
- use_kv_cache=False,
 
 
 
31
  seed=seed,
32
- # separate_cfg_infer=separate_cfg_infer,
33
  )
34
  img = output[0]
35
  return img
36
- # def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
37
- # input_images = []
38
- # if img1:
39
- # input_images.append(Image.open(img1))
40
- # if img2:
41
- # input_images.append(Image.open(img2))
42
- # if img3:
43
- # input_images.append(Image.open(img3))
44
-
45
- # return input_images[0] if input_images else None
46
 
47
 
48
  def get_example():
@@ -59,6 +53,8 @@ def get_example():
59
  50,
60
  0,
61
  True,
 
 
62
  ],
63
  [
64
  "The woman in <img><|image_1|></img> waves her hand happily in the crowd",
@@ -72,6 +68,8 @@ def get_example():
72
  50,
73
  128,
74
  True,
 
 
75
  ],
76
  [
77
  "A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
@@ -85,6 +83,8 @@ def get_example():
85
  50,
86
  0,
87
  True,
 
 
88
  ],
89
  [
90
  "Two woman are raising fried chicken legs in a bar. A woman is <img><|image_1|></img>. The other woman is <img><|image_2|></img>.",
@@ -98,6 +98,8 @@ def get_example():
98
  50,
99
  168,
100
  True,
 
 
101
  ],
102
  [
103
  "A man and a short-haired woman with a wrinkled face are standing in front of a bookshelf in a library. The man is the man in the middle of <img><|image_1|></img>, and the woman is oldest woman in <img><|image_2|></img>",
@@ -111,6 +113,8 @@ def get_example():
111
  50,
112
  60,
113
  True,
 
 
114
  ],
115
  [
116
  "A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>",
@@ -124,6 +128,8 @@ def get_example():
124
  50,
125
  66,
126
  True,
 
 
127
  ],
128
  [
129
  "The flower <img><|image_1|><\/img> is placed in the vase which is in the middle of <img><|image_2|><\/img> on a wooden table of a living room",
@@ -137,147 +143,185 @@ def get_example():
137
  50,
138
  0,
139
  True,
 
 
140
  ],
141
  [
142
  "<img><|image_1|><img>\n Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola.",
143
  "./imgs/demo_cases/t2i_woman_with_book.png",
144
  None,
145
  None,
146
- 1024,
147
- 1024,
148
  2.5,
149
  1.6,
150
  50,
151
  222,
152
  True,
 
 
153
  ],
154
  [
155
  "Detect the skeleton of human in this image: <img><|image_1|></img>.",
156
  "./imgs/test_cases/control.jpg",
157
  None,
158
  None,
159
- 1024,
160
- 1024,
161
  2.0,
162
  1.6,
163
  50,
164
  0,
165
  True,
 
 
166
  ],
167
  [
168
  "Generate a new photo using the following picture and text as conditions: <img><|image_1|><img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
169
  "./imgs/demo_cases/skeletal.png",
170
  None,
171
  None,
172
- 1024,
173
- 1024,
174
  2,
175
  1.6,
176
  50,
177
  42,
178
  True,
 
 
179
  ],
180
  [
181
  "Following the pose of this image <img><|image_1|><img>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
182
  "./imgs/demo_cases/edit.png",
183
  None,
184
  None,
185
- 1024,
186
- 1024,
187
  2.0,
188
  1.6,
189
  50,
190
  123,
191
  True,
 
 
192
  ],
193
  [
194
  "Following the depth mapping of this image <img><|image_1|><img>, generate a new photo: A young girl is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
195
  "./imgs/demo_cases/edit.png",
196
  None,
197
  None,
198
- 1024,
199
- 1024,
200
  2.0,
201
  1.6,
202
  50,
203
  1,
204
  True,
 
 
205
  ],
206
  [
207
  "<img><|image_1|><\/img> What item can be used to see the current time? Please remove it.",
208
  "./imgs/test_cases/watch.jpg",
209
  None,
210
  None,
211
- 1024,
212
- 1024,
213
  2.5,
214
  1.6,
215
  50,
216
  0,
217
  True,
 
 
218
  ],
219
  [
220
  "According to the following examples, generate an output for the input.\nInput: <img><|image_1|></img>\nOutput: <img><|image_2|></img>\n\nInput: <img><|image_3|></img>\nOutput: ",
221
  "./imgs/test_cases/icl1.jpg",
222
  "./imgs/test_cases/icl2.jpg",
223
  "./imgs/test_cases/icl3.jpg",
224
- 1024,
225
- 1024,
226
  2.5,
227
  1.6,
228
  50,
229
  1,
230
  True,
 
 
231
  ],
232
  ]
233
  return case
234
 
235
- def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer,):
236
- return generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer,)
 
 
237
 
238
  description = """
239
  OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, and image-conditioned generation.
240
-
241
  For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>` (for the first image, the placeholder is <img><|image_1|></img>. for the second image, the the placeholder is <img><|image_2|></img>).
242
  For example, use an image of a woman to generate a new image:
243
  prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is \<img\>\<|image_1|\>\</img\>."
244
 
245
  Tips:
 
246
  - Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
 
247
  - Low-quality: More detailed prompt will lead to better results.
248
  - Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
249
  - Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
250
- - For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  """
252
 
253
- separate_cfg_infer_arg = False
254
 
255
- # Gradio 接口
256
  with gr.Blocks() as demo:
257
  gr.Markdown("# OmniGen: Unified Image Generation [paper](https://arxiv.org/abs/2409.11340) [code](https://github.com/VectorSpaceLab/OmniGen)")
258
  gr.Markdown(description)
259
  with gr.Row():
260
  with gr.Column():
261
- # 文本输入框
262
  prompt_input = gr.Textbox(
263
  label="Enter your prompt, use <img><|image_i|></img> to represent i-th input image", placeholder="Type your prompt here..."
264
  )
265
 
266
  with gr.Row(equal_height=True):
267
- # 图片上传框
268
  image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
269
  image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
270
  image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
271
 
272
- # 高度和宽度滑块
273
  height_input = gr.Slider(
274
- label="Height", minimum=256, maximum=2048, value=1024, step=16
275
  )
276
  width_input = gr.Slider(
277
- label="Width", minimum=256, maximum=2048, value=1024, step=16
278
  )
279
 
280
- # 引导尺度输入
281
  guidance_scale_input = gr.Slider(
282
  label="Guidance Scale", minimum=1.0, maximum=5.0, value=2.5, step=0.1
283
  )
@@ -295,17 +339,24 @@ with gr.Blocks() as demo:
295
  )
296
 
297
  separate_cfg_infer = gr.Checkbox(
298
- label="separate_cfg_infer", info="enable separate cfg infer"
 
 
 
 
 
 
299
  )
300
 
301
- # 生成按钮
302
  generate_button = gr.Button("Generate Image")
 
303
 
304
  with gr.Column():
305
- # 输出图像框
306
  output_image = gr.Image(label="Output Image")
307
 
308
- # 按钮点击事件
309
  generate_button.click(
310
  generate_image,
311
  inputs=[
@@ -320,6 +371,8 @@ with gr.Blocks() as demo:
320
  num_inference_steps,
321
  seed_input,
322
  separate_cfg_infer,
 
 
323
  ],
324
  outputs=output_image,
325
  )
@@ -339,9 +392,13 @@ with gr.Blocks() as demo:
339
  num_inference_steps,
340
  seed_input,
341
  separate_cfg_infer,
 
 
342
  ],
343
  outputs=output_image,
344
  )
345
 
346
- # 启动应用
 
 
347
  demo.launch()
 
1
  import gradio as gr
2
  from PIL import Image
3
  import os
4
+
5
  import spaces
6
 
7
  from OmniGen import OmniGenPipeline
 
10
  "Shitao/OmniGen-v1"
11
  )
12
 
13
+ @spaces.GPU(duration=300)
14
+ def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model,
15
+ use_input_image_size_as_output):
16
  input_images = [img1, img2, img3]
17
+ # Delete None
18
  input_images = [img for img in input_images if img is not None]
19
  if len(input_images) == 0:
20
  input_images = None
 
25
  height=height,
26
  width=width,
27
  guidance_scale=guidance_scale,
28
+ img_guidance_scale=img_guidance_scale,
29
  num_inference_steps=inference_steps,
30
+ separate_cfg_infer=separate_cfg_infer,
31
+ use_kv_cache=True,
32
+ offload_kv_cache=True,
33
+ offload_model=offload_model,
34
+ use_input_image_size_as_output=use_input_image_size_as_output,
35
  seed=seed,
 
36
  )
37
  img = output[0]
38
  return img
39
+
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def get_example():
 
53
  50,
54
  0,
55
  True,
56
+ False,
57
+ False,
58
  ],
59
  [
60
  "The woman in <img><|image_1|></img> waves her hand happily in the crowd",
 
68
  50,
69
  128,
70
  True,
71
+ False,
72
+ False,
73
  ],
74
  [
75
  "A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
 
83
  50,
84
  0,
85
  True,
86
+ False,
87
+ False,
88
  ],
89
  [
90
  "Two woman are raising fried chicken legs in a bar. A woman is <img><|image_1|></img>. The other woman is <img><|image_2|></img>.",
 
98
  50,
99
  168,
100
  True,
101
+ False,
102
+ False,
103
  ],
104
  [
105
  "A man and a short-haired woman with a wrinkled face are standing in front of a bookshelf in a library. The man is the man in the middle of <img><|image_1|></img>, and the woman is oldest woman in <img><|image_2|></img>",
 
113
  50,
114
  60,
115
  True,
116
+ False,
117
+ False,
118
  ],
119
  [
120
  "A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>",
 
128
  50,
129
  66,
130
  True,
131
+ False,
132
+ False,
133
  ],
134
  [
135
  "The flower <img><|image_1|><\/img> is placed in the vase which is in the middle of <img><|image_2|><\/img> on a wooden table of a living room",
 
143
  50,
144
  0,
145
  True,
146
+ False,
147
+ False,
148
  ],
149
  [
150
  "<img><|image_1|><img>\n Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola.",
151
  "./imgs/demo_cases/t2i_woman_with_book.png",
152
  None,
153
  None,
154
+ None,
155
+ None,
156
  2.5,
157
  1.6,
158
  50,
159
  222,
160
  True,
161
+ False,
162
+ True,
163
  ],
164
  [
165
  "Detect the skeleton of human in this image: <img><|image_1|></img>.",
166
  "./imgs/test_cases/control.jpg",
167
  None,
168
  None,
169
+ None,
170
+ None,
171
  2.0,
172
  1.6,
173
  50,
174
  0,
175
  True,
176
+ False,
177
+ True,
178
  ],
179
  [
180
  "Generate a new photo using the following picture and text as conditions: <img><|image_1|><img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
181
  "./imgs/demo_cases/skeletal.png",
182
  None,
183
  None,
184
+ None,
185
+ None,
186
  2,
187
  1.6,
188
  50,
189
  42,
190
  True,
191
+ False,
192
+ True,
193
  ],
194
  [
195
  "Following the pose of this image <img><|image_1|><img>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
196
  "./imgs/demo_cases/edit.png",
197
  None,
198
  None,
199
+ None,
200
+ None,
201
  2.0,
202
  1.6,
203
  50,
204
  123,
205
  True,
206
+ False,
207
+ True,
208
  ],
209
  [
210
  "Following the depth mapping of this image <img><|image_1|><img>, generate a new photo: A young girl is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
211
  "./imgs/demo_cases/edit.png",
212
  None,
213
  None,
214
+ None,
215
+ None,
216
  2.0,
217
  1.6,
218
  50,
219
  1,
220
  True,
221
+ False,
222
+ True,
223
  ],
224
  [
225
  "<img><|image_1|><\/img> What item can be used to see the current time? Please remove it.",
226
  "./imgs/test_cases/watch.jpg",
227
  None,
228
  None,
229
+ None,
230
+ None,
231
  2.5,
232
  1.6,
233
  50,
234
  0,
235
  True,
236
+ False,
237
+ True,
238
  ],
239
  [
240
  "According to the following examples, generate an output for the input.\nInput: <img><|image_1|></img>\nOutput: <img><|image_2|></img>\n\nInput: <img><|image_3|></img>\nOutput: ",
241
  "./imgs/test_cases/icl1.jpg",
242
  "./imgs/test_cases/icl2.jpg",
243
  "./imgs/test_cases/icl3.jpg",
244
+ 224,
245
+ 224,
246
  2.5,
247
  1.6,
248
  50,
249
  1,
250
  True,
251
+ False,
252
+ False,
253
  ],
254
  ]
255
  return case
256
 
257
+ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model,
258
+ use_input_image_size_as_output):
259
+ return generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model,
260
+ use_input_image_size_as_output)
261
 
262
  description = """
263
  OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, and image-conditioned generation.
 
264
  For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>` (for the first image, the placeholder is <img><|image_1|></img>. for the second image, the the placeholder is <img><|image_2|></img>).
265
  For example, use an image of a woman to generate a new image:
266
  prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is \<img\>\<|image_1|\>\</img\>."
267
 
268
  Tips:
269
+ - For out of memory or time cost, you can refer to [./docs/inference.md#requiremented-resources](https://github.com/VectorSpaceLab/OmniGen/blob/main/docs/inference.md#requiremented-resources) to select a appropriate setting.
270
  - Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
271
+ - Not match the prompt: If the image does not match the prompt, please try to increase the `guidance_scale`.
272
  - Low-quality: More detailed prompt will lead to better results.
273
  - Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
274
  - Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
275
+ - For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
276
+ - For image editing task and controlnet task, we recommend to set the height and width of output image as the same as input image. For example, if you want to edit a 512x512 image, you should set the height and width of output image as 512x512. You also can set the `use_input_image_size_as_output` to automatically set the height and width of output image as the same as input image.
277
+
278
+
279
+ """
280
+
281
+ article = """
282
+ ---
283
+ **Citation**
284
+ <br>
285
+ If you find this repository useful, please consider giving a star ⭐ and citation
286
+ ```
287
+ @article{xiao2024omnigen,
288
+ title={Omnigen: Unified image generation},
289
+ author={Xiao, Shitao and Wang, Yueze and Zhou, Junjie and Yuan, Huaying and Xing, Xingrun and Yan, Ruiran and Wang, Shuting and Huang, Tiejun and Liu, Zheng},
290
+ journal={arXiv preprint arXiv:2409.11340},
291
+ year={2024}
292
+ }
293
+ ```
294
+ **Contact**
295
+ <br>
296
+ If you have any questions, please feel free to open an issue or directly reach us out via email.
297
  """
298
 
 
299
 
300
+ # Gradio
301
  with gr.Blocks() as demo:
302
  gr.Markdown("# OmniGen: Unified Image Generation [paper](https://arxiv.org/abs/2409.11340) [code](https://github.com/VectorSpaceLab/OmniGen)")
303
  gr.Markdown(description)
304
  with gr.Row():
305
  with gr.Column():
306
+ # text prompt
307
  prompt_input = gr.Textbox(
308
  label="Enter your prompt, use <img><|image_i|></img> to represent i-th input image", placeholder="Type your prompt here..."
309
  )
310
 
311
  with gr.Row(equal_height=True):
312
+ # input images
313
  image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
314
  image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
315
  image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
316
 
317
+ # slider
318
  height_input = gr.Slider(
319
+ label="Height", minimum=128, maximum=2048, value=1024, step=16
320
  )
321
  width_input = gr.Slider(
322
+ label="Width", minimum=128, maximum=2048, value=1024, step=16
323
  )
324
 
 
325
  guidance_scale_input = gr.Slider(
326
  label="Guidance Scale", minimum=1.0, maximum=5.0, value=2.5, step=0.1
327
  )
 
339
  )
340
 
341
  separate_cfg_infer = gr.Checkbox(
342
+ label="separate_cfg_infer", info="Whether to use separate inference process for different guidance. This will reduce the memory cost.", value=True,
343
+ )
344
+ offload_model = gr.Checkbox(
345
+ label="offload_model", info="Offload model to CPU, which will significantly reduce the memory cost but slow down the generation speed. You can cancle separate_cfg_infer and set offload_model=True. If both separate_cfg_infer and offload_model be True, further reduce the memory, but slowest generation", value=False,
346
+ )
347
+ use_input_image_size_as_output = gr.Checkbox(
348
+ label="use_input_image_size_as_output", info="Automatically adjust the output image size to be same as input image size. For editing and controlnet task, it can make sure the output image has the same size with input image leading to better performance", value=False,
349
  )
350
 
351
+ # generate
352
  generate_button = gr.Button("Generate Image")
353
+
354
 
355
  with gr.Column():
356
+ # output image
357
  output_image = gr.Image(label="Output Image")
358
 
359
+ # click
360
  generate_button.click(
361
  generate_image,
362
  inputs=[
 
371
  num_inference_steps,
372
  seed_input,
373
  separate_cfg_infer,
374
+ offload_model,
375
+ use_input_image_size_as_output,
376
  ],
377
  outputs=output_image,
378
  )
 
392
  num_inference_steps,
393
  seed_input,
394
  separate_cfg_infer,
395
+ offload_model,
396
+ use_input_image_size_as_output,
397
  ],
398
  outputs=output_image,
399
  )
400
 
401
+ gr.Markdown(article)
402
+
403
+ # launch
404
  demo.launch()
docs/fine-tuning.md ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning OmniGen
2
+
3
+ Fine-tuning Omnigen can better help you handle specific image generation tasks. For example, by fine-tuning on a person's images, you can generate multiple pictures of that person while maintaining task consistency.
4
+
5
+ A lot of previous work focused on designing new networks to facilitate specific tasks. For instance, ControlNet was proposed to handle image conditions, and IP-Adapter was constructed to maintain ID features. If you want to perform new tasks, you need to build new architectures and repeatedly debug them. Adding and adjusting extra network parameters is usually time-consuming and labor-intensive, which is not user-friendly and cost-efficient enough. However, with Omnigen, all of this becomes very simple.
6
+
7
+ By comparison, Omnigen can accept multi-modal conditional inputs and has been pre-trained on various tasks. You can fine-tune it on any task without designing specialized networks like ControlNet or IP-Adapter for a specific task.
8
+
9
+ **All you need to do is prepare the data and start training. You can break the limitations of previous models, allowing Omnigen to accomplish a variety of interesting tasks, even those that have never been done before.**
10
+
11
+
12
+ ## Installation
13
+
14
+ ```bash
15
+ git clone https://github.com/VectorSpaceLab/OmniGen.git
16
+ cd OmniGen
17
+ pip install -e .
18
+ ```
19
+
20
+
21
+ ## Full fine-tuning
22
+
23
+ ### Fine-tuning command
24
+
25
+ ```bash
26
+ accelerate launch \
27
+ --num_processes=1 \
28
+ --use_fsdp \
29
+ --fsdp_offload_params false \
30
+ --fsdp_sharding_strategy SHARD_GRAD_OP \
31
+ --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \
32
+ --fsdp_transformer_layer_cls_to_wrap Phi3DecoderLayer \
33
+ --fsdp_state_dict_type FULL_STATE_DICT \
34
+ --fsdp_forward_prefetch false \
35
+ --fsdp_use_orig_params True \
36
+ --fsdp_cpu_ram_efficient_loading false \
37
+ --fsdp_sync_module_states True \
38
+ train.py \
39
+ --model_name_or_path Shitao/OmniGen-v1 \
40
+ --json_file ./toy_data/toy_data.jsonl \
41
+ --image_path ./toy_data/images \
42
+ --batch_size_per_device 1 \
43
+ --lr 2e-5 \
44
+ --keep_raw_resolution \
45
+ --max_image_size 1024 \
46
+ --gradient_accumulation_steps 1 \
47
+ --ckpt_every 100 \
48
+ --epochs 100 \
49
+ --log_every 1 \
50
+ --results_dir ./results/toy_finetune
51
+ ```
52
+
53
+ Some important arguments:
54
+ - `num_processes`: number of GPU to use for training
55
+ - `model_name_or_path`: path to the pretrained model
56
+ - `json_file`: path to the json file containing the training data, e.g., ./toy_data/toy_data.jsonl
57
+ - `image_path`: path to the image folder, e.g., ./toy_data/images
58
+ - `batch_size_per_device`: batch size per device
59
+ - `lr`: learning rate
60
+ - `keep_raw_resolution`: whether to keep the original resolution of the image, if not, all images will be resized to (max_image_size, max_image_size)
61
+ - `max_image_size`: max image size
62
+ - `gradient_accumulation_steps`: number of steps to accumulate gradients
63
+ - `ckpt_every`: number of steps to save checkpoint
64
+ - `epochs`: number of epochs
65
+ - `log_every`: number of steps to log
66
+ - `results_dir`: path to the results folder
67
+
68
+ The data format of json_file is as follows:
69
+ ```
70
+ {
71
+ "instruction": str,
72
+ "input_images": [str, str, ...],
73
+ "output_images": str
74
+ }
75
+ ```
76
+ You can see a toy example in `./toy_data/toy_data.jsonl`.
77
+
78
+ If an OOM(Out of Memory) issue occurs, you can try to decrease the `batch_size_per_device` or `max_image_size`. You can also try to use LoRA instead of full fine-tuning.
79
+
80
+
81
+ ### Inference
82
+
83
+ The checkpoint can be found at `{results_dir}/checkpoints/*`. You can use the following command to load saved checkpoint:
84
+ ```python
85
+ from OmniGen import OmniGenPipeline
86
+
87
+ pipe = OmniGenPipeline.from_pretrained("checkpoint_path") # e.g., ./results/toy_finetune/checkpoints/0000200
88
+ ```
89
+
90
+
91
+
92
+
93
+
94
+ ## LoRA fine-tuning
95
+ LoRA fine-tuning is a simple way to fine-tune OmniGen with less GPU memory. To use lora, you should add `--use_lora` and `--lora_rank` to the command.
96
+
97
+ ```bash
98
+ accelerate launch \
99
+ --num_processes=1 \
100
+ train.py \
101
+ --model_name_or_path Shitao/OmniGen-v1 \
102
+ --batch_size_per_device 2 \
103
+ --condition_dropout_prob 0.01 \
104
+ --lr 3e-4 \
105
+ --use_lora \
106
+ --lora_rank 8 \
107
+ --json_file ./toy_data/toy_data.jsonl \
108
+ --image_path ./toy_data/images \
109
+ --max_input_length_limit 18000 \
110
+ --keep_raw_resolution \
111
+ --max_image_size 1024 \
112
+ --gradient_accumulation_steps 1 \
113
+ --ckpt_every 100 \
114
+ --epochs 100 \
115
+ --log_every 1 \
116
+ --results_dir ./results/toy_finetune_lora
117
+ ```
118
+
119
+ ### Inference
120
+
121
+ The checkpoint can be found at `{results_dir}/checkpoints/*`. You can use the following command to load checkpoint:
122
+ ```python
123
+ from OmniGen import OmniGenPipeline
124
+
125
+ pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
126
+ pipe.merge_lora("checkpoint_path") # e.g., ./results/toy_finetune_lora/checkpoints/0000100
127
+ ```
128
+
129
+
130
+ ## A simple example
131
+
132
+ Here is an example for learning new concepts: "sks dog". We use five images of one dog from [dog-example](https://huggingface.co/datasets/diffusers/dog-example).
133
+
134
+ The json file is `./toy_data/toy_subject_data.jsonl`, and the images have been saved in `./toy_data/images`.
135
+
136
+ ```bash
137
+ accelerate launch \
138
+ --num_processes=1 \
139
+ train.py \
140
+ --model_name_or_path Shitao/OmniGen-v1 \
141
+ --batch_size_per_device 2 \
142
+ --condition_dropout_prob 0.01 \
143
+ --lr 1e-3 \
144
+ --use_lora \
145
+ --lora_rank 8 \
146
+ --json_file ./toy_data/toy_subject_data.jsonl \
147
+ --image_path ./toy_data/images \
148
+ --max_input_length_limit 18000 \
149
+ --keep_raw_resolution \
150
+ --max_image_size 1024 \
151
+ --gradient_accumulation_steps 1 \
152
+ --ckpt_every 100 \
153
+ --epochs 200 \
154
+ --log_every 1 \
155
+ --results_dir ./results/toy_finetune_lora
156
+ ```
157
+
158
+ After training, you can use the following command to generate images:
159
+ ```python
160
+ from OmniGen import OmniGenPipeline
161
+
162
+ pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
163
+ pipe.merge_lora("checkpoint_path") # e.g., ./results/toy_finetune_lora/checkpoints/0000200
164
+
165
+ images = pipe(
166
+ prompt="a photo of sks dog running in the snow",
167
+ height=1024,
168
+ width=1024,
169
+ guidance_scale=3
170
+ )
171
+ images[0].save("example_sks_dog_snow.png")
172
+ ```
docs/inference.md ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference with OmniGen
2
+
3
+ To handle some complex tasks, image generation models are becoming increasingly sophisticated, leading to more and more cumbersome workflows. Existing image generation models like SD and Flux require loading many additional network modules (such as ControlNet, IP-Adapter, Reference-Net) and extra preprocessing steps (e.g., face detection, pose detection, image cropping) to generate a satisfactory image. This complex workflow is not user-friendly. We believe that future image generation models should be simpler, generating various images directly through instructions, similar to how GPT works in language generation.
4
+
5
+ Therefore, we propose OmniGen, a model capable of handling various image generation tasks within a single framework. The goal of OmniGen is to complete various image generation tasks without relying on any additional components or image preprocessing steps. OmniGen supports tasks including text-to-image generation, image editing, subject-driven image generation, and classical vision tasks, among others. More capabilities can be found in our examples. We provide inference code so you can explore more unknown functionalities yourself.
6
+
7
+
8
+
9
+ ## Install
10
+ ```bash
11
+ git clone https://github.com/staoxiao/OmniGen.git
12
+ cd OmniGen
13
+ pip install -e .
14
+ ```
15
+
16
+
17
+
18
+ ## Generate Images
19
+ You can use the following code to generate images:
20
+ ```python
21
+ from OmniGen import OmniGenPipeline
22
+
23
+ pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
24
+
25
+ # Text to Image
26
+ images = pipe(
27
+ prompt="A curly-haired man in a red shirt is drinking tea.",
28
+ height=1024,
29
+ width=1024,
30
+ guidance_scale=2.5,
31
+ seed=0,
32
+ )
33
+ images[0].save("example_t2i.png") # save output PIL Image
34
+
35
+ # Multi-modal to Image
36
+ # In prompt, we use the placeholder to represent the image. The image placeholder should be in the format of <img><|image_*|></img>
37
+ # You can add multiple images in the input_images. Please ensure that each image has its placeholder. For example, for the list input_images [img1_path, img2_path], the prompt needs to have two placeholders: <img><|image_1|></img>, <img><|image_2|></img>.
38
+ images = pipe(
39
+ prompt="A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
40
+ input_images=["./imgs/test_cases/two_man.jpg"],
41
+ height=1024,
42
+ width=1024,
43
+ guidance_scale=2.5,
44
+ img_guidance_scale=1.6,
45
+ max_input_image_size=1024,
46
+ separate_cfg_infer=True,
47
+ use_kv_cache=True,
48
+ offload_kv_cache=True,
49
+ offload_model=False,
50
+ use_input_image_size_as_output=False,
51
+ seed=0,
52
+ )
53
+ images[0].save("example_ti2i.png") # save output PIL image
54
+ ```
55
+
56
+ Some important arguments:
57
+ - `guidance_scale`: The strength of the guidance. Based on our experience, it is usually best to set it between 2 and 3. The higher the value, the more similar the generated image will be to the prompt. If the image appears oversaturated, please reduce the scale.
58
+ - `height` and `width`: The height and width of the generated image. The default value is 1024x1024. OmniGen support any size, but these number must be divisible by 16.
59
+ - `num_inference_steps`: The number of steps to take in the diffusion process. The higher the value, the more detailed the generated image will be.
60
+ - `max_input_image_size`: the maximum size of input image, which will be used to crop the input image to the maximum size. A smaller number will result in faster generation speed and lower memory cost.
61
+ - `separate_cfg_infer`: Whether to use separate inference process for CFG guidance. If set to True, memory cost will be lower. Default is True.
62
+ - `use_kv_cache`: Whether to use key-value cache. Default is True.
63
+ - `offload_kv_cache`: offload the cached key and value to cpu, which can save memory but slow down the generation silightly. Default is True.
64
+ - `offload_model`: offload the model to cpu, which can save memory but slow down the generation. Default is False.
65
+ - `use_input_image_size_as_output`: whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task. Default is False.
66
+ - `seed`: The seed for random number generator.
67
+
68
+ **More examples please refer to [inference.ipynb](../inference.ipynb)**
69
+
70
+
71
+ #### Input data
72
+ OmniGen can accept multi-modal input data. Specifically, you should pass two arguments: `prompt` and `input_images`.
73
+ For text to image generation, you can pass a string as `prompt`, or pass a list of strings as `prompt` to generate multiple images.
74
+
75
+ For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>`.
76
+ For example, if you want to generate an image with a person holding a bouquet of flowers, you can pass the following prompt:
77
+ ```
78
+ prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is <img><|image_1|></img>."
79
+ input_images = ["./imgs/test_cases/liuyifei.png"]
80
+ ```
81
+ The placeholder `<|image_1|>` will be replaced by the image at `input_images[0]`, i.e., `./imgs/test_cases/liuyifei.png`.
82
+
83
+ If you want to generate multiple images, you can pass a list of prompts and a list of image paths. For example:
84
+ ```
85
+ prompt = ["A woman holds a bouquet of flowers and faces the camera.", "A woman holds a bouquet of flowers and faces the camera. Thw woman is <img><|image_1|></img>."]
86
+ input_images = [[], ["./imgs/test_cases/liuyifei.png"]]
87
+ ```
88
+
89
+
90
+ #### Gradio Demo
91
+ We have constructed a online demo in [Huggingface](https://huggingface.co/spaces/Shitao/OmniGen).
92
+
93
+ For the local gradio demo, you can run with the following command:
94
+ ```python
95
+ python app.py
96
+ ```
97
+
98
+
99
+ ## Tips
100
+ - For out of memory or time cost, you can refer to [./docs/inference.md#requiremented-resources](https://github.com/VectorSpaceLab/OmniGen/blob/main/docs/inference.md#requiremented-resources) to select a appropriate setting.
101
+ - Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
102
+ - Not match the prompt: If the image does not match the prompt, please try to increase the `guidance_scale`.
103
+ - Low-quality: More detailed prompt will lead to better results.
104
+ - Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
105
+ - Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
106
+ - For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
107
+ - For image editing task and controlnet task, we recommend to set the height and width of output image as the same
108
+ as input image. For example, if you want to edit a 512x512 image, you should set the height and width of output image as 512x512. You also can set the `use_input_image_size_as_output` to automatically set the height and width of output image as the same as input image.
109
+
110
+
111
+ ## Requiremented Resources
112
+
113
+ We are currently experimenting with some techniques to reduce memory usage and improve speed, including `use_kv_cache, offload_kv_cache, separate_cfg_infer, offload_model`, which you can enable in the pipeline.
114
+ The default setting is`use_kv_cache=True, offload_kv_cache=True, separate_cfg_infer=True, offload_model=False`.
115
+ To reduce memory consumption while maintaining inference speed, quantization is also a method worth exploring and is left for future work.
116
+
117
+ We conducted experiments on the A800 and RTX 3090. The memory requirements and inference times are shown in the table below. You can choose the appropriate settings based on your available resources.
118
+
119
+ **Overall, the text-to-image task requires minimal memory and time costs, comparable to other latest text-to-image models. However, when using input images, the computational cost increases. Memory usage can be reduced by extending the processing time.**
120
+
121
+
122
+ - Different image size.
123
+
124
+ Different image size (`max_input_image_size` is the max size of input image, `height` and `width` are the size of output image) with the default inference settings (`use_kv_cache=True,offload_kv_cache=True,separate_cfg_infer=True`)
125
+
126
+ For A800 GPU:
127
+ | Settings | Only Text | Text + Single Image | Text + Two Images |
128
+ |:-------------|:----------:|:-------------------:|:---------------------:|
129
+ | max_input_image_size=1024,height=1024,width=1024 | 9G, 31s | 12G, 1m6s | 13G, 1m20s |
130
+ | max_input_image_size=512,height=1024,width=1024 | 9G, 31s | 10G, 50s | 10G, 54s |
131
+ | max_input_image_size=768,height=768,width=768 | 9G, 16s | 10G, 32s | 10G, 37s |
132
+ | max_input_image_size=512,height=512,width=512 | 9G, 7s | 9G, 14s | 9G, 15s |
133
+
134
+ For RTX 3090 GPU(24G):
135
+ | Settings | Only Text | Text + Single Image | Text + Two Images |
136
+ |:-------------|:----------:|:-------------------:|:---------------------:|
137
+ | max_input_image_size=1024,height=1024,width=1024 | 9G, 1m17s | 12G, 2m46s | 13G, 3m23s |
138
+ | max_input_image_size=512,height=1024,width=1024 | 9G, 1m18s | 10G, 2m8s | 10G, 2m18s |
139
+ | max_input_image_size=768,height=768,width=768 | 9G, 41s | 10G, 1m22s | 10G, 1m38s |
140
+ | max_input_image_size=512,height=512,width=512 | 9G, 19s | 9G, 36s | 9G, 43s |
141
+
142
+
143
+ You can set smaller `max_input_image_size` to reduce memory usage, but note that the generation quality may be lower.
144
+ And please set the `height` and `width` the same as the size of input image for image editing task.
145
+
146
+
147
+ - Different inference settings
148
+
149
+ Default image size: height=1024, width=1024, max_input_image_size=1024
150
+
151
+ For A800 GPU:
152
+ | Settings | Only Text | Text + Single Image | Text + Two Images |
153
+ |:-------------|:----------:|:-------------------:|:---------------------:|
154
+ | use_kv_cache | 18G, 30s | 36G, 1m | 48G, 1m13s |
155
+ | use_kv_cache,offload_kv_cache | 10G, 30s | 14G, 1m10s | 17G, 1m30s |
156
+ | use_kv_cache,offload_kv_cache,separate_cfg_infer | 9G, 31s | 12G, 1m6s | 13G, 1m20s |
157
+ | use_kv_cache,offload_kv_cache,offload_model | 4G, 55s | 7G, 1m30s | 11G, 1m48s |
158
+ | use_kv_cache,offload_kv_cache,separate_cfg_infer,offload_model | 3G, 1m23s | 5G, 2m19s | 6G, 2m30s |
159
+
160
+ For RTX 3090 GPU(24G):
161
+ | Settings | Only Text | Text + Single Image | Text + Two Images |
162
+ |:-------------|:----------:|:-------------------:|:---------------------:|
163
+ | use_kv_cache | 18G, 1m14s | OOM | OOM |
164
+ | use_kv_cache,offload_kv_cache | 10G, 1m17s | 14G, 3m11s | 17G, 4m3s |
165
+ | use_kv_cache,offload_kv_cache,separate_cfg_infer | 9G, 1m18s | 12G, 2m46s | 13G, 3m21s |
166
+ | use_kv_cache,offload_kv_cache,offload_model | 4G,3m1s | 7G, 4m14s | 11G, 5m4s |
167
+ | use_kv_cache,offload_kv_cache,separate_cfg_infer,offload_model | 3G, 4m56s | 5G, 7m49s | 6G, 8m6s |
edit.png → imgs/demo_cases.png RENAMED
File without changes
imgs/demo_cases/edit.png CHANGED

Git LFS Details

  • SHA256: a83fc3b2ab185a93cb10d207a8776f3a04dc187739d87816cfb33f52d46af502
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB

Git LFS Details

  • SHA256: 2fac5461b2c06a99664ba1299fd9fcebd781a26afa5ebc07aa07cb678ebae2af
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
imgs/demo_cases/entity.png CHANGED

Git LFS Details

  • SHA256: 5e18387fa43989515fd18dcb4ce8edeab0e32aa539d6c14ce374cb5790d8f64b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB

Git LFS Details

  • SHA256: 7c622ebecd3210c80e8d913158ee3564168c77c576f04b56e34d2d28bfea9e06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
imgs/demo_cases/t2i_woman_with_book.png CHANGED

Git LFS Details

  • SHA256: 624ae749478b4ced358c6482385fd35271cbfe25eea0581d2a323bffebde8b39
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB

Git LFS Details

  • SHA256: fe258160193adeaff960a838de01d7f7294ab09899de534f2dee99043b0c747a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
imgs/overall.jpg ADDED

Git LFS Details

  • SHA256: ffa229632ac0bb248eee87cf823a0dc18c22c0a81a57d4c639e7fb1986d4e029
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
imgs/referring.png ADDED

Git LFS Details

  • SHA256: 393fab6a4d51e84555f75162430e35a64a49670d9e6c3986cd80bca318a4fb3e
  • Pointer size: 132 Bytes
  • Size of remote file: 4.09 MB
requirements.txt CHANGED
@@ -1,9 +1,11 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
7
- timm
8
- peft
9
- safetensors
 
 
 
1
+ torch==2.3.1
2
+ transformers==4.45.2
3
+ datasets==2.20.0
4
+ accelerate==0.26.1
5
+ jupyter==1.0.0
6
+ numpy==1.26.3
7
+ pillow==10.2.0
8
+ torch==2.3.1
9
+ peft==0.9.0
10
+ diffusers==0.30.3
11
+ timm==0.9.16
setup.py CHANGED
@@ -5,7 +5,7 @@ with open("README.md", mode="r", encoding="utf-8") as readme_file:
5
 
6
  setup(
7
  name='OmniGen',
8
- version='1.0.0',
9
  description='OmniGen',
10
  long_description=readme,
11
  long_description_content_type="text/markdown",
@@ -14,10 +14,13 @@ setup(
14
  packages=find_packages(),
15
  include_package_data=True,
16
  install_requires=[
17
- 'torch>=1.6.0',
18
- 'transformers>=4.41.0',
19
  'datasets',
20
- 'accelerate>=0.20.1',
21
- 'diffusers>=0.30.3'
 
 
 
22
  ],
23
- )
 
5
 
6
  setup(
7
  name='OmniGen',
8
+ version='1.0.3',
9
  description='OmniGen',
10
  long_description=readme,
11
  long_description_content_type="text/markdown",
 
14
  packages=find_packages(),
15
  include_package_data=True,
16
  install_requires=[
17
+ 'torch==2.3.1',
18
+ 'transformers==4.45.2',
19
  'datasets',
20
+ 'accelerate==0.26.1',
21
+ 'diffusers==0.30.3',
22
+ "timm",
23
+ "peft==0.9.0",
24
+ "safetensors"
25
  ],
26
+ )
train.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from time import time
3
+ import argparse
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ import math
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+ from copy import deepcopy
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from torch.utils.data import Dataset, DataLoader
16
+ from torch.utils.data.distributed import DistributedSampler
17
+ from torchvision import transforms
18
+
19
+ from accelerate import Accelerator
20
+ from accelerate.utils import ProjectConfiguration, set_seed
21
+ from diffusers.optimization import get_scheduler
22
+ from accelerate.utils import DistributedType
23
+ from peft import LoraConfig, set_peft_model_state_dict, PeftModel, get_peft_model
24
+ from peft.utils import get_peft_model_state_dict
25
+ from huggingface_hub import snapshot_download
26
+ from safetensors.torch import save_file
27
+
28
+ from diffusers.models import AutoencoderKL
29
+
30
+ from OmniGen import OmniGen, OmniGenProcessor
31
+ from OmniGen.train_helper import DatasetFromJson, TrainDataCollator
32
+ from OmniGen.train_helper import training_losses
33
+ from OmniGen.utils import (
34
+ create_logger,
35
+ update_ema,
36
+ requires_grad,
37
+ center_crop_arr,
38
+ crop_arr,
39
+ vae_encode,
40
+ vae_encode_list
41
+ )
42
+
43
+ def main(args):
44
+ # Setup accelerator:
45
+ from accelerate import DistributedDataParallelKwargs as DDPK
46
+ kwargs = DDPK(find_unused_parameters=False)
47
+ accelerator = Accelerator(
48
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
49
+ mixed_precision=args.mixed_precision,
50
+ log_with=args.report_to,
51
+ project_dir=args.results_dir,
52
+ kwargs_handlers=[kwargs],
53
+ )
54
+ device = accelerator.device
55
+ accelerator.init_trackers("tensorboard_log", config=args.__dict__)
56
+
57
+ # Setup an experiment folder:
58
+ checkpoint_dir = f"{args.results_dir}/checkpoints" # Stores saved model checkpoints
59
+ logger = create_logger(args.results_dir)
60
+ if accelerator.is_main_process:
61
+ os.makedirs(checkpoint_dir, exist_ok=True)
62
+ logger.info(f"Experiment directory created at {args.results_dir}")
63
+ json.dump(args.__dict__, open(os.path.join(args.results_dir, 'train_args.json'), 'w'))
64
+
65
+
66
+ # Create model:
67
+ if not os.path.exists(args.model_name_or_path):
68
+ cache_folder = os.getenv('HF_HUB_CACHE')
69
+ args.model_name_or_path = snapshot_download(repo_id=args.model_name_or_path,
70
+ cache_dir=cache_folder,
71
+ ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
72
+ logger.info(f"Downloaded model to {args.model_name_or_path}")
73
+ model = OmniGen.from_pretrained(args.model_name_or_path)
74
+ model.llm.config.use_cache = False
75
+ model.llm.gradient_checkpointing_enable()
76
+ model = model.to(device)
77
+
78
+ if args.vae_path is None:
79
+ print(args.model_name_or_path)
80
+ vae_path = os.path.join(args.model_name_or_path, "vae")
81
+ if os.path.exists(vae_path):
82
+ vae = AutoencoderKL.from_pretrained(vae_path).to(device)
83
+ else:
84
+ logger.info("No VAE found in model, downloading stabilityai/sdxl-vae from HF")
85
+ logger.info("If you have VAE in local folder, please specify the path with --vae_path")
86
+ vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
87
+ else:
88
+ vae = AutoencoderKL.from_pretrained(args.vae_path).to(device)
89
+
90
+ weight_dtype = torch.float32
91
+ if accelerator.mixed_precision == "fp16":
92
+ weight_dtype = torch.float16
93
+ elif accelerator.mixed_precision == "bf16":
94
+ weight_dtype = torch.bfloat16
95
+ vae.to(dtype=torch.float32)
96
+ model.to(weight_dtype)
97
+
98
+ processor = OmniGenProcessor.from_pretrained(args.model_name_or_path)
99
+
100
+ requires_grad(vae, False)
101
+ if args.use_lora:
102
+ if accelerator.distributed_type == DistributedType.FSDP:
103
+ raise NotImplementedError("FSDP does not support LoRA")
104
+ requires_grad(model, False)
105
+ transformer_lora_config = LoraConfig(
106
+ r=args.lora_rank,
107
+ lora_alpha=args.lora_rank,
108
+ init_lora_weights="gaussian",
109
+ target_modules=["qkv_proj", "o_proj"],
110
+ )
111
+ model.llm.enable_input_require_grads()
112
+ model = get_peft_model(model, transformer_lora_config)
113
+ model.to(weight_dtype)
114
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
115
+ for n,p in model.named_parameters():
116
+ print(n, p.requires_grad)
117
+ opt = torch.optim.AdamW(transformer_lora_parameters, lr=args.lr, weight_decay=args.adam_weight_decay)
118
+ else:
119
+ opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.adam_weight_decay)
120
+
121
+ ema = None
122
+ if args.use_ema:
123
+ ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
124
+ requires_grad(ema, False)
125
+
126
+
127
+ # Setup data:
128
+ crop_func = crop_arr
129
+ if not args.keep_raw_resolution:
130
+ crop_func = center_crop_arr
131
+ image_transform = transforms.Compose([
132
+ transforms.Lambda(lambda pil_image: crop_func(pil_image, args.max_image_size)),
133
+ transforms.ToTensor(),
134
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
135
+ ])
136
+
137
+ dataset = DatasetFromJson(json_file=args.json_file,
138
+ image_path=args.image_path,
139
+ processer=processor,
140
+ image_transform=image_transform,
141
+ max_input_length_limit=args.max_input_length_limit,
142
+ condition_dropout_prob=args.condition_dropout_prob,
143
+ keep_raw_resolution=args.keep_raw_resolution
144
+ )
145
+ collate_fn = TrainDataCollator(pad_token_id=processor.text_tokenizer.eos_token_id, hidden_size=model.llm.config.hidden_size, keep_raw_resolution=args.keep_raw_resolution)
146
+
147
+ loader = DataLoader(
148
+ dataset,
149
+ collate_fn=collate_fn,
150
+ batch_size=args.batch_size_per_device,
151
+ shuffle=True,
152
+ num_workers=args.num_workers,
153
+ pin_memory=True,
154
+ drop_last=True,
155
+ prefetch_factor=2,
156
+ )
157
+
158
+ if accelerator.is_main_process:
159
+ logger.info(f"Dataset contains {len(dataset):,}")
160
+
161
+ num_update_steps_per_epoch = math.ceil(len(loader) / args.gradient_accumulation_steps)
162
+ max_train_steps = args.epochs * num_update_steps_per_epoch
163
+ lr_scheduler = get_scheduler(
164
+ args.lr_scheduler,
165
+ optimizer=opt,
166
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
167
+ num_training_steps=max_train_steps * args.gradient_accumulation_steps,
168
+ )
169
+
170
+ # Prepare models for training:
171
+ model.train() # important! This enables embedding dropout for classifier-free guidance
172
+
173
+ if ema is not None:
174
+ update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights
175
+ ema.eval() # EMA model should always be in eval mode
176
+
177
+
178
+ if ema is not None:
179
+ model, ema = accelerator.prepare(model, ema)
180
+ else:
181
+ model = accelerator.prepare(model)
182
+
183
+ opt, loader, lr_scheduler = accelerator.prepare(opt, loader, lr_scheduler)
184
+
185
+
186
+ # Variables for monitoring/logging purposes:
187
+ train_steps, log_steps = 0, 0
188
+ running_loss = 0
189
+ start_time = time()
190
+
191
+ if accelerator.is_main_process:
192
+ logger.info(f"Training for {args.epochs} epochs...")
193
+ for epoch in range(args.epochs):
194
+ if accelerator.is_main_process:
195
+ logger.info(f"Beginning epoch {epoch}...")
196
+
197
+ for data in loader:
198
+ with accelerator.accumulate(model):
199
+ with torch.no_grad():
200
+ output_images = data['output_images']
201
+ input_pixel_values = data['input_pixel_values']
202
+ if isinstance(output_images, list):
203
+ output_images = vae_encode_list(vae, output_images, weight_dtype)
204
+ if input_pixel_values is not None:
205
+ input_pixel_values = vae_encode_list(vae, input_pixel_values, weight_dtype)
206
+ else:
207
+ output_images = vae_encode(vae, output_images, weight_dtype)
208
+ if input_pixel_values is not None:
209
+ input_pixel_values = vae_encode(vae, input_pixel_values, weight_dtype)
210
+
211
+
212
+ model_kwargs = dict(input_ids=data['input_ids'], input_img_latents=input_pixel_values, input_image_sizes=data['input_image_sizes'], attention_mask=data['attention_mask'], position_ids=data['position_ids'], padding_latent=data['padding_images'], past_key_values=None, return_past_key_values=False)
213
+
214
+ loss_dict = training_losses(model, output_images, model_kwargs)
215
+ loss = loss_dict["loss"].mean()
216
+
217
+ running_loss += loss.item()
218
+ accelerator.backward(loss)
219
+ if args.max_grad_norm is not None and accelerator.sync_gradients:
220
+ accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
221
+ opt.step()
222
+ lr_scheduler.step()
223
+ opt.zero_grad()
224
+
225
+ log_steps += 1
226
+ train_steps += 1
227
+
228
+ accelerator.log({"training_loss": loss.item()}, step=train_steps)
229
+ if train_steps % args.gradient_accumulation_steps == 0:
230
+ if accelerator.sync_gradients and ema is not None:
231
+ update_ema(ema, model)
232
+
233
+ if train_steps % (args.log_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
234
+ torch.cuda.synchronize()
235
+ end_time = time()
236
+ steps_per_sec = log_steps / args.gradient_accumulation_steps / (end_time - start_time)
237
+ # Reduce loss history over all processes:
238
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
239
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
240
+ avg_loss = avg_loss.item() / accelerator.num_processes
241
+
242
+ if accelerator.is_main_process:
243
+ cur_lr = opt.param_groups[0]["lr"]
244
+ logger.info(f"(step={int(train_steps/args.gradient_accumulation_steps):07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Epoch: {train_steps/len(loader)}, LR: {cur_lr}")
245
+
246
+ # Reset monitoring variables:
247
+ running_loss = 0
248
+ log_steps = 0
249
+ start_time = time()
250
+
251
+
252
+ if train_steps % (args.ckpt_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
253
+ if accelerator.distributed_type == DistributedType.FSDP:
254
+ state_dict = accelerator.get_state_dict(model)
255
+ ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None
256
+ else:
257
+ if not args.use_lora:
258
+ state_dict = model.module.state_dict()
259
+ ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None
260
+
261
+ if accelerator.is_main_process:
262
+ if args.use_lora:
263
+ checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}/"
264
+ os.makedirs(checkpoint_path, exist_ok=True)
265
+
266
+ model.module.save_pretrained(checkpoint_path)
267
+ else:
268
+ checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}/"
269
+ os.makedirs(checkpoint_path, exist_ok=True)
270
+ torch.save(state_dict, os.path.join(checkpoint_path, "model.pt"))
271
+ processor.text_tokenizer.save_pretrained(checkpoint_path)
272
+ model.llm.config.save_pretrained(checkpoint_path)
273
+ if ema_state_dict is not None:
274
+ checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}_ema"
275
+ os.makedirs(checkpoint_path, exist_ok=True)
276
+ torch.save(state_dict, os.path.join(checkpoint_path, "model.pt"))
277
+ processor.text_tokenizer.save_pretrained(checkpoint_path)
278
+ model.llm.config.save_pretrained(checkpoint_path)
279
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
280
+
281
+ dist.barrier()
282
+ accelerator.end_training()
283
+ model.eval()
284
+
285
+ if accelerator.is_main_process:
286
+ logger.info("Done!")
287
+
288
+
289
+ if __name__ == "__main__":
290
+ parser = argparse.ArgumentParser()
291
+ parser.add_argument("--results_dir", type=str, default="results")
292
+ parser.add_argument("--model_name_or_path", type=str, default="OmniGen")
293
+ parser.add_argument("--json_file", type=str)
294
+ parser.add_argument("--image_path", type=str, default=None)
295
+ parser.add_argument("--epochs", type=int, default=1400)
296
+ parser.add_argument("--batch_size_per_device", type=int, default=1)
297
+ parser.add_argument("--vae_path", type=str, default=None)
298
+ parser.add_argument("--num_workers", type=int, default=4)
299
+ parser.add_argument("--log_every", type=int, default=100)
300
+ parser.add_argument("--ckpt_every", type=int, default=20000)
301
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
302
+ parser.add_argument("--lr", type=float, default=1e-4)
303
+ parser.add_argument("--max_input_length_limit", type=int, default=1024)
304
+ parser.add_argument("--condition_dropout_prob", type=float, default=0.1)
305
+ parser.add_argument("--adam_weight_decay", type=float, default=0.0)
306
+ parser.add_argument(
307
+ "--keep_raw_resolution",
308
+ action="store_true",
309
+ help="multiple_resolutions",
310
+ )
311
+ parser.add_argument("--max_image_size", type=int, default=1344)
312
+
313
+ parser.add_argument(
314
+ "--use_lora",
315
+ action="store_true",
316
+ )
317
+ parser.add_argument(
318
+ "--lora_rank",
319
+ type=int,
320
+ default=8
321
+ )
322
+
323
+ parser.add_argument(
324
+ "--use_ema",
325
+ action="store_true",
326
+ help="Whether or not to use ema.",
327
+ )
328
+ parser.add_argument(
329
+ "--lr_scheduler",
330
+ type=str,
331
+ default="constant",
332
+ help=(
333
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
334
+ ' "constant", "constant_with_warmup"]'
335
+ ),
336
+ )
337
+ parser.add_argument(
338
+ "--lr_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler."
339
+ )
340
+ parser.add_argument(
341
+ "--report_to",
342
+ type=str,
343
+ default="tensorboard",
344
+ help=(
345
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
346
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
347
+ ),
348
+ )
349
+ parser.add_argument(
350
+ "--mixed_precision",
351
+ type=str,
352
+ default="bf16",
353
+ choices=["no", "fp16", "bf16"],
354
+ help=(
355
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
356
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
357
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
358
+ ),
359
+ )
360
+ parser.add_argument(
361
+ "--gradient_accumulation_steps",
362
+ type=int,
363
+ default=1,
364
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
365
+ )
366
+
367
+
368
+ args = parser.parse_args()
369
+ assert args.max_image_size % 16 == 0, "Image size must be divisible by 16."
370
+
371
+ main(args)
372
+
373
+