mattmdjaga committed
Commit
9e9fb6f
•
1 Parent(s): ae479b2

App init + req.txt

Files changed (2)
  1. app.py +166 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,166 @@
+ import io
+ import base64
+ from typing import List, Tuple
+
+ import numpy as np
+ import gradio as gr
+ from datasets import load_dataset
+ from transformers import AutoProcessor, AutoModel
+ import torch
+ from PIL import Image
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if device == "cuda" else torch.float32
+
+ # Load the example preference dataset
+ dataset = load_dataset("xzuyn/dalle-3_vs_sd-v1-5_dpo", num_proc=8)
+
+ processor_name = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+ model_name = "yuvalkirstain/PickScore_v1"
+ processor = AutoProcessor.from_pretrained(processor_name)
+ model = AutoModel.from_pretrained(model_name, torch_dtype=dtype).to(device)
+
+
+ def decode_image(image: str) -> Image.Image:
+     """
+     Decodes a base64 string into a PIL image.
+     Args:
+         image: base64 string
+     Returns:
+         PIL image
+     """
+     img_byte_arr = base64.b64decode(image)
+     img_byte_arr = io.BytesIO(img_byte_arr)
+     img = Image.open(img_byte_arr)
+     return img
+
+
+ def get_preference(img_1: Image.Image, img_2: Image.Image, caption: str) -> Image.Image:
+     """
+     Returns whichever of the two images the model prefers for the caption.
+     Args:
+         img_1: PIL image
+         img_2: PIL image
+         caption: string
+     Returns:
+         preferred image: PIL image
+     """
+     imgs = [img_1, img_2]
+     logits = get_logits(caption, imgs)
+     preference = logits.argmax().item()
+
+     return imgs[preference]
+
+
+ def sample_example() -> Tuple[Image.Image, Image.Image, Image.Image, str]:
+     """
+     Samples a random example from the dataset and scores it.
+
+     Returns:
+         img_1: PIL image
+         img_2: PIL image
+         preference: PIL image
+         caption: string
+     """
+     example = dataset["train"][np.random.randint(0, len(dataset["train"]))]
+     img_1 = decode_image(example["jpg_0"])
+     img_2 = decode_image(example["jpg_1"])
+     caption = example["caption"]
+     imgs = [img_1, img_2]
+     logits = get_logits(caption, imgs)
+     preference = logits.argmax().item()
+     return (img_1, img_2, imgs[preference], caption)
+
+
+ def get_logits(caption: str, imgs: List[Image.Image]) -> torch.Tensor:
+     """
+     Returns the model's logits for the caption and images.
+
+     Args:
+         caption: string
+         imgs: list of PIL images
+     Returns:
+         logits: torch.Tensor
+     """
+     inputs = processor(
+         text=caption,
+         images=imgs,
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+         max_length=77,
+     ).to(device)
+     # Cast pixel values to fp16 to match the model weights on GPU
+     inputs["pixel_values"] = (
+         inputs["pixel_values"].half() if device == "cuda" else inputs["pixel_values"]
+     )
+     outputs = model(**inputs)
+     logits_per_image = outputs.logits_per_image
+
+     return logits_per_image
+
+
+ ### Description
+ title = r"""
+ <h1 align="center">Aesthetic Scorer: CLIP fine-tuned for DPO scoring</h1>
+ """
+
+ description = r"""
+ <b>This is a demo for the paper <a href="https://arxiv.org/abs/2305.01569">Pick-a-Pic: An Open Dataset of User Preferences for Text-to-Image Generation</a></b> <br>
+
+ How to use this demo: <br>
+ 1. Upload two images generated from the same caption.
+ 2. Enter the caption used to generate the images.
+ 3. Click the "Get Preference" button to see which image scores higher on user preferences according to the model. <br>
+ <b>OR</b> <br>
+ 1. Click the "Random Example" button to load a random example from the <a href="https://huggingface.co/datasets/xzuyn/dalle-3_vs_sd-v1-5_dpo">DALL-E 3 vs SD 1.5 DPO dataset</a>.<br>
+
+ This demo shows how this CLIP variant can be used for DPO scoring. The resulting scores can then be used for DPO fine-tuning with these <a href="https://github.com/huggingface/diffusers/tree/main/examples/research_projects/diffusion_dpo">scripts</a>.<br>
+
+ Accuracy on the <a href="https://huggingface.co/datasets/xzuyn/dalle-3_vs_sd-v1-5_dpo">DALL-E 3 vs SD 1.5 DPO dataset</a>:<br>
+ <a href="https://huggingface.co/yuvalkirstain/PickScore_v1">PickScore_v1</a> - 97.3 <br>
+ <a href="https://huggingface.co/CIDAS/clipseg-rd64-refined">CLIPSeg</a> - 70.9 <br>
+ <a href="https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K">CLIP-ViT-H-14-laion2B-s32B-b79K</a> - 82.3 <br>
+ """
+
+ citation = r"""
+ 📝 **Citation**
+ ```bibtex
+ @inproceedings{Kirstain2023PickaPicAO,
+     title={Pick-a-Pic: An Open Dataset of User Preferences for Text-to-Image Generation},
+     author={Yuval Kirstain and Adam Polyak and Uriel Singer and Shahbuland Matiana and Joe Penna and Omer Levy},
+     year={2023}
+ }
+ ```
+ """
+
+ with gr.Blocks() as demo:
+     gr.Markdown(title)
+     gr.Markdown(description)
+
+     with gr.Row():
+         first_image = gr.Image(height=400, width=400, label="First Image")
+         second_image = gr.Image(height=400, width=400, label="Second Image")
+
+     caption_box = gr.Textbox(lines=1, label="Caption")
+
+     with gr.Row():
+         image_button = gr.Button("Get Preference")
+         random_example = gr.Button("Random Example")
+
+     image_output = gr.Image(height=500, width=500, label="Preference")
+
+     image_button.click(
+         get_preference,
+         inputs=[first_image, second_image, caption_box],
+         outputs=image_output,
+     )
+
+     random_example.click(
+         sample_example, outputs=[first_image, second_image, image_output, caption_box]
+     )
+
+     gr.Markdown(citation)
+
+ if __name__ == "__main__":
+     demo.launch()
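
The committed app wraps scoring in Gradio, but the same checkpoint can be scored directly. Below is a minimal standalone sketch that ranks two candidate images for one caption and turns the logits into preference probabilities with a softmax; the image paths and caption are placeholders, not part of the commit.

```python
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Same checkpoints as in app.py above
processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").to(device).eval()

# Placeholder inputs -- swap in your own generations and prompt
images = [Image.open("candidate_a.png"), Image.open("candidate_b.png")]
caption = "a photo of an astronaut riding a horse"

inputs = processor(
    text=caption,
    images=images,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=77,
).to(device)

with torch.no_grad():
    # logits_per_image has one row per image, one column per caption: (2, 1)
    logits = model(**inputs).logits_per_image.squeeze(-1)
    # Softmax over the two candidates gives relative preference probabilities
    probs = logits.softmax(dim=0)

print({f"image_{i}": round(p, 3) for i, p in enumerate(probs.tolist())})
```

The argmax of `probs` matches the `get_preference` choice in app.py; the probabilities simply make the margin between the two candidates visible.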
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ transformers
+ datasets
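
The accuracy numbers quoted in the demo description could be spot-checked with a short loop over the dataset. A rough sketch, reusing `dataset`, `decode_image`, and `get_logits` from app.py above, and assuming (as the dataset name suggests, though worth confirming on the dataset card) that `jpg_0` always holds the preferred DALL-E 3 image:

```python
import numpy as np

def estimate_accuracy(n_samples: int = 500) -> float:
    """Fraction of sampled pairs where the model picks jpg_0,
    assumed here to be the preferred (DALL-E 3) image."""
    train = dataset["train"]
    idxs = np.random.choice(len(train), size=min(n_samples, len(train)), replace=False)
    correct = 0
    for i in idxs:
        example = train[int(i)]
        imgs = [decode_image(example["jpg_0"]), decode_image(example["jpg_1"])]
        logits = get_logits(example["caption"], imgs)
        correct += int(logits.argmax().item() == 0)  # index 0 == assumed winner
    return correct / len(idxs)

print(f"estimated accuracy: {estimate_accuracy():.3f}")
```

Sampling keeps the check cheap; scoring the full split should approach the quoted figures if the labeling assumption holds.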