waste-classifier / trash_detector.py
santit96's picture
Stop versioning the model checkpoints; they are now downloaded from Hugging Face. Add env vars.
dd14920
import os
import numpy as np
import torch
from fastai.vision.all import load_learner
from huggingface_hub import hf_hub_download
from constants import (CLAS_FILENAME, CLAS_FILEPATH, CLAS_THRESHOLD,
DET_FILENAME, DET_FILEPATH, DET_NAME, DET_THRESHOLD,
DEVICE, HF_CLAS_REPO_NAME, HF_DET_REPO_NAME,
MODELS_PATH)
from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model
def localize_trash(im):
    """Locate candidate trash regions in an image with the detection model.

    The detector checkpoint is downloaded from the Hugging Face Hub when it
    is not already present at ``DET_FILEPATH``.

    Args:
        im: Input image. It is handed to ``get_transforms`` and its ``size``
            attribute is read, so a PIL-style image is presumably expected
            (TODO confirm against the caller).

    Returns:
        A ``(probas, bboxes_scaled)`` tuple. ``probas`` holds columns 4 and
        onward of every prediction whose column-4 value exceeds
        ``DET_THRESHOLD``; ``bboxes_scaled`` is the result of
        ``rescale_bboxes`` applied to the first four columns of those rows.
    """
    # Fetch the detector checkpoint on first use.
    if not os.path.exists(DET_FILEPATH):
        hf_hub_download(HF_DET_REPO_NAME, DET_FILENAME, local_dir=MODELS_PATH)

    # NOTE(review): the model is rebuilt from the checkpoint on every call;
    # consider caching it if this shows up as a latency problem.
    detector = set_model(DET_NAME, 1, DET_FILEPATH, DEVICE)
    detector.eval()

    # Turn the image into a model input batch (per the original note:
    # mean-std normalized, batch size 1).
    img = get_transforms(im)

    # Inference only: never record an autograd graph, even when this function
    # is called directly rather than through detect_trash.
    with torch.no_grad():
        outputs = detector(img.to(DEVICE))

    # Keep only the predictions of the first batch item whose confidence
    # (column 4) is above the detection threshold.
    confident = outputs[0, :, 4] > DET_THRESHOLD
    bboxes_keep = outputs[0, confident]
    probas = bboxes_keep[:, 4:]

    # Convert boxes from the network input scale to the original image scale.
    bboxes_scaled = rescale_bboxes(
        bboxes_keep[:, :4], im.size, tuple(img.size()[2:])
    )
    return probas, bboxes_scaled
def classify_trash(im, probas, bboxes_scaled):
    """Classify every localized region and keep the confident ones.

    The classifier checkpoint is downloaded from the Hugging Face Hub when it
    is not already present at ``CLAS_FILEPATH``. Nothing is downloaded or
    loaded when there are no detections to classify.

    Args:
        im: Source image; must provide ``crop((xmin, ymin, xmax, ymax))``.
        probas: Per-detection rows, iterated in lockstep with the boxes. Each
            row needs at least two elements, because slots 0 and 1 of the
            returned copy are overwritten. The caller's tensor is left
            unmodified.
        bboxes_scaled: Object with ``tolist()`` yielding one
            ``[xmin, ymin, xmax, ymax]`` entry per detection.

    Returns:
        A ``(cls_prob, bboxes_final)`` tuple of parallel lists. Each entry of
        ``cls_prob`` is a copy of the detection row with element 0 set to the
        top class probability as a truncated percentage and element 1 set to
        the index of that class. ``bboxes_final`` holds the matching
        ``(xmin, ymin, xmax, ymax)`` tuples. Only detections whose percentage
        is at least ``CLAS_THRESHOLD * 100`` are included.
    """
    # No localized regions: skip the checkpoint download and the model load.
    if len(probas) == 0:
        return [], []

    # Fetch the classifier checkpoint on first use.
    if not os.path.exists(CLAS_FILEPATH):
        hf_hub_download(HF_CLAS_REPO_NAME, CLAS_FILENAME, local_dir=MODELS_PATH)
    classifier = load_learner(CLAS_FILEPATH)

    min_percent = CLAS_THRESHOLD * 100
    bboxes_final = []
    cls_prob = []
    for det_row, (xmin, ymin, xmax, ymax) in zip(probas, bboxes_scaled.tolist()):
        crop = im.crop((xmin, ymin, xmax, ymax))
        # Element 2 of the prediction is the per-class score tensor.
        class_probs = classifier.predict(crop)[2]

        # A single top-1 query yields both the winning class and its score.
        best = torch.topk(class_probs, k=1)

        # Work on a copy so the caller's `probas` tensor is not mutated.
        entry = det_row.clone()
        entry[1] = best.indices.squeeze(0).item()
        # torch.trunc keeps the computation in torch; truncating the maximum
        # equals the maximum of the truncated values because trunc is
        # monotonic.
        entry[0] = torch.trunc(best.values.squeeze(0) * 100)

        if entry[0] >= min_percent:
            bboxes_final.append((xmin, ymin, xmax, ymax))
            cls_prob.append(entry)
    return cls_prob, bboxes_final
def detect_trash(img):
    """Run the two-stage pipeline: localize regions, then classify each one.

    Args:
        img: Input image, forwarded unchanged to ``localize_trash`` and
            ``classify_trash``.

    Returns:
        The ``(cls_prob, bboxes_final)`` tuple produced by
        ``classify_trash``.
    """
    # Disable autograd only for the duration of inference. The context
    # manager restores the previous grad mode on exit (including on error),
    # instead of switching it off for the rest of the calling thread.
    with torch.no_grad():
        # 1) Localize
        probas, bboxes_scaled = localize_trash(img)
        # 2) Classify
        cls_prob, bboxes_final = classify_trash(img, probas, bboxes_scaled)
    return cls_prob, bboxes_final