# Daryl Lim — "Update app.py" (commit 802f3ff)
"""
This module provides an interface for image captioning using the BLIP-2 model.
The interface allows users to upload an image and receive a caption.
"""
import gradio as gr
import spaces
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub checkpoint shared by the processor and the model.
MODEL_ID = "Salesforce/blip2-opt-6.7b-coco"

# Initialize the processor and model. Loading is best-effort: on failure the
# error is printed and both names are bound to None, so the app can still start
# and callers can detect the missing model instead of hitting a NameError.
processor = None
model = None
try:
    processor = Blip2Processor.from_pretrained(MODEL_ID)
    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
except Exception as error:
    print(f"Error initializing model: {error}")
    # Never leave a half-initialized pair (e.g. processor loaded, model not).
    processor = None
    model = None
def generate_caption(image: Image.Image) -> str:
    """
    Generates a caption for the given image using the BLIP-2 model.

    Args:
        image (PIL.Image): The input image to generate a caption for.

    Returns:
        str: The generated caption with surrounding whitespace removed, or a
        message starting with "Error generating caption:" if the model is not
        loaded or inference fails.

    Raises:
        ValueError: If ``image`` is not a PIL Image (e.g. ``None`` when no
        image was provided).
    """
    if not isinstance(image, Image.Image):
        raise ValueError("Input must be a PIL Image.")
    try:
        # Module-level loading is best-effort and may have failed; report that
        # plainly instead of failing on a None processor/model.
        if processor is None or model is None:
            return "Error generating caption: model is not loaded"
        inputs = processor(images=image, return_tensors="pt").to(device)
        outputs = model.generate(**inputs)
        # Decoded BLIP-2 text can carry leading/trailing whitespace (such as a
        # trailing newline); strip it so the UI shows a clean caption.
        return processor.decode(outputs[0], skip_special_tokens=True).strip()
    except Exception as error:
        return f"Error generating caption: {error}"
@spaces.GPU
def caption_image(image: Image.Image) -> str:
    """
    Gradio callback: caption an uploaded image.

    Delegates to ``generate_caption`` and converts any exception it raises
    (for example the ValueError for a non-PIL input) into a user-facing string,
    so the interface always displays text rather than an error page.

    Args:
        image (PIL.Image): The uploaded image to describe.

    Returns:
        str: The caption, or "An error occurred: <details>" on failure.
    """
    try:
        return generate_caption(image)
    except Exception as exc:
        return "An error occurred: " + str(exc)
# Constants for Gradio interface configuration
IMAGE_TYPE = "pil"  # hand uploads to the callback as PIL images
OUTPUT_TYPE = gr.Textbox(label="Caption")  # output component showing the caption

# Define the Gradio interface for image captioning
demo = gr.Interface(
    fn=caption_image,
    inputs=gr.Image(type=IMAGE_TYPE),
    outputs=OUTPUT_TYPE,
    title="Image Captioning with BLIP-2",
    description="Upload an image to generate a caption.",
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()