import io import base64 from typing import List, Tuple import numpy as np import gradio as gr from datasets import load_dataset from transformers import AutoProcessor, AutoModel import torch from PIL import Image device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float16 if device == "cuda" else torch.float32 # Load example dataset dataset = load_dataset("xzuyn/dalle-3_vs_sd-v1-5_dpo", num_proc=4) processor_name = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" model_name = "yuvalkirstain/PickScore_v1" processor = AutoProcessor.from_pretrained(processor_name) model = AutoModel.from_pretrained(model_name, torch_dtype=dtype).to(device) def decode_image(image: str) -> Image: """ Decodes base64 string to PIL image. Args: image: base64 string Returns: PIL image """ img_byte_arr = base64.b64decode(image) img_byte_arr = io.BytesIO(img_byte_arr) img_byte_arr = Image.open(img_byte_arr) return img_byte_arr def get_preference(img_1: Image.Image, img_2: Image.Image, caption: str) -> Image.Image: """ Returns the preference of the caption for the two images. Args: img_1: PIL image img_2: PIL image caption: string Returns: preference image: PIL image """ imgs = [img_1, img_2] logits = get_logits(caption, imgs) preference = logits.argmax().item() return imgs[preference] def sample_example() -> Tuple[Image.Image, Image.Image, Image.Image, str]: """ Samples a random example from the dataset and displays it. Returns: img_1: PIL image img_2: PIL image preference: PIL image caption: string """ example = dataset["train"][np.random.randint(0, len(dataset["train"]))] img_1 = decode_image(example["jpg_0"]) img_2 = decode_image(example["jpg_1"]) caption = example["caption"] imgs = [img_1, img_2] logits = get_logits(caption, imgs) preference = logits.argmax().item() return (img_1, img_2, imgs[preference], caption) def get_logits(caption: str, imgs: List[Image.Image]) -> torch.Tensor: """ Returns the logits for the caption and images. Args: caption: string imgs: list of PIL images Returns: logits: torch.Tensor """ inputs = processor( text=caption, images=imgs, return_tensors="pt", padding=True, truncation=True, max_length=77, ).to(device) inputs["pixel_values"] = ( inputs["pixel_values"].half() if device == "cuda" else inputs["pixel_values"] ) with torch.no_grad(): outputs = model(**inputs) logits_per_image = outputs.logits_per_image return logits_per_image ### Description title = r"""