Spaces:
Sleeping
Sleeping
""" | |
This module provides an interface for image captioning using the BLIP-2 model.
The interface allows users to upload an image and receive a caption.
""" | |
import gradio as gr | |
import spaces | |
import torch | |
from transformers import Blip2Processor, Blip2ForConditionalGeneration | |
from PIL import Image | |
# Checkpoint used for both the processor and the model, named once so the two
# loads below cannot drift apart.
MODEL_ID = "Salesforce/blip2-opt-6.7b-coco"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Bind both names up front: if loading fails below, later code sees ``None``
# instead of hitting a NameError on a name that was never assigned.
processor = None
model = None

# Initialize the processor and model. A failure is reported but deliberately
# not re-raised, so the rest of the module still loads.
try:
    processor = Blip2Processor.from_pretrained(MODEL_ID)
    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
except Exception as error:
    print(f"Error initializing model: {error}")
# `spaces` is imported by this file; on ZeroGPU hardware a GPU is only attached
# while a function decorated with `spaces.GPU` is running. The decorator is
# documented as having no effect in other environments.
@spaces.GPU
def generate_caption(image: Image.Image) -> str:
    """
    Generate a caption for the given image using the BLIP-2 model.

    Args:
        image (PIL.Image): The input image to generate a caption for.

    Returns:
        str: The generated caption with surrounding whitespace removed, or a
        message starting with "Error generating caption: " if preprocessing,
        generation or decoding fails.

    Raises:
        ValueError: If ``image`` is not a PIL Image (this includes ``None``).
    """
    if not isinstance(image, Image.Image):
        raise ValueError("Input must be a PIL Image.")
    try:
        # Preprocess the image into model tensors on the same device as the model.
        inputs = processor(images=image, return_tensors="pt").to(device)
        outputs = model.generate(**inputs)
        caption = processor.decode(outputs[0], skip_special_tokens=True)
        # The decoded text may carry leading/trailing whitespace or a newline.
        return caption.strip()
    except Exception as error:
        # Failures are reported as text rather than raised, so the caller
        # always receives a string.
        return f"Error generating caption: {str(error)}"
def caption_image(image: Image.Image) -> str:
    """
    Produce a caption for a PIL image, never raising to the caller.

    Args:
        image (PIL.Image): The image to caption.

    Returns:
        str: The caption returned by ``generate_caption``, or a message
        starting with "An error occurred: " if that call raises.
    """
    try:
        result = generate_caption(image)
    except Exception as exc:
        # Turn any failure into displayable text instead of propagating it.
        return f"An error occurred: {exc}"
    else:
        return result
# Gradio configuration: images are passed to the callback as PIL objects and
# the result is displayed in a text box labelled "Caption".
IMAGE_TYPE = "pil"
OUTPUT_TYPE = gr.Textbox(label="Caption")

# Input component for the uploaded image (created after the output component,
# matching the original construction order).
_image_input = gr.Image(type=IMAGE_TYPE)

# Wire the captioning callback to its input and output components.
demo = gr.Interface(
    fn=caption_image,
    inputs=_image_input,
    outputs=OUTPUT_TYPE,
    title="Image Captioning with BLIP-2",
    description="Upload an image to generate a caption.",
)

# Start serving the interface.
demo.launch()