# Waste-Detector / model.py
# Author: Hector Lopez
# Last change: "Fixed upload file bug" (commit 782cec7)
from io import BytesIO
from icevision import *
import collections
import PIL
import torch
import numpy as np
import torchvision
import icevision.models.ross.efficientdet
# icevision EfficientDet model module; get_model() builds the detector through
# it (backbones / model) and predict() runs inference through it (end2end_detect).
MODEL_TYPE = icevision.models.ross.efficientdet
def get_model(checkpoint_path):
    """Build an EfficientDet-D0 detector and load trained weights into it.

    Args:
        checkpoint_path: path of the training checkpoint; it is read by
            ``get_checkpoint`` and the resulting state dict is loaded into the
            freshly built model.

    Returns:
        The ``MODEL_TYPE`` model with the checkpoint weights loaded.
    """
    # EfficientDet must know its input resolution when it is constructed.
    # num_classes=2 is presumably background + 'Waste' (see predict's ClassMap)
    # -- confirm against the training setup.
    detector = MODEL_TYPE.model(
        backbone=MODEL_TYPE.backbones.d0(pretrained=True),
        num_classes=2,
        img_size=512,
    )
    detector.load_state_dict(get_checkpoint(checkpoint_path))
    return detector
def get_checkpoint(checkpoint_path):
    """Load a training checkpoint and return a state dict for the bare model.

    The file at ``checkpoint_path`` is loaded onto the CPU. Every key of its
    ``'state_dict'`` entry has its first 6 characters stripped. This is
    presumably the ``"model."`` prefix added by the training wrapper; confirm
    against the training code. The result can then be passed straight to the
    model's ``load_state_dict``.

    Args:
        checkpoint_path: path to the checkpoint file to load.

    Returns:
        ``collections.OrderedDict`` mapping the stripped parameter names to
        their tensors, in the checkpoint's original order.
    """
    # Previously this ignored ``checkpoint_path`` and always read
    # 'checkpoint.ckpt' from the working directory.
    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    return collections.OrderedDict(
        (key[6:], tensor) for key, tensor in ckpt['state_dict'].items()
    )
def predict(model, image):
    """Run the waste detector end-to-end on one encoded image.

    Args:
        model: the detector returned by ``get_model``.
        image: raw bytes of an encoded image file (e.g. an uploaded JPEG or
            PNG). They are wrapped in ``BytesIO`` and decoded with PIL.

    Returns:
        The dict produced by ``MODEL_TYPE.end2end_detect``. It is called with
        ``detection_threshold=0.5`` and ``return_img=True``, so the image is
        included, and with ``return_as_pil_img=False``. Bounding boxes, scores
        and labels are not drawn.
    """
    # ``import PIL`` alone does not guarantee that the ``PIL.Image`` submodule
    # is loaded, so import it explicitly.
    import PIL.Image

    # Force 3 channels. RGBA, grayscale and palette uploads would otherwise
    # reach Normalize() with a channel count that does not match its
    # per-channel statistics (assumed 3-channel defaults -- confirm).
    # The former debug ``print(img.shape)`` was removed: PIL images have no
    # ``.shape`` attribute, so it raised AttributeError on every call.
    img = PIL.Image.open(BytesIO(image)).convert('RGB')

    class_map = ClassMap(classes=['Waste'])
    transforms = tfms.A.Adapter([
        *tfms.A.resize_and_pad(512),
        tfms.A.Normalize(),
    ])
    return MODEL_TYPE.end2end_detect(
        img,
        transforms,
        model,
        class_map=class_map,
        detection_threshold=0.5,
        return_as_pil_img=False,
        return_img=True,
        display_bbox=False,
        display_score=False,
        display_label=False,
    )
def prepare_prediction(pred_dict, iou_threshold=0.1):
    """Suppress overlapping detections and unpack a prediction for display.

    Args:
        pred_dict: dict returned by ``predict``. Reads
            ``pred_dict['detection']['bboxes']`` (objects exposing
            ``to_tensor()``), ``['scores']`` and ``['label_ids']``, plus
            ``pred_dict['img']``.
        iou_threshold: boxes with the same label whose IoU with a
            higher-scoring box exceeds this value are discarded by
            ``torchvision.ops.batched_nms``. Defaults to the previously
            hard-coded 0.1.

    Returns:
        ``(boxes, image)``:
            - ``boxes`` holds the boxes kept by NMS, ordered by decreasing
              score as returned by ``batched_nms``. It is an empty ``(0, 4)``
              tensor when there are no detections.
            - ``image`` is ``pred_dict['img']`` converted with ``np.array``.
    """
    detection = pred_dict['detection']
    image = np.array(pred_dict['img'])

    bboxes = detection['bboxes']
    if len(bboxes) == 0:
        # torch.stack raises on an empty list, so handle "nothing detected"
        # explicitly instead of crashing.
        return torch.zeros((0, 4)), image

    boxes = torch.stack([box.to_tensor() for box in bboxes])
    scores = torch.as_tensor(detection['scores'])
    labels = torch.as_tensor(detection['label_ids'])
    keep = torchvision.ops.batched_nms(boxes, scores, labels, iou_threshold)
    return boxes[keep, :], image