Spaces:
Runtime error
Runtime error
# Check Pytorch installation | |
import torch, torchvision | |
print("torch version:",torch.__version__, "cuda:",torch.cuda.is_available()) | |
# Check MMDetection installation | |
import mmdet | |
import os | |
import mmcv | |
import mmengine | |
from mmdet.apis import init_detector, inference_detector | |
from mmdet.utils import register_all_modules | |
from mmdet.registry import VISUALIZERS | |
from huggingface_hub import hf_hub_download | |
from huggingface_hub import snapshot_download | |
from time import time | |
import concurrent.futures | |
import threading | |
# Display names for the detector's labels. They replace
# ``model.dataset_meta['classes']`` (and hence the visualizer's metadata) below.
# NOTE(review): 'Surfer' appears twice and 'Kayal shade' looks like a typo of
# 'Kayak shade'. Names are tied to predictions by position, so they are left
# as-is -- confirm against the training annotations before editing.
# NOTE(review): the base list is repeated 3x, presumably so that label indices
# beyond its length still resolve to a name -- confirm the model's class count.
classes = ['Beach',
           'Sea',
           'Wave',
           'Rock',
           'Breaking wave',
           'Reflection of the sea',
           'Foam',
           'Algae',
           'Vegetation',
           'Watermark',
           'Bird',
           'Ship',
           'Boat',
           'Car',
           'Kayak',
           "Shark's line",
           'Dock',
           'Dog',
           'Unidentifiable shade',
           'Bird shadow',
           'Boat shadow',
           'Kayal shade',
           'Surfer shadow',
           'Shark shadow',
           'Surfboard shadow',
           'Crocodile',
           'Sea cow',
           'Stingray',
           'Person',
           'ocean',
           'Surfer',
           'Surfer',
           'Fish',
           'Killer whale',
           'Whale',
           'Dolphin',
           'Miscellaneous',
           'Unidentifiable shark',
           'Carpet shark',
           'Dusty shark',
           'Blue shark',
           'Great white shark',
           'Copper shark',
           'Nurse shark',
           'Silky shark',
           'Leopard shark',
           'Shortfin mako shark',
           'Hammerhead shark',
           'Oceanic whitetip shark',
           'Blacktip shark',
           'Tiger shark',
           'Bull shark'] * 3

REPO_ID = "SharkSpace/maskformer_model"
# NOTE(review): FILENAME is not used in this file; kept in case another module imports it.
FILENAME = "mask2former"
# Download the model repository into ./model/. The Hub token, if any, is read
# from the SHARK_MODEL environment variable.
snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'), local_dir='model/')

# Config and checkpoint of the downloaded Mask2Former (Swin-T) panoptic model.
# Earlier alternatives that were used instead:
#   config:     '/content/mmdetection/configs/panoptic_fpn/panoptic-fpn_r50_fpn_ms-3x_coco.py'
#   checkpoint: '/content/drive/MyDrive/Algorithms/weights/shark_panoptic_weights_16_4_23/panoptic-fpn_r50_fpn_ms-3x_coco/epoch_36.pth'
_MODEL_RUN_DIR = 'model/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic/'
config_file = _MODEL_RUN_DIR + 'mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic.py'
checkpoint_file = _MODEL_RUN_DIR + 'checkpoint.pth'

# Register all mmdet modules in the registries so the config can be built.
register_all_modules()

# Fall back to CPU when no CUDA device is available: a hard-coded 'cuda:0'
# makes model initialisation fail on CPU-only hosts.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Build the model from the config file and the checkpoint file.
model = init_detector(config_file, checkpoint_file, device=DEVICE)
model.dataset_meta['classes'] = classes
print(model.cfg.visualizer)

# Build the visualizer once; the inference helpers below all share it.
visualizer = VISUALIZERS.build(model.cfg.visualizer)
print(dir(visualizer))
# dataset_meta is loaded from the checkpoint in init_detector (with the class
# names overridden above); hand the same metadata to the visualizer.
visualizer.dataset_meta = model.dataset_meta
def inference_frame_serial(image):
    """Run the detector on a single image and return the rendered overlay.

    Args:
        image: Passed both to ``inference_detector`` and to the visualizer as
            the canvas to draw on. NOTE(review): presumably an image array
            rather than a file path -- confirm against the callers.

    Returns:
        The frame returned by ``visualizer.get_image()`` after the
        predictions have been drawn.
    """
    prediction = inference_detector(model, image)
    # Draw predictions only: draw_gt is falsy and no window is shown.
    visualizer.add_datasample('result', image, data_sample=prediction,
                              draw_gt=None, show=False)
    return visualizer.get_image()
def inference_frame(image):
    """Run batched inference and render one overlay per input image.

    Args:
        image: Batch of images passed to ``inference_detector``. ``image[i]``
            is drawn together with the i-th prediction, so it must be
            indexable and aligned with the returned results.

    Returns:
        List of frames from ``visualizer.get_image()``, one per prediction,
        in result order.
    """
    results = inference_detector(model, image)
    frames = []
    # Pair each prediction with its source image by position. Indexing (rather
    # than zip) keeps raising IndexError if the two ever disagree in length.
    for idx, res in enumerate(results):
        visualizer.add_datasample(
            'result',
            image[idx],
            data_sample=res.numpy(),
            draw_gt=None,
            show=False,
        )
        frames.append(visualizer.get_image())
    return frames
def inference_frame_par_ready(image):
    """Run batched inference and return the predictions without drawing them.

    Rendering is left to a separate step (e.g. ``process_frame``), presumably
    so it can be distributed over workers -- confirm against the callers.

    Args:
        image: Batch of images passed to ``inference_detector``.

    Returns:
        List containing ``.numpy()`` of each element of the inference result,
        in the same order.
    """
    return [res.numpy() for res in inference_detector(model, image)]
# Serialises use of the module-level ``visualizer``. add_datasample() draws
# into it and get_image() then reads the result back, so two concurrent calls
# could interleave and return each other's frames.
_VISUALIZER_LOCK = threading.Lock()


def process_frame(in_tuple=(None, None, None)):
    """Draw one precomputed prediction onto its image and return the frame.

    Args:
        in_tuple: ``(data_sample, image, tag)``. Only the first two items are
            used. NOTE(review): the third item is presumably a frame index
            kept for the caller -- confirm.

    Returns:
        The frame returned by ``visualizer.get_image()``.
    """
    data_sample = in_tuple[0]
    frame_image = in_tuple[1]
    # Hold the lock across both steps so the frame read back is the one just drawn.
    with _VISUALIZER_LOCK:
        visualizer.add_datasample(
            'result',
            frame_image,
            data_sample=data_sample,
            draw_gt=None,
            show=False,
        )
        return visualizer.get_image()
#def process_frame(frame): | |
# def process_frames(result, image, visualizer): | |
# frames = [] | |
# lock = threading.Lock() | |
# def process_data(cnt, res, img): | |
# visualizer.add_datasample('result', img, data_sample=res, draw_gt=None, show=False) | |
# frame = visualizer.get_image() | |
# with lock: | |
# frames.append(frame) | |
# threads = [] | |
# for cnt, res in enumerate(result): | |
# t = threading.Thread(target=process_data, args=(cnt, res, image[cnt])) | |
# threads.append(t) | |
# t.start() | |
# for t in threads: | |
# t.join() | |
# return frames | |