File size: 2,245 Bytes
0398e5a
e16c83b
0398e5a
 
 
 
 
a6773ea
802f3ff
0398e5a
 
e16c83b
 
 
0398e5a
e16c83b
 
 
9d456e3
e16c83b
 
 
0398e5a
e16c83b
0398e5a
e16c83b
0398e5a
 
e16c83b
0398e5a
 
e16c83b
0398e5a
e16c83b
 
 
 
 
 
 
 
 
 
0398e5a
 
e16c83b
0398e5a
 
 
 
e16c83b
0398e5a
 
e16c83b
0398e5a
 
e16c83b
 
 
 
 
 
 
802f3ff
0398e5a
e16c83b
0398e5a
 
e16c83b
 
802f3ff
0398e5a
 
 
e16c83b
0398e5a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
This module provides an interface for image captioning using the BLIP-2 model.
The interface allows users to upload an image and receive a caption.
"""

import gradio as gr
import spaces
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image

# Select GPU when available; model weights and inputs must share this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Single source of truth for the checkpoint name (was duplicated).
MODEL_NAME = "Salesforce/blip2-opt-6.7b-coco"

# Initialize the processor and model once at import time so every request
# reuses them instead of reloading multi-GB weights per call.
try:
    processor = Blip2Processor.from_pretrained(MODEL_NAME)
    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
except Exception as error:
    # Fail fast: swallowing this error left `processor`/`model` undefined,
    # so every later call crashed with a confusing NameError far from the
    # real cause. Log the cause, then propagate it.
    print(f"Error initializing model: {error}")
    raise

def generate_caption(image: Image.Image) -> str:
    """
    Generates a caption for the given image using the BLIP-2 model.

    Args:
        image (PIL.Image): The input image to generate a caption for.

    Returns:
        str: The generated caption, or an error message string if
            captioning fails.

    Raises:
        ValueError: If ``image`` is not a PIL Image.
    """
    if not isinstance(image, Image.Image):
        raise ValueError("Input must be a PIL Image.")

    try:
        inputs = processor(images=image, return_tensors="pt").to(device)
        # Inference only: disable autograd so no computation graph is built
        # and activations are not kept alive (saves memory and time).
        with torch.no_grad():
            outputs = model.generate(**inputs)
        return processor.decode(outputs[0], skip_special_tokens=True)
    except Exception as error:
        # Keep the original contract of returning an error string rather
        # than raising, so the Gradio UI shows the message to the user.
        return f"Error generating caption: {str(error)}"

@spaces.GPU
def caption_image(image: Image.Image) -> str:
    """
    Gradio entry point: caption an uploaded image.

    Args:
        image (PIL.Image): The input image to generate a caption for.

    Returns:
        str: The generated caption, or an error message if something goes wrong.
    """
    try:
        return generate_caption(image)
    except Exception as exc:
        return f"An error occurred: {str(exc)}"

# Gradio interface configuration constants.
IMAGE_TYPE = "pil"  # deliver uploads to caption_image as PIL images
OUTPUT_TYPE = gr.Textbox(label="Caption")

# Wire the captioning function into a simple upload-and-caption UI.
demo = gr.Interface(
    fn=caption_image,
    inputs=gr.Image(type=IMAGE_TYPE),
    outputs=OUTPUT_TYPE,
    title="Image Captioning with BLIP-2",
    description="Upload an image to generate a caption.",
)

# Start the web server for the interface.
demo.launch()