Spaces:
Runtime error
Runtime error
import subprocess | |
import os | |
if os.getenv('SYSTEM') == 'spaces':
    # On Hugging Face Spaces, install the pinned runtime stack before the
    # third-party imports further down the file. Installs are best-effort:
    # a failing command is reported but does not abort start-up.
    import sys

    # Use the running interpreter's pip rather than whatever `pip` is first
    # on PATH, so packages land in the environment this app actually uses.
    _PIP = [sys.executable, '-m', 'pip']
    _SETUP_COMMANDS = [
        _PIP + ['install', '-U', 'openmim'],
        _PIP + ['install', 'python-dotenv'],
        _PIP + ['install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113',
                'torchaudio==0.12.1', '--extra-index-url',
                'https://download.pytorch.org/whl/cu113'],
        ['mim', 'install', 'mmcv>=2.0.0'],
        ['mim', 'install', 'mmengine'],
        ['mim', 'install', 'mmdet'],
        _PIP + ['install', 'opencv-python-headless==4.5.5.64'],
        _PIP + ['install', 'git+https://github.com/cocodataset/panopticapi.git'],
    ]
    for _cmd in _SETUP_COMMANDS:
        _returncode = subprocess.call(_cmd)
        if _returncode != 0:
            print(f"WARNING: setup command exited with code {_returncode}: {' '.join(_cmd)}")
import gradio as gr | |
from huggingface_hub import snapshot_download | |
import cv2 | |
import dotenv | |
dotenv.load_dotenv() | |
import numpy as np | |
import gradio as gr | |
import glob | |
from inference import inference_frame,inference_frame_serial | |
from inference import inference_frame_par_ready | |
from inference import process_frame | |
from inference import classes | |
from inference import class_sizes_lower | |
from metrics import process_results_for_plot | |
from metrics import prediction_dashboard | |
import os | |
import pathlib | |
import multiprocessing as mp | |
from time import time | |
# Fetch the example clips (a private HF dataset) once; skip when the local
# folder is already present.
_examples_dir = 'videos_example'
if not os.path.exists(_examples_dir):
    REPO_ID = 'SharkSpace/videos_examples'
    _hf_token = os.environ.get('SHARK_MODEL')
    snapshot_download(
        repo_id=REPO_ID,
        token=_hf_token,
        repo_type='dataset',
        local_dir=_examples_dir,
    )
# Shared Gradio theme for the whole app: sky-blue accents on slate neutrals.
_theme_options = dict(primary_hue="sky", neutral_hue="slate")
theme = gr.themes.Soft(**_theme_options)
def add_border(frame, color=(255, 0, 0), thickness=None):
    """Return a copy of ``frame`` surrounded by a solid border on all four sides.

    Parameters:
    - frame: image array whose first two dimensions are height and width.
    - color: constant fill value for the border, handed to OpenCV as-is, so
      its channel order must match the frame's (e.g. (255, 0, 0) is red for
      an RGB frame, blue for a BGR one).
    - thickness: border width in pixels on every side. ``None`` (default)
      uses 2.5% of the frame's larger dimension.

    Returns:
    A new image, ``2 * thickness`` pixels taller and wider than ``frame``.
    """
    if thickness is None:
        thickness = int(max(frame.shape[0], frame.shape[1]) * 0.025)
    # OpenCV rejects negative border sizes; clamp instead of failing.
    thickness = max(int(thickness), 0)
    return cv2.copyMakeBorder(
        frame, thickness, thickness, thickness, thickness,
        cv2.BORDER_CONSTANT, value=color,
    )
def overlay_text_on_image(image, text_list, font=cv2.FONT_HERSHEY_SIMPLEX, font_size=0.5, font_thickness=1, margin=10, color=(255, 255, 255)):
    """Draw each string of ``text_list`` right-aligned near the top-right corner.

    The first line's baseline sits at y = ``margin``. Following lines step
    down by 10% of the image's smaller dimension. Every line ends ``margin``
    pixels from the right edge. The image is drawn on in place.

    Returns:
    The same ``image`` object.
    """
    line_gap = int(min(image.shape[0], image.shape[1]) * 0.1)
    right_edge = image.shape[1] - margin
    for index, text in enumerate(text_list):
        (text_width, _text_height), _baseline = cv2.getTextSize(text, font, font_size, font_thickness)
        origin = (right_edge - text_width, margin + index * line_gap)
        cv2.putText(image, text, origin, font, font_size, color, font_thickness, lineType=cv2.LINE_AA)
    return image
def overlay_logo(frame, logo, position=(10, 10), transparency_threshold=150):
    """Paste ``logo`` onto ``frame`` in place, treating bright-green pixels as transparent.

    A logo pixel is copied onto the frame only where its channel 1 (green in
    both BGR and RGB order) is ``<= transparency_threshold``. Brighter pixels,
    e.g. a white background, leave the frame untouched. The logo is clipped
    at the frame's right and bottom edges.

    Parameters:
    - frame: image array of shape (H, W, C) with C >= 3; modified in place.
    - logo: 3-channel image array of shape (h, w, 3) to overlay.
    - position: (x, y) of the logo's top-left corner in the frame. Assumed
      non-negative.
    - transparency_threshold: channel-1 value above which a logo pixel is
      treated as transparent.

    Returns:
    The same ``frame`` object, with the logo drawn on it.

    Raises:
    - ValueError: if ``logo`` is not a 3-channel image.
    """
    if logo.ndim != 3 or logo.shape[2] != 3:
        raise ValueError("Logo must be a 3-channel image!")
    x, y = position
    # Clip so a logo overhanging the right/bottom edge does not cause a
    # shape mismatch between the frame region and the logo.
    h = max(0, min(logo.shape[0], frame.shape[0] - y))
    w = max(0, min(logo.shape[1], frame.shape[1] - x))
    if h == 0 or w == 0:
        return frame
    visible = logo[:h, :w]
    opaque = visible[:, :, 1] <= transparency_threshold
    # Basic slicing yields a view, so writing into roi writes into frame.
    roi = frame[y:y + h, x:x + w]
    roi[opaque, :3] = visible[opaque]
    return frame
# Decoded and resized danger symbol, keyed by side length in pixels. The image
# file is therefore read from disk once per output size instead of once per
# flagged frame.
_DANGER_SYMBOL_CACHE = {}


def _load_danger_symbol(size):
    """Return static/danger_symbol.jpeg resized to size x size with channels reversed (BGR -> RGB)."""
    symbol = _DANGER_SYMBOL_CACHE.get(size)
    if symbol is None:
        raw = cv2.imread('static/danger_symbol.jpeg')
        if raw is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError('static/danger_symbol.jpeg')
        resized = cv2.resize(raw, (size, size), interpolation=cv2.INTER_AREA)
        symbol = np.ascontiguousarray(resized[:, :, ::-1])
        _DANGER_SYMBOL_CACHE[size] = symbol
    return symbol


def add_danger_symbol_from_image(frame, top_pred):
    """Overlay the danger symbol near the top-left corner when a dangerous shark is detected.

    Parameters:
    - frame: image array; the symbol is overlaid via ``overlay_logo``.
    - top_pred: prediction summary dict; its 'shark_sighted' and
      'dangerous_dist' entries must both be truthy for the symbol to appear.

    Returns:
    The frame with the symbol drawn on it, or ``frame`` unchanged otherwise.

    Raises:
    - FileNotFoundError: if static/danger_symbol.jpeg cannot be read.
    """
    if not (top_pred['shark_sighted'] and top_pred['dangerous_dist']):
        return frame
    relative = max(frame.shape[0], frame.shape[1])
    symbol = _load_danger_symbol(int(relative * 0.1))
    offset = int(relative * 0.05)
    # Hand over a copy so the cached symbol can never be altered by the overlay.
    return overlay_logo(frame, symbol.copy(), position=(offset, offset))
def draw_cockpit(frame, top_pred, cnt):
    """Add the status border, danger symbol and text panel to ``frame``.

    When a shark is sighted and ``cnt`` is even, a coloured border is drawn.
    It is ``high_danger_color`` if 'dangerous_dist' is truthy, else
    ``low_danger_color``, and is followed by the danger symbol. Otherwise the
    border is black. Alternating ``cnt`` parity therefore makes the border
    blink across frames.

    Parameters:
    - frame: image array. Colours below are red/yellow-green if the frame is
      RGB (process_video passes an RGB frame).
    - top_pred: prediction summary dict read for 'shark_sighted', 'human_n',
      'biggest_shark_size', 'biggest_shark_weight' and 'dangerous_dist'.
    - cnt: frame counter; only its parity is used.

    Returns:
    A new, bordered frame, larger than the input by the border on every side,
    with the status lines written in its top-right corner.
    """
    high_danger_color = (255, 0, 0)
    low_danger_color = (154, 205, 50)  # "yellowgreen"

    danger_level = 'Danger Level: ' + ('High' if top_pred['dangerous_dist'] else 'Low')
    strings = [
        'Shark Detected: ' + str(top_pred['shark_sighted']),
        'Number of Humans: ' + str(top_pred['human_n']),
        'Biggest shark size: ' + str(top_pred['biggest_shark_size']),
        'Biggest shark weight: ' + str(top_pred['biggest_shark_weight']),
        danger_level,
    ]

    relative = max(frame.shape[0], frame.shape[1])
    thickness = int(relative * 0.025)
    if top_pred['shark_sighted'] and cnt % 2 == 0:
        color = high_danger_color if top_pred['dangerous_dist'] else low_danger_color
        frame = add_border(frame, color=color, thickness=thickness)
        frame = add_danger_symbol_from_image(frame, top_pred)
    else:
        frame = add_border(frame, color=(0, 0, 0), thickness=thickness)

    # Draws in place on the bordered frame.
    overlay_text_on_image(frame, strings, font=cv2.FONT_HERSHEY_SIMPLEX, font_size=relative * 0.0007, font_thickness=1, margin=int(relative * 0.05), color=(255, 255, 255))
    return frame
def _output_fps(source_fps, out_fps, skip_frames):
    """Return the integer frame rate for the output video (always >= 1)."""
    if out_fps != 'auto' and isinstance(out_fps, int):
        fps = int(out_fps)
    else:
        fps = int(source_fps)
        if out_fps == 'auto':
            # Only every skip_frames-th frame is written, so slow down to match.
            fps = int(fps / skip_frames)
    # A 0 fps writer is invalid: this happens when the capture reports 0 fps
    # or when the source rate is below skip_frames.
    return max(fps, 1)


def _render_frame(frame, processed_idx, width, height):
    """Run inference on one BGR frame and return the RGB frame to display and write.

    When a shark is sighted, even ``processed_idx`` values show the display
    frame returned by inference (presumably with detections drawn) and odd
    ones the raw frame, so the annotation blinks. The cockpit overlay is
    added whenever a shark is sighted. The result is resized to (width, height).
    """
    display_frame, result = inference_frame_serial(frame)
    top_pred = process_results_for_plot(predictions=result.numpy(),
                                        classes=classes,
                                        class_sizes=class_sizes_lower)
    rendered = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    if top_pred['shark_sighted']:
        if processed_idx % 2 == 0:
            prediction_frame = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
            rendered = cv2.resize(prediction_frame, (width, height))
        rendered = draw_cockpit(rendered, top_pred, processed_idx)
    return cv2.resize(rendered, (width, height))


def process_video(input_video, out_fps='auto', skip_frames=7):
    """Run shark detection on a video, streaming annotated frames and then the output file.

    Every ``skip_frames``-th frame is sent through inference, annotated,
    written to ``output.mp4`` and yielded as ``(rgb_frame, None)`` for live
    preview. After the last frame the writer is finalised and
    ``(None, "output.mp4")`` is yielded. The capture and the writer are
    released even if the generator is closed early.

    Parameters:
    - input_video: video source passed to ``cv2.VideoCapture``
      (presumably a file path from the gradio Video component).
    - out_fps: explicit integer output frame rate, or 'auto' for the source
      fps divided by ``skip_frames``. Any other value keeps the source fps.
    - skip_frames: process one out of every ``skip_frames`` frames (>= 1).

    Raises:
    - ValueError: if ``skip_frames`` < 1.
    - gr.Error: if no video was given or it cannot be opened.
    """
    if skip_frames < 1:
        raise ValueError("skip_frames must be >= 1")
    if not input_video:
        raise gr.Error("Please provide an input video.")

    output_path = "output.mp4"
    cap = cv2.VideoCapture(input_video)
    video = None
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the input video.")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = _output_fps(cap.get(cv2.CAP_PROP_FPS), out_fps, skip_frames)
        video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

        cnt = 0
        iterating, frame = cap.read()
        while iterating:
            if cnt % skip_frames == 0:
                print('starting Frame: ', cnt)
                # Index among processed frames drives the blink parity, so it
                # alternates regardless of whether skip_frames is even or odd.
                rendered = _render_frame(frame, cnt // skip_frames, width, height)
                video.write(cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR))
                print('finalizing frame:', cnt)
                yield rendered, None
            cnt += 1
            iterating, frame = cap.read()
    finally:
        cap.release()
        if video is not None:
            video.release()
    yield None, output_path
with gr.Blocks(theme=theme) as demo:
    # Top row: the uploaded video, a live preview of each processed frame and
    # the final rendered video. Creation order here defines the layout.
    with gr.Row().style(equal_height=True, height='25%'):
        input_video = gr.Video(label="Input")
        original_frames = gr.Image(label="Processed Frame").style(height=650)
        output_video = gr.Video(label="Output Video")

    # Second row: clickable example clips (only files under a raw_videos
    # folder) and the button that starts processing.
    with gr.Row():
        example_rows = []
        for clip_path in sorted(pathlib.Path('videos_example/').rglob('*.mp4')):
            if 'raw_videos' in str(clip_path):
                example_rows.append([clip_path.as_posix()])
        examples = gr.Examples(example_rows, inputs=input_video)
        process_video_btn = gr.Button("Process Video")

    # process_video is a generator: it streams frames into the preview image
    # and finally the output video path.
    process_video_btn.click(process_video, input_video, [original_frames, output_video])

demo.queue()
launch_options = {}
if os.getenv('SYSTEM') == 'spaces':
    # On Spaces the demo is password-protected with credentials from the env.
    launch_options = {
        'width': '40%',
        'auth': (os.environ.get('SHARK_USERNAME'), os.environ.get('SHARK_PASSWORD')),
    }
demo.launch(**launch_options)