yrr
commited on
Commit
•
a713a09
1
Parent(s):
0d3229d
update inference code
Browse files- LICENSE +21 -0
- OmniGen/model.py +6 -12
- OmniGen/pipeline.py +94 -17
- OmniGen/processor.py +5 -2
- OmniGen/scheduler.py +137 -11
- OmniGen/transformer.py +44 -9
- README.md +180 -13
- app.py +106 -49
- docs/fine-tuning.md +172 -0
- docs/inference.md +167 -0
- edit.png → imgs/demo_cases.png +2 -2
- imgs/demo_cases/edit.png +2 -2
- imgs/demo_cases/entity.png +2 -2
- imgs/demo_cases/t2i_woman_with_book.png +2 -2
- imgs/overall.jpg +3 -0
- imgs/referring.png +3 -0
- requirements.txt +11 -9
- setup.py +9 -6
- train.py +373 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 VectorSpaceLab
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
OmniGen/model.py
CHANGED
@@ -312,7 +312,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
312 |
return latents, num_tokens, shapes
|
313 |
|
314 |
|
315 |
-
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True):
|
316 |
"""
|
317 |
|
318 |
"""
|
@@ -335,7 +335,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
335 |
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
336 |
else:
|
337 |
input_emb = torch.cat([time_token, x], dim=1)
|
338 |
-
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
|
339 |
output, past_key_values = output.last_hidden_state, output.past_key_values
|
340 |
if input_is_list:
|
341 |
image_embedding = output[:, -max(num_tokens):]
|
@@ -357,12 +357,9 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
357 |
return latents
|
358 |
|
359 |
@torch.no_grad()
|
360 |
-
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
|
361 |
-
"""
|
362 |
-
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
363 |
-
"""
|
364 |
self.llm.config.use_cache = use_kv_cache
|
365 |
-
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True)
|
366 |
if use_img_cfg:
|
367 |
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
368 |
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
@@ -376,10 +373,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
376 |
|
377 |
|
378 |
@torch.no_grad()
|
379 |
-
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache,
|
380 |
-
"""
|
381 |
-
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
382 |
-
"""
|
383 |
self.llm.config.use_cache = use_kv_cache
|
384 |
if past_key_values is None:
|
385 |
past_key_values = [None] * len(attention_mask)
|
@@ -390,7 +384,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
390 |
|
391 |
model_out, pask_key_values = [], []
|
392 |
for i in range(len(input_ids)):
|
393 |
-
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values[i])
|
394 |
model_out.append(temp_out)
|
395 |
pask_key_values.append(temp_pask_key_values)
|
396 |
|
|
|
312 |
return latents, num_tokens, shapes
|
313 |
|
314 |
|
315 |
+
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
316 |
"""
|
317 |
|
318 |
"""
|
|
|
335 |
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
336 |
else:
|
337 |
input_emb = torch.cat([time_token, x], dim=1)
|
338 |
+
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
339 |
output, past_key_values = output.last_hidden_state, output.past_key_values
|
340 |
if input_is_list:
|
341 |
image_embedding = output[:, -max(num_tokens):]
|
|
|
357 |
return latents
|
358 |
|
359 |
@torch.no_grad()
|
360 |
+
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
|
|
|
|
|
|
361 |
self.llm.config.use_cache = use_kv_cache
|
362 |
+
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
|
363 |
if use_img_cfg:
|
364 |
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
365 |
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
|
|
373 |
|
374 |
|
375 |
@torch.no_grad()
|
376 |
+
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
|
|
|
|
|
|
377 |
self.llm.config.use_cache = use_kv_cache
|
378 |
if past_key_values is None:
|
379 |
past_key_values = [None] * len(attention_mask)
|
|
|
384 |
|
385 |
model_out, pask_key_values = [], []
|
386 |
for i in range(len(input_ids)):
|
387 |
+
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
|
388 |
model_out.append(temp_out)
|
389 |
pask_key_values.append(temp_pask_key_values)
|
390 |
|
OmniGen/pipeline.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
import inspect
|
3 |
from typing import Any, Callable, Dict, List, Optional, Union
|
|
|
4 |
|
5 |
from PIL import Image
|
6 |
import numpy as np
|
@@ -33,7 +34,7 @@ EXAMPLE_DOC_STRING = """
|
|
33 |
>>> prompt = "A woman holds a bouquet of flowers and faces the camera"
|
34 |
>>> image = pipe(
|
35 |
... prompt,
|
36 |
-
... guidance_scale=
|
37 |
... num_inference_steps=50,
|
38 |
... ).images[0]
|
39 |
>>> image.save("t2i.png")
|
@@ -41,7 +42,7 @@ EXAMPLE_DOC_STRING = """
|
|
41 |
"""
|
42 |
|
43 |
|
44 |
-
|
45 |
class OmniGenPipeline:
|
46 |
def __init__(
|
47 |
self,
|
@@ -53,10 +54,21 @@ class OmniGenPipeline:
|
|
53 |
self.model = model
|
54 |
self.processor = processor
|
55 |
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
self.model.eval()
|
59 |
-
self.vae.
|
|
|
|
|
60 |
|
61 |
@classmethod
|
62 |
def from_pretrained(cls, model_name, vae_path: str=None):
|
@@ -84,7 +96,6 @@ class OmniGenPipeline:
|
|
84 |
model = PeftModel.from_pretrained(self.model, lora_path)
|
85 |
model.merge_and_unload()
|
86 |
|
87 |
-
|
88 |
self.model = model
|
89 |
|
90 |
def to(self, device: Union[str, torch.device]):
|
@@ -92,6 +103,7 @@ class OmniGenPipeline:
|
|
92 |
device = torch.device(device)
|
93 |
self.model.to(device)
|
94 |
self.vae.to(device)
|
|
|
95 |
|
96 |
def vae_encode(self, x, dtype):
|
97 |
if self.vae.config.shift_factor is not None:
|
@@ -107,6 +119,17 @@ class OmniGenPipeline:
|
|
107 |
return [x.to(self.device) for x in data]
|
108 |
return data.to(self.device)
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
@torch.no_grad()
|
112 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
@@ -120,8 +143,12 @@ class OmniGenPipeline:
|
|
120 |
guidance_scale: float = 3,
|
121 |
use_img_guidance: bool = True,
|
122 |
img_guidance_scale: float = 1.6,
|
123 |
-
|
|
|
|
|
124 |
use_kv_cache: bool = True,
|
|
|
|
|
125 |
dtype: torch.dtype = torch.bfloat16,
|
126 |
seed: int = None,
|
127 |
):
|
@@ -149,31 +176,50 @@ class OmniGenPipeline:
|
|
149 |
Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
|
150 |
img_guidance_scale (`float`, *optional*, defaults to 1.6):
|
151 |
Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
|
|
|
152 |
separate_cfg_infer (`bool`, *optional*, defaults to False):
|
153 |
Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
|
154 |
use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
|
155 |
-
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
|
|
158 |
Examples:
|
159 |
|
160 |
Returns:
|
161 |
A list with the generated images.
|
162 |
"""
|
163 |
-
|
164 |
-
if
|
165 |
-
|
166 |
-
|
|
|
167 |
if input_images is None:
|
168 |
use_img_guidance = False
|
169 |
if isinstance(prompt, str):
|
170 |
prompt = [prompt]
|
171 |
input_images = [input_images] if input_images is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
-
input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer)
|
174 |
|
175 |
num_prompt = len(prompt)
|
176 |
num_cfg = 2 if use_img_guidance else 1
|
|
|
|
|
|
|
|
|
|
|
177 |
latent_size_h, latent_size_w = height//8, width//8
|
178 |
|
179 |
if seed is not None:
|
@@ -183,6 +229,7 @@ class OmniGenPipeline:
|
|
183 |
latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
|
184 |
latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
|
185 |
|
|
|
186 |
input_img_latents = []
|
187 |
if separate_cfg_infer:
|
188 |
for temp_pixel_values in input_data['input_pixel_values']:
|
@@ -195,6 +242,10 @@ class OmniGenPipeline:
|
|
195 |
for img in input_data['input_pixel_values']:
|
196 |
img = self.vae_encode(img.to(self.device), dtype)
|
197 |
input_img_latents.append(img)
|
|
|
|
|
|
|
|
|
198 |
|
199 |
model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
|
200 |
input_img_latents=input_img_latents,
|
@@ -204,7 +255,9 @@ class OmniGenPipeline:
|
|
204 |
cfg_scale=guidance_scale,
|
205 |
img_cfg_scale=img_guidance_scale,
|
206 |
use_img_cfg=use_img_guidance,
|
207 |
-
use_kv_cache=use_kv_cache
|
|
|
|
|
208 |
|
209 |
if separate_cfg_infer:
|
210 |
func = self.model.forward_with_separate_cfg
|
@@ -212,16 +265,38 @@ class OmniGenPipeline:
|
|
212 |
func = self.model.forward_with_cfg
|
213 |
self.model.to(dtype)
|
214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
scheduler = OmniGenScheduler(num_steps=num_inference_steps)
|
216 |
-
samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache)
|
217 |
samples = samples.chunk((1+num_cfg), dim=0)[0]
|
218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
samples = samples.to(torch.float32)
|
220 |
if self.vae.config.shift_factor is not None:
|
221 |
samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
|
222 |
else:
|
223 |
samples = samples / self.vae.config.scaling_factor
|
224 |
samples = self.vae.decode(samples).sample
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
output_samples = (samples * 0.5 + 0.5).clamp(0, 1)*255
|
227 |
output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
@@ -229,4 +304,6 @@ class OmniGenPipeline:
|
|
229 |
for i, sample in enumerate(output_samples):
|
230 |
output_images.append(Image.fromarray(sample))
|
231 |
|
|
|
|
|
232 |
return output_images
|
|
|
1 |
import os
|
2 |
import inspect
|
3 |
from typing import Any, Callable, Dict, List, Optional, Union
|
4 |
+
import gc
|
5 |
|
6 |
from PIL import Image
|
7 |
import numpy as np
|
|
|
34 |
>>> prompt = "A woman holds a bouquet of flowers and faces the camera"
|
35 |
>>> image = pipe(
|
36 |
... prompt,
|
37 |
+
... guidance_scale=2.5,
|
38 |
... num_inference_steps=50,
|
39 |
... ).images[0]
|
40 |
>>> image.save("t2i.png")
|
|
|
42 |
"""
|
43 |
|
44 |
|
45 |
+
90
|
46 |
class OmniGenPipeline:
|
47 |
def __init__(
|
48 |
self,
|
|
|
54 |
self.model = model
|
55 |
self.processor = processor
|
56 |
|
57 |
+
if torch.cuda.is_available():
|
58 |
+
self.device = torch.device("cuda")
|
59 |
+
elif torch.backends.mps.is_available():
|
60 |
+
self.device = torch.device("mps")
|
61 |
+
elif is_torch_npu_available():
|
62 |
+
self.device = torch.device("npu")
|
63 |
+
else:
|
64 |
+
logger.info("Don't detect any available devices, using CPU instead")
|
65 |
+
self.device = torch.device("cpu")
|
66 |
+
|
67 |
+
self.model.to(torch.bfloat16)
|
68 |
self.model.eval()
|
69 |
+
self.vae.eval()
|
70 |
+
|
71 |
+
self.model_cpu_offload = False
|
72 |
|
73 |
@classmethod
|
74 |
def from_pretrained(cls, model_name, vae_path: str=None):
|
|
|
96 |
model = PeftModel.from_pretrained(self.model, lora_path)
|
97 |
model.merge_and_unload()
|
98 |
|
|
|
99 |
self.model = model
|
100 |
|
101 |
def to(self, device: Union[str, torch.device]):
|
|
|
103 |
device = torch.device(device)
|
104 |
self.model.to(device)
|
105 |
self.vae.to(device)
|
106 |
+
self.device = device
|
107 |
|
108 |
def vae_encode(self, x, dtype):
|
109 |
if self.vae.config.shift_factor is not None:
|
|
|
119 |
return [x.to(self.device) for x in data]
|
120 |
return data.to(self.device)
|
121 |
|
122 |
+
def enable_model_cpu_offload(self):
|
123 |
+
self.model_cpu_offload = True
|
124 |
+
self.model.to("cpu")
|
125 |
+
self.vae.to("cpu")
|
126 |
+
torch.cuda.empty_cache() # Clear VRAM
|
127 |
+
gc.collect() # Run garbage collection to free system RAM
|
128 |
+
|
129 |
+
def disable_model_cpu_offload(self):
|
130 |
+
self.model_cpu_offload = False
|
131 |
+
self.model.to(self.device)
|
132 |
+
self.vae.to(self.device)
|
133 |
|
134 |
@torch.no_grad()
|
135 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
|
|
143 |
guidance_scale: float = 3,
|
144 |
use_img_guidance: bool = True,
|
145 |
img_guidance_scale: float = 1.6,
|
146 |
+
max_input_image_size: int = 1024,
|
147 |
+
separate_cfg_infer: bool = True,
|
148 |
+
offload_model: bool = False,
|
149 |
use_kv_cache: bool = True,
|
150 |
+
offload_kv_cache: bool = True,
|
151 |
+
use_input_image_size_as_output: bool = False,
|
152 |
dtype: torch.dtype = torch.bfloat16,
|
153 |
seed: int = None,
|
154 |
):
|
|
|
176 |
Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
|
177 |
img_guidance_scale (`float`, *optional*, defaults to 1.6):
|
178 |
Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
|
179 |
+
max_input_image_size (`int`, *optional*, defaults to 1024): the maximum size of input image, which will be used to crop the input image to the maximum size
|
180 |
separate_cfg_infer (`bool`, *optional*, defaults to False):
|
181 |
Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
|
182 |
use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
|
183 |
+
offload_kv_cache (`bool`, *optional*, defaults to True): offload the cached key and value to cpu, which can save memory but slow down the generation silightly
|
184 |
+
offload_model (`bool`, *optional*, defaults to False): offload the model to cpu, which can save memory but slow down the generation
|
185 |
+
use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task
|
186 |
+
seed (`int`, *optional*):
|
187 |
+
A random seed for generating output.
|
188 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
189 |
+
data type for the model
|
190 |
Examples:
|
191 |
|
192 |
Returns:
|
193 |
A list with the generated images.
|
194 |
"""
|
195 |
+
# check inputs:
|
196 |
+
if use_input_image_size_as_output:
|
197 |
+
assert isinstance(prompt, str) and len(input_images) == 1, "if you want to make sure the output image have the same size as the input image, please only input one image instead of multiple input images"
|
198 |
+
else:
|
199 |
+
assert height%16 == 0 and width%16 == 0, "The height and width must be a multiple of 16."
|
200 |
if input_images is None:
|
201 |
use_img_guidance = False
|
202 |
if isinstance(prompt, str):
|
203 |
prompt = [prompt]
|
204 |
input_images = [input_images] if input_images is not None else None
|
205 |
+
|
206 |
+
# set model and processor
|
207 |
+
if max_input_image_size != self.processor.max_image_size:
|
208 |
+
self.processor = OmniGenProcessor(self.processor.text_tokenizer, max_image_size=max_input_image_size)
|
209 |
+
if offload_model:
|
210 |
+
self.enable_model_cpu_offload()
|
211 |
+
else:
|
212 |
+
self.disable_model_cpu_offload()
|
213 |
|
214 |
+
input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer, use_input_image_size_as_output=use_input_image_size_as_output)
|
215 |
|
216 |
num_prompt = len(prompt)
|
217 |
num_cfg = 2 if use_img_guidance else 1
|
218 |
+
if use_input_image_size_as_output:
|
219 |
+
if separate_cfg_infer:
|
220 |
+
height, width = input_data['input_pixel_values'][0][0].shape[-2:]
|
221 |
+
else:
|
222 |
+
height, width = input_data['input_pixel_values'][0].shape[-2:]
|
223 |
latent_size_h, latent_size_w = height//8, width//8
|
224 |
|
225 |
if seed is not None:
|
|
|
229 |
latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
|
230 |
latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
|
231 |
|
232 |
+
if input_images is not None and self.model_cpu_offload: self.vae.to(self.device)
|
233 |
input_img_latents = []
|
234 |
if separate_cfg_infer:
|
235 |
for temp_pixel_values in input_data['input_pixel_values']:
|
|
|
242 |
for img in input_data['input_pixel_values']:
|
243 |
img = self.vae_encode(img.to(self.device), dtype)
|
244 |
input_img_latents.append(img)
|
245 |
+
if input_images is not None and self.model_cpu_offload:
|
246 |
+
self.vae.to('cpu')
|
247 |
+
torch.cuda.empty_cache() # Clear VRAM
|
248 |
+
gc.collect() # Run garbage collection to free system RAM
|
249 |
|
250 |
model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
|
251 |
input_img_latents=input_img_latents,
|
|
|
255 |
cfg_scale=guidance_scale,
|
256 |
img_cfg_scale=img_guidance_scale,
|
257 |
use_img_cfg=use_img_guidance,
|
258 |
+
use_kv_cache=use_kv_cache,
|
259 |
+
offload_model=offload_model,
|
260 |
+
)
|
261 |
|
262 |
if separate_cfg_infer:
|
263 |
func = self.model.forward_with_separate_cfg
|
|
|
265 |
func = self.model.forward_with_cfg
|
266 |
self.model.to(dtype)
|
267 |
|
268 |
+
if self.model_cpu_offload:
|
269 |
+
for name, param in self.model.named_parameters():
|
270 |
+
if 'layers' in name and 'layers.0' not in name:
|
271 |
+
param.data = param.data.cpu()
|
272 |
+
else:
|
273 |
+
param.data = param.data.to(self.device)
|
274 |
+
for buffer_name, buffer in self.model.named_buffers():
|
275 |
+
setattr(self.model, buffer_name, buffer.to(self.device))
|
276 |
+
# else:
|
277 |
+
# self.model.to(self.device)
|
278 |
+
|
279 |
scheduler = OmniGenScheduler(num_steps=num_inference_steps)
|
280 |
+
samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache)
|
281 |
samples = samples.chunk((1+num_cfg), dim=0)[0]
|
282 |
|
283 |
+
if self.model_cpu_offload:
|
284 |
+
self.model.to('cpu')
|
285 |
+
torch.cuda.empty_cache()
|
286 |
+
gc.collect()
|
287 |
+
|
288 |
+
self.vae.to(self.device)
|
289 |
samples = samples.to(torch.float32)
|
290 |
if self.vae.config.shift_factor is not None:
|
291 |
samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
|
292 |
else:
|
293 |
samples = samples / self.vae.config.scaling_factor
|
294 |
samples = self.vae.decode(samples).sample
|
295 |
+
|
296 |
+
if self.model_cpu_offload:
|
297 |
+
self.vae.to('cpu')
|
298 |
+
torch.cuda.empty_cache()
|
299 |
+
gc.collect()
|
300 |
|
301 |
output_samples = (samples * 0.5 + 0.5).clamp(0, 1)*255
|
302 |
output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
|
|
304 |
for i, sample in enumerate(output_samples):
|
305 |
output_images.append(Image.fromarray(sample))
|
306 |
|
307 |
+
torch.cuda.empty_cache() # Clear VRAM
|
308 |
+
gc.collect() # Run garbage collection to free system RAM
|
309 |
return output_images
|
OmniGen/processor.py
CHANGED
@@ -108,6 +108,7 @@ class OmniGenProcessor:
|
|
108 |
negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
|
109 |
use_img_cfg: bool = True,
|
110 |
separate_cfg_input: bool = False,
|
|
|
111 |
) -> Dict:
|
112 |
|
113 |
if input_images is None:
|
@@ -138,7 +139,10 @@ class OmniGenProcessor:
|
|
138 |
else:
|
139 |
img_cfg_mllm_input = neg_mllm_input
|
140 |
|
141 |
-
|
|
|
|
|
|
|
142 |
|
143 |
if separate_cfg_input:
|
144 |
return self.separate_collator(input_data)
|
@@ -295,7 +299,6 @@ class OmniGenSeparateCollator(OmniGenCollator):
|
|
295 |
cfg_mllm_inputs = [f[1] for f in features]
|
296 |
img_cfg_mllm_input = [f[2] for f in features]
|
297 |
target_img_size = [f[3] for f in features]
|
298 |
-
|
299 |
|
300 |
all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
|
301 |
|
|
|
108 |
negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
|
109 |
use_img_cfg: bool = True,
|
110 |
separate_cfg_input: bool = False,
|
111 |
+
use_input_image_size_as_output: bool=False,
|
112 |
) -> Dict:
|
113 |
|
114 |
if input_images is None:
|
|
|
139 |
else:
|
140 |
img_cfg_mllm_input = neg_mllm_input
|
141 |
|
142 |
+
if use_input_image_size_as_output:
|
143 |
+
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
|
144 |
+
else:
|
145 |
+
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
|
146 |
|
147 |
if separate_cfg_input:
|
148 |
return self.separate_collator(input_data)
|
|
|
299 |
cfg_mllm_inputs = [f[1] for f in features]
|
300 |
img_cfg_mllm_input = [f[2] for f in features]
|
301 |
target_img_size = [f[3] for f in features]
|
|
|
302 |
|
303 |
all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
|
304 |
|
OmniGen/scheduler.py
CHANGED
@@ -1,6 +1,116 @@
|
|
1 |
-
import torch
|
2 |
from tqdm import tqdm
|
3 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
class OmniGenScheduler:
|
6 |
def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
|
@@ -12,12 +122,13 @@ class OmniGenScheduler:
|
|
12 |
self.sigma = t
|
13 |
|
14 |
def crop_kv_cache(self, past_key_values, num_tokens_for_img):
|
|
|
15 |
crop_past_key_values = ()
|
16 |
for layer_idx in range(len(past_key_values)):
|
17 |
key_states, value_states = past_key_values[layer_idx][:2]
|
18 |
crop_past_key_values += ((key_states[..., :-(num_tokens_for_img+1), :], value_states[..., :-(num_tokens_for_img+1), :], ),)
|
19 |
-
return crop_past_key_values
|
20 |
-
|
21 |
|
22 |
def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
|
23 |
if isinstance(position_ids, list):
|
@@ -32,24 +143,39 @@ class OmniGenScheduler:
|
|
32 |
return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
|
33 |
return attention_mask[..., -(num_tokens_for_img+1):, :]
|
34 |
|
35 |
-
def
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
for i in tqdm(range(self.num_steps)):
|
38 |
timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
|
39 |
-
pred,
|
40 |
sigma_next = self.sigma[i+1]
|
41 |
sigma = self.sigma[i]
|
42 |
z = z + (sigma_next - sigma) * pred
|
43 |
if i == 0 and use_kv_cache:
|
44 |
num_tokens_for_img = z.size(-1)*z.size(-2) // 4
|
45 |
-
if isinstance(
|
46 |
-
|
47 |
-
model_kwargs['input_ids'] = [None] * len(temp_past_key_values)
|
48 |
else:
|
49 |
-
past_key_values = self.crop_kv_cache(temp_past_key_values, num_tokens_for_img)
|
50 |
model_kwargs['input_ids'] = None
|
51 |
|
52 |
model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
|
53 |
model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
|
|
|
|
|
|
|
|
|
54 |
return z
|
55 |
|
|
|
|
|
|
1 |
from tqdm import tqdm
|
2 |
+
from typing import Optional, Dict, Any, Tuple, List
|
3 |
+
import gc
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
class OmniGenCache(DynamicCache):
|
11 |
+
def __init__(self,
|
12 |
+
num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
|
13 |
+
if not torch.cuda.is_available():
|
14 |
+
raise RuntimeError("OffloadedCache can only be used with a GPU")
|
15 |
+
super().__init__()
|
16 |
+
self.original_device = []
|
17 |
+
self.prefetch_stream = torch.cuda.Stream()
|
18 |
+
self.num_tokens_for_img = num_tokens_for_img
|
19 |
+
self.offload_kv_cache = offload_kv_cache
|
20 |
+
|
21 |
+
def prefetch_layer(self, layer_idx: int):
|
22 |
+
"Starts prefetching the next layer cache"
|
23 |
+
if layer_idx < len(self):
|
24 |
+
with torch.cuda.stream(self.prefetch_stream):
|
25 |
+
# Prefetch next layer tensors to GPU
|
26 |
+
device = self.original_device[layer_idx]
|
27 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
|
28 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
|
29 |
+
|
30 |
+
|
31 |
+
def evict_previous_layer(self, layer_idx: int):
|
32 |
+
"Moves the previous layer cache to the CPU"
|
33 |
+
if len(self) > 2:
|
34 |
+
# We do it on the default stream so it occurs after all earlier computations on these tensors are done
|
35 |
+
if layer_idx == 0:
|
36 |
+
prev_layer_idx = -1
|
37 |
+
else:
|
38 |
+
prev_layer_idx = (layer_idx - 1) % len(self)
|
39 |
+
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
|
40 |
+
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
|
41 |
+
|
42 |
+
|
43 |
+
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
44 |
+
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
|
45 |
+
if layer_idx < len(self):
|
46 |
+
if self.offload_kv_cache:
|
47 |
+
# Evict the previous layer if necessary
|
48 |
+
torch.cuda.current_stream().synchronize()
|
49 |
+
self.evict_previous_layer(layer_idx)
|
50 |
+
# Load current layer cache to its original device if not already there
|
51 |
+
original_device = self.original_device[layer_idx]
|
52 |
+
# self.prefetch_stream.synchronize(original_device)
|
53 |
+
torch.cuda.synchronize(self.prefetch_stream)
|
54 |
+
key_tensor = self.key_cache[layer_idx]
|
55 |
+
value_tensor = self.value_cache[layer_idx]
|
56 |
+
|
57 |
+
# Prefetch the next layer
|
58 |
+
self.prefetch_layer((layer_idx + 1) % len(self))
|
59 |
+
else:
|
60 |
+
key_tensor = self.key_cache[layer_idx]
|
61 |
+
value_tensor = self.value_cache[layer_idx]
|
62 |
+
return (key_tensor, value_tensor)
|
63 |
+
else:
|
64 |
+
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
65 |
+
|
66 |
+
|
67 |
+
def update(
|
68 |
+
self,
|
69 |
+
key_states: torch.Tensor,
|
70 |
+
value_states: torch.Tensor,
|
71 |
+
layer_idx: int,
|
72 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
73 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
74 |
+
"""
|
75 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
76 |
+
Parameters:
|
77 |
+
key_states (`torch.Tensor`):
|
78 |
+
The new key states to cache.
|
79 |
+
value_states (`torch.Tensor`):
|
80 |
+
The new value states to cache.
|
81 |
+
layer_idx (`int`):
|
82 |
+
The index of the layer to cache the states for.
|
83 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
84 |
+
Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
|
85 |
+
Return:
|
86 |
+
A tuple containing the updated key and value states.
|
87 |
+
"""
|
88 |
+
# Update the cache
|
89 |
+
if len(self.key_cache) < layer_idx:
|
90 |
+
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
|
91 |
+
elif len(self.key_cache) == layer_idx:
|
92 |
+
# only cache the states for condition tokens
|
93 |
+
key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
|
94 |
+
value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
|
95 |
+
|
96 |
+
# Update the number of seen tokens
|
97 |
+
if layer_idx == 0:
|
98 |
+
self._seen_tokens += key_states.shape[-2]
|
99 |
+
|
100 |
+
self.key_cache.append(key_states)
|
101 |
+
self.value_cache.append(value_states)
|
102 |
+
self.original_device.append(key_states.device)
|
103 |
+
if self.offload_kv_cache:
|
104 |
+
self.evict_previous_layer(layer_idx)
|
105 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
106 |
+
else:
|
107 |
+
# only cache the states for condition tokens
|
108 |
+
key_tensor, value_tensor = self[layer_idx]
|
109 |
+
k = torch.cat([key_tensor, key_states], dim=-2)
|
110 |
+
v = torch.cat([value_tensor, value_states], dim=-2)
|
111 |
+
return k, v
|
112 |
+
|
113 |
+
|
114 |
|
115 |
class OmniGenScheduler:
|
116 |
def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
|
|
|
122 |
self.sigma = t
|
123 |
|
124 |
def crop_kv_cache(self, past_key_values, num_tokens_for_img):
|
125 |
+
# return
|
126 |
crop_past_key_values = ()
|
127 |
for layer_idx in range(len(past_key_values)):
|
128 |
key_states, value_states = past_key_values[layer_idx][:2]
|
129 |
crop_past_key_values += ((key_states[..., :-(num_tokens_for_img+1), :], value_states[..., :-(num_tokens_for_img+1), :], ),)
|
130 |
+
# return crop_past_key_values
|
131 |
+
return DynamicCache.from_legacy_cache(crop_past_key_values)
|
132 |
|
133 |
def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
|
134 |
if isinstance(position_ids, list):
|
|
|
143 |
return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
|
144 |
return attention_mask[..., -(num_tokens_for_img+1):, :]
|
145 |
|
146 |
+
def crop_cache(self, cache, num_tokens_for_img):
|
147 |
+
for i in range(len(cache.key_cache)):
|
148 |
+
cache.key_cache[i] = cache.key_cache[i][..., :-(num_tokens_for_img+1), :]
|
149 |
+
cache.value_cache[i] = cache.value_cache[i][..., :-(num_tokens_for_img+1), :]
|
150 |
+
|
151 |
+
return cache
|
152 |
+
|
153 |
+
def __call__(self, z, func, model_kwargs, use_kv_cache: bool=True, offload_kv_cache: bool=True):
|
154 |
+
num_tokens_for_img = z.size(-1)*z.size(-2) // 4
|
155 |
+
if isinstance(model_kwargs['input_ids'], list):
|
156 |
+
cache = [OmniGenCache(num_tokens_for_img, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
|
157 |
+
else:
|
158 |
+
cache = OmniGenCache(num_tokens_for_img, offload_kv_cache) if use_kv_cache else None
|
159 |
+
results = {}
|
160 |
for i in tqdm(range(self.num_steps)):
|
161 |
timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
|
162 |
+
pred, cache = func(z, timesteps, past_key_values=cache, **model_kwargs)
|
163 |
sigma_next = self.sigma[i+1]
|
164 |
sigma = self.sigma[i]
|
165 |
z = z + (sigma_next - sigma) * pred
|
166 |
if i == 0 and use_kv_cache:
|
167 |
num_tokens_for_img = z.size(-1)*z.size(-2) // 4
|
168 |
+
if isinstance(cache, list):
|
169 |
+
model_kwargs['input_ids'] = [None] * len(cache)
|
|
|
170 |
else:
|
|
|
171 |
model_kwargs['input_ids'] = None
|
172 |
|
173 |
model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
|
174 |
model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
|
175 |
+
|
176 |
+
del cache
|
177 |
+
torch.cuda.empty_cache()
|
178 |
+
gc.collect()
|
179 |
return z
|
180 |
|
181 |
+
|
OmniGen/transformer.py
CHANGED
@@ -29,6 +29,34 @@ class Phi3Transformer(Phi3Model):
|
|
29 |
Args:
|
30 |
config: Phi3Config
|
31 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def forward(
|
34 |
self,
|
@@ -42,6 +70,7 @@ class Phi3Transformer(Phi3Model):
|
|
42 |
output_hidden_states: Optional[bool] = None,
|
43 |
return_dict: Optional[bool] = None,
|
44 |
cache_position: Optional[torch.LongTensor] = None,
|
|
|
45 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
46 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
47 |
output_hidden_states = (
|
@@ -75,16 +104,16 @@ class Phi3Transformer(Phi3Model):
|
|
75 |
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
76 |
)
|
77 |
|
78 |
-
if inputs_embeds is None:
|
79 |
-
|
80 |
|
81 |
-
if cache_position is None:
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
if position_ids is None:
|
87 |
-
|
88 |
|
89 |
if attention_mask is not None and attention_mask.dim() == 3:
|
90 |
dtype = inputs_embeds.dtype
|
@@ -104,7 +133,10 @@ class Phi3Transformer(Phi3Model):
|
|
104 |
all_self_attns = () if output_attentions else None
|
105 |
next_decoder_cache = None
|
106 |
|
|
|
107 |
for decoder_layer in self.layers:
|
|
|
|
|
108 |
if output_hidden_states:
|
109 |
all_hidden_states += (hidden_states,)
|
110 |
|
@@ -120,6 +152,8 @@ class Phi3Transformer(Phi3Model):
|
|
120 |
cache_position,
|
121 |
)
|
122 |
else:
|
|
|
|
|
123 |
layer_outputs = decoder_layer(
|
124 |
hidden_states,
|
125 |
attention_mask=attention_mask,
|
@@ -142,6 +176,7 @@ class Phi3Transformer(Phi3Model):
|
|
142 |
|
143 |
# add hidden states from the last decoder layer
|
144 |
if output_hidden_states:
|
|
|
145 |
all_hidden_states += (hidden_states,)
|
146 |
|
147 |
next_cache = next_decoder_cache if use_cache else None
|
|
|
29 |
Args:
|
30 |
config: Phi3Config
|
31 |
"""
|
32 |
+
def prefetch_layer(self, layer_idx: int, device: torch.device):
|
33 |
+
"Starts prefetching the next layer cache"
|
34 |
+
with torch.cuda.stream(self.prefetch_stream):
|
35 |
+
# Prefetch next layer tensors to GPU
|
36 |
+
for name, param in self.layers[layer_idx].named_parameters():
|
37 |
+
param.data = param.data.to(device, non_blocking=True)
|
38 |
+
|
39 |
+
def evict_previous_layer(self, layer_idx: int):
|
40 |
+
"Moves the previous layer cache to the CPU"
|
41 |
+
prev_layer_idx = layer_idx - 1
|
42 |
+
for name, param in self.layers[prev_layer_idx].named_parameters():
|
43 |
+
param.data = param.data.to("cpu", non_blocking=True)
|
44 |
+
|
45 |
+
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
|
46 |
+
# init stream
|
47 |
+
if not hasattr(self, "prefetch_stream"):
|
48 |
+
self.prefetch_stream = torch.cuda.Stream()
|
49 |
+
|
50 |
+
# delete previous layer
|
51 |
+
torch.cuda.current_stream().synchronize()
|
52 |
+
self.evict_previous_layer(layer_idx)
|
53 |
+
|
54 |
+
# make sure the current layer is ready
|
55 |
+
torch.cuda.synchronize(self.prefetch_stream)
|
56 |
+
|
57 |
+
# load next layer
|
58 |
+
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
|
59 |
+
|
60 |
|
61 |
def forward(
|
62 |
self,
|
|
|
70 |
output_hidden_states: Optional[bool] = None,
|
71 |
return_dict: Optional[bool] = None,
|
72 |
cache_position: Optional[torch.LongTensor] = None,
|
73 |
+
offload_model: Optional[bool] = False,
|
74 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
75 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
76 |
output_hidden_states = (
|
|
|
104 |
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
105 |
)
|
106 |
|
107 |
+
# if inputs_embeds is None:
|
108 |
+
# inputs_embeds = self.embed_tokens(input_ids)
|
109 |
|
110 |
+
# if cache_position is None:
|
111 |
+
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
112 |
+
# cache_position = torch.arange(
|
113 |
+
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
114 |
+
# )
|
115 |
+
# if position_ids is None:
|
116 |
+
# position_ids = cache_position.unsqueeze(0)
|
117 |
|
118 |
if attention_mask is not None and attention_mask.dim() == 3:
|
119 |
dtype = inputs_embeds.dtype
|
|
|
133 |
all_self_attns = () if output_attentions else None
|
134 |
next_decoder_cache = None
|
135 |
|
136 |
+
layer_idx = -1
|
137 |
for decoder_layer in self.layers:
|
138 |
+
layer_idx += 1
|
139 |
+
|
140 |
if output_hidden_states:
|
141 |
all_hidden_states += (hidden_states,)
|
142 |
|
|
|
152 |
cache_position,
|
153 |
)
|
154 |
else:
|
155 |
+
if offload_model and not self.training:
|
156 |
+
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
|
157 |
layer_outputs = decoder_layer(
|
158 |
hidden_states,
|
159 |
attention_mask=attention_mask,
|
|
|
176 |
|
177 |
# add hidden states from the last decoder layer
|
178 |
if output_hidden_states:
|
179 |
+
print('************')
|
180 |
all_hidden_states += (hidden_states,)
|
181 |
|
182 |
next_cache = next_decoder_cache if use_cache else None
|
README.md
CHANGED
@@ -1,13 +1,180 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h1 align="center">OmniGen: Unified Image Generation</h1>
|
2 |
+
|
3 |
+
|
4 |
+
<p align="center">
|
5 |
+
<a href="">
|
6 |
+
<img alt="Build" src="https://img.shields.io/badge/Project%20Page-OmniGen-yellow">
|
7 |
+
</a>
|
8 |
+
<a href="https://arxiv.org/abs/2409.11340">
|
9 |
+
<img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-2409.11340-b31b1b.svg">
|
10 |
+
</a>
|
11 |
+
<a href="https://huggingface.co/spaces/Shitao/OmniGen">
|
12 |
+
<img alt="License" src="https://img.shields.io/badge/HF%20Demo-🤗-lightblue">
|
13 |
+
</a>
|
14 |
+
<a href="https://huggingface.co/Shitao/OmniGen-v1">
|
15 |
+
<img alt="Build" src="https://img.shields.io/badge/HF%20Model-🤗-yellow">
|
16 |
+
</a>
|
17 |
+
</p>
|
18 |
+
|
19 |
+
<h4 align="center">
|
20 |
+
<p>
|
21 |
+
<a href=#1-news>News</a> |
|
22 |
+
<a href=#3-methodology>Methodology</a> |
|
23 |
+
<a href=#4-what-can-omnigen-do>Capabilities</a> |
|
24 |
+
<a href=#5-quick-start>Quick Start</a> |
|
25 |
+
<a href="#6-finetune">Finetune</a> |
|
26 |
+
<a href="#license">License</a> |
|
27 |
+
<a href="#citation">Citation</a>
|
28 |
+
<p>
|
29 |
+
</h4>
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
## 1. News
|
34 |
+
- 2024-10-28: We release new version of inference code, optimizing the memory usage and time cost. You can refer to [docs/inference.md](docs/inference.md#requiremented-resources) for detailed information.
|
35 |
+
- 2024-10-22: :fire: We release the code for OmniGen. Inference: [docs/inference.md](docs/inference.md) Train: [docs/fine-tuning.md](docs/fine-tuning.md)
|
36 |
+
- 2024-10-22: :fire: We release the first version of OmniGen. Model Weight: [Shitao/OmniGen-v1](https://huggingface.co/Shitao/OmniGen-v1) HF Demo: [🤗](https://huggingface.co/spaces/Shitao/OmniGen)
|
37 |
+
|
38 |
+
|
39 |
+
## 2. Overview
|
40 |
+
|
41 |
+
OmniGen is a unified image generation model that can generate a wide range of images from multi-modal prompts. It is designed to be simple, flexible and easy to use. We provide [inference code](#5-quick-start) so that everyone can explore more functionalities of OmniGen.
|
42 |
+
|
43 |
+
Existing image generation models often require loading several additional network modules (such as ControlNet, IP-Adapter, Reference-Net, etc.) and performing extra preprocessing steps (e.g., face detection, pose estimation, cropping, etc.) to generate a satisfactory image. However, **we believe that the future image generation paradigm should be more simple and flexible, that is, generating various images directly through arbitrarily multi-modal instructions without the need for additional plugins and operations, similar to how GPT works in language generation.**
|
44 |
+
|
45 |
+
Due to the limited resources, OmniGen still has room for improvement. We will continue to optimize it, and hope it inspire more universal image generation models. You can also easily fine-tune OmniGen without worrying about designing networks for specific tasks; you just need to prepare the corresponding data, and then run the [script](#6-finetune). Imagination is no longer limited; everyone can construct any image generation task, and perhaps we can achieve very interesting, wonderful and creative things.
|
46 |
+
|
47 |
+
If you have any questions, ideas or interesting tasks you want OmniGen to accomplish, feel free to discuss with us: 2906698981@qq.com, wangyueze@tju.edu.cn, zhengliu1026@gmail.com. We welcome any feedback to help us improve the model.
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
## 3. Methodology
|
53 |
+
|
54 |
+
You can see details in our [paper](https://arxiv.org/abs/2409.11340).
|
55 |
+
|
56 |
+
|
57 |
+
## 4. What Can OmniGen do?
|
58 |
+
|
59 |
+
|
60 |
+
OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, image editing, and image-conditioned generation. **OmniGen don't need additional plugins or operations, it can automatically identify the features (e.g., required object, human pose, depth mapping) in input images according the text prompt.**
|
61 |
+
We showcase some examples in [inference.ipynb](inference.ipynb). And in [inference_demo.ipynb](inference_demo.ipynb), we show an interesting pipeline to generate and modify a image.
|
62 |
+
|
63 |
+
Here is the illustration of OmniGen's capabilities:
|
64 |
+
- You can control the image generation flexibly via OmniGen
|
65 |
+
![demo](./imgs/demo_cases.png)
|
66 |
+
- Referring Expression Generation: You can input multiple images and use simple, general language to refer to the objects within those images. OmniGen can automatically recognize the necessary objects in each image and generate new images based on them. No additional operations, such as image cropping or face detection, are required.
|
67 |
+
![demo](./imgs/referring.png)
|
68 |
+
|
69 |
+
If you are not entirely satisfied with certain functionalities or wish to add new capabilities, you can try [fine-tuning OmniGen](#6-finetune).
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
## 5. Quick Start
|
74 |
+
|
75 |
+
|
76 |
+
### Using OmniGen
|
77 |
+
Install via Github:
|
78 |
+
```bash
|
79 |
+
git clone https://github.com/staoxiao/OmniGen.git
|
80 |
+
cd OmniGen
|
81 |
+
pip install -e .
|
82 |
+
```
|
83 |
+
|
84 |
+
Here are some examples:
|
85 |
+
```python
|
86 |
+
from OmniGen import OmniGenPipeline
|
87 |
+
|
88 |
+
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
|
89 |
+
|
90 |
+
# Text to Image
|
91 |
+
images = pipe(
|
92 |
+
prompt="A curly-haired man in a red shirt is drinking tea.",
|
93 |
+
height=1024,
|
94 |
+
width=1024,
|
95 |
+
guidance_scale=2.5,
|
96 |
+
seed=0,
|
97 |
+
)
|
98 |
+
images[0].save("example_t2i.png") # save output PIL Image
|
99 |
+
|
100 |
+
# Multi-modal to Image
|
101 |
+
# In prompt, we use the placeholder to represent the image. The image placeholder should be in the format of <img><|image_*|></img>
|
102 |
+
# You can add multiple images in the input_images. Please ensure that each image has its placeholder. For example, for the list input_images [img1_path, img2_path], the prompt needs to have two placeholders: <img><|image_1|></img>, <img><|image_2|></img>.
|
103 |
+
images = pipe(
|
104 |
+
prompt="A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
|
105 |
+
input_images=["./imgs/test_cases/two_man.jpg"],
|
106 |
+
height=1024,
|
107 |
+
width=1024,
|
108 |
+
guidance_scale=2.5,
|
109 |
+
img_guidance_scale=1.6,
|
110 |
+
seed=0
|
111 |
+
)
|
112 |
+
images[0].save("example_ti2i.png") # save output PIL image
|
113 |
+
```
|
114 |
+
- For thre required resources and the method to run OmniGen efficiently, please refer to [docs/inference.md#requiremented-resources](docs/inference.md#requiremented-resources).
|
115 |
+
- For more examples for image generation, you can refer to [inference.ipynb](inference.ipynb) and [inference_demo.ipynb](inference_demo.ipynb)
|
116 |
+
- For more details about the argument in inference, please refer to [docs/inference.md](docs/inference.md).
|
117 |
+
|
118 |
+
|
119 |
+
### Using Diffusers
|
120 |
+
Coming soon.
|
121 |
+
|
122 |
+
|
123 |
+
### Gradio Demo
|
124 |
+
|
125 |
+
We construct an online demo in [Huggingface](https://huggingface.co/spaces/Shitao/OmniGen).
|
126 |
+
|
127 |
+
For the local gradio demo, you need to install `pip install gradio spaces` , and then you can run:
|
128 |
+
```python
|
129 |
+
pip install gradio spaces
|
130 |
+
python app.py
|
131 |
+
```
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
## 6. Finetune
|
136 |
+
We provide a training script `train.py` to fine-tune OmniGen.
|
137 |
+
Here is a toy example about LoRA finetune:
|
138 |
+
```bash
|
139 |
+
accelerate launch --num_processes=1 train.py \
|
140 |
+
--model_name_or_path Shitao/OmniGen-v1 \
|
141 |
+
--batch_size_per_device 2 \
|
142 |
+
--condition_dropout_prob 0.01 \
|
143 |
+
--lr 1e-3 \
|
144 |
+
--use_lora \
|
145 |
+
--lora_rank 8 \
|
146 |
+
--json_file ./toy_data/toy_subject_data.jsonl \
|
147 |
+
--image_path ./toy_data/images \
|
148 |
+
--max_input_length_limit 18000 \
|
149 |
+
--keep_raw_resolution \
|
150 |
+
--max_image_size 1024 \
|
151 |
+
--gradient_accumulation_steps 1 \
|
152 |
+
--ckpt_every 10 \
|
153 |
+
--epochs 200 \
|
154 |
+
--log_every 1 \
|
155 |
+
--results_dir ./results/toy_finetune_lora
|
156 |
+
```
|
157 |
+
|
158 |
+
Please refer to [docs/fine-tuning.md](docs/fine-tuning.md) for more details (e.g. full finetune).
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
## License
|
163 |
+
This repo is licensed under the [MIT License](LICENSE).
|
164 |
+
|
165 |
+
|
166 |
+
## Citation
|
167 |
+
If you find this repository useful, please consider giving a star ⭐ and citation
|
168 |
+
```
|
169 |
+
@article{xiao2024omnigen,
|
170 |
+
title={Omnigen: Unified image generation},
|
171 |
+
author={Xiao, Shitao and Wang, Yueze and Zhou, Junjie and Yuan, Huaying and Xing, Xingrun and Yan, Ruiran and Wang, Shuting and Huang, Tiejun and Liu, Zheng},
|
172 |
+
journal={arXiv preprint arXiv:2409.11340},
|
173 |
+
year={2024}
|
174 |
+
}
|
175 |
+
```
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
|
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
from PIL import Image
|
3 |
import os
|
|
|
4 |
import spaces
|
5 |
|
6 |
from OmniGen import OmniGenPipeline
|
@@ -9,11 +10,11 @@ pipe = OmniGenPipeline.from_pretrained(
|
|
9 |
"Shitao/OmniGen-v1"
|
10 |
)
|
11 |
|
12 |
-
@spaces.GPU(duration=
|
13 |
-
|
14 |
-
|
15 |
input_images = [img1, img2, img3]
|
16 |
-
#
|
17 |
input_images = [img for img in input_images if img is not None]
|
18 |
if len(input_images) == 0:
|
19 |
input_images = None
|
@@ -24,25 +25,18 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_gu
|
|
24 |
height=height,
|
25 |
width=width,
|
26 |
guidance_scale=guidance_scale,
|
27 |
-
img_guidance_scale=
|
28 |
num_inference_steps=inference_steps,
|
29 |
-
separate_cfg_infer=
|
30 |
-
use_kv_cache=
|
|
|
|
|
|
|
31 |
seed=seed,
|
32 |
-
# separate_cfg_infer=separate_cfg_infer,
|
33 |
)
|
34 |
img = output[0]
|
35 |
return img
|
36 |
-
|
37 |
-
# input_images = []
|
38 |
-
# if img1:
|
39 |
-
# input_images.append(Image.open(img1))
|
40 |
-
# if img2:
|
41 |
-
# input_images.append(Image.open(img2))
|
42 |
-
# if img3:
|
43 |
-
# input_images.append(Image.open(img3))
|
44 |
-
|
45 |
-
# return input_images[0] if input_images else None
|
46 |
|
47 |
|
48 |
def get_example():
|
@@ -59,6 +53,8 @@ def get_example():
|
|
59 |
50,
|
60 |
0,
|
61 |
True,
|
|
|
|
|
62 |
],
|
63 |
[
|
64 |
"The woman in <img><|image_1|></img> waves her hand happily in the crowd",
|
@@ -72,6 +68,8 @@ def get_example():
|
|
72 |
50,
|
73 |
128,
|
74 |
True,
|
|
|
|
|
75 |
],
|
76 |
[
|
77 |
"A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
|
@@ -85,6 +83,8 @@ def get_example():
|
|
85 |
50,
|
86 |
0,
|
87 |
True,
|
|
|
|
|
88 |
],
|
89 |
[
|
90 |
"Two woman are raising fried chicken legs in a bar. A woman is <img><|image_1|></img>. The other woman is <img><|image_2|></img>.",
|
@@ -98,6 +98,8 @@ def get_example():
|
|
98 |
50,
|
99 |
168,
|
100 |
True,
|
|
|
|
|
101 |
],
|
102 |
[
|
103 |
"A man and a short-haired woman with a wrinkled face are standing in front of a bookshelf in a library. The man is the man in the middle of <img><|image_1|></img>, and the woman is oldest woman in <img><|image_2|></img>",
|
@@ -111,6 +113,8 @@ def get_example():
|
|
111 |
50,
|
112 |
60,
|
113 |
True,
|
|
|
|
|
114 |
],
|
115 |
[
|
116 |
"A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>",
|
@@ -124,6 +128,8 @@ def get_example():
|
|
124 |
50,
|
125 |
66,
|
126 |
True,
|
|
|
|
|
127 |
],
|
128 |
[
|
129 |
"The flower <img><|image_1|><\/img> is placed in the vase which is in the middle of <img><|image_2|><\/img> on a wooden table of a living room",
|
@@ -137,147 +143,185 @@ def get_example():
|
|
137 |
50,
|
138 |
0,
|
139 |
True,
|
|
|
|
|
140 |
],
|
141 |
[
|
142 |
"<img><|image_1|><img>\n Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola.",
|
143 |
"./imgs/demo_cases/t2i_woman_with_book.png",
|
144 |
None,
|
145 |
None,
|
146 |
-
|
147 |
-
|
148 |
2.5,
|
149 |
1.6,
|
150 |
50,
|
151 |
222,
|
152 |
True,
|
|
|
|
|
153 |
],
|
154 |
[
|
155 |
"Detect the skeleton of human in this image: <img><|image_1|></img>.",
|
156 |
"./imgs/test_cases/control.jpg",
|
157 |
None,
|
158 |
None,
|
159 |
-
|
160 |
-
|
161 |
2.0,
|
162 |
1.6,
|
163 |
50,
|
164 |
0,
|
165 |
True,
|
|
|
|
|
166 |
],
|
167 |
[
|
168 |
"Generate a new photo using the following picture and text as conditions: <img><|image_1|><img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
|
169 |
"./imgs/demo_cases/skeletal.png",
|
170 |
None,
|
171 |
None,
|
172 |
-
|
173 |
-
|
174 |
2,
|
175 |
1.6,
|
176 |
50,
|
177 |
42,
|
178 |
True,
|
|
|
|
|
179 |
],
|
180 |
[
|
181 |
"Following the pose of this image <img><|image_1|><img>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
|
182 |
"./imgs/demo_cases/edit.png",
|
183 |
None,
|
184 |
None,
|
185 |
-
|
186 |
-
|
187 |
2.0,
|
188 |
1.6,
|
189 |
50,
|
190 |
123,
|
191 |
True,
|
|
|
|
|
192 |
],
|
193 |
[
|
194 |
"Following the depth mapping of this image <img><|image_1|><img>, generate a new photo: A young girl is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
|
195 |
"./imgs/demo_cases/edit.png",
|
196 |
None,
|
197 |
None,
|
198 |
-
|
199 |
-
|
200 |
2.0,
|
201 |
1.6,
|
202 |
50,
|
203 |
1,
|
204 |
True,
|
|
|
|
|
205 |
],
|
206 |
[
|
207 |
"<img><|image_1|><\/img> What item can be used to see the current time? Please remove it.",
|
208 |
"./imgs/test_cases/watch.jpg",
|
209 |
None,
|
210 |
None,
|
211 |
-
|
212 |
-
|
213 |
2.5,
|
214 |
1.6,
|
215 |
50,
|
216 |
0,
|
217 |
True,
|
|
|
|
|
218 |
],
|
219 |
[
|
220 |
"According to the following examples, generate an output for the input.\nInput: <img><|image_1|></img>\nOutput: <img><|image_2|></img>\n\nInput: <img><|image_3|></img>\nOutput: ",
|
221 |
"./imgs/test_cases/icl1.jpg",
|
222 |
"./imgs/test_cases/icl2.jpg",
|
223 |
"./imgs/test_cases/icl3.jpg",
|
224 |
-
|
225 |
-
|
226 |
2.5,
|
227 |
1.6,
|
228 |
50,
|
229 |
1,
|
230 |
True,
|
|
|
|
|
231 |
],
|
232 |
]
|
233 |
return case
|
234 |
|
235 |
-
def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer,
|
236 |
-
|
|
|
|
|
237 |
|
238 |
description = """
|
239 |
OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, and image-conditioned generation.
|
240 |
-
|
241 |
For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>` (for the first image, the placeholder is <img><|image_1|></img>. for the second image, the the placeholder is <img><|image_2|></img>).
|
242 |
For example, use an image of a woman to generate a new image:
|
243 |
prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is \<img\>\<|image_1|\>\</img\>."
|
244 |
|
245 |
Tips:
|
|
|
246 |
- Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
|
|
|
247 |
- Low-quality: More detailed prompt will lead to better results.
|
248 |
- Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
|
249 |
- Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
|
250 |
-
- For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
"""
|
252 |
|
253 |
-
separate_cfg_infer_arg = False
|
254 |
|
255 |
-
# Gradio
|
256 |
with gr.Blocks() as demo:
|
257 |
gr.Markdown("# OmniGen: Unified Image Generation [paper](https://arxiv.org/abs/2409.11340) [code](https://github.com/VectorSpaceLab/OmniGen)")
|
258 |
gr.Markdown(description)
|
259 |
with gr.Row():
|
260 |
with gr.Column():
|
261 |
-
#
|
262 |
prompt_input = gr.Textbox(
|
263 |
label="Enter your prompt, use <img><|image_i|></img> to represent i-th input image", placeholder="Type your prompt here..."
|
264 |
)
|
265 |
|
266 |
with gr.Row(equal_height=True):
|
267 |
-
#
|
268 |
image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
|
269 |
image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
|
270 |
image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
|
271 |
|
272 |
-
#
|
273 |
height_input = gr.Slider(
|
274 |
-
label="Height", minimum=
|
275 |
)
|
276 |
width_input = gr.Slider(
|
277 |
-
label="Width", minimum=
|
278 |
)
|
279 |
|
280 |
-
# 引导尺度输入
|
281 |
guidance_scale_input = gr.Slider(
|
282 |
label="Guidance Scale", minimum=1.0, maximum=5.0, value=2.5, step=0.1
|
283 |
)
|
@@ -295,17 +339,24 @@ with gr.Blocks() as demo:
|
|
295 |
)
|
296 |
|
297 |
separate_cfg_infer = gr.Checkbox(
|
298 |
-
label="separate_cfg_infer", info="
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
)
|
300 |
|
301 |
-
#
|
302 |
generate_button = gr.Button("Generate Image")
|
|
|
303 |
|
304 |
with gr.Column():
|
305 |
-
#
|
306 |
output_image = gr.Image(label="Output Image")
|
307 |
|
308 |
-
#
|
309 |
generate_button.click(
|
310 |
generate_image,
|
311 |
inputs=[
|
@@ -320,6 +371,8 @@ with gr.Blocks() as demo:
|
|
320 |
num_inference_steps,
|
321 |
seed_input,
|
322 |
separate_cfg_infer,
|
|
|
|
|
323 |
],
|
324 |
outputs=output_image,
|
325 |
)
|
@@ -339,9 +392,13 @@ with gr.Blocks() as demo:
|
|
339 |
num_inference_steps,
|
340 |
seed_input,
|
341 |
separate_cfg_infer,
|
|
|
|
|
342 |
],
|
343 |
outputs=output_image,
|
344 |
)
|
345 |
|
346 |
-
|
|
|
|
|
347 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
from PIL import Image
|
3 |
import os
|
4 |
+
|
5 |
import spaces
|
6 |
|
7 |
from OmniGen import OmniGenPipeline
|
|
|
10 |
"Shitao/OmniGen-v1"
|
11 |
)
|
12 |
|
13 |
+
@spaces.GPU(duration=300)
|
14 |
+
def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model,
|
15 |
+
use_input_image_size_as_output):
|
16 |
input_images = [img1, img2, img3]
|
17 |
+
# Delete None
|
18 |
input_images = [img for img in input_images if img is not None]
|
19 |
if len(input_images) == 0:
|
20 |
input_images = None
|
|
|
25 |
height=height,
|
26 |
width=width,
|
27 |
guidance_scale=guidance_scale,
|
28 |
+
img_guidance_scale=img_guidance_scale,
|
29 |
num_inference_steps=inference_steps,
|
30 |
+
separate_cfg_infer=separate_cfg_infer,
|
31 |
+
use_kv_cache=True,
|
32 |
+
offload_kv_cache=True,
|
33 |
+
offload_model=offload_model,
|
34 |
+
use_input_image_size_as_output=use_input_image_size_as_output,
|
35 |
seed=seed,
|
|
|
36 |
)
|
37 |
img = output[0]
|
38 |
return img
|
39 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
def get_example():
|
|
|
53 |
50,
|
54 |
0,
|
55 |
True,
|
56 |
+
False,
|
57 |
+
False,
|
58 |
],
|
59 |
[
|
60 |
"The woman in <img><|image_1|></img> waves her hand happily in the crowd",
|
|
|
68 |
50,
|
69 |
128,
|
70 |
True,
|
71 |
+
False,
|
72 |
+
False,
|
73 |
],
|
74 |
[
|
75 |
"A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
|
|
|
83 |
50,
|
84 |
0,
|
85 |
True,
|
86 |
+
False,
|
87 |
+
False,
|
88 |
],
|
89 |
[
|
90 |
"Two woman are raising fried chicken legs in a bar. A woman is <img><|image_1|></img>. The other woman is <img><|image_2|></img>.",
|
|
|
98 |
50,
|
99 |
168,
|
100 |
True,
|
101 |
+
False,
|
102 |
+
False,
|
103 |
],
|
104 |
[
|
105 |
"A man and a short-haired woman with a wrinkled face are standing in front of a bookshelf in a library. The man is the man in the middle of <img><|image_1|></img>, and the woman is oldest woman in <img><|image_2|></img>",
|
|
|
113 |
50,
|
114 |
60,
|
115 |
True,
|
116 |
+
False,
|
117 |
+
False,
|
118 |
],
|
119 |
[
|
120 |
"A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>",
|
|
|
128 |
50,
|
129 |
66,
|
130 |
True,
|
131 |
+
False,
|
132 |
+
False,
|
133 |
],
|
134 |
[
|
135 |
"The flower <img><|image_1|><\/img> is placed in the vase which is in the middle of <img><|image_2|><\/img> on a wooden table of a living room",
|
|
|
143 |
50,
|
144 |
0,
|
145 |
True,
|
146 |
+
False,
|
147 |
+
False,
|
148 |
],
|
149 |
[
|
150 |
"<img><|image_1|><img>\n Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola.",
|
151 |
"./imgs/demo_cases/t2i_woman_with_book.png",
|
152 |
None,
|
153 |
None,
|
154 |
+
None,
|
155 |
+
None,
|
156 |
2.5,
|
157 |
1.6,
|
158 |
50,
|
159 |
222,
|
160 |
True,
|
161 |
+
False,
|
162 |
+
True,
|
163 |
],
|
164 |
[
|
165 |
"Detect the skeleton of human in this image: <img><|image_1|></img>.",
|
166 |
"./imgs/test_cases/control.jpg",
|
167 |
None,
|
168 |
None,
|
169 |
+
None,
|
170 |
+
None,
|
171 |
2.0,
|
172 |
1.6,
|
173 |
50,
|
174 |
0,
|
175 |
True,
|
176 |
+
False,
|
177 |
+
True,
|
178 |
],
|
179 |
[
|
180 |
"Generate a new photo using the following picture and text as conditions: <img><|image_1|><img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
|
181 |
"./imgs/demo_cases/skeletal.png",
|
182 |
None,
|
183 |
None,
|
184 |
+
None,
|
185 |
+
None,
|
186 |
2,
|
187 |
1.6,
|
188 |
50,
|
189 |
42,
|
190 |
True,
|
191 |
+
False,
|
192 |
+
True,
|
193 |
],
|
194 |
[
|
195 |
"Following the pose of this image <img><|image_1|><img>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
|
196 |
"./imgs/demo_cases/edit.png",
|
197 |
None,
|
198 |
None,
|
199 |
+
None,
|
200 |
+
None,
|
201 |
2.0,
|
202 |
1.6,
|
203 |
50,
|
204 |
123,
|
205 |
True,
|
206 |
+
False,
|
207 |
+
True,
|
208 |
],
|
209 |
[
|
210 |
"Following the depth mapping of this image <img><|image_1|><img>, generate a new photo: A young girl is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
|
211 |
"./imgs/demo_cases/edit.png",
|
212 |
None,
|
213 |
None,
|
214 |
+
None,
|
215 |
+
None,
|
216 |
2.0,
|
217 |
1.6,
|
218 |
50,
|
219 |
1,
|
220 |
True,
|
221 |
+
False,
|
222 |
+
True,
|
223 |
],
|
224 |
[
|
225 |
"<img><|image_1|><\/img> What item can be used to see the current time? Please remove it.",
|
226 |
"./imgs/test_cases/watch.jpg",
|
227 |
None,
|
228 |
None,
|
229 |
+
None,
|
230 |
+
None,
|
231 |
2.5,
|
232 |
1.6,
|
233 |
50,
|
234 |
0,
|
235 |
True,
|
236 |
+
False,
|
237 |
+
True,
|
238 |
],
|
239 |
[
|
240 |
"According to the following examples, generate an output for the input.\nInput: <img><|image_1|></img>\nOutput: <img><|image_2|></img>\n\nInput: <img><|image_3|></img>\nOutput: ",
|
241 |
"./imgs/test_cases/icl1.jpg",
|
242 |
"./imgs/test_cases/icl2.jpg",
|
243 |
"./imgs/test_cases/icl3.jpg",
|
244 |
+
224,
|
245 |
+
224,
|
246 |
2.5,
|
247 |
1.6,
|
248 |
50,
|
249 |
1,
|
250 |
True,
|
251 |
+
False,
|
252 |
+
False,
|
253 |
],
|
254 |
]
|
255 |
return case
|
256 |
|
257 |
+
def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model,
|
258 |
+
use_input_image_size_as_output):
|
259 |
+
return generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model,
|
260 |
+
use_input_image_size_as_output)
|
261 |
|
262 |
description = """
|
263 |
OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, and image-conditioned generation.
|
|
|
264 |
For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>` (for the first image, the placeholder is <img><|image_1|></img>. for the second image, the the placeholder is <img><|image_2|></img>).
|
265 |
For example, use an image of a woman to generate a new image:
|
266 |
prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is \<img\>\<|image_1|\>\</img\>."
|
267 |
|
268 |
Tips:
|
269 |
+
- For out of memory or time cost, you can refer to [./docs/inference.md#requiremented-resources](https://github.com/VectorSpaceLab/OmniGen/blob/main/docs/inference.md#requiremented-resources) to select a appropriate setting.
|
270 |
- Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
|
271 |
+
- Not match the prompt: If the image does not match the prompt, please try to increase the `guidance_scale`.
|
272 |
- Low-quality: More detailed prompt will lead to better results.
|
273 |
- Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
|
274 |
- Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
|
275 |
+
- For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
|
276 |
+
- For image editing task and controlnet task, we recommend to set the height and width of output image as the same as input image. For example, if you want to edit a 512x512 image, you should set the height and width of output image as 512x512. You also can set the `use_input_image_size_as_output` to automatically set the height and width of output image as the same as input image.
|
277 |
+
|
278 |
+
|
279 |
+
"""
|
280 |
+
|
281 |
+
article = """
|
282 |
+
---
|
283 |
+
**Citation**
|
284 |
+
<br>
|
285 |
+
If you find this repository useful, please consider giving a star ⭐ and citation
|
286 |
+
```
|
287 |
+
@article{xiao2024omnigen,
|
288 |
+
title={Omnigen: Unified image generation},
|
289 |
+
author={Xiao, Shitao and Wang, Yueze and Zhou, Junjie and Yuan, Huaying and Xing, Xingrun and Yan, Ruiran and Wang, Shuting and Huang, Tiejun and Liu, Zheng},
|
290 |
+
journal={arXiv preprint arXiv:2409.11340},
|
291 |
+
year={2024}
|
292 |
+
}
|
293 |
+
```
|
294 |
+
**Contact**
|
295 |
+
<br>
|
296 |
+
If you have any questions, please feel free to open an issue or directly reach us out via email.
|
297 |
"""
|
298 |
|
|
|
299 |
|
300 |
+
# Gradio
|
301 |
with gr.Blocks() as demo:
|
302 |
gr.Markdown("# OmniGen: Unified Image Generation [paper](https://arxiv.org/abs/2409.11340) [code](https://github.com/VectorSpaceLab/OmniGen)")
|
303 |
gr.Markdown(description)
|
304 |
with gr.Row():
|
305 |
with gr.Column():
|
306 |
+
# text prompt
|
307 |
prompt_input = gr.Textbox(
|
308 |
label="Enter your prompt, use <img><|image_i|></img> to represent i-th input image", placeholder="Type your prompt here..."
|
309 |
)
|
310 |
|
311 |
with gr.Row(equal_height=True):
|
312 |
+
# input images
|
313 |
image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
|
314 |
image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
|
315 |
image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
|
316 |
|
317 |
+
# slider
|
318 |
height_input = gr.Slider(
|
319 |
+
label="Height", minimum=128, maximum=2048, value=1024, step=16
|
320 |
)
|
321 |
width_input = gr.Slider(
|
322 |
+
label="Width", minimum=128, maximum=2048, value=1024, step=16
|
323 |
)
|
324 |
|
|
|
325 |
guidance_scale_input = gr.Slider(
|
326 |
label="Guidance Scale", minimum=1.0, maximum=5.0, value=2.5, step=0.1
|
327 |
)
|
|
|
339 |
)
|
340 |
|
341 |
separate_cfg_infer = gr.Checkbox(
|
342 |
+
label="separate_cfg_infer", info="Whether to use separate inference process for different guidance. This will reduce the memory cost.", value=True,
|
343 |
+
)
|
344 |
+
offload_model = gr.Checkbox(
|
345 |
+
label="offload_model", info="Offload model to CPU, which will significantly reduce the memory cost but slow down the generation speed. You can cancle separate_cfg_infer and set offload_model=True. If both separate_cfg_infer and offload_model be True, further reduce the memory, but slowest generation", value=False,
|
346 |
+
)
|
347 |
+
use_input_image_size_as_output = gr.Checkbox(
|
348 |
+
label="use_input_image_size_as_output", info="Automatically adjust the output image size to be same as input image size. For editing and controlnet task, it can make sure the output image has the same size with input image leading to better performance", value=False,
|
349 |
)
|
350 |
|
351 |
+
# generate
|
352 |
generate_button = gr.Button("Generate Image")
|
353 |
+
|
354 |
|
355 |
with gr.Column():
|
356 |
+
# output image
|
357 |
output_image = gr.Image(label="Output Image")
|
358 |
|
359 |
+
# click
|
360 |
generate_button.click(
|
361 |
generate_image,
|
362 |
inputs=[
|
|
|
371 |
num_inference_steps,
|
372 |
seed_input,
|
373 |
separate_cfg_infer,
|
374 |
+
offload_model,
|
375 |
+
use_input_image_size_as_output,
|
376 |
],
|
377 |
outputs=output_image,
|
378 |
)
|
|
|
392 |
num_inference_steps,
|
393 |
seed_input,
|
394 |
separate_cfg_infer,
|
395 |
+
offload_model,
|
396 |
+
use_input_image_size_as_output,
|
397 |
],
|
398 |
outputs=output_image,
|
399 |
)
|
400 |
|
401 |
+
gr.Markdown(article)
|
402 |
+
|
403 |
+
# launch
|
404 |
demo.launch()
|
docs/fine-tuning.md
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Fine-tuning OmniGen
|
2 |
+
|
3 |
+
Fine-tuning Omnigen can better help you handle specific image generation tasks. For example, by fine-tuning on a person's images, you can generate multiple pictures of that person while maintaining task consistency.
|
4 |
+
|
5 |
+
A lot of previous work focused on designing new networks to facilitate specific tasks. For instance, ControlNet was proposed to handle image conditions, and IP-Adapter was constructed to maintain ID features. If you want to perform new tasks, you need to build new architectures and repeatedly debug them. Adding and adjusting extra network parameters is usually time-consuming and labor-intensive, which is not user-friendly and cost-efficient enough. However, with Omnigen, all of this becomes very simple.
|
6 |
+
|
7 |
+
By comparison, Omnigen can accept multi-modal conditional inputs and has been pre-trained on various tasks. You can fine-tune it on any task without designing specialized networks like ControlNet or IP-Adapter for a specific task.
|
8 |
+
|
9 |
+
**All you need to do is prepare the data and start training. You can break the limitations of previous models, allowing Omnigen to accomplish a variety of interesting tasks, even those that have never been done before.**
|
10 |
+
|
11 |
+
|
12 |
+
## Installation
|
13 |
+
|
14 |
+
```bash
|
15 |
+
git clone https://github.com/VectorSpaceLab/OmniGen.git
|
16 |
+
cd OmniGen
|
17 |
+
pip install -e .
|
18 |
+
```
|
19 |
+
|
20 |
+
|
21 |
+
## Full fine-tuning
|
22 |
+
|
23 |
+
### Fine-tuning command
|
24 |
+
|
25 |
+
```bash
|
26 |
+
accelerate launch \
|
27 |
+
--num_processes=1 \
|
28 |
+
--use_fsdp \
|
29 |
+
--fsdp_offload_params false \
|
30 |
+
--fsdp_sharding_strategy SHARD_GRAD_OP \
|
31 |
+
--fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \
|
32 |
+
--fsdp_transformer_layer_cls_to_wrap Phi3DecoderLayer \
|
33 |
+
--fsdp_state_dict_type FULL_STATE_DICT \
|
34 |
+
--fsdp_forward_prefetch false \
|
35 |
+
--fsdp_use_orig_params True \
|
36 |
+
--fsdp_cpu_ram_efficient_loading false \
|
37 |
+
--fsdp_sync_module_states True \
|
38 |
+
train.py \
|
39 |
+
--model_name_or_path Shitao/OmniGen-v1 \
|
40 |
+
--json_file ./toy_data/toy_data.jsonl \
|
41 |
+
--image_path ./toy_data/images \
|
42 |
+
--batch_size_per_device 1 \
|
43 |
+
--lr 2e-5 \
|
44 |
+
--keep_raw_resolution \
|
45 |
+
--max_image_size 1024 \
|
46 |
+
--gradient_accumulation_steps 1 \
|
47 |
+
--ckpt_every 100 \
|
48 |
+
--epochs 100 \
|
49 |
+
--log_every 1 \
|
50 |
+
--results_dir ./results/toy_finetune
|
51 |
+
```
|
52 |
+
|
53 |
+
Some important arguments:
|
54 |
+
- `num_processes`: number of GPU to use for training
|
55 |
+
- `model_name_or_path`: path to the pretrained model
|
56 |
+
- `json_file`: path to the json file containing the training data, e.g., ./toy_data/toy_data.jsonl
|
57 |
+
- `image_path`: path to the image folder, e.g., ./toy_data/images
|
58 |
+
- `batch_size_per_device`: batch size per device
|
59 |
+
- `lr`: learning rate
|
60 |
+
- `keep_raw_resolution`: whether to keep the original resolution of the image, if not, all images will be resized to (max_image_size, max_image_size)
|
61 |
+
- `max_image_size`: max image size
|
62 |
+
- `gradient_accumulation_steps`: number of steps to accumulate gradients
|
63 |
+
- `ckpt_every`: number of steps to save checkpoint
|
64 |
+
- `epochs`: number of epochs
|
65 |
+
- `log_every`: number of steps to log
|
66 |
+
- `results_dir`: path to the results folder
|
67 |
+
|
68 |
+
The data format of json_file is as follows:
|
69 |
+
```
|
70 |
+
{
|
71 |
+
"instruction": str,
|
72 |
+
"input_images": [str, str, ...],
|
73 |
+
"output_images": str
|
74 |
+
}
|
75 |
+
```
|
76 |
+
You can see a toy example in `./toy_data/toy_data.jsonl`.
|
77 |
+
|
78 |
+
If an OOM(Out of Memory) issue occurs, you can try to decrease the `batch_size_per_device` or `max_image_size`. You can also try to use LoRA instead of full fine-tuning.
|
79 |
+
|
80 |
+
|
81 |
+
### Inference
|
82 |
+
|
83 |
+
The checkpoint can be found at `{results_dir}/checkpoints/*`. You can use the following command to load saved checkpoint:
|
84 |
+
```python
|
85 |
+
from OmniGen import OmniGenPipeline
|
86 |
+
|
87 |
+
pipe = OmniGenPipeline.from_pretrained("checkpoint_path") # e.g., ./results/toy_finetune/checkpoints/0000200
|
88 |
+
```
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
## LoRA fine-tuning
|
95 |
+
LoRA fine-tuning is a simple way to fine-tune OmniGen with less GPU memory. To use lora, you should add `--use_lora` and `--lora_rank` to the command.
|
96 |
+
|
97 |
+
```bash
|
98 |
+
accelerate launch \
|
99 |
+
--num_processes=1 \
|
100 |
+
train.py \
|
101 |
+
--model_name_or_path Shitao/OmniGen-v1 \
|
102 |
+
--batch_size_per_device 2 \
|
103 |
+
--condition_dropout_prob 0.01 \
|
104 |
+
--lr 3e-4 \
|
105 |
+
--use_lora \
|
106 |
+
--lora_rank 8 \
|
107 |
+
--json_file ./toy_data/toy_data.jsonl \
|
108 |
+
--image_path ./toy_data/images \
|
109 |
+
--max_input_length_limit 18000 \
|
110 |
+
--keep_raw_resolution \
|
111 |
+
--max_image_size 1024 \
|
112 |
+
--gradient_accumulation_steps 1 \
|
113 |
+
--ckpt_every 100 \
|
114 |
+
--epochs 100 \
|
115 |
+
--log_every 1 \
|
116 |
+
--results_dir ./results/toy_finetune_lora
|
117 |
+
```
|
118 |
+
|
119 |
+
### Inference
|
120 |
+
|
121 |
+
The checkpoint can be found at `{results_dir}/checkpoints/*`. You can use the following command to load checkpoint:
|
122 |
+
```python
|
123 |
+
from OmniGen import OmniGenPipeline
|
124 |
+
|
125 |
+
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
|
126 |
+
pipe.merge_lora("checkpoint_path") # e.g., ./results/toy_finetune_lora/checkpoints/0000100
|
127 |
+
```
|
128 |
+
|
129 |
+
|
130 |
+
## A simple example
|
131 |
+
|
132 |
+
Here is an example for learning new concepts: "sks dog". We use five images of one dog from [dog-example](https://huggingface.co/datasets/diffusers/dog-example).
|
133 |
+
|
134 |
+
The json file is `./toy_data/toy_subject_data.jsonl`, and the images have been saved in `./toy_data/images`.
|
135 |
+
|
136 |
+
```bash
|
137 |
+
accelerate launch \
|
138 |
+
--num_processes=1 \
|
139 |
+
train.py \
|
140 |
+
--model_name_or_path Shitao/OmniGen-v1 \
|
141 |
+
--batch_size_per_device 2 \
|
142 |
+
--condition_dropout_prob 0.01 \
|
143 |
+
--lr 1e-3 \
|
144 |
+
--use_lora \
|
145 |
+
--lora_rank 8 \
|
146 |
+
--json_file ./toy_data/toy_subject_data.jsonl \
|
147 |
+
--image_path ./toy_data/images \
|
148 |
+
--max_input_length_limit 18000 \
|
149 |
+
--keep_raw_resolution \
|
150 |
+
--max_image_size 1024 \
|
151 |
+
--gradient_accumulation_steps 1 \
|
152 |
+
--ckpt_every 100 \
|
153 |
+
--epochs 200 \
|
154 |
+
--log_every 1 \
|
155 |
+
--results_dir ./results/toy_finetune_lora
|
156 |
+
```
|
157 |
+
|
158 |
+
After training, you can use the following command to generate images:
|
159 |
+
```python
|
160 |
+
from OmniGen import OmniGenPipeline
|
161 |
+
|
162 |
+
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
|
163 |
+
pipe.merge_lora("checkpoint_path") # e.g., ./results/toy_finetune_lora/checkpoints/0000200
|
164 |
+
|
165 |
+
images = pipe(
|
166 |
+
prompt="a photo of sks dog running in the snow",
|
167 |
+
height=1024,
|
168 |
+
width=1024,
|
169 |
+
guidance_scale=3
|
170 |
+
)
|
171 |
+
images[0].save("example_sks_dog_snow.png")
|
172 |
+
```
|
docs/inference.md
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Inference with OmniGen
|
2 |
+
|
3 |
+
To handle some complex tasks, image generation models are becoming increasingly sophisticated, leading to more and more cumbersome workflows. Existing image generation models like SD and Flux require loading many additional network modules (such as ControlNet, IP-Adapter, Reference-Net) and extra preprocessing steps (e.g., face detection, pose detection, image cropping) to generate a satisfactory image. This complex workflow is not user-friendly. We believe that future image generation models should be simpler, generating various images directly through instructions, similar to how GPT works in language generation.
|
4 |
+
|
5 |
+
Therefore, we propose OmniGen, a model capable of handling various image generation tasks within a single framework. The goal of OmniGen is to complete various image generation tasks without relying on any additional components or image preprocessing steps. OmniGen supports tasks including text-to-image generation, image editing, subject-driven image generation, and classical vision tasks, among others. More capabilities can be found in our examples. We provide inference code so you can explore more unknown functionalities yourself.
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
## Install
|
10 |
+
```bash
|
11 |
+
git clone https://github.com/staoxiao/OmniGen.git
|
12 |
+
cd OmniGen
|
13 |
+
pip install -e .
|
14 |
+
```
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
## Generate Images
|
19 |
+
You can use the following code to generate images:
|
20 |
+
```python
|
21 |
+
from OmniGen import OmniGenPipeline
|
22 |
+
|
23 |
+
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
|
24 |
+
|
25 |
+
# Text to Image
|
26 |
+
images = pipe(
|
27 |
+
prompt="A curly-haired man in a red shirt is drinking tea.",
|
28 |
+
height=1024,
|
29 |
+
width=1024,
|
30 |
+
guidance_scale=2.5,
|
31 |
+
seed=0,
|
32 |
+
)
|
33 |
+
images[0].save("example_t2i.png") # save output PIL Image
|
34 |
+
|
35 |
+
# Multi-modal to Image
|
36 |
+
# In prompt, we use the placeholder to represent the image. The image placeholder should be in the format of <img><|image_*|></img>
|
37 |
+
# You can add multiple images in the input_images. Please ensure that each image has its placeholder. For example, for the list input_images [img1_path, img2_path], the prompt needs to have two placeholders: <img><|image_1|></img>, <img><|image_2|></img>.
|
38 |
+
images = pipe(
|
39 |
+
prompt="A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
|
40 |
+
input_images=["./imgs/test_cases/two_man.jpg"],
|
41 |
+
height=1024,
|
42 |
+
width=1024,
|
43 |
+
guidance_scale=2.5,
|
44 |
+
img_guidance_scale=1.6,
|
45 |
+
max_input_image_size=1024,
|
46 |
+
separate_cfg_infer=True,
|
47 |
+
use_kv_cache=True,
|
48 |
+
offload_kv_cache=True,
|
49 |
+
offload_model=False,
|
50 |
+
use_input_image_size_as_output=False,
|
51 |
+
seed=0,
|
52 |
+
)
|
53 |
+
images[0].save("example_ti2i.png") # save output PIL image
|
54 |
+
```
|
55 |
+
|
56 |
+
Some important arguments:
|
57 |
+
- `guidance_scale`: The strength of the guidance. Based on our experience, it is usually best to set it between 2 and 3. The higher the value, the more similar the generated image will be to the prompt. If the image appears oversaturated, please reduce the scale.
|
58 |
+
- `height` and `width`: The height and width of the generated image. The default value is 1024x1024. OmniGen support any size, but these number must be divisible by 16.
|
59 |
+
- `num_inference_steps`: The number of steps to take in the diffusion process. The higher the value, the more detailed the generated image will be.
|
60 |
+
- `max_input_image_size`: the maximum size of input image, which will be used to crop the input image to the maximum size. A smaller number will result in faster generation speed and lower memory cost.
|
61 |
+
- `separate_cfg_infer`: Whether to use separate inference process for CFG guidance. If set to True, memory cost will be lower. Default is True.
|
62 |
+
- `use_kv_cache`: Whether to use key-value cache. Default is True.
|
63 |
+
- `offload_kv_cache`: offload the cached key and value to cpu, which can save memory but slow down the generation silightly. Default is True.
|
64 |
+
- `offload_model`: offload the model to cpu, which can save memory but slow down the generation. Default is False.
|
65 |
+
- `use_input_image_size_as_output`: whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task. Default is False.
|
66 |
+
- `seed`: The seed for random number generator.
|
67 |
+
|
68 |
+
**More examples please refer to [inference.ipynb](../inference.ipynb)**
|
69 |
+
|
70 |
+
|
71 |
+
#### Input data
|
72 |
+
OmniGen can accept multi-modal input data. Specifically, you should pass two arguments: `prompt` and `input_images`.
|
73 |
+
For text to image generation, you can pass a string as `prompt`, or pass a list of strings as `prompt` to generate multiple images.
|
74 |
+
|
75 |
+
For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>`.
|
76 |
+
For example, if you want to generate an image with a person holding a bouquet of flowers, you can pass the following prompt:
|
77 |
+
```
|
78 |
+
prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is <img><|image_1|></img>."
|
79 |
+
input_images = ["./imgs/test_cases/liuyifei.png"]
|
80 |
+
```
|
81 |
+
The placeholder `<|image_1|>` will be replaced by the image at `input_images[0]`, i.e., `./imgs/test_cases/liuyifei.png`.
|
82 |
+
|
83 |
+
If you want to generate multiple images, you can pass a list of prompts and a list of image paths. For example:
|
84 |
+
```
|
85 |
+
prompt = ["A woman holds a bouquet of flowers and faces the camera.", "A woman holds a bouquet of flowers and faces the camera. Thw woman is <img><|image_1|></img>."]
|
86 |
+
input_images = [[], ["./imgs/test_cases/liuyifei.png"]]
|
87 |
+
```
|
88 |
+
|
89 |
+
|
90 |
+
#### Gradio Demo
|
91 |
+
We have constructed a online demo in [Huggingface](https://huggingface.co/spaces/Shitao/OmniGen).
|
92 |
+
|
93 |
+
For the local gradio demo, you can run with the following command:
|
94 |
+
```python
|
95 |
+
python app.py
|
96 |
+
```
|
97 |
+
|
98 |
+
|
99 |
+
## Tips
|
100 |
+
- For out of memory or time cost, you can refer to [./docs/inference.md#requiremented-resources](https://github.com/VectorSpaceLab/OmniGen/blob/main/docs/inference.md#requiremented-resources) to select a appropriate setting.
|
101 |
+
- Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
|
102 |
+
- Not match the prompt: If the image does not match the prompt, please try to increase the `guidance_scale`.
|
103 |
+
- Low-quality: More detailed prompt will lead to better results.
|
104 |
+
- Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
|
105 |
+
- Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
|
106 |
+
- For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
|
107 |
+
- For image editing task and controlnet task, we recommend to set the height and width of output image as the same
|
108 |
+
as input image. For example, if you want to edit a 512x512 image, you should set the height and width of output image as 512x512. You also can set the `use_input_image_size_as_output` to automatically set the height and width of output image as the same as input image.
|
109 |
+
|
110 |
+
|
111 |
+
## Requiremented Resources
|
112 |
+
|
113 |
+
We are currently experimenting with some techniques to reduce memory usage and improve speed, including `use_kv_cache, offload_kv_cache, separate_cfg_infer, offload_model`, which you can enable in the pipeline.
|
114 |
+
The default setting is`use_kv_cache=True, offload_kv_cache=True, separate_cfg_infer=True, offload_model=False`.
|
115 |
+
To reduce memory consumption while maintaining inference speed, quantization is also a method worth exploring and is left for future work.
|
116 |
+
|
117 |
+
We conducted experiments on the A800 and RTX 3090. The memory requirements and inference times are shown in the table below. You can choose the appropriate settings based on your available resources.
|
118 |
+
|
119 |
+
**Overall, the text-to-image task requires minimal memory and time costs, comparable to other latest text-to-image models. However, when using input images, the computational cost increases. Memory usage can be reduced by extending the processing time.**
|
120 |
+
|
121 |
+
|
122 |
+
- Different image size.
|
123 |
+
|
124 |
+
Different image size (`max_input_image_size` is the max size of input image, `height` and `width` are the size of output image) with the default inference settings (`use_kv_cache=True,offload_kv_cache=True,separate_cfg_infer=True`)
|
125 |
+
|
126 |
+
For A800 GPU:
|
127 |
+
| Settings | Only Text | Text + Single Image | Text + Two Images |
|
128 |
+
|:-------------|:----------:|:-------------------:|:---------------------:|
|
129 |
+
| max_input_image_size=1024,height=1024,width=1024 | 9G, 31s | 12G, 1m6s | 13G, 1m20s |
|
130 |
+
| max_input_image_size=512,height=1024,width=1024 | 9G, 31s | 10G, 50s | 10G, 54s |
|
131 |
+
| max_input_image_size=768,height=768,width=768 | 9G, 16s | 10G, 32s | 10G, 37s |
|
132 |
+
| max_input_image_size=512,height=512,width=512 | 9G, 7s | 9G, 14s | 9G, 15s |
|
133 |
+
|
134 |
+
For RTX 3090 GPU(24G):
|
135 |
+
| Settings | Only Text | Text + Single Image | Text + Two Images |
|
136 |
+
|:-------------|:----------:|:-------------------:|:---------------------:|
|
137 |
+
| max_input_image_size=1024,height=1024,width=1024 | 9G, 1m17s | 12G, 2m46s | 13G, 3m23s |
|
138 |
+
| max_input_image_size=512,height=1024,width=1024 | 9G, 1m18s | 10G, 2m8s | 10G, 2m18s |
|
139 |
+
| max_input_image_size=768,height=768,width=768 | 9G, 41s | 10G, 1m22s | 10G, 1m38s |
|
140 |
+
| max_input_image_size=512,height=512,width=512 | 9G, 19s | 9G, 36s | 9G, 43s |
|
141 |
+
|
142 |
+
|
143 |
+
You can set smaller `max_input_image_size` to reduce memory usage, but note that the generation quality may be lower.
|
144 |
+
And please set the `height` and `width` the same as the size of input image for image editing task.
|
145 |
+
|
146 |
+
|
147 |
+
- Different inference settings
|
148 |
+
|
149 |
+
Default image size: height=1024, width=1024, max_input_image_size=1024
|
150 |
+
|
151 |
+
For A800 GPU:
|
152 |
+
| Settings | Only Text | Text + Single Image | Text + Two Images |
|
153 |
+
|:-------------|:----------:|:-------------------:|:---------------------:|
|
154 |
+
| use_kv_cache | 18G, 30s | 36G, 1m | 48G, 1m13s |
|
155 |
+
| use_kv_cache,offload_kv_cache | 10G, 30s | 14G, 1m10s | 17G, 1m30s |
|
156 |
+
| use_kv_cache,offload_kv_cache,separate_cfg_infer | 9G, 31s | 12G, 1m6s | 13G, 1m20s |
|
157 |
+
| use_kv_cache,offload_kv_cache,offload_model | 4G, 55s | 7G, 1m30s | 11G, 1m48s |
|
158 |
+
| use_kv_cache,offload_kv_cache,separate_cfg_infer,offload_model | 3G, 1m23s | 5G, 2m19s | 6G, 2m30s |
|
159 |
+
|
160 |
+
For RTX 3090 GPU(24G):
|
161 |
+
| Settings | Only Text | Text + Single Image | Text + Two Images |
|
162 |
+
|:-------------|:----------:|:-------------------:|:---------------------:|
|
163 |
+
| use_kv_cache | 18G, 1m14s | OOM | OOM |
|
164 |
+
| use_kv_cache,offload_kv_cache | 10G, 1m17s | 14G, 3m11s | 17G, 4m3s |
|
165 |
+
| use_kv_cache,offload_kv_cache,separate_cfg_infer | 9G, 1m18s | 12G, 2m46s | 13G, 3m21s |
|
166 |
+
| use_kv_cache,offload_kv_cache,offload_model | 4G,3m1s | 7G, 4m14s | 11G, 5m4s |
|
167 |
+
| use_kv_cache,offload_kv_cache,separate_cfg_infer,offload_model | 3G, 4m56s | 5G, 7m49s | 6G, 8m6s |
|
edit.png → imgs/demo_cases.png
RENAMED
File without changes
|
imgs/demo_cases/edit.png
CHANGED
Git LFS Details
|
Git LFS Details
|
imgs/demo_cases/entity.png
CHANGED
Git LFS Details
|
Git LFS Details
|
imgs/demo_cases/t2i_woman_with_book.png
CHANGED
Git LFS Details
|
Git LFS Details
|
imgs/overall.jpg
ADDED
Git LFS Details
|
imgs/referring.png
ADDED
Git LFS Details
|
requirements.txt
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
1 |
+
torch==2.3.1
|
2 |
+
transformers==4.45.2
|
3 |
+
datasets==2.20.0
|
4 |
+
accelerate==0.26.1
|
5 |
+
jupyter==1.0.0
|
6 |
+
numpy==1.26.3
|
7 |
+
pillow==10.2.0
|
8 |
+
torch==2.3.1
|
9 |
+
peft==0.9.0
|
10 |
+
diffusers==0.30.3
|
11 |
+
timm==0.9.16
|
setup.py
CHANGED
@@ -5,7 +5,7 @@ with open("README.md", mode="r", encoding="utf-8") as readme_file:
|
|
5 |
|
6 |
setup(
|
7 |
name='OmniGen',
|
8 |
-
version='1.0.
|
9 |
description='OmniGen',
|
10 |
long_description=readme,
|
11 |
long_description_content_type="text/markdown",
|
@@ -14,10 +14,13 @@ setup(
|
|
14 |
packages=find_packages(),
|
15 |
include_package_data=True,
|
16 |
install_requires=[
|
17 |
-
'torch
|
18 |
-
'transformers
|
19 |
'datasets',
|
20 |
-
'accelerate
|
21 |
-
'diffusers
|
|
|
|
|
|
|
22 |
],
|
23 |
-
)
|
|
|
5 |
|
6 |
setup(
|
7 |
name='OmniGen',
|
8 |
+
version='1.0.3',
|
9 |
description='OmniGen',
|
10 |
long_description=readme,
|
11 |
long_description_content_type="text/markdown",
|
|
|
14 |
packages=find_packages(),
|
15 |
include_package_data=True,
|
16 |
install_requires=[
|
17 |
+
'torch==2.3.1',
|
18 |
+
'transformers==4.45.2',
|
19 |
'datasets',
|
20 |
+
'accelerate==0.26.1',
|
21 |
+
'diffusers==0.30.3',
|
22 |
+
"timm",
|
23 |
+
"peft==0.9.0",
|
24 |
+
"safetensors"
|
25 |
],
|
26 |
+
)
|
train.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from time import time
|
3 |
+
import argparse
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
from pathlib import Path
|
7 |
+
import math
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
from copy import deepcopy
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.distributed as dist
|
15 |
+
from torch.utils.data import Dataset, DataLoader
|
16 |
+
from torch.utils.data.distributed import DistributedSampler
|
17 |
+
from torchvision import transforms
|
18 |
+
|
19 |
+
from accelerate import Accelerator
|
20 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
21 |
+
from diffusers.optimization import get_scheduler
|
22 |
+
from accelerate.utils import DistributedType
|
23 |
+
from peft import LoraConfig, set_peft_model_state_dict, PeftModel, get_peft_model
|
24 |
+
from peft.utils import get_peft_model_state_dict
|
25 |
+
from huggingface_hub import snapshot_download
|
26 |
+
from safetensors.torch import save_file
|
27 |
+
|
28 |
+
from diffusers.models import AutoencoderKL
|
29 |
+
|
30 |
+
from OmniGen import OmniGen, OmniGenProcessor
|
31 |
+
from OmniGen.train_helper import DatasetFromJson, TrainDataCollator
|
32 |
+
from OmniGen.train_helper import training_losses
|
33 |
+
from OmniGen.utils import (
|
34 |
+
create_logger,
|
35 |
+
update_ema,
|
36 |
+
requires_grad,
|
37 |
+
center_crop_arr,
|
38 |
+
crop_arr,
|
39 |
+
vae_encode,
|
40 |
+
vae_encode_list
|
41 |
+
)
|
42 |
+
|
43 |
+
def main(args):
|
44 |
+
# Setup accelerator:
|
45 |
+
from accelerate import DistributedDataParallelKwargs as DDPK
|
46 |
+
kwargs = DDPK(find_unused_parameters=False)
|
47 |
+
accelerator = Accelerator(
|
48 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
49 |
+
mixed_precision=args.mixed_precision,
|
50 |
+
log_with=args.report_to,
|
51 |
+
project_dir=args.results_dir,
|
52 |
+
kwargs_handlers=[kwargs],
|
53 |
+
)
|
54 |
+
device = accelerator.device
|
55 |
+
accelerator.init_trackers("tensorboard_log", config=args.__dict__)
|
56 |
+
|
57 |
+
# Setup an experiment folder:
|
58 |
+
checkpoint_dir = f"{args.results_dir}/checkpoints" # Stores saved model checkpoints
|
59 |
+
logger = create_logger(args.results_dir)
|
60 |
+
if accelerator.is_main_process:
|
61 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
62 |
+
logger.info(f"Experiment directory created at {args.results_dir}")
|
63 |
+
json.dump(args.__dict__, open(os.path.join(args.results_dir, 'train_args.json'), 'w'))
|
64 |
+
|
65 |
+
|
66 |
+
# Create model:
|
67 |
+
if not os.path.exists(args.model_name_or_path):
|
68 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
69 |
+
args.model_name_or_path = snapshot_download(repo_id=args.model_name_or_path,
|
70 |
+
cache_dir=cache_folder,
|
71 |
+
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
72 |
+
logger.info(f"Downloaded model to {args.model_name_or_path}")
|
73 |
+
model = OmniGen.from_pretrained(args.model_name_or_path)
|
74 |
+
model.llm.config.use_cache = False
|
75 |
+
model.llm.gradient_checkpointing_enable()
|
76 |
+
model = model.to(device)
|
77 |
+
|
78 |
+
if args.vae_path is None:
|
79 |
+
print(args.model_name_or_path)
|
80 |
+
vae_path = os.path.join(args.model_name_or_path, "vae")
|
81 |
+
if os.path.exists(vae_path):
|
82 |
+
vae = AutoencoderKL.from_pretrained(vae_path).to(device)
|
83 |
+
else:
|
84 |
+
logger.info("No VAE found in model, downloading stabilityai/sdxl-vae from HF")
|
85 |
+
logger.info("If you have VAE in local folder, please specify the path with --vae_path")
|
86 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
|
87 |
+
else:
|
88 |
+
vae = AutoencoderKL.from_pretrained(args.vae_path).to(device)
|
89 |
+
|
90 |
+
weight_dtype = torch.float32
|
91 |
+
if accelerator.mixed_precision == "fp16":
|
92 |
+
weight_dtype = torch.float16
|
93 |
+
elif accelerator.mixed_precision == "bf16":
|
94 |
+
weight_dtype = torch.bfloat16
|
95 |
+
vae.to(dtype=torch.float32)
|
96 |
+
model.to(weight_dtype)
|
97 |
+
|
98 |
+
processor = OmniGenProcessor.from_pretrained(args.model_name_or_path)
|
99 |
+
|
100 |
+
requires_grad(vae, False)
|
101 |
+
if args.use_lora:
|
102 |
+
if accelerator.distributed_type == DistributedType.FSDP:
|
103 |
+
raise NotImplementedError("FSDP does not support LoRA")
|
104 |
+
requires_grad(model, False)
|
105 |
+
transformer_lora_config = LoraConfig(
|
106 |
+
r=args.lora_rank,
|
107 |
+
lora_alpha=args.lora_rank,
|
108 |
+
init_lora_weights="gaussian",
|
109 |
+
target_modules=["qkv_proj", "o_proj"],
|
110 |
+
)
|
111 |
+
model.llm.enable_input_require_grads()
|
112 |
+
model = get_peft_model(model, transformer_lora_config)
|
113 |
+
model.to(weight_dtype)
|
114 |
+
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
|
115 |
+
for n,p in model.named_parameters():
|
116 |
+
print(n, p.requires_grad)
|
117 |
+
opt = torch.optim.AdamW(transformer_lora_parameters, lr=args.lr, weight_decay=args.adam_weight_decay)
|
118 |
+
else:
|
119 |
+
opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.adam_weight_decay)
|
120 |
+
|
121 |
+
ema = None
|
122 |
+
if args.use_ema:
|
123 |
+
ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
|
124 |
+
requires_grad(ema, False)
|
125 |
+
|
126 |
+
|
127 |
+
# Setup data:
|
128 |
+
crop_func = crop_arr
|
129 |
+
if not args.keep_raw_resolution:
|
130 |
+
crop_func = center_crop_arr
|
131 |
+
image_transform = transforms.Compose([
|
132 |
+
transforms.Lambda(lambda pil_image: crop_func(pil_image, args.max_image_size)),
|
133 |
+
transforms.ToTensor(),
|
134 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
135 |
+
])
|
136 |
+
|
137 |
+
dataset = DatasetFromJson(json_file=args.json_file,
|
138 |
+
image_path=args.image_path,
|
139 |
+
processer=processor,
|
140 |
+
image_transform=image_transform,
|
141 |
+
max_input_length_limit=args.max_input_length_limit,
|
142 |
+
condition_dropout_prob=args.condition_dropout_prob,
|
143 |
+
keep_raw_resolution=args.keep_raw_resolution
|
144 |
+
)
|
145 |
+
collate_fn = TrainDataCollator(pad_token_id=processor.text_tokenizer.eos_token_id, hidden_size=model.llm.config.hidden_size, keep_raw_resolution=args.keep_raw_resolution)
|
146 |
+
|
147 |
+
loader = DataLoader(
|
148 |
+
dataset,
|
149 |
+
collate_fn=collate_fn,
|
150 |
+
batch_size=args.batch_size_per_device,
|
151 |
+
shuffle=True,
|
152 |
+
num_workers=args.num_workers,
|
153 |
+
pin_memory=True,
|
154 |
+
drop_last=True,
|
155 |
+
prefetch_factor=2,
|
156 |
+
)
|
157 |
+
|
158 |
+
if accelerator.is_main_process:
|
159 |
+
logger.info(f"Dataset contains {len(dataset):,}")
|
160 |
+
|
161 |
+
num_update_steps_per_epoch = math.ceil(len(loader) / args.gradient_accumulation_steps)
|
162 |
+
max_train_steps = args.epochs * num_update_steps_per_epoch
|
163 |
+
lr_scheduler = get_scheduler(
|
164 |
+
args.lr_scheduler,
|
165 |
+
optimizer=opt,
|
166 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
167 |
+
num_training_steps=max_train_steps * args.gradient_accumulation_steps,
|
168 |
+
)
|
169 |
+
|
170 |
+
# Prepare models for training:
|
171 |
+
model.train() # important! This enables embedding dropout for classifier-free guidance
|
172 |
+
|
173 |
+
if ema is not None:
|
174 |
+
update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights
|
175 |
+
ema.eval() # EMA model should always be in eval mode
|
176 |
+
|
177 |
+
|
178 |
+
if ema is not None:
|
179 |
+
model, ema = accelerator.prepare(model, ema)
|
180 |
+
else:
|
181 |
+
model = accelerator.prepare(model)
|
182 |
+
|
183 |
+
opt, loader, lr_scheduler = accelerator.prepare(opt, loader, lr_scheduler)
|
184 |
+
|
185 |
+
|
186 |
+
# Variables for monitoring/logging purposes:
|
187 |
+
train_steps, log_steps = 0, 0
|
188 |
+
running_loss = 0
|
189 |
+
start_time = time()
|
190 |
+
|
191 |
+
if accelerator.is_main_process:
|
192 |
+
logger.info(f"Training for {args.epochs} epochs...")
|
193 |
+
for epoch in range(args.epochs):
|
194 |
+
if accelerator.is_main_process:
|
195 |
+
logger.info(f"Beginning epoch {epoch}...")
|
196 |
+
|
197 |
+
for data in loader:
|
198 |
+
with accelerator.accumulate(model):
|
199 |
+
with torch.no_grad():
|
200 |
+
output_images = data['output_images']
|
201 |
+
input_pixel_values = data['input_pixel_values']
|
202 |
+
if isinstance(output_images, list):
|
203 |
+
output_images = vae_encode_list(vae, output_images, weight_dtype)
|
204 |
+
if input_pixel_values is not None:
|
205 |
+
input_pixel_values = vae_encode_list(vae, input_pixel_values, weight_dtype)
|
206 |
+
else:
|
207 |
+
output_images = vae_encode(vae, output_images, weight_dtype)
|
208 |
+
if input_pixel_values is not None:
|
209 |
+
input_pixel_values = vae_encode(vae, input_pixel_values, weight_dtype)
|
210 |
+
|
211 |
+
|
212 |
+
model_kwargs = dict(input_ids=data['input_ids'], input_img_latents=input_pixel_values, input_image_sizes=data['input_image_sizes'], attention_mask=data['attention_mask'], position_ids=data['position_ids'], padding_latent=data['padding_images'], past_key_values=None, return_past_key_values=False)
|
213 |
+
|
214 |
+
loss_dict = training_losses(model, output_images, model_kwargs)
|
215 |
+
loss = loss_dict["loss"].mean()
|
216 |
+
|
217 |
+
running_loss += loss.item()
|
218 |
+
accelerator.backward(loss)
|
219 |
+
if args.max_grad_norm is not None and accelerator.sync_gradients:
|
220 |
+
accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
221 |
+
opt.step()
|
222 |
+
lr_scheduler.step()
|
223 |
+
opt.zero_grad()
|
224 |
+
|
225 |
+
log_steps += 1
|
226 |
+
train_steps += 1
|
227 |
+
|
228 |
+
accelerator.log({"training_loss": loss.item()}, step=train_steps)
|
229 |
+
if train_steps % args.gradient_accumulation_steps == 0:
|
230 |
+
if accelerator.sync_gradients and ema is not None:
|
231 |
+
update_ema(ema, model)
|
232 |
+
|
233 |
+
if train_steps % (args.log_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
|
234 |
+
torch.cuda.synchronize()
|
235 |
+
end_time = time()
|
236 |
+
steps_per_sec = log_steps / args.gradient_accumulation_steps / (end_time - start_time)
|
237 |
+
# Reduce loss history over all processes:
|
238 |
+
avg_loss = torch.tensor(running_loss / log_steps, device=device)
|
239 |
+
dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
|
240 |
+
avg_loss = avg_loss.item() / accelerator.num_processes
|
241 |
+
|
242 |
+
if accelerator.is_main_process:
|
243 |
+
cur_lr = opt.param_groups[0]["lr"]
|
244 |
+
logger.info(f"(step={int(train_steps/args.gradient_accumulation_steps):07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Epoch: {train_steps/len(loader)}, LR: {cur_lr}")
|
245 |
+
|
246 |
+
# Reset monitoring variables:
|
247 |
+
running_loss = 0
|
248 |
+
log_steps = 0
|
249 |
+
start_time = time()
|
250 |
+
|
251 |
+
|
252 |
+
if train_steps % (args.ckpt_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
|
253 |
+
if accelerator.distributed_type == DistributedType.FSDP:
|
254 |
+
state_dict = accelerator.get_state_dict(model)
|
255 |
+
ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None
|
256 |
+
else:
|
257 |
+
if not args.use_lora:
|
258 |
+
state_dict = model.module.state_dict()
|
259 |
+
ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None
|
260 |
+
|
261 |
+
if accelerator.is_main_process:
|
262 |
+
if args.use_lora:
|
263 |
+
checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}/"
|
264 |
+
os.makedirs(checkpoint_path, exist_ok=True)
|
265 |
+
|
266 |
+
model.module.save_pretrained(checkpoint_path)
|
267 |
+
else:
|
268 |
+
checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}/"
|
269 |
+
os.makedirs(checkpoint_path, exist_ok=True)
|
270 |
+
torch.save(state_dict, os.path.join(checkpoint_path, "model.pt"))
|
271 |
+
processor.text_tokenizer.save_pretrained(checkpoint_path)
|
272 |
+
model.llm.config.save_pretrained(checkpoint_path)
|
273 |
+
if ema_state_dict is not None:
|
274 |
+
checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}_ema"
|
275 |
+
os.makedirs(checkpoint_path, exist_ok=True)
|
276 |
+
torch.save(state_dict, os.path.join(checkpoint_path, "model.pt"))
|
277 |
+
processor.text_tokenizer.save_pretrained(checkpoint_path)
|
278 |
+
model.llm.config.save_pretrained(checkpoint_path)
|
279 |
+
logger.info(f"Saved checkpoint to {checkpoint_path}")
|
280 |
+
|
281 |
+
dist.barrier()
|
282 |
+
accelerator.end_training()
|
283 |
+
model.eval()
|
284 |
+
|
285 |
+
if accelerator.is_main_process:
|
286 |
+
logger.info("Done!")
|
287 |
+
|
288 |
+
|
289 |
+
if __name__ == "__main__":
|
290 |
+
parser = argparse.ArgumentParser()
|
291 |
+
parser.add_argument("--results_dir", type=str, default="results")
|
292 |
+
parser.add_argument("--model_name_or_path", type=str, default="OmniGen")
|
293 |
+
parser.add_argument("--json_file", type=str)
|
294 |
+
parser.add_argument("--image_path", type=str, default=None)
|
295 |
+
parser.add_argument("--epochs", type=int, default=1400)
|
296 |
+
parser.add_argument("--batch_size_per_device", type=int, default=1)
|
297 |
+
parser.add_argument("--vae_path", type=str, default=None)
|
298 |
+
parser.add_argument("--num_workers", type=int, default=4)
|
299 |
+
parser.add_argument("--log_every", type=int, default=100)
|
300 |
+
parser.add_argument("--ckpt_every", type=int, default=20000)
|
301 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
302 |
+
parser.add_argument("--lr", type=float, default=1e-4)
|
303 |
+
parser.add_argument("--max_input_length_limit", type=int, default=1024)
|
304 |
+
parser.add_argument("--condition_dropout_prob", type=float, default=0.1)
|
305 |
+
parser.add_argument("--adam_weight_decay", type=float, default=0.0)
|
306 |
+
parser.add_argument(
|
307 |
+
"--keep_raw_resolution",
|
308 |
+
action="store_true",
|
309 |
+
help="multiple_resolutions",
|
310 |
+
)
|
311 |
+
parser.add_argument("--max_image_size", type=int, default=1344)
|
312 |
+
|
313 |
+
parser.add_argument(
|
314 |
+
"--use_lora",
|
315 |
+
action="store_true",
|
316 |
+
)
|
317 |
+
parser.add_argument(
|
318 |
+
"--lora_rank",
|
319 |
+
type=int,
|
320 |
+
default=8
|
321 |
+
)
|
322 |
+
|
323 |
+
parser.add_argument(
|
324 |
+
"--use_ema",
|
325 |
+
action="store_true",
|
326 |
+
help="Whether or not to use ema.",
|
327 |
+
)
|
328 |
+
parser.add_argument(
|
329 |
+
"--lr_scheduler",
|
330 |
+
type=str,
|
331 |
+
default="constant",
|
332 |
+
help=(
|
333 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
334 |
+
' "constant", "constant_with_warmup"]'
|
335 |
+
),
|
336 |
+
)
|
337 |
+
parser.add_argument(
|
338 |
+
"--lr_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler."
|
339 |
+
)
|
340 |
+
parser.add_argument(
|
341 |
+
"--report_to",
|
342 |
+
type=str,
|
343 |
+
default="tensorboard",
|
344 |
+
help=(
|
345 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
346 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
347 |
+
),
|
348 |
+
)
|
349 |
+
parser.add_argument(
|
350 |
+
"--mixed_precision",
|
351 |
+
type=str,
|
352 |
+
default="bf16",
|
353 |
+
choices=["no", "fp16", "bf16"],
|
354 |
+
help=(
|
355 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
356 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
357 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
358 |
+
),
|
359 |
+
)
|
360 |
+
parser.add_argument(
|
361 |
+
"--gradient_accumulation_steps",
|
362 |
+
type=int,
|
363 |
+
default=1,
|
364 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
365 |
+
)
|
366 |
+
|
367 |
+
|
368 |
+
args = parser.parse_args()
|
369 |
+
assert args.max_image_size % 16 == 0, "Image size must be divisible by 16."
|
370 |
+
|
371 |
+
main(args)
|
372 |
+
|
373 |
+
|