Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -48,8 +48,8 @@ def segment_to_bbox(segment_indexs):
|
|
48 |
|
49 |
def clipseg_prediction(image):
|
50 |
|
51 |
-
print('
|
52 |
-
img_w, img_h = image.shape
|
53 |
inputs = clip_processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")
|
54 |
# predict
|
55 |
with torch.no_grad():
|
@@ -81,6 +81,7 @@ def clipseg_prediction(image):
|
|
81 |
|
82 |
@torch.no_grad()
|
83 |
def foward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
|
|
|
84 |
global cache_data
|
85 |
image_input = Image.fromarray(image_input)
|
86 |
inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
|
@@ -112,10 +113,9 @@ def main_func(inputs):
|
|
112 |
mask_colors[final_mask, :] = np.array([[128, 0, 0]])
|
113 |
return Image.fromarray((mask_colors * 0.6 + image_input * 0.4).astype('uint8'), 'RGB')
|
114 |
else:
|
|
|
115 |
return Image.fromarray(image_input)
|
116 |
|
117 |
-
return pred_masks
|
118 |
-
|
119 |
def reset_data():
|
120 |
global cache_data
|
121 |
cache_data = None
|
|
|
48 |
|
49 |
def clipseg_prediction(image):
|
50 |
|
51 |
+
print('Clip-Segmentation-started------->')
|
52 |
+
img_w, img_h,_ = image.shape
|
53 |
inputs = clip_processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")
|
54 |
# predict
|
55 |
with torch.no_grad():
|
|
|
81 |
|
82 |
@torch.no_grad()
|
83 |
def foward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
|
84 |
+
print('SAM-Segmentation-started------->')
|
85 |
global cache_data
|
86 |
image_input = Image.fromarray(image_input)
|
87 |
inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
|
|
|
113 |
mask_colors[final_mask, :] = np.array([[128, 0, 0]])
|
114 |
return Image.fromarray((mask_colors * 0.6 + image_input * 0.4).astype('uint8'), 'RGB')
|
115 |
else:
|
116 |
+
print('Prediction:: No vehicle found in the image')
|
117 |
return Image.fromarray(image_input)
|
118 |
|
|
|
|
|
119 |
def reset_data():
|
120 |
global cache_data
|
121 |
cache_data = None
|