sagar007 commited on
Commit
26c0f04
Β·
verified Β·
1 Parent(s): 3cd1243

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -7
app.py CHANGED
@@ -4,6 +4,7 @@ import cv2
4
  import numpy as np
5
  from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
6
  from PIL import Image
 
7
 
8
  # Set up device
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -26,6 +27,26 @@ def process_mask(mask, target_size):
26
  mask_image = mask_image.resize(target_size, Image.NEAREST)
27
  return np.array(mask_image) > 0
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def segment_image(input_image, object_name):
30
  try:
31
  if input_image is None:
@@ -36,9 +57,9 @@ def segment_image(input_image, object_name):
36
  if not original_size or 0 in original_size:
37
  return None, "Invalid image size. Please upload a different image."
38
 
39
- # Generate image caption
40
  blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
41
- caption = blip_model.generate(**blip_inputs)
42
  caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)
43
 
44
  # Process the image with SAM
@@ -58,15 +79,21 @@ def segment_image(input_image, object_name):
58
  # Find the mask that best matches the specified object
59
  best_mask = None
60
  best_score = -1
 
 
 
 
 
61
  for mask in masks[0]:
62
  mask_binary = mask.numpy() > 0.5
63
- mask_area = mask_binary.sum()
64
- if object_name.lower() in caption_text.lower() and mask_area > best_score:
65
- best_mask = mask_binary
66
- best_score = mask_area
 
67
 
68
  if best_mask is None:
69
- return input_image, f"Could not find '{object_name}' in the image."
70
 
71
  combined_mask = process_mask(best_mask, original_size)
72
 
 
4
  import numpy as np
5
  from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
6
  from PIL import Image
7
+ from scipy.ndimage import label, center_of_mass
8
 
9
  # Set up device
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
27
  mask_image = mask_image.resize(target_size, Image.NEAREST)
28
  return np.array(mask_image) > 0
29
 
30
+ def is_cat_like(mask, image_area):
31
+ labeled, num_features = label(mask)
32
+ if num_features == 0:
33
+ return False
34
+
35
+ largest_component = (labeled == (np.bincount(labeled.flatten())[1:].argmax() + 1))
36
+ area = largest_component.sum()
37
+
38
+ # Check if the area is reasonable for a cat (between 5% and 30% of image)
39
+ if not (0.05 * image_area < area < 0.3 * image_area):
40
+ return False
41
+
42
+ # Check if the shape is roughly elliptical
43
+ cy, cx = center_of_mass(largest_component)
44
+ major_axis = max(largest_component.shape)
45
+ minor_axis = min(largest_component.shape)
46
+ aspect_ratio = major_axis / minor_axis
47
+
48
+ return 1.5 < aspect_ratio < 3 # Most cats have an aspect ratio in this range
49
+
50
  def segment_image(input_image, object_name):
51
  try:
52
  if input_image is None:
 
57
  if not original_size or 0 in original_size:
58
  return None, "Invalid image size. Please upload a different image."
59
 
60
+ # Generate detailed image caption
61
  blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
62
+ caption = blip_model.generate(**blip_inputs, max_length=50)
63
  caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)
64
 
65
  # Process the image with SAM
 
79
  # Find the mask that best matches the specified object
80
  best_mask = None
81
  best_score = -1
82
+ image_area = original_size[0] * original_size[1]
83
+
84
+ cat_related_words = ['cat', 'kitten', 'feline', 'tabby', 'kitty']
85
+ caption_contains_cat = any(word in caption_text.lower() for word in cat_related_words)
86
+
87
  for mask in masks[0]:
88
  mask_binary = mask.numpy() > 0.5
89
+ if is_cat_like(mask_binary, image_area) and caption_contains_cat:
90
+ mask_area = mask_binary.sum()
91
+ if mask_area > best_score:
92
+ best_mask = mask_binary
93
+ best_score = mask_area
94
 
95
  if best_mask is None:
96
+ return input_image, f"Could not find a suitable '{object_name}' in the image."
97
 
98
  combined_mask = process_mask(best_mask, original_size)
99