shark_detection / app.py
piperod91's picture
Iterating predictions even when no shark is sighted. Improving messages for info
afeb582
raw
history blame
11.6 kB
import subprocess
import os
# When the SYSTEM environment variable is 'spaces', install the runtime
# dependencies at start-up. The commands run strictly in the order listed,
# and their exit codes are not checked.
if os.getenv('SYSTEM') == 'spaces':
    _SETUP_COMMANDS = (
        'pip install -U openmim',
        'pip install python-dotenv',
        'pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113',
        'mim install mmcv>=2.0.0',
        'mim install mmengine==0.7.2',
        'mim install mmdet==3.0.0',
        'pip install opencv-python-headless==4.5.5.64',
        'pip install git+https://github.com/cocodataset/panopticapi.git',
    )
    for _setup_command in _SETUP_COMMANDS:
        # Each command is split on whitespace and executed without a shell.
        subprocess.call(_setup_command.split())
import gradio as gr
from huggingface_hub import snapshot_download
import cv2
import dotenv
dotenv.load_dotenv()
import numpy as np
import gradio as gr
import glob
from inference import inference_frame,inference_frame_serial
from inference import inference_frame_par_ready
from inference import process_frame
from inference import classes
from inference import class_sizes_lower
from metrics import process_results_for_plot
from metrics import prediction_dashboard
import os
import pathlib
import multiprocessing as mp
from time import time
# Download the example videos only when the local 'videos_example' path is
# missing; the access token is read from the SHARK_MODEL environment variable.
if not os.path.exists('videos_example'):
    REPO_ID = 'SharkSpace/videos_examples'
    snapshot_download(
        repo_id=REPO_ID,
        repo_type='dataset',
        local_dir='videos_example',
        token=os.environ.get('SHARK_MODEL'),
    )

# Colour theme applied to the Gradio Blocks UI defined further down.
theme = gr.themes.Soft(primary_hue="sky", neutral_hue="slate")
def add_border(frame, color = (255, 0, 0), thickness = 2):
    """
    Surround a frame with a constant-colour border and return the result.

    The border width on every side is 2.5% of the frame's longer side.

    Parameters:
    - frame: Image array; its first two shape entries are used to size the border.
    - color: Border colour, forwarded to OpenCV as the constant border value.
    - thickness: Accepted for call compatibility but not used here; the width
      is always derived from the frame size.
    """
    pad = int(max(frame.shape[:2]) * 0.025)
    return cv2.copyMakeBorder(frame, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=color)
def overlay_text_on_image(image, text_list, font=cv2.FONT_HERSHEY_SIMPLEX, font_size=0.5, font_thickness=1, margin=10, color=(255, 255, 255)):
    """
    Draw lines of text, right-aligned, onto an image (modified in place).

    Lines containing 'Shark' or 'Human' are emphasised by drawing them at
    1.2x the base font size. Every line is measured with the same scale it is
    drawn with, so all lines end exactly ``margin`` pixels from the right edge.

    Parameters:
    - image: Image array to draw on.
    - text_list: Iterable of strings, drawn top to bottom.
    - font: OpenCV font face.
    - font_size: Base font scale.
    - font_thickness: Stroke thickness of the glyphs.
    - margin: Distance in pixels from the right edge, and the y position of the first line.
    - color: Text colour.

    Returns the same ``image`` object.
    """
    # Vertical distance between lines: 10% of the image's shorter side.
    line_gap = int(min(image.shape[0], image.shape[1]) * 0.1)
    right_edge = image.shape[1] - margin

    for index, line in enumerate(text_list):
        emphasised = 'Shark' in line or 'Human' in line
        scale = font_size * 1.2 if emphasised else font_size
        # Measure at the drawing scale; measuring at a different scale would
        # shift emphasised lines away from the shared right edge.
        text_width, _ = cv2.getTextSize(line, font, scale, font_thickness)[0]
        origin = (right_edge - text_width, margin + index * line_gap)
        cv2.putText(image, line, origin, font, scale, color, font_thickness, lineType=cv2.LINE_AA)
    return image
def overlay_logo(frame,logo, position=(10, 10)):
    """
    Overlay a logo on a frame, keying out the logo's bright background.

    Logo pixels whose channel at index 1 is greater than 150 are treated as
    fully transparent (the frame shows through); all other logo pixels
    replace the frame pixels underneath. If the logo would extend past the
    right or bottom edge of the frame, only the part that fits is drawn.

    Parameters:
    - frame: The main image/frame to overlay the logo on. Modified in place.
    - logo: 3-channel logo image array. It is not modified.
    - position: (x, y) tuple indicating where the logo starts (top left corner).

    Returns the same ``frame`` object.
    Raises ValueError if ``logo`` is not a 3-channel image.
    """
    if logo.ndim != 3 or logo.shape[2] != 3:
        raise ValueError("Logo must be a 3-channel image!")

    # Region of the frame covered by the logo; this is a view, so writing to
    # it updates ``frame`` directly.
    roi = frame[position[1]:position[1] + logo.shape[0],
                position[0]:position[0] + logo.shape[1]]

    # Slicing silently truncates at the frame edge, so crop the logo to match.
    visible = logo[:roi.shape[0], :roi.shape[1]]

    # Binary transparency mask: True where the frame must be kept.
    transparent = visible[:, :, 1] > 150
    roi[:, :, :3] = np.where(transparent[:, :, None], roi[:, :, :3], visible)
    return frame
def add_danger_symbol_from_image(frame, top_pred):
    """
    Stamp the danger symbol onto the frame when a shark is sighted at a dangerous distance.

    Parameters:
    - frame: Image array to decorate.
    - top_pred: Mapping read for the 'shark_sighted' and 'dangerous_dist' entries.

    Returns the frame, unchanged unless both entries are truthy.
    """
    longest_side = max(frame.shape[0], frame.shape[1])
    if not (top_pred['shark_sighted'] and top_pred['dangerous_dist']):
        return frame

    # Symbol is a square of 10% of the longer side, offset 5% from the top-left corner.
    symbol_side = int(longest_side * 0.1)
    offset = int(longest_side * 0.05)

    symbol = cv2.imread('static/danger_symbol.jpeg')
    symbol = cv2.resize(symbol, (symbol_side, symbol_side), interpolation = cv2.INTER_AREA)
    # Reverse the channel order of the loaded image before overlaying it.
    symbol = symbol[:, :, ::-1]
    return overlay_logo(frame, symbol, position=(offset, offset))
def draw_cockpit(frame, top_pred,cnt):
    """
    Decorate a frame with a status border, an optional danger symbol and a text panel.

    Parameters:
    - frame: Image array to decorate.
    - top_pred: Mapping read for 'shark_sighted', 'shark_suspected',
      'human_sighted', 'human_suspected', 'biggest_shark_size',
      'biggest_shark_weight' and 'dangerous_dist_confirmed'.
    - cnt: Counter; the coloured border and symbol are only drawn when it is
      even, otherwise a black border is used.

    Returns the bordered frame with the status text drawn on it.
    """
    high_danger_color = (255,0,0)
    low_danger_color = (154,205,50)  # yellow-green

    # Status lines: a confirmed sighting takes precedence over a suspicion.
    shark_line = (
        'Shark Sighted !' if top_pred['shark_sighted'] > 0
        else 'Shark Suspected !' if top_pred['shark_suspected'] > 0
        else 'No Sharks ...'
    )
    human_line = (
        'Human Sighted !' if top_pred['human_sighted'] > 0
        else 'Human Suspected !' if top_pred['human_suspected'] > 0
        else 'No Humans ...'
    )

    # Size and weight fall back to '...' when the estimate is missing or falsy.
    biggest_size = top_pred['biggest_shark_size']
    size_line = 'Biggest shark size: ' + (str(biggest_size) if biggest_size else '...')
    biggest_weight = top_pred['biggest_shark_weight']
    weight_line = 'Biggest shark weight: ' + (str(biggest_weight) if biggest_weight else '...')

    danger_confirmed = top_pred['dangerous_dist_confirmed']
    danger_line = 'Danger Level: ' + ('High' if danger_confirmed else 'Low')

    panel_lines = [shark_line, human_line, size_line, weight_line, danger_line]

    relative = max(frame.shape[0],frame.shape[1])
    border_thickness = int(relative*0.025)

    if top_pred['shark_sighted'] and cnt%2 == 0:
        border_color = high_danger_color if danger_confirmed else low_danger_color
        frame = add_border(frame, color=border_color, thickness=border_thickness)
        frame = add_danger_symbol_from_image(frame, top_pred)
    else:
        frame = add_border(frame, color=(0,0,0), thickness=border_thickness)

    # The text is drawn in place on the bordered frame.
    overlay_text_on_image(
        frame,
        panel_lines,
        font=cv2.FONT_HERSHEY_SIMPLEX,
        font_size=relative*0.0007,
        font_thickness=1,
        margin=int(relative*0.05),
        color=(255, 255, 255),
    )
    return frame
def process_video(input_video, out_fps = 'auto', skip_frames = 7):
    """
    Run detection over a video and stream the decorated frames as they are produced.

    Every ``skip_frames``-th frame is resized, passed to
    ``inference_frame_serial``, summarised by ``process_results_for_plot``,
    decorated by ``draw_cockpit`` and appended to 'output.mp4'.
    Sighting flags are smoothed over the last five processed frames: a flag is
    set only when more than three of those five frames reported it.

    Parameters:
    - input_video: Source handed to ``cv2.VideoCapture``.
    - out_fps: 'auto' to use the source frame rate divided by ``skip_frames``,
      or an int to force that output frame rate.
    - skip_frames: Process one frame out of this many.

    Yields:
    - ``(frame, None)`` for every processed frame,
    - then ``(None, output_path)`` once the capture and the writer are released.
    """
    print('Processing video: ')
    cap = cv2.VideoCapture(input_video)
    output_path = "output.mp4"
    video = None

    # try/finally guarantees the capture and writer are released even when the
    # consumer stops iterating early or inference raises.
    try:
        if out_fps != 'auto' and isinstance(out_fps, int):
            fps = int(out_fps)
        else:
            fps = int(cap.get(cv2.CAP_PROP_FPS))
            if out_fps == 'auto':
                fps = int(fps / skip_frames)
        # A source slower than skip_frames fps would otherwise give 0 fps.
        fps = max(fps, 1)

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Anything larger than 1920x1080 is downscaled by a factor of four.
        if width > 1920 or height > 1080:
            width = int(width//4)
            height = int(height//4)

        video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

        iterating, frame = cap.read()
        cnt = 0
        drawn_count = 0
        # Ring buffers holding the 0/1 detection results of the last five processed frames.
        last_5_shark_detected = np.array([0, 0, 0, 0, 0])
        last_5_human_detected = np.array([0, 0, 0, 0, 0])
        last_5_dangerous_dist = np.array([0, 0, 0, 0, 0])

        while iterating:
            print('overall count ', cnt)
            if (cnt % skip_frames) == 0:
                # One step per processed frame is enough: each slot is still
                # overwritten exactly every five processed frames.
                drawn_count += 1
                slot = drawn_count % 5

                frame = cv2.resize(frame, (int(width), int(height)))
                print('starting Frame: ', cnt)
                display_frame, result = inference_frame_serial(frame)
                top_pred = process_results_for_plot(predictions = result.numpy(),
                                                    classes = classes,
                                                    class_sizes = class_sizes_lower)

                last_5_shark_detected[slot] = int(top_pred['shark_n'] > 0)
                last_5_human_detected[slot] = int(top_pred['human_n'] > 0)
                last_5_dangerous_dist[slot] = int(top_pred['dangerous_dist'] > 0)
                top_pred['shark_sighted'] = int(np.sum(last_5_shark_detected) > 3)
                top_pred['human_sighted'] = int(np.sum(last_5_human_detected) > 3)
                top_pred['dangerous_dist_confirmed'] = int(np.sum(last_5_dangerous_dist) > 3)

                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                prediction_frame = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)

                # NOTE(review): the annotated prediction frame replaces the plain
                # frame only when cnt*skip_frames is even -- confirm this
                # alternation is intended.
                if (cnt * skip_frames) % 2 == 0:
                    prediction_frame = cv2.resize(prediction_frame, (int(width), int(height)))
                    frame = prediction_frame

                frame = draw_cockpit(frame, top_pred, cnt*skip_frames)
                # draw_cockpit adds a border, so bring the frame back to the writer's size.
                frame = cv2.resize(frame, (int(width), int(height)))
                video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

                # The returned dashboard is not used here; the call is kept in
                # case it has side effects.
                prediction_dashboard(top_pred = top_pred)
                print('finalizing frame:',cnt)
                yield frame , None
            cnt += 1
            iterating, frame = cap.read()
    finally:
        cap.release()
        if video is not None:
            video.release()

    yield None, output_path
# Gradio UI: one row with the input video, the live processed frame and the
# final output video, followed by a row with example videos and the button
# that starts processing.
with gr.Blocks(theme=theme) as demo:
    with gr.Row().style(equal_height=True,height='25%'):
        input_video = gr.Video(label="Input")
        original_frames = gr.Image(label="Processed Frame").style( height=650)
        #processed_frames = gr.Image(label="Shark Engine")
        output_video = gr.Video(label="Output Video")
        #dashboard = gr.Image(label="Events")
    with gr.Row():
        # Only .mp4 files whose path contains 'raw_videos' are offered as examples.
        paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
        samples=[[path.as_posix()] for path in paths if 'raw_videos' in str(path)]
        examples = gr.Examples(samples, inputs=input_video)
        process_video_btn = gr.Button("Process Video")
    #process_video_btn.click(process_video, input_video, [processed_frames, original_frames, output_video, dashboard])
    # process_video yields (frame, output_path) pairs, matching these two outputs.
    process_video_btn.click(process_video, input_video, [ original_frames, output_video])
demo.queue()
# When SYSTEM is 'spaces' the app is launched with username/password auth read
# from the environment; otherwise it is launched in debug mode.
if os.getenv('SYSTEM') == 'spaces':
    demo.launch(width='40%',auth=(os.environ.get('SHARK_USERNAME'), os.environ.get('SHARK_PASSWORD')))
else:
    demo.launch(debug=True)