import os
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
from PIL import Image
from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel

# Load the model, tokenizer, and image processor for a given checkpoint
def load_model_and_components(model_name):
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    image_processor = AutoImageProcessor.from_pretrained(model_name)
    return model, tokenizer, image_processor

# Preload both checkpoints in parallel so switching models in the UI is instant
def preload_models():
    models = {}
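    # Two Portuguese image-captioning checkpoints: Swin-DistilBERTimbau (168M
    # parameters) and Swin-GPorTuguese-2 (240M parameters)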
    model_names = ["laicsiifes/swin-distilbertimbau", "laicsiifes/swin-gportuguese-2"]
    with ThreadPoolExecutor() as executor:
        results = executor.map(load_model_and_components, model_names)
    for name, result in zip(model_names, results):
        models[name] = result
    return models

models = preload_models()
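# NOTE: the first run downloads both checkpoints from the Hugging Face Hub;
# later runs load them from the local cache.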

# Predefined example images offered in the Examples gallery
image_folder = "images"
predefined_images = [
    Image.open(os.path.join(image_folder, fname)).convert("RGB")
    for fname in sorted(os.listdir(image_folder))
    if fname.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp", ".ppm"))
]

# Convert an image to RGB and clear any previously generated caption
def preprocess_image(image):
    if image is None:
        return None, None
    pil_image = image.convert("RGB")
    return pil_image, None

# Function to process the image and generate a caption
def generate_caption(image, selected_model):
    if image is None:
        return "Please upload an image to generate a caption."
    model, tokenizer, image_processor = models[selected_model]
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
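    # generate() uses the checkpoint's default generation config; beam search
    # (e.g. model.generate(pixel_values, num_beams=4)) could be tried for
    # potentially better captions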
    generated_ids = model.generate(pixel_values)
    caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption

# Define UI
with gr.Blocks(theme=gr.themes.Citrus(primary_hue="blue", secondary_hue="orange")) as interface:
    gr.Markdown("""
        # Welcome to the LAICSI-IFES demonstration space for Vision Encoder-Decoder (VED) models
        ---
        ### Select an available model: Swin-DistilBERTimbau (168M parameters) or Swin-GPorTuguese-2 (240M parameters)
    """)
    with gr.Row(variant='panel'):
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=list(models.keys()),
                value="laicsiifes/swin-distilbertimbau",
                label="Select Model"
            )
            
    gr.Markdown("""
        ---
        ### Upload an image or pick one of the examples below, then click `Generate`
        """)
    
    with gr.Row(variant='panel'):
        with gr.Column():
            image_display = gr.Image(type="pil", label="Image Preview", image_mode="RGB", height=400)
        with gr.Column():
            output_text = gr.Textbox(label="Generated Caption")
            generate_button = gr.Button("Generate")
            
    gr.Markdown("""---""")
    
    with gr.Row(variant='panel'):
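        # Clicking an example runs preprocess_image, which loads the image into
        # the preview and clears any stale caption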
        examples = gr.Examples(
            examples=predefined_images,
            fn=preprocess_image,
            inputs=[image_display],
            outputs=[image_display, output_text],
            label="Examples"
        )

    # Define actions
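    # Switching models clears both the image preview and the previous caption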
    model_selector.change(fn=lambda: (None, None), outputs=[image_display, output_text])

    image_display.upload(fn=preprocess_image, inputs=[image_display], outputs=[image_display, output_text])
    image_display.clear(fn=lambda: None, outputs=[output_text])

    generate_button.click(fn=generate_caption, inputs=[image_display, model_selector], outputs=output_text)

# Launch after the Blocks context closes, so the UI is fully defined first
interface.launch(share=False)