Spaces:
Runtime error
Runtime error
# --- Bootstrap: imports, runtime dependencies, environment, example data ----
# Statement order matters in this section: on Hugging Face Spaces
# (SYSTEM=spaces) several packages are installed at start-up, so imports of
# those packages (dotenv, cv2, the local `inference` module and its
# mmdet/mmcv stack) must stay *below* the install step.
import glob
import multiprocessing as mp
import os
import pathlib
import subprocess
import sys
from time import time

import gradio as gr
from huggingface_hub import snapshot_download

REPO_ID = 'SharkSpace/videos_examples'

# Start-up installs for Spaces, run in this order. pip is invoked through the
# running interpreter (`python -m pip`) so packages land in the environment
# this app imports from, not whichever `pip` happens to be first on PATH.
_PIP = [sys.executable, '-m', 'pip']
_SPACES_INSTALL_COMMANDS = [
    _PIP + ['install', '-U', 'openmim'],
    _PIP + ['install', 'python-dotenv'],
    _PIP + ['install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113',
            'torchaudio==0.12.1',
            '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],
    # argv list, no shell: '>=' is a version specifier, not a redirection.
    ['mim', 'install', 'mmcv>=2.0.0'],
    ['mim', 'install', 'mmengine'],
    ['mim', 'install', 'mmdet'],
    _PIP + ['install', 'opencv-python-headless==4.5.5.64'],
    _PIP + ['install', 'git+https://github.com/cocodataset/panopticapi.git'],
]

if os.getenv('SYSTEM') == 'spaces':
    for _cmd in _SPACES_INSTALL_COMMANDS:
        _status = subprocess.call(_cmd)
        if _status != 0:
            # Still best-effort (later imports will fail loudly if something
            # essential is missing), but no longer silent.
            print(f'WARNING: {" ".join(_cmd)} exited with status {_status}')

import dotenv

# Load .env BEFORE any secret is read from the environment; previously the
# dataset download below ran first and never saw a SHARK_MODEL token that was
# only defined in .env.
dotenv.load_dotenv()

# Fetch the example-video dataset into ./videos_example (the UI galleries
# glob this directory). SHARK_MODEL is the Hub access token.
snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),
                  repo_type='dataset', local_dir='videos_example')

import cv2
import numpy as np

from inference import inference_frame, inference_frame_serial
from inference import inference_frame_par_ready
from inference import process_frame
def analize_video_serial(x):
    """Run detection frame by frame on a video, streaming progress to Gradio.

    This generator is the handler for the "Analyze your Video" button, whose
    outputs are ``[video_output, image_temp]``, so every yield is a
    ``(video_path, image)`` pair.

    Each decoded frame is passed through ``inference_frame_serial`` and written
    as a numbered PNG into a fresh working directory under ``/tmp/test/``.
    The PNG sequence is then encoded with ffmpeg (20 fps, H.264) to
    ``<workdir>_processed.mp4``.

    Args:
        x: Path of the uploaded video (anything ``cv2.VideoCapture`` accepts),
            or ``None``/empty when nothing was uploaded.

    Yields:
        ``(None, frame)`` for each decoded frame (the raw frame, shown while
        it is being processed), then ``(mp4_path, last_annotated_frame)`` on
        success. When there is no result to show, the last yield is
        ``(None, last_annotated_frame)`` or ``(None, None)``.
    """
    print(x)
    if not x:
        yield None, None
        return

    work_root = '/tmp/test/'
    os.makedirs(work_root, exist_ok=True)
    # Each run is numbered by how many entries the work root already holds.
    path = f'{work_root}{len(os.listdir(work_root))}'
    os.makedirs(path, exist_ok=True)
    outname = f'{path}_processed.mp4'
    if os.path.exists(outname):
        print('video already processed')
        # This is a generator: a `return value` is discarded, so the result
        # must be yielded to reach the Gradio outputs.
        yield outname, None
        return

    cap = cv2.VideoCapture(x)
    counter = 0
    last_annotated = None
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of stream / read error: stop before yielding a None frame.
                break
            # Preview the frame currently being processed.
            # NOTE(review): OpenCV frames are BGR while gr.Image treats arrays
            # as RGB, so colors may appear swapped — confirm intent.
            yield None, frame
            annotated = inference_frame_serial(frame)
            cv2.imwrite(os.path.join(path, f'{counter:05d}.png'), annotated)
            last_annotated = annotated
            counter += 1
    finally:
        # Release the decoder even if inference raises or the generator is
        # closed early.
        cap.release()

    if counter == 0:
        print(f'no frames could be read from {x}')
        yield None, None
        return

    print(path)
    # argv list, no shell: the glob pattern reaches ffmpeg verbatim and ffmpeg
    # expands it itself because of `-pattern_type glob`.
    result = subprocess.run(
        ['ffmpeg', '-framerate', '20', '-pattern_type', 'glob',
         '-i', f'{path}/*.png', '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
         outname, '-y'],
        check=False,
    )
    if result.returncode != 0 or not os.path.exists(outname):
        print(f'ffmpeg failed with exit status {result.returncode}')
        yield None, last_annotated
        return
    yield outname, last_annotated
def _read_and_predict(x, skip_frames, batch_size):
    """Decode video ``x``, keep every ``skip_frames``-th frame, and run batched inference.

    Sampled frames are grouped into batches of up to ``batch_size`` and passed
    to ``inference_frame_par_ready``.

    Returns:
        ``(frames, predictions)``: the sampled frames in decode order, and the
        per-batch predictions concatenated in the same order (the caller zips
        the two lists together).
    """
    cap = cv2.VideoCapture(x)
    frames_all = []
    pred_results_all = []
    counter = 0
    exhausted = False
    try:
        while cap.isOpened() and not exhausted:
            batch = []
            while len(batch) < batch_size:
                ret, frame = cap.read()
                if not ret:
                    exhausted = True
                    break
                if counter % skip_frames == 0:
                    batch.append(frame)
                counter += 1
            frames_all.extend(batch)
            start = time()
            print('len frames passed: ', len(batch))
            if batch:
                pred_results = inference_frame_par_ready(batch)
                print(f'inference time: {time()-start}')
                pred_results_all.extend(pred_results)
    finally:
        cap.release()
    print('exited prediction loop')
    return frames_all, pred_results_all


def analyze_video_parallel(x, skip_frames=5,
                           frame_rate_out=8, batch_size=16):
    """Detect on every ``skip_frames``-th frame with batched inference, then render in parallel.

    Steps:
    1. Sample frames and run ``inference_frame_par_ready`` in batches of
       ``batch_size``.
    2. Render the predictions onto the frames in a process pool via
       ``process_frame``.
    3. Write the rendered frames as numbered PNGs under ``/tmp/test/<n>``.
    4. Encode the PNGs with ffmpeg at ``frame_rate_out`` fps into
       ``/tmp/test/<n>_processed.mp4``.

    This is a generator suited to a Gradio handler with outputs
    ``(video, image)``.

    Args:
        x: Path of the input video, or ``None``/empty.
        skip_frames: Keep one frame out of every ``skip_frames`` (>= 1).
        frame_rate_out: Frame rate of the encoded output video.
        batch_size: Maximum number of frames per inference call (>= 1).

    Yields:
        ``(None, rendered_frame)`` per output frame, then
        ``(mp4_path, last_rendered_frame)`` on success. When there is nothing
        to encode, or encoding fails, the last yield has ``None`` as the path.

    Raises:
        ValueError: if ``skip_frames`` or ``batch_size`` is < 1.
    """
    print(x)
    if skip_frames < 1 or batch_size < 1:
        raise ValueError('skip_frames and batch_size must both be >= 1')
    if not x:
        yield None, None
        return

    work_root = '/tmp/test/'
    os.makedirs(work_root, exist_ok=True)
    # Each run is numbered by how many entries the work root already holds.
    path = f'{work_root}{len(os.listdir(work_root))}'
    os.makedirs(path, exist_ok=True)
    outname = f'{path}_processed.mp4'
    if os.path.exists(outname):
        print('video already processed')
        # Generator: results must be yielded, a `return value` is discarded.
        yield outname, None
        return

    frames_all, pred_results_all = _read_and_predict(x, skip_frames, batch_size)
    jobs = list(zip(pred_results_all, frames_all, range(len(pred_results_all))))
    if not jobs:
        print(f'no frames could be processed from {x}')
        yield None, None
        return

    start = time()
    # Leave two cores for the rest of the app, but always start at least one
    # worker: mp.Pool raises ValueError for processes < 1 on 1-2 core hosts.
    n_workers = max(1, mp.cpu_count() - 2)
    with mp.Pool(n_workers) as pool:
        pool_out = pool.map(process_frame, jobs)
    print(f'pool time: {time()-start}')

    start = time()
    for idx, rendered in enumerate(pool_out):
        cv2.imwrite(os.path.join(path, f'{idx:05d}.png'), rendered)
        yield None, rendered
    print(f'write time: {time()-start}')

    print(path)
    # argv list, no shell: ffmpeg expands the glob itself (-pattern_type glob).
    result = subprocess.run(
        ['ffmpeg', '-framerate', str(frame_rate_out), '-pattern_type', 'glob',
         '-i', f'{path}/*.png', '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
         outname, '-y'],
        check=False,
    )
    if result.returncode != 0 or not os.path.exists(outname):
        print(f'ffmpeg failed with exit status {result.returncode}')
        yield None, pool_out[-1]
        return
    yield outname, pool_out[-1]
def set_example_image(example: list) -> dict:
    """Load a clicked example clip into the upload video component.

    Despite the name, this sets a *video*. ``example`` is the sample row passed
    by the clicked ``gr.Dataset``, and its first entry is the clip's file path.
    """
    clip_path = example[0]
    return gr.Video.update(value=clip_path)
def show_video(example: list) -> dict:
    """Show the clicked preloaded example in the video player.

    ``example`` is the sample row passed by the clicked ``gr.Dataset``, and its
    first entry is the clip's file path.
    """
    selected = example[0]
    return gr.Video.update(value=selected)
# Sample lists for the two example galleries, built from the dataset snapshot
# in videos_example/. The preloaded tab gets every *rgb.mp4 clip. The
# "test your own" tab gets every .mp4 whose path contains 'raw_videos'.
_examples_root = pathlib.Path('videos_example/')
_preloaded_samples = [[clip.as_posix()]
                      for clip in sorted(_examples_root.rglob('*rgb.mp4'))]
paths = sorted(_examples_root.rglob('*.mp4'))
_raw_video_samples = [[clip.as_posix()]
                      for clip in paths if 'raw_videos' in str(clip)]

with gr.Blocks(title='Shark Patrol', theme=gr.themes.Soft(), live=True) as demo:
    gr.Markdown("Alpha Demo of the Sharkpatrol Oceanlife Detector.")

    with gr.Tab("Preloaded Examples"):
        with gr.Row():
            video_example = gr.Video(source='upload', include_audio=False,
                                     stream=True)
        with gr.Row():
            example_preds = gr.Dataset(components=[video_example],
                                       samples=_preloaded_samples)
        # Clicking a sample plays that clip in the player above.
        example_preds.click(fn=show_video,
                            inputs=example_preds,
                            outputs=video_example)

    with gr.Tab("Test your own Video"):
        with gr.Row():
            video_input = gr.Video(source='upload', include_audio=False)
            # Receives the frames streamed by the analysis generator.
            image_temp = gr.Image()
        # NOTE(review): the original file's indentation was lost; placing the
        # button in the same Row as video_output is a reconstruction — confirm.
        with gr.Row():
            video_output = gr.Video()
            video_button = gr.Button("Analyze your Video")
        with gr.Row():
            example_images = gr.Dataset(components=[video_input],
                                        samples=_raw_video_samples)
        video_button.click(analize_video_serial,
                           inputs=video_input,
                           outputs=[video_output, image_temp])
        example_images.click(fn=set_example_image,
                             inputs=example_images,
                             outputs=video_input)
# Queueing is enabled because the event handlers (e.g. analize_video_serial)
# are generators that stream intermediate results.
demo.queue()

launch_options = {}
if os.getenv('SYSTEM') == 'spaces':
    # On Spaces the demo is password-protected with credentials from secrets.
    launch_options = {
        'width': '40%',
        'auth': (os.environ.get('SHARK_USERNAME'),
                 os.environ.get('SHARK_PASSWORD')),
    }
demo.launch(**launch_options)