import io
import logging

import gradio as gr
import requests
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel


def load_model_and_components(model_name):
    """Load a vision-encoder-decoder captioning checkpoint and its preprocessors.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint, e.g.
            ``"laicsiifes/swin-distilbertimbau"``.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple: the
        ``VisionEncoderDecoderModel``, the tokenizer used to decode generated
        ids, and the image processor that turns images into ``pixel_values``.

    Raises:
        RuntimeError: if any of the three components fails to load. The
            underlying exception is chained as ``__cause__`` so its traceback
            is preserved for debugging.
    """
    try:
        model = VisionEncoderDecoderModel.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        image_processor = AutoImageProcessor.from_pretrained(model_name)
    except Exception as e:
        raise RuntimeError(f"Error loading model components: {e}") from e
    return model, tokenizer, image_processor


# Checkpoint loaded when the app starts; matches the model dropdown's default choice.
current_model_name = "laicsiifes/swin-distilbertimbau"

# Load eagerly at import time: caption requests never wait for initialisation,
# and a broken checkpoint raises (RuntimeError) before the UI is built.
model, tokenizer, image_processor = load_model_and_components(
    model_name=current_model_name,
)


def generate_caption(image):
    """Generate a caption for ``image`` with the currently loaded model.

    Reads the module-level ``model``, ``tokenizer`` and ``image_processor``
    globals, so it always uses whichever checkpoint was loaded most recently.

    Args:
        image: The picture to caption, normally a ``PIL.Image.Image``.

    Returns:
        The decoded caption string, or ``"Please upload a valid image."`` if
        preprocessing or generation fails for any reason (including ``None``).
    """
    try:
        # PIL images expose `mode`. Alpha PNGs, palette and greyscale images are
        # not 3-channel, so normalise them to RGB before preprocessing instead of
        # rejecting them as "invalid".
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        pixel_values = image_processor(image, return_tensors="pt").pixel_values
        generated_ids = model.generate(pixel_values)
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception:
        # Keep the friendly UI message, but record the real cause so that model
        # or preprocessing bugs are not silently hidden.
        logging.getLogger(__name__).exception("Caption generation failed")
        return "Please upload a valid image."


# Sample COCO 2014 validation images shown in the "Choose a Predefined Image"
# gallery; each URL embeds the image id zero-padded to 12 digits.
predefined_images_urls = [
    f"http://images.cocodataset.org/val2014/COCO_val2014_{image_id:012d}.jpg"
    for image_id in (458153, 74)
]


# Best-effort pre-download of the sample images. A failed download is reported
# and skipped, so it never prevents the app from starting.
# NOTE(review): nothing in the code visible here reads `predefined_images` (the
# gallery is built from the URLs directly); kept in case other code relies on it.
predefined_images = []

for url in predefined_images_urls:
    try:
        # Bounded timeout so a stalled connection cannot hang start-up forever;
        # the context manager releases the HTTP connection afterwards.
        with requests.get(url, timeout=10) as response:
            response.raise_for_status()
            downloaded = Image.open(io.BytesIO(response.content))
            # Decode now, while the bytes are in hand, rather than lazily later.
            downloaded.load()
    except Exception as e:
        logging.getLogger(__name__).warning(
            "Error loading predefined image from %s: %s", url, e
        )
    else:
        predefined_images.append(downloaded)


def app(image=None, predefined_image=None):
    """Caption either a predefined sample image or the user's own image.

    Args:
        image: Image supplied by the user; used only when ``predefined_image``
            is ``None``.
        predefined_image: Sample image; takes precedence over ``image`` when
            given.

    Returns:
        The caption from ``generate_caption``, or
        ``"Please upload a valid image."`` when neither image is provided or
        anything raises.
    """
    try:
        chosen = image if predefined_image is None else predefined_image
        if chosen is None:
            return "Please upload a valid image."
        return generate_caption(chosen)
    except Exception:
        return "Please upload a valid image."


# Build the UI and wire its event handlers. The handlers read (and, on a model
# switch, rebind) the module-level `model`, `tokenizer` and `image_processor`.
with gr.Blocks() as interface:
    gr.Markdown("""
# Welcome to the LAICSI-IFES space for Vision Encoder-Decoder (VED) demonstration

### Be patient with the Swin-GPorTuguese-2 as it is heavier than the Swin-DistilBERTimbau.
""")

    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=["laicsiifes/swin-distilbertimbau", "laicsiifes/swin-gportuguese-2"],
                # Start on the checkpoint that was actually loaded at import time.
                value=current_model_name,
                label="Select Model",
            )
            loading_message = gr.Textbox(label="Status Message")
            image_display = gr.Image(type="pil", label="Image Preview", interactive=False)
            upload_button = gr.File(label="Upload an Image", file_types=["image"], type="filepath")
            predefined_images_display = gr.Gallery(predefined_images_urls, label="Choose a Predefined Image")

        with gr.Column():
            output_text = gr.Textbox(label="Generated Caption")

    def handle_uploaded_image(image):
        """Preview and caption the file the user uploaded.

        Args:
            image: Path of the uploaded file (the ``gr.File`` uses
                ``type="filepath"``), or ``None`` when it holds no file.

        Returns:
            ``(preview, caption)`` for ``image_display`` and ``output_text``, or
            ``(None, "Please upload a valid image.")`` if it cannot be opened.
        """
        try:
            if image is None:
                return None, "Please upload a valid image."
            pil_image = Image.open(image)
            return pil_image, generate_caption(pil_image)
        except Exception:
            return None, "Please upload a valid image."

    def handle_predefined_image(evt: gr.SelectData, _):
        """Download, preview and caption the gallery image the user clicked.

        Args:
            evt: Gradio selection event; ``evt.value["image"]["url"]`` is read
                as the URL of the selected gallery item.
            _: Value of ``predefined_images_display`` (the listener's input);
                unused.

        Returns:
            ``(preview, caption)``, or ``(None, "Please upload a valid image.")``
            if the image cannot be fetched or decoded.
        """
        try:
            if not evt:
                return None, "Please upload a valid image."
            # Bounded timeout so a stalled download cannot hang the handler; the
            # context manager releases the connection once the bytes are decoded.
            with requests.get(evt.value["image"]["url"], timeout=10) as response:
                response.raise_for_status()
                pil_image = Image.open(io.BytesIO(response.content))
                pil_image.load()
            return pil_image, generate_caption(pil_image)
        except Exception:
            return None, "Please upload a valid image."

    def switch_model(selected_model):
        """First step of a model switch: announce it and clear stale outputs.

        Returns values for ``loading_message``, ``upload_button``,
        ``image_display`` and ``output_text``.
        """
        gr.Info("Loading model... Please wait.")
        return "Loading model... Please wait.", None, None, None

    def load_new_model(selected_model):
        """Second step of a model switch: load ``selected_model`` and activate it.

        If loading fails, the previously loaded model stays active and the error
        is shown in the status box, instead of leaving it on "Loading model...".
        """
        global model, tokenizer, image_processor, current_model_name
        try:
            components = load_model_and_components(selected_model)
        except RuntimeError as err:
            return f"Failed to load {selected_model}: {err}", None, None, None
        model, tokenizer, image_processor = components
        current_model_name = selected_model
        return "Model loaded successfully.", None, None, None

    switch_outputs = [loading_message, upload_button, image_display, output_text]
    # `.then()` starts load_new_model only after switch_model has finished. Two
    # separate `.change` listeners are independent events with no ordering
    # guarantee, so "Loading model..." could overwrite "Model loaded successfully.".
    model_selector.change(
        fn=switch_model, inputs=model_selector, outputs=switch_outputs
    ).then(fn=load_new_model, inputs=model_selector, outputs=switch_outputs)

    # NOTE(review): `.change` also fires on programmatic updates, so clearing
    # `upload_button` during a model switch calls this handler with None.
    upload_button.change(
        fn=handle_uploaded_image,
        inputs=upload_button,
        outputs=[image_display, output_text],
    )
    predefined_images_display.select(
        fn=handle_predefined_image,
        inputs=predefined_images_display,
        outputs=[image_display, output_text],
    )

# Start the web server only when this file is executed as a script (e.g.
# `python app.py`), so importing the module does not block on a running server.
if __name__ == "__main__":
    interface.launch()