File size: 3,505 Bytes
5041f6c
 
 
 
656934c
b63cc04
656934c
 
5741e23
656934c
 
1112f1b
 
 
656934c
 
 
 
 
 
 
5041f6c
656934c
 
1112f1b
656934c
1112f1b
 
5041f6c
 
1112f1b
656934c
5041f6c
d258d19
656934c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5041f6c
 
 
 
656934c
5041f6c
656934c
 
5041f6c
656934c
 
 
 
 
5041f6c
 
 
8e64bf0
 
656934c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
import torch
from t2v_metrics import VQAScore, list_all_vqascore_models

# Turn TorchScript compilation into a no-op: any later `torch.jit.script(...)`
# call (directly or used as a decorator) returns its argument unchanged.
# NOTE(review): presumably added so that code scripted by a dependency does not
# fail to compile at import time — confirm which library required this.
# The real torch.jit.script also takes extra arguments (e.g. `optimize=`,
# `example_inputs=`); accept and ignore them so such callers don't hit TypeError.
torch.jit.script = lambda obj, *args, **kwargs: obj

def update_model(model_name, device="cuda"):
    """Construct and return a VQAScore scorer for ``model_name``.

    Args:
        model_name: Name of the t2v_metrics VQAScore model to load
            (this app offers "clip-flant5-xl" and "clip-flant5-xxl").
        device: Device passed to ``VQAScore``. Defaults to "cuda", which keeps
            the previous hard-coded behavior; other values (e.g. "cpu") are
            forwarded as-is — NOTE(review): confirm the backend supports them.

    Returns:
        The ``VQAScore`` instance built for ``model_name`` on ``device``.
    """
    return VQAScore(model=model_name, device=device)

# Resolve the GPU decorator BEFORE any model is moved to CUDA below: on
# ZeroGPU hardware the `spaces` package must be imported before CUDA is
# initialised (NOTE(review): per the Hugging Face ZeroGPU docs — confirm for
# the deployed hardware). When `spaces` is not installed (e.g. running
# locally), fall back to a no-op decorator.
try:
    from spaces import GPU
except ImportError:
    def GPU(func=None, **_kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Supports both bare ``@GPU`` and parametrised ``@GPU(duration=...)``
        usage; the decorated function is returned unchanged.
        """
        if callable(func):
            return func
        return lambda f: f

# Module-level state shared by the request handlers: the name of the loaded
# model and the loaded scorer itself. The handlers declare these `global`
# before reassigning them; a `global` statement at module scope is a no-op,
# so none is needed here.
cur_model_name = "clip-flant5-xl"
model_pipe = update_model(cur_model_name)

@GPU(duration=20)
def generate(model_name, image, text):
    """Score how well a single image matches a text prompt with VQAScore.

    Args:
        model_name: Name of the VQAScore model to use. When it differs from the
            currently loaded model, the old one is released and the requested
            one is loaded via ``update_model``.
        image: The value of the ``gr.Image(type="filepath")`` input — a file
            path, or ``None`` when nothing was uploaded.
        text: The prompt to score the image against.

    Returns:
        The score as a Python float (element ``[0][0]`` of the model output).

    Raises:
        gr.Error: If no model is selected or no image was provided.
        RuntimeError: Re-raised (after printing) if model inference fails.
    """
    global model_pipe, cur_model_name

    # Fail with a readable UI message instead of an opaque model error.
    if not model_name:
        raise gr.Error("Please select a model.")
    if image is None:
        raise gr.Error("Please upload an image.")

    if model_name != cur_model_name:
        # Drop the reference to the old model before loading the new one so its
        # memory can be reclaimed first. Record the new name only after loading
        # succeeds: if update_model raises, the next request retries the load
        # instead of silently scoring with the stale model under the new name.
        model_pipe = None
        cur_model_name = None
        model_pipe = update_model(model_name)
        cur_model_name = model_name

    print("Image:", image)  # Debug: Print image path
    print("Text:", text)  # Debug: Print text input
    print("Using model:", model_name)

    try:
        # The output is indexed as a 2-D (images x texts) grid, so [0][0] is the
        # score of this single image/prompt pair.
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise  # bare raise re-raises the original exception unchanged

    print("Result:", result)
    return result

@GPU(duration=20)
def rank_images(model_name, images, text):
    """Rank a gallery of images by their VQAScore against one text prompt.

    Args:
        model_name: Name of the VQAScore model to use. When it differs from the
            currently loaded model, the old one is released and the requested
            one is loaded via ``update_model``.
        images: The value of the ``gr.Gallery`` input. Gradio 4 passes a list
            of ``(media, caption)`` tuples; bare paths are accepted as well.
            ``None`` or empty when nothing was uploaded.
        text: The prompt every image is scored against.

    Returns:
        The images (media items, captions stripped) ordered from highest to
        lowest score.

    Raises:
        gr.Error: If no model is selected or no images were provided.
        RuntimeError: Re-raised (after printing) if model inference fails.
    """
    global model_pipe, cur_model_name

    if not model_name:
        raise gr.Error("Please select a model.")
    # Strip Gallery captions so only the image itself reaches the model.
    image_items = [
        item[0] if isinstance(item, (tuple, list)) else item
        for item in (images or [])
    ]
    if not image_items:
        raise gr.Error("Please upload at least one image.")

    if model_name != cur_model_name:
        # Drop the reference to the old model before loading the new one so its
        # memory can be reclaimed first. Record the new name only after loading
        # succeeds, so a failed load is retried on the next request.
        model_pipe = None
        cur_model_name = None
        model_pipe = update_model(model_name)
        cur_model_name = model_name

    print("Images:", image_items)  # Debug: Print image paths
    print("Text:", text)  # Debug: Print text input
    print("Using model:", model_name)

    try:
        # The scorer returns an (images x texts) grid of scores, so passing the
        # prompt once yields an (n, 1) result; repeating it n times (as before)
        # made the model score n*n pairs to use only the first column.
        scores = model_pipe(images=image_items, texts=[text]).cpu()[:, 0].tolist()
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise  # bare raise re-raises the original exception unchanged

    ranked_results = sorted(zip(image_items, scores), key=lambda pair: pair[1], reverse=True)
    print("Ranked Results:", ranked_results)
    return [img for img, _score in ranked_results]

# Create the first demo: score a single image against a prompt.
demo_vqascore = gr.Interface(
    fn=generate,  # function to call
    inputs=[
        # Preselect a model so the handler never receives an empty selection.
        gr.Dropdown(
            ["clip-flant5-xl", "clip-flant5-xxl"],
            value="clip-flant5-xl",
            label="Model Name",
        ),
        gr.Image(type="filepath"),  # the handler receives the upload's file path
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.Number(label="VQAScore"),  # a single numeric score
    title="VQAScore",  # title of the app
    description="This model evaluates the similarity between an image and a text prompt.",
)

# Create the second demo: rank several images against one prompt.
demo_vqascore_ranking = gr.Interface(
    fn=rank_images,  # function to call
    inputs=[
        # Preselect a model so the handler never receives an empty selection.
        gr.Dropdown(
            ["clip-flant5-xl", "clip-flant5-xxl"],
            value="clip-flant5-xl",
            label="Model Name",
        ),
        gr.Gallery(label="Generated Images"),  # images to be ranked
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.Gallery(label="Ranked Images"),  # same images, best score first
    title="VQAScore Ranking",  # title of the app
    description="This model ranks a gallery of images based on their similarity to a text prompt.",
)

# Combine the demos into a tabbed interface (one tab per demo).
tabbed_interface = gr.TabbedInterface(
    [demo_vqascore, demo_vqascore_ranking],
    ["VQAScore", "VQAScore Ranking"],
)

# Start the server only when this file is executed as a script, so importing
# the module does not block inside launch().
# NOTE(review): assumes the app is started with `python <this file>`, as Gradio
# Spaces do — confirm for the deployment.
if __name__ == "__main__":
    tabbed_interface.queue()  # route requests through Gradio's queue
    tabbed_interface.launch()