import torch
import os

print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Remove the Gradio example cache left over from a previous run, if any.
os.system("rm -r /home/user/app/gradio_cached_examples/14")

# Install detectron2 on first launch if it is not already available.
try:
    import detectron2
except ImportError:
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
    os.system('python -m pip install -e detectron2')

import sys
import cv2
import glob
import shutil
import gdown
import zipfile
import time
import random
import gradio as gr
import numpy as np
from PIL import Image
from pathlib import Path

# Make the MEMTrack sources importable.
sys.path.insert(1, "MEMTrack/src")
from data_prep_utils import process_data
from data_feature_gen import create_train_data, create_test_data
from inferenceBacteriaRetinanet_Motility_v2 import run_inference
from GenerateTrackingData import gen_tracking_data
from Tracking import track_bacteria
from TrackingAnalysis import analyse_tracking
from GenerateVideo import gen_tracking_video

def find_and_return_csv_files(folder_path, search_pattern):
    """Return all CSV files in folder_path whose names start with search_pattern."""
    search_pattern = f"{folder_path}/{search_pattern}*.csv"
    csv_files = list(glob.glob(search_pattern))
    return csv_files

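# Example (illustrative path): collect the tracking CSVs written by analyse_tracking()
# for a processed video. Only the "TrackedRawData" prefix is used in this app.
# csv_files = find_and_return_csv_files("DataFeatures/collagen_motility_inference/video1", "TrackedRawData")
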
def read_video(video, raw_frame_dir, progress=gr.Progress()):
    """Extract every frame of the input video as a grayscale JPEG and return the
    randomly named sub-directory the frames were written to."""
    video_dir = str(random.randint(111111111, 999999999))
    images_dir = "Images without Labels"
    frames_dir = os.path.join(raw_frame_dir, video_dir, images_dir)
    os.makedirs(frames_dir, exist_ok=True)
    count = 0
    frames = []

    cap = cv2.VideoCapture(video)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    processed_frames = 0

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        frame_path = os.path.join(frames_dir, f"{count}.jpg")
        cv2.imwrite(frame_path, frame)
        frames.append(frame)
        count += 1
        processed_frames += 1
        print(f"Processing frame {processed_frames}")
        progress(processed_frames / total_frames, desc=f"Reading frame {processed_frames}/{total_frames}")

    cap.release()
    return video_dir

def download_and_unzip_google_drive_file(file_id, output_path, unzip_path):
    """Download a zip archive from Google Drive and extract it to unzip_path."""
    # Use the direct-download endpoint with confirm=t so large files skip the
    # virus-scan confirmation page.
    url = f"https://drive.usercontent.google.com/download?id={file_id}&export=download&confirm=t"
    gdown.download(url, output_path, quiet=False)

    with zipfile.ZipFile(output_path, 'r') as zip_ref:
        zip_ref.extractall(unzip_path)

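# Hypothetical usage (the file id, archive name, and destination below are
# placeholders, not values used by this app):
# download_and_unzip_google_drive_file("<drive-file-id>", "models.zip", "models/")
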
def clear_form():
    """Reset all inputs to their defaults (used by the Clear button)."""
    return None, None, None, 60, 60

def doo(video, tiff_stack, images, fps, min_track_length, progress=gr.Progress()):
    """Run the full MEMTrack pipeline on the uploaded video: frame extraction,
    feature generation, per-motility-class detection, tracking, and video rendering.

    Note: only the video input is processed at the moment; tiff_stack and images
    are accepted by the UI but not used here.
    """
    raw_frame_dir = "raw_data/"
    final_data_dir = "data"
    out_sub_dir = "bacteria"
    target_data_sub_dir = os.path.join(final_data_dir, out_sub_dir)
    feature_dir = "DataFeatures"
    test_video_list = ["video1"]
    exp_name = "collagen_motility_inference"
    feature_data_path = os.path.join(feature_dir, exp_name)

    # One detector per motility class.
    no_motility_model_path = "models/motility/no/collagen_optical_flow_median_bkg_more_data_90k/"
    low_motility_model_path = "models/motility/low/collagen_optical_flow_median_bkg_more_data_90k/"
    mid_motility_model_path = "models/motility/mid/collagen_optical_flow_median_bkg_more_data_90k/"
    high_motility_model_path = "models/motility/high/collagen_optical_flow_median_bkg_more_data_90k/"

    # Start from a clean slate: remove any outputs left over from a previous run.
    if os.path.exists(final_data_dir):
        shutil.rmtree(final_data_dir)
    if os.path.exists(raw_frame_dir):
        shutil.rmtree(raw_frame_dir)
        print("deleted raw_frame_dir")
    if os.path.exists(feature_dir):
        shutil.rmtree(feature_dir)
        print("deleted feature dir")

    print("check dirs")
    print(os.listdir("."))

    # Extract frames and prepare the data layout expected by MEMTrack.
    video_dir = read_video(video, raw_frame_dir, progress=progress)

    progress(1 / 3, desc="Processing Frames 1/3")
    video_num = process_data(video_dir, raw_frame_dir, final_data_dir, out_sub_dir)
    progress(3 / 3, desc="Processing Frames 3/3")

    progress(1 / 3, desc="Generating Features 1/3")
    create_test_data(target_data_sub_dir, feature_dir, exp_name, test_video_list)
    progress(3 / 3, desc="Features Generated 3/3")

    # Run detection with each motility-specific model.
    progress(1 / 3, desc="Loading Models 1/3")
    for video_num in [1]:
        run_inference(video_num=video_num, output_dir=no_motility_model_path,
                      annotations_test="All", test_dir=feature_data_path, register_dataset=True)
        progress(3 / 3, desc="Models Loaded 3/3")
        run_inference(video_num=video_num, output_dir=mid_motility_model_path,
                      annotations_test="Motility-mid", test_dir=feature_data_path, register_dataset=False)
        progress(1 / 3, desc="Running Bacteria Detection 1/3")
        run_inference(video_num=video_num, output_dir=high_motility_model_path,
                      annotations_test="Motility-high", test_dir=feature_data_path, register_dataset=False)
        progress(2 / 3, desc="Running Bacteria Detection 2/3")
        run_inference(video_num=video_num, output_dir=low_motility_model_path,
                      annotations_test="Motility-low", test_dir=feature_data_path, register_dataset=False)
        progress(3 / 3, desc="Running Bacteria Detection 3/3")

    # Link detections into tracks and analyse them.
    progress(0 / 3, desc="Tracking 0/3")
    for video_num in [1]:
        gen_tracking_data(video_num=video_num, data_path=feature_data_path, filter_thresh=0.3)
        progress(1 / 3, desc="Tracking 1/3")
        track_bacteria(video_num=video_num, max_age=35, max_interpolation=35, data_path=feature_data_path)
        progress(2 / 3, desc="Tracking 2/3")
        folder_path = analyse_tracking(video_num=video_num, min_track_length=min_track_length,
                                       data_feature_path=feature_data_path, data_root_path=final_data_dir, plot=True)
        progress(3 / 3, desc="Tracking 3/3")

    # Render the tracking results as videos and copy them next to the app so
    # Gradio can serve them.
    output_video1 = gen_tracking_video(video_num=video_num, fps=fps, data_path=feature_data_path)
    output_video2 = gen_tracking_video(video_num=video_num, fps=fps, data_path=feature_data_path, all_images=True)
    output_video3 = gen_tracking_video(video_num=video_num, fps=fps, data_path=feature_data_path, all_images=True)
    final_videos = [os.path.basename(output_video1), os.path.basename(output_video2), os.path.basename(output_video3)]
    shutil.copy(output_video1, final_videos[0])
    shutil.copy(output_video2, final_videos[1])
    shutil.copy(output_video3, final_videos[2])
    print(output_video1)
    print(final_videos)

    # Collect the per-track CSV predictions produced by the tracking analysis.
    search_pattern = "TrackedRawData"
    tracking_preds = find_and_return_csv_files(folder_path, search_pattern)

    return final_videos[0], final_videos[1], final_videos[2], tracking_preds

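# Minimal headless sketch (assumption: running the pipeline without the Gradio UI).
# The filename matches one of the bundled example videos; tiff_stack and images are
# unused by doo() in its current form, so None is passed for both.
# tracked, tracked_all, tracked_all_tracks, csvs = doo("RBS_2_4_h264.mp4", None, None, fps=60, min_track_length=60)
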
if __name__ == "__main__":
    # Bundled example videos.
    examples = [['RBS_2_4_h264.mp4'], ['RBS_4_4_h264.mp4'], ['RBS_7_6_h264.mp4']]

    title = "🎞️ MEMTrack Bacteria Tracking Video Tool"
    description = ("Upload a video or select one of the examples to track. <br><br> "
                   "If the input video does not play in the browser, make sure it is in a "
                   "browser-acceptable format; the output is generated irrespective of in-browser playback. "
                   "Refer: https://colab.research.google.com/drive/1U5pX_9iaR_T8knVV7o4ftKdDoGndCdEM?usp=sharing")

    with gr.Blocks() as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                gr.Markdown("Select the appropriate tab to upload a video, a TIFF stack, or a folder containing image frames.")

                with gr.Tabs():
                    with gr.Tab("Upload Video"):
                        video_input = gr.Video(label="Video File")

                    with gr.Tab("Upload TIFF Stack"):
                        tiff_input = gr.File(label="TIFF File", file_types=["tif", "tiff"])

                    with gr.Tab("Upload Images"):
                        image_input = gr.File(label="Image Files", file_types=["jpg", "jpeg", "png", "tif", "tiff"], file_count="multiple")

                fps_slider = gr.Slider(minimum=1, maximum=100, step=1, value=60, label="Output Video FPS")
                track_length_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=60, label="Minimum Track Length Threshold")

                submit_button = gr.Button("Submit")
                clear_button = gr.Button("Clear")

                clear_button.click(
                    fn=clear_form,
                    inputs=[],
                    outputs=[video_input, tiff_input, image_input, fps_slider, track_length_slider]
                )

            with gr.Column():
                outputs = [
                    gr.Video(label="Tracked Video (tracked frames)"),
                    gr.Video(label="Tracked Video (all frames)"),
                    gr.Video(label="Tracked Video (all frames, all tracks)"),
                    gr.Files(label="CSV Data")
                ]

        submit_button.click(
            fn=doo,
            inputs=[video_input, tiff_input, image_input, fps_slider, track_length_slider],
            outputs=outputs
        )

    demo.launch(share=True, debug=True)