Daryl Lim committed on
Commit
e16c83b
·
1 Parent(s): 58c7226

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -26
app.py CHANGED
@@ -1,60 +1,81 @@
1
  """
2
- This module provides an interface for image captioning using the BLIP model.
3
  The interface allows users to upload an image and receive a caption.
4
  """
5
 
6
  import gradio as gr
7
  import spaces
8
- from transformers import BlipProcessor, BlipForConditionalGeneration
9
  from PIL import Image
10
 
 
 
 
 
 
 
11
  # Initialize the processor and model
12
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
13
- model = (
14
- BlipForConditionalGeneration
15
- .from_pretrained("Salesforce/blip-image-captioning-base")
16
- .to("cuda")
17
- )
 
 
 
 
18
 
19
- def generate_caption(image: Image) -> str:
20
  """
21
- Generates a caption for a given image using the BLIP model.
22
 
23
  Args:
24
- image (Image): The input image as a PIL Image object.
25
 
26
  Returns:
27
- str: The generated caption.
28
  """
29
- inputs = processor(images=image, return_tensors="pt").to("cuda")
30
- outputs = model.generate(**inputs)
31
- caption = processor.decode(outputs[0], skip_special_tokens=True)
32
- return caption
 
 
 
 
 
 
33
 
34
  @spaces.GPU
35
- def caption_image(image: Image) -> str:
36
  """
37
  Takes a PIL Image input and returns a caption.
38
 
39
  Args:
40
- image (Image): The input image as a PIL Image object.
41
 
42
  Returns:
43
- str: The generated caption or an error message.
44
  """
45
  try:
46
- return generate_caption(image)
47
- except Exception as e:
48
- return f"An error occurred: {str(e)}"
 
 
 
 
 
49
 
50
- # Define the Gradio interface
51
  demo = gr.Interface(
52
  fn=caption_image,
53
- inputs=gr.Image(type="pil"),
54
- outputs="text",
55
  title="Image Captioning with BLIP",
56
  description="Upload an image to generate a caption."
57
  )
58
 
59
- # Launch the interface
60
  demo.launch()
 
1
  """
2
+ This module provides an interface for image captioning using the BLIP-2 model.
3
  The interface allows users to upload an image and receive a caption.
4
  """
5
 
6
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    BlipForConditionalGeneration,
    BlipProcessor,
)
10
 
11
# Select the compute device: use the GPU when CUDA is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 8-bit weight quantization (bitsandbytes) to keep the large model in memory.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
16
+
17
# Initialize the processor and model.
# NOTE: with load_in_8bit=True and device_map="auto", accelerate/bitsandbytes
# place the quantized weights themselves; calling .to(device) on a quantized
# model is unsupported (transformers raises on it), so it is deliberately
# omitted here.
try:
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-6.7b-coco")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-6.7b-coco",
        quantization_config=quantization_config,  # Quantize weights to 8-bit
        device_map="auto",  # Let accelerate handle device placement
        torch_dtype=torch.float16,  # Non-quantized modules in fp16 to save memory
    )
except Exception as error:
    # Swallowing the failure would only defer it to a NameError on the first
    # caption request; report the cause and fail fast instead.
    print(f"Error initializing model: {error}")
    raise
28
 
29
def generate_caption(image: Image.Image) -> str:
    """
    Generate a caption for an image with the BLIP-2 model.

    Args:
        image (PIL.Image): The image to describe.

    Returns:
        str: The decoded caption, or an error message if generation fails.

    Raises:
        ValueError: If *image* is not a PIL Image instance.
    """
    if not isinstance(image, Image.Image):
        raise ValueError("Input must be a PIL Image.")

    try:
        # Preprocess to tensors and move them to the same device as the model.
        model_inputs = processor(images=image, return_tensors="pt").to(device)
        generated_ids = model.generate(**model_inputs)
        return processor.decode(generated_ids[0], skip_special_tokens=True)
    except Exception as error:
        return f"Error generating caption: {str(error)}"
49
 
50
@spaces.GPU
def caption_image(image: Image.Image) -> str:
    """
    Takes a PIL Image input and returns a caption.

    Runs on a GPU-allocated worker via the @spaces.GPU decorator.

    Args:
        image (PIL.Image): The image to caption.

    Returns:
        str: The generated caption, or an error message if something goes wrong.
    """
    try:
        return generate_caption(image)
    except Exception as error:
        return f"An error occurred: {str(error)}"
66
+
67
# Gradio component type constants.
IMAGE_TYPE = "pil"  # deliver uploads as PIL Images
OUTPUT_TYPE = "text"

# Assemble the image-captioning UI around caption_image.
interface_kwargs = dict(
    fn=caption_image,
    inputs=gr.Image(type=IMAGE_TYPE),
    outputs=OUTPUT_TYPE,
    title="Image Captioning with BLIP",
    description="Upload an image to generate a caption.",
)
demo = gr.Interface(**interface_kwargs)

# Start the web app.
demo.launch()