---
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- DPO
- DiffusionDPO
inference: true
---

# Aligned Diffusion Model via DPO

This diffusion model was aligned using the Direct Preference Optimization (DPO) algorithm, with feedback from the following reward models:
```
closed-source VLMs: claude3-opus  gemini-1.5  gpt-4o  gpt-4v
open-source VLM: internvl-1.5
score model: hps-2.1
```
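
For readers unfamiliar with DPO, the following is a minimal sketch of the generic DPO loss for a single preference pair, written in plain Python. It is not the training code used for this checkpoint: the function name `dpo_loss`, its argument names, and the default `beta=0.1` are hypothetical and chosen for illustration only. Diffusion variants of DPO typically replace the log-probabilities with quantities derived from denoising errors, which this sketch does not cover.

```python
import math


def dpo_loss(logp_chosen, logp_rejected, ref_logp_chosen, ref_logp_rejected, beta=0.1):
    """Generic DPO loss for one (chosen, rejected) pair.

    All arguments are log-probabilities (floats). `beta` is an assumed
    illustrative value, not a setting taken from this model card.
    """
    # How much more the trained model favors the chosen sample over the
    # rejected one, relative to the frozen reference model.
    margin = (logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected)
    z = beta * margin
    # -log(sigmoid(z)) == log(1 + exp(-z)), evaluated in a numerically stable way.
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


# The loss is lower when the model prefers the chosen sample more than the reference does.
agrees = dpo_loss(-1.0, -3.0, -2.0, -2.0)
disagrees = dpo_loss(-3.0, -1.0, -2.0, -2.0)
print(agrees < disagrees)
```

In this formulation the reward models listed above would only supply the preference labels (which image of a pair is "chosen"); the loss itself depends solely on the trained and reference model likelihoods.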

## How to Use

You can load the model and perform inference as follows:
```python
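import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel

pretrained_model_name = "runwayml/stable-diffusion-v1-5"

# Load the DPO-aligned UNet; replace the placeholder path with your checkpoint location.
dpo_unet = UNet2DConditionModel.from_pretrained(
    "path/to/checkpoint",
    subfolder='unet',
    torch_dtype=torch.float16,
).to('cuda')

# Load the base pipeline and swap in the aligned UNet.
pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name, torch_dtype=torch.float16)
pipeline = pipeline.to('cuda')
pipeline.safety_checker = None  # disables the built-in safety checker
pipeline.unet = dpo_unet

# Seed the generator so sampling is reproducible.
generator = torch.Generator(device='cuda')
generator = generator.manual_seed(1)

prompt = "a pink flower"
gs = 7.5  # guidance scale; assumed value (the diffusers default), adjust as needed

image = pipeline(prompt=prompt, generator=generator, guidance_scale=gs).images[0]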
```


## Citation
```
@misc{chen2024mjbenchmultimodalrewardmodel,
    title={MJ-Bench: Is Your Multimodal Reward Model Really a Good Judge for Text-to-Image Generation?}, 
    author={Zhaorun Chen and Yichao Du and Zichen Wen and Yiyang Zhou and Chenhang Cui and Zhenzhen Weng and Haoqin Tu and Chaoqi Wang and Zhengwei Tong and Qinglan Huang and Canyu Chen and Qinghao Ye and Zhihong Zhu and Yuqing Zhang and Jiawei Zhou and Zhuokai Zhao and Rafael Rafailov and Chelsea Finn and Huaxiu Yao},
    year={2024},
    eprint={2407.04842},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/2407.04842}, 
}
```