Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import cv2
|
|
4 |
import numpy as np
|
5 |
from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
|
6 |
from PIL import Image
|
|
|
7 |
|
8 |
# Set up device
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -26,6 +27,26 @@ def process_mask(mask, target_size):
|
|
26 |
mask_image = mask_image.resize(target_size, Image.NEAREST)
|
27 |
return np.array(mask_image) > 0
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def segment_image(input_image, object_name):
|
30 |
try:
|
31 |
if input_image is None:
|
@@ -36,9 +57,9 @@ def segment_image(input_image, object_name):
|
|
36 |
if not original_size or 0 in original_size:
|
37 |
return None, "Invalid image size. Please upload a different image."
|
38 |
|
39 |
-
# Generate image caption
|
40 |
blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
|
41 |
-
caption = blip_model.generate(**blip_inputs)
|
42 |
caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)
|
43 |
|
44 |
# Process the image with SAM
|
@@ -58,15 +79,21 @@ def segment_image(input_image, object_name):
|
|
58 |
# Find the mask that best matches the specified object
|
59 |
best_mask = None
|
60 |
best_score = -1
|
|
|
|
|
|
|
|
|
|
|
61 |
for mask in masks[0]:
|
62 |
mask_binary = mask.numpy() > 0.5
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
67 |
|
68 |
if best_mask is None:
|
69 |
-
return input_image, f"Could not find '{object_name}' in the image."
|
70 |
|
71 |
combined_mask = process_mask(best_mask, original_size)
|
72 |
|
|
|
4 |
import numpy as np
|
5 |
from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
|
6 |
from PIL import Image
|
7 |
+
from scipy.ndimage import label, center_of_mass
|
8 |
|
9 |
# Set up device
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
27 |
mask_image = mask_image.resize(target_size, Image.NEAREST)
|
28 |
return np.array(mask_image) > 0
|
29 |
|
30 |
+
def is_cat_like(mask, image_area):
    """Heuristically decide whether a binary mask could outline a cat.

    The largest connected component of ``mask`` must pass two checks:

    1. Its pixel count lies strictly between 5% and 30% of ``image_area``.
    2. The aspect ratio of its bounding box (longest extent divided by
       shortest extent) lies strictly between 1.5 and 3.

    Args:
        mask: Boolean/integer array in which non-zero entries belong to the
            candidate region. Presumably a 2-D (height, width) mask; the
            bounding-box extents are taken over every axis of the array --
            TODO confirm the dimensionality the caller passes in.
        image_area: Total number of pixels in the image the mask refers to.

    Returns:
        bool: True when both the area and the aspect-ratio checks pass,
        False otherwise (including when the mask is empty).
    """
    labeled, num_features = label(mask)
    if num_features == 0:
        return False

    # Label 0 is the background, so drop it before picking the biggest blob.
    component_sizes = np.bincount(labeled.ravel())[1:]
    largest_label = component_sizes.argmax() + 1
    largest_component = labeled == largest_label
    area = component_sizes[largest_label - 1]

    # Check if the area is reasonable for a cat (between 5% and 30% of image)
    if not (0.05 * image_area < area < 0.3 * image_area):
        return False

    # Measure the bounding box of the component itself. The array's own
    # ``.shape`` is the size of the whole image and says nothing about the
    # segmented region, so it must not be used for the aspect ratio.
    extents = [
        int(axis_coords.max() - axis_coords.min() + 1)
        for axis_coords in np.nonzero(largest_component)
    ]
    major_axis = max(extents)
    minor_axis = min(extents)
    aspect_ratio = major_axis / minor_axis

    # Most cats have an aspect ratio in this range
    return bool(1.5 < aspect_ratio < 3)
|
49 |
+
|
50 |
def segment_image(input_image, object_name):
|
51 |
try:
|
52 |
if input_image is None:
|
|
|
57 |
if not original_size or 0 in original_size:
|
58 |
return None, "Invalid image size. Please upload a different image."
|
59 |
|
60 |
+
# Generate detailed image caption
|
61 |
blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
|
62 |
+
caption = blip_model.generate(**blip_inputs, max_length=50)
|
63 |
caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)
|
64 |
|
65 |
# Process the image with SAM
|
|
|
79 |
# Find the mask that best matches the specified object
|
80 |
best_mask = None
|
81 |
best_score = -1
|
82 |
+
image_area = original_size[0] * original_size[1]
|
83 |
+
|
84 |
+
cat_related_words = ['cat', 'kitten', 'feline', 'tabby', 'kitty']
|
85 |
+
caption_contains_cat = any(word in caption_text.lower() for word in cat_related_words)
|
86 |
+
|
87 |
for mask in masks[0]:
|
88 |
mask_binary = mask.numpy() > 0.5
|
89 |
+
if is_cat_like(mask_binary, image_area) and caption_contains_cat:
|
90 |
+
mask_area = mask_binary.sum()
|
91 |
+
if mask_area > best_score:
|
92 |
+
best_mask = mask_binary
|
93 |
+
best_score = mask_area
|
94 |
|
95 |
if best_mask is None:
|
96 |
+
return input_image, f"Could not find a suitable '{object_name}' in the image."
|
97 |
|
98 |
combined_mask = process_mask(best_mask, original_size)
|
99 |
|