# PickScore / app.py
# Last change: "Update app.py" by micole66 (commit dcef9e4, 3.92 kB).
import time
from glob import glob
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
# Folder whose prompt.txt / image_1.png / image_2.png pre-fill the UI on load.
DEFAULT_EXAMPLE_PATH = 'examples/example_0'

# Run on the GPU in bfloat16 when one is available; fall back to float32 on CPU.
if torch.cuda.is_available():
    device, weight_dtype = "cuda", torch.bfloat16
else:
    device, weight_dtype = "cpu", torch.float32
print(f"Using device: {device} ({weight_dtype})")

# NOTE(review): the UI text describes PickScore (yuvalkirstain/PickScore_v1), but
# this loads a MetaCLIP checkpoint, so the app reports MetaCLIP scores. Confirm
# this is intended.
print("Loading model...")
model_pretrained_name_or_path = "facebook/metaclip-h14-fullcc2.5b"
processor = AutoProcessor.from_pretrained(model_pretrained_name_or_path)
model = AutoModel.from_pretrained(
    model_pretrained_name_or_path,
    torch_dtype=weight_dtype,
)
# Inference only: switch to eval mode, then move the weights onto the chosen device.
model = model.eval().to(device)
print("Model loaded.")
def calc_probs(prompt, images):
    """Score each image against ``prompt`` and turn the scores into probabilities.

    Uses the module-level ``processor``, ``model``, ``device`` and ``weight_dtype``.

    Args:
        prompt: Text prompt. The intended use is a single string; only the first
            prompt's row of scores is used.
        images: Sequence of PIL images to compare against the prompt.

    Returns:
        list[float]: Softmax probabilities, one per image in input order. They sum to 1.

    Raises:
        ValueError: If ``images`` is empty or contains ``None`` (for example a
            cleared Gradio image slot), which the processor cannot handle.
    """
    if not images or any(image is None for image in images):
        raise ValueError("calc_probs needs at least one image and no empty (None) entries")

    print("Processing inputs...")
    # Images need only pixel preprocessing. The tokenizer options (padding,
    # truncation, max_length) apply to text only, so they are not passed here.
    image_inputs = processor(images=images, return_tensors="pt")
    # Move and cast in one step so the pixels match the model's weight dtype.
    image_inputs = {
        key: tensor.to(device=device, dtype=weight_dtype)
        for key, tensor in image_inputs.items()
    }
    # Token ids must stay integer, so they are only moved, not cast.
    # max_length=77 is presumably the text encoder's context length (TODO confirm).
    text_inputs = processor(
        text=prompt,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        print("Embedding images and text...")
        # Upcast before L2-normalising so the norm is not computed in bfloat16 on GPU.
        image_embs = model.get_image_features(**image_inputs).float()
        image_embs = image_embs / image_embs.norm(dim=-1, keepdim=True)
        text_embs = model.get_text_features(**text_inputs).float()
        text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)

        print("Calculating scores...")
        # Scaled cosine similarity. The product is (num_prompts, num_images);
        # row 0 is the first prompt against every image.
        scores = model.logit_scale.exp().float() * (text_embs @ image_embs.T)[0]

        print("Calculating probabilities...")
        probs = torch.softmax(scores, dim=-1)
    return probs.cpu().tolist()
def predict(prompt, image_1, image_2):
    """Gradio handler: compare two images against a prompt.

    Args:
        prompt: Text from the prompt textbox.
        image_1: PIL image from the first image input, or ``None`` if that slot was cleared.
        image_2: PIL image from the second image input, or ``None`` if that slot was cleared.

    Returns:
        tuple[str, str]: Each image's probability rounded to 3 decimals, for the
        two output textboxes.

    Raises:
        gr.Error: If either image is missing. The UI then shows a readable
            message instead of a processor traceback.
    """
    if image_1 is None or image_2 is None:
        raise gr.Error("Please provide both images.")

    print(f"Starting prediction for prompt: {prompt}")
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
    probs = calc_probs(prompt, [image_1, image_2])
    print(f"Prediction: {probs} ({time.perf_counter() - start_time:.2f} seconds)")
    if device == "cuda":
        gib = 1024 ** 3
        # max_memory_allocated reports the peak tensor allocation, not current usage.
        used = round(torch.cuda.max_memory_allocated(device) / gib, 2)
        total = round(torch.cuda.get_device_properties(device).total_memory / gib, 2)
        print(f"GPU mem used: {used}/{total} GB")
    return str(round(probs[0], 3)), str(round(probs[1], 3))
_EXAMPLE_FILES = ("prompt.txt", "image_1.png", "image_2.png")


def _read_prompt(example_dir):
    """Return the first line of ``<example_dir>/prompt.txt`` without its line ending."""
    with open(Path(example_dir) / "prompt.txt", encoding="utf-8") as f:
        return f.readline().rstrip("\r\n")


# Complete example folders under examples/, in a stable (sorted) order. Stray
# files and folders missing any of the expected files are skipped.
_example_dirs = sorted(
    d for d in Path("examples").glob("*")
    if all((d / name).is_file() for name in _EXAMPLE_FILES)
)

with gr.Blocks(title="PickScore v1") as demo:
    # NOTE(review): this text describes PickScore, but the checkpoint loaded at
    # module top is facebook/metaclip-h14-fullcc2.5b. Confirm the branding.
    gr.Markdown("# PickScore v1")
    gr.Markdown(
        "This is a demo for the PickScore model - see [paper](https://arxiv.org/abs/2305.01569), [code](https://github.com/yuvalkirstain/PickScore), [dataset](https://huggingface.co/datasets/pickapic-anonymous/pickapic_v1), and [model](https://huggingface.co/yuvalkirstain/PickScore_v1).")
    gr.Markdown("## Instructions")
    gr.Markdown("Write a prompt, place two images, and press run to get their PickScore!")
    with gr.Row():
        # gr.Textbox(value=...) replaces the deprecated gr.inputs.Textbox(default=...).
        prompt = gr.Textbox(
            lines=1,
            label="Prompt",
            value=_read_prompt(DEFAULT_EXAMPLE_PATH),
        )
    with gr.Row():
        image_1 = gr.Image(
            type="pil",
            label="image 1",
            value=Image.open(Path(DEFAULT_EXAMPLE_PATH) / "image_1.png"),
        )
        image_2 = gr.Image(
            type="pil",
            label="image 2",
            value=Image.open(Path(DEFAULT_EXAMPLE_PATH) / "image_2.png"),
        )
    with gr.Row():
        # Output-only boxes; these replace the deprecated gr.outputs.Textbox.
        pred_1 = gr.Textbox(label="Probability 1", interactive=False)
        pred_2 = gr.Textbox(label="Probability 2", interactive=False)
    btn = gr.Button("Run")
    btn.click(fn=predict, inputs=[prompt, image_1, image_2], outputs=[pred_1, pred_2])
    # Clear stale probabilities whenever the prompt is edited.
    prompt.change(lambda: ("", ""), inputs=[], outputs=[pred_1, pred_2])
    gr.Examples(
        [
            [_read_prompt(d), str(d / "image_1.png"), str(d / "image_2.png")]
            for d in _example_dirs
        ],
        [prompt, image_1, image_2],
        [pred_1, pred_2],
        predict,
    )

# NOTE: `concurrency_count` is the Gradio 3.x queue option. Gradio 4 removed it
# in favour of `default_concurrency_limit`.
demo.queue(concurrency_count=5).launch()