|
import gradio as gr |
|
from google import genai |
|
from utils import * |
|
from PIL import Image |
|
import os |
|
|
|
|
|
# Gemini API client; the key is read from the GEMINI_API_KEY environment variable.
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))

# Text passed to the model as its system instruction for every request.
system_instruction = load_verifier_prompt()

# Request configuration shared by all calls: JSON output shaped as a list of
# `Grading` entries, with a fixed seed.
generation_config = types.GenerateContentConfig(
    system_instruction=system_instruction,
    response_mime_type="application/json",
    response_schema=list[Grading],
    seed=1994,
)
|
|
|
|
|
def make_inputs(prompt, image):
    """Return a fresh list of the request parts for one prompt/image pair.

    The parts come from ``prepare_inputs`` and are copied into a new list.
    """
    return list(prepare_inputs(prompt=prompt, image=image))
|
|
|
|
|
def format_response(response: dict) -> str:
    """Render a grading result as a Markdown bullet list.

    Args:
        response: Mapping from an aspect name to a mapping that provides
            ``"score"`` and ``"explanation"`` entries.

    Returns:
        One ``* **<aspect>**: <score> (explanation: <explanation>)`` line per
        entry, each terminated by a newline, in the mapping's iteration
        order. An empty mapping yields an empty string.

    Raises:
        KeyError: If an entry lacks ``"score"`` or ``"explanation"``.
    """
    # Join once instead of growing a string with repeated `+=`.
    return "".join(
        f"* **{key}**: {value['score']} (explanation: {value['explanation']})\n"
        for key, value in response.items()
    )
|
|
|
|
|
def grade(prompt, image):
    """Grade ``image`` against ``prompt`` with Gemini and return Markdown.

    Args:
        prompt: Text prompt associated with the image.
        image: Image to grade; forwarded unchanged to ``make_inputs``.

    Returns:
        The Markdown produced by ``format_response`` for the first grading
        in the parsed model response.

    Raises:
        gr.Error: If the response contains no parsed grading.
    """
    response = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=types.Content(parts=make_inputs(prompt, image), role="user"),
        config=generation_config,
    )

    gradings = response.parsed
    # `parsed` may be None or empty when the reply could not be decoded into
    # the requested schema; surface a readable error instead of letting the
    # `[0]` lookup fail with an opaque TypeError/IndexError.
    if not gradings:
        raise gr.Error("Gemini did not return a parsable grading. Please try again.")

    return format_response(gradings[0])
|
|
|
|
|
# Example rows for the UI: each row is [prompt text, opened PIL image].
# The relative file names are resolved against the current working directory.
examples = [
    [caption, Image.open(filename)]
    for caption, filename in (
        ("realistic photo a shiny black SUV car with a mountain in the background.", "car.jpg"),
        ("photo a green and funny creature standing in front a lightweight forest.", "green_creature.jpg"),
    )
]

# Page styling: centers the element with id "col-container" and caps its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""
|
|
|
# Introductory text shown at the top of the page.
INTRO_MARKDOWN = """# Grade images with Gemini 2.0 Flash

Following aspects are considered during grading:

* Accuracy to Prompt
* Creativity and Originality
* Visual Quality and Realism
* Consistency and Cohesion
* Emotional or Thematic Resonance

The [system prompt](./verifier_prompt.txt) comes from the paper: [Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps](https://arxiv.org/abs/2501.09732).
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(INTRO_MARKDOWN)

        # Prompt entry and the run button share a single row.
        with gr.Row():
            prompt_box = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter the prompt that generated the image to be graded.",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        image_input = gr.Image(
            format="png",
            type="pil",
            label="Image",
            placeholder="The image to be graded.",
        )

        grading_output = gr.Markdown(label="Grading Output")

        # Example rows are run through `grade` and their outputs cached.
        gr.Examples(
            examples=examples,
            fn=grade,
            inputs=[prompt_box, image_input],
            outputs=[grading_output],
            cache_examples=True,
        )

    # Clicking "Run" or submitting the prompt box both trigger grading.
    gr.on(
        triggers=[run_button.click, prompt_box.submit],
        fn=grade,
        inputs=[prompt_box, image_input],
        outputs=[grading_output],
    )

demo.launch()
|
|