Abijith commited on
Commit
b377691
1 Parent(s): 91b984b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -48,8 +48,8 @@ def segment_to_bbox(segment_indexs):
48
 
49
  def clipseg_prediction(image):
50
 
51
- print('image_shape:: ', image.shape)
52
- img_w, img_h = image.shape
53
  inputs = clip_processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")
54
  # predict
55
  with torch.no_grad():
@@ -81,6 +81,7 @@ def clipseg_prediction(image):
81
 
82
  @torch.no_grad()
83
  def foward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
 
84
  global cache_data
85
  image_input = Image.fromarray(image_input)
86
  inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
@@ -112,10 +113,9 @@ def main_func(inputs):
112
  mask_colors[final_mask, :] = np.array([[128, 0, 0]])
113
  return Image.fromarray((mask_colors * 0.6 + image_input * 0.4).astype('uint8'), 'RGB')
114
  else:
 
115
  return Image.fromarray(image_input)
116
 
117
- return pred_masks
118
-
119
  def reset_data():
120
  global cache_data
121
  cache_data = None
 
48
 
49
  def clipseg_prediction(image):
50
 
51
+ print('Clip-Segmentation-started------->')
52
+ img_w, img_h,_ = image.shape
53
  inputs = clip_processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")
54
  # predict
55
  with torch.no_grad():
 
81
 
82
  @torch.no_grad()
83
  def foward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
84
+ print('SAM-Segmentation-started------->')
85
  global cache_data
86
  image_input = Image.fromarray(image_input)
87
  inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
 
113
  mask_colors[final_mask, :] = np.array([[128, 0, 0]])
114
  return Image.fromarray((mask_colors * 0.6 + image_input * 0.4).astype('uint8'), 'RGB')
115
  else:
116
+ print('Prediction:: No vehicle found in the image')
117
  return Image.fromarray(image_input)
118
 
 
 
119
  def reset_data():
120
  global cache_data
121
  cache_data = None