---
pipeline_tag: any-to-any
license: apache-2.0
library_name: transformers
---
<div align='center'>
<h1>Emu3: Next-Token Prediction is All You Need</h1>
<h3></h3>
[Emu3 Team, BAAI](https://www.baai.ac.cn/english.html)
| [Project Page](https://emu.baai.ac.cn) | [Paper](https://huggingface.co/papers/2409.18869) | [🤗HF Models](https://huggingface.co/collections/BAAI/emu3-66f4e64f70850ff358a2e60f) | [github](https://github.com/baaivision/Emu3)
| [Demo](https://huggingface.co/spaces/BAAI/Emu3) |
</div>
<div align='center'>
<img src="https://github.com/baaivision/Emu3/blob/main/assets/arch.png?raw=True" class="interpolation-image" alt="arch." height="80%" width="70%" />
</div>
We introduce **Emu3**, a new suite of state-of-the-art multimodal models trained solely with **<i>next-token prediction</i>**! By tokenizing images, text, and videos into a discrete space, we train a single transformer from scratch on a mixture of multimodal sequences.
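As a rough illustration of the discrete-token interface, the sketch below round-trips an image through the Emu3 vision tokenizer: pixels are encoded into a grid of discrete codes and decoded back. It assumes the `BAAI/Emu3-VisionTokenizer` remote code exposes `encode`/`decode` on the model and a `postprocess` helper on its image processor, as shown on that model's card; exact tensor shapes may differ, so treat it as a sketch rather than the training pipeline.
```python
# Sketch: tokenize an image into discrete codes and reconstruct it.
# Assumes the Emu3-VisionTokenizer remote code provides encode()/decode()
# and that its image processor provides postprocess(); shapes may need adjusting.
import torch
from PIL import Image
from transformers import AutoModel, AutoImageProcessor

VQ_HUB = "BAAI/Emu3-VisionTokenizer"

vq_model = AutoModel.from_pretrained(VQ_HUB, trust_remote_code=True).eval().to("cuda:0")
vq_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)

image = Image.open("assets/demo.png").convert("RGB")
pixel_values = vq_processor(image, return_tensors="pt")["pixel_values"].to("cuda:0")

with torch.no_grad():
    codes = vq_model.encode(pixel_values)   # discrete vision tokens, one id per spatial patch
    recon = vq_model.decode(codes)          # pixels reconstructed purely from the discrete codes

print(codes.shape)
vq_processor.postprocess(recon)["pixel_values"][0].save("reconstruction.png")
```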
### Emu3 excels in both generation and perception
**Emu3** outperforms several well-established task-specific models in both generation and perception tasks, surpassing flagship open models such as SDXL, LLaVA-1.6 and OpenSora-1.2, while eliminating the need for diffusion or compositional architectures.
<div align='center'>
<img src="https://github.com/baaivision/Emu3/blob/main/assets/comparison.png?raw=True" class="interpolation-image" alt="comparison." height="80%" width="80%" />
</div>
### Highlights
- **Emu3** generates high-quality images from text input simply by predicting the next vision token. The model naturally supports flexible resolutions and styles.
- **Emu3** shows strong vision-language understanding, perceiving the physical world and providing coherent text responses. Notably, this capability is achieved without relying on CLIP or a pretrained LLM.
- **Emu3** generates video causally, simply by predicting the next token in a video sequence, unlike video diffusion models such as Sora. With a video in context, Emu3 can also naturally extend the video and predict what will happen next.
### Model Information
The **Emu3-Stage1** model contains the pre-trained weights from the first stage of Emu3's two-stage pre-training. This stage **does not use video data**: training starts from scratch with a context length of 5120 on text and image data. The resulting model supports image captioning and can generate images at a resolution of 512x512. You can use our [training scripts](https://github.com/baaivision/Emu3/tree/main/scripts) for further instruction tuning on more **image generation and perception tasks**.
#### Quickstart
```python
from PIL import Image
from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor
import torch
import sys
sys.path.append("PATH_TO_BAAI_Emu3-Stage1_MODEL")  # local snapshot directory containing processing_emu3.py
from processing_emu3 import Emu3Processor
# model path
EMU_HUB = "BAAI/Emu3-Stage1"
VQ_HUB = "BAAI/Emu3-VisionTokenizer"
# prepare model and processor
model = AutoModelForCausalLM.from_pretrained(
EMU_HUB,
device_map="cuda:0",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left")
image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
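# Emu3Processor (from the remote processing_emu3.py) ties together the text tokenizer,
# the VQ image tokenizer and the image processor, so prompts and images become one
# discrete token sequence.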
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer, chat_template="{image_prompt}{text_prompt}")

# Image Generation
# prepare input
POSITIVE_PROMPT = " masterpiece, film grained, best quality."
NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
classifier_free_guidance = 3.0
prompt = "a portrait of young girl."
prompt += POSITIVE_PROMPT
kwargs = dict(
mode='G',
ratio="1:1",
image_area=model.config.image_area,
return_tensors="pt",
padding="longest",
)
pos_inputs = processor(text=prompt, **kwargs)
neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)
# prepare hyper parameters
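# Sampling settings for image generation: top-k sampling over vision tokens with a
# generous max_new_tokens cap; the actual image length follows the h x w token grid
# enforced by the prefix constraint built below.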
GENERATION_CONFIG = GenerationConfig(
use_cache=True,
eos_token_id=model.config.eos_token_id,
pad_token_id=model.config.pad_token_id,
max_new_tokens=40960,
do_sample=True,
top_k=2048,
)
h = pos_inputs.image_size[:, 0]
w = pos_inputs.image_size[:, 1]
constrained_fn = processor.build_prefix_constrained_fn(h, w)
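# Two logits processors: classifier-free guidance contrasts the conditional prompt with
# the negative prompt at the chosen guidance scale, and the prefix constraint keeps the
# sampled tokens on a valid h x w image grid.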
logits_processor = LogitsProcessorList([
UnbatchedClassifierFreeGuidanceLogitsProcessor(
classifier_free_guidance,
model,
unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
),
PrefixConstrainedLogitsProcessor(
constrained_fn,
num_beams=1,
),
])
# generate
outputs = model.generate(
pos_inputs.input_ids.to("cuda:0"),
GENERATION_CONFIG,
logits_processor=logits_processor,
attention_mask=pos_inputs.attention_mask.to("cuda:0"),
)
mm_list = processor.decode(outputs[0])
for idx, im in enumerate(mm_list):
if not isinstance(im, Image.Image):
continue
im.save(f"result_{idx}.png")

# Multimodal Understanding
text = "The image depicts "
image = Image.open("assets/demo.png")
inputs = processor(
text=text,
image=image,
mode='U',
padding="longest",
return_tensors="pt",
)
GENERATION_CONFIG = GenerationConfig(
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=1024,
)
outputs = model.generate(
inputs.input_ids.to("cuda:0"),
GENERATION_CONFIG,
attention_mask=inputs.attention_mask.to("cuda:0"),
)
outputs = outputs[:, inputs.input_ids.shape[-1]:]
answers = processor.batch_decode(outputs, skip_special_tokens=True)
for ans in answers:
print(ans)
```