import gradio as gr
import torch
from t2v_metrics import VQAScore, list_all_vqascore_models
torch.jit.script = lambda f: f  # Stub out TorchScript compilation: scripting lambdas would otherwise raise an error
def update_model(model_name):
    return VQAScore(model=model_name, device="cuda")
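# The returned pipe maps (images, texts) to a (num_images, num_texts) score
# tensor, which is why the handlers below index [0][0] and [:, 0]. A minimal
# sketch (the image path is hypothetical):
#
#   pipe = update_model("clip-flant5-xl")
#   score = pipe(images=["dog.jpg"], texts=["a photo of a dog"])[0][0].item()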
# Module-level model pipe and current model name, shared across requests
cur_model_name = "clip-flant5-xl"
model_pipe = update_model(cur_model_name)
# Import the GPU decorator from the `spaces` package (available on Hugging Face Spaces)
try:
    from spaces import GPU
except ImportError:
    GPU = lambda duration: (lambda f: f)  # No-op decorator factory when spaces.GPU is unavailable
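# On ZeroGPU Spaces, @GPU(duration=20) requests a GPU allocation of up to
# 20 seconds for each decorated call; elsewhere the fallback above is a no-op.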
@GPU(duration=20)
def generate(model_name, image, text):
    global model_pipe, cur_model_name
    if model_name != cur_model_name:
        cur_model_name = model_name  # Remember which model is loaded
        model_pipe = update_model(model_name)
    print("Image:", image)  # Debug: image path
    print("Text:", text)  # Debug: text prompt
    print("Using model:", model_name)
    try:
        # Score the single (image, text) pair; the output is a 1x1 tensor
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
        print("Result:", result)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    return result
@GPU(duration=20)
def rank_images(model_name, images, text):
    global model_pipe, cur_model_name
    if model_name != cur_model_name:
        cur_model_name = model_name  # Remember which model is loaded
        model_pipe = update_model(model_name)
    print("Images:", images)  # Debug: image paths
    print("Text:", text)  # Debug: text prompt
    print("Using model:", model_name)
    try:
        # Score every image against the same prompt, then sort by score, highest first
        results = model_pipe(images=images, texts=[text] * len(images)).cpu()[:, 0].tolist()
        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True)
        ranked_images = [img for img, score in ranked_results]
        print("Ranked Results:", ranked_results)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    return ranked_images
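# A minimal sketch of ranking outside the UI (the paths are hypothetical):
#
#   order = rank_images("clip-flant5-xl", ["a.jpg", "b.jpg"], "a red car")
#   # `order` lists the input paths from best to worst match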
# First demo: score a single image against a prompt
demo_vqascore = gr.Interface(
    fn=generate,
    inputs=[
        gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"),
        gr.Image(type="filepath"),
        gr.Textbox(label="Prompt"),
    ],
    outputs="number",
    title="VQAScore",
    description="This model evaluates the similarity between an image and a text prompt.",
)
# Second demo: rank a gallery of images by their match to a prompt
demo_vqascore_ranking = gr.Interface(
    fn=rank_images,
    inputs=[
        gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"),
        gr.Gallery(label="Generated Images"),
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.Gallery(label="Ranked Images"),
    title="VQAScore Ranking",
    description="This model ranks a gallery of images based on their similarity to a text prompt.",
)
# Combine the demos into a tabbed interface
tabbed_interface = gr.TabbedInterface([demo_vqascore, demo_vqascore_ranking], ["VQAScore", "VQAScore Ranking"])
# Launch the tabbed interface
tabbed_interface.queue()
tabbed_interface.launch()
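# To run locally (assuming this file is saved as app.py, a CUDA GPU, and
# `pip install gradio torch t2v_metrics`):
#
#   python app.py
#
# then open the local URL that Gradio prints.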