Spaces:
Runtime error
Runtime error
File size: 3,371 Bytes
cd4c90e 9fbf078 cd4c90e e5bb367 9fbf078 cd4c90e e5bb367 9fbf078 e5bb367 9fbf078 cd4c90e 9fbf078 cd4c90e 9fbf078 ab5b42b cd4c90e b80c100 cd4c90e b80c100 1463eb9 b80c100 cd4c90e b80c100 cd4c90e b80c100 9fbf078 cd4c90e 9fbf078 cd4c90e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import cv2
import PIL
import torch
from classifier import CustomEfficientNet, CustomViT
from model import get_model, predict, prepare_prediction, predict_class
# Build the object detector and the 7-class classifier.
# Streamlit re-executes this whole script on every widget interaction. Without
# caching, both checkpoints would be reloaded from disk on each slider move or
# button click. st.cache_resource (Streamlit >= 1.18) keeps one instance across
# reruns. On older Streamlit versions, fall back to a pass-through decorator so
# the app behaves exactly as it did before.
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_models():
    """Load the detector and classifier checkpoints; return ``(model, classifier)``."""
    print('Creating the model')
    detector = get_model('efficientDet_icevision.ckpt')
    print('Loading the classifier')
    vit = CustomViT(target_size=7, pretrained=False)
    vit.load_state_dict(torch.load('class_ViT_taco_7_class.pth', map_location='cpu'))
    # The classifier is only used for inference here, so put it in eval mode
    # (disables dropout and similar layers). This is idempotent, so it is
    # harmless if predict_class also calls eval() (not visible from this file).
    vit.eval()
    return detector, vit


model, classifier = _load_models()
def plot_img_no_mask(image, boxes, labels, output_path='img.png'):
    """Draw labelled bounding boxes on ``image`` and save the figure to disk.

    Args:
        image: Image to annotate, in any form ``np.array`` accepts (for
            example a PIL image or an ndarray). It is copied, so the caller's
            object is left untouched.
        boxes: ``(N, 4)`` boxes given as ``[x1, y1, x2, y2]`` pixel corners.
            Either a torch tensor (detached and moved to CPU here) or any
            array-like. Values are truncated to integers.
        labels: Indexable sequence with at least N class ids in ``0..6``.
            Each id selects a colour and a caption from the tables below.
            Plain ints, NumPy ints and single-element tensors are accepted.
        output_path: File the rendered figure is written to. Defaults to
            ``'img.png'``, the path the Streamlit page displays.

    Returns:
        The annotated image as a NumPy array.

    Raises:
        KeyError: if a label is an integer outside ``0..6``.
        IndexError: if ``labels`` has fewer entries than ``boxes``.
    """
    # Per-class colour tuples. cv2 writes them directly into the array's
    # channels, so they read as RGB (e.g. (255, 165, 0) is orange). That
    # matches an RGB image such as one converted from PIL.
    colors = {
        0: (255, 255, 0),
        1: (255, 0, 0),
        2: (0, 0, 255),
        3: (0, 128, 0),
        4: (255, 165, 0),
        5: (230, 230, 250),
        6: (192, 192, 192),
    }
    texts = {
        0: 'plastic',
        1: 'dangerous',
        2: 'carton',
        3: 'glass',
        4: 'organic',
        5: 'rest',
        6: 'other',
    }

    # Accept torch tensors (which may live on GPU or require grad) as well as
    # plain arrays.
    if hasattr(boxes, 'detach'):
        boxes = boxes.detach().cpu().numpy()
    boxes = np.asarray(boxes).astype(np.int32)

    # cv2 draws in place, so work on a private copy. The original author noted
    # that cv2.rectangle raised an error without this copy. It is done once
    # here rather than on every loop iteration.
    canvas = np.array(image).copy()

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    for i, box in enumerate(boxes):
        label = int(labels[i])
        color = colors[label]
        # Plain Python ints: cv2 point arguments reject some NumPy scalar types.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, thickness=5)
        # The caption baseline sits 10 px above the box's top-left corner.
        cv2.putText(canvas, texts[label], (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)
    ax.axis('off')
    ax.imshow(canvas)
    fig.savefig(output_path, bbox_inches='tight')
    # Streamlit re-runs the page script on every interaction. pyplot keeps each
    # figure alive until it is closed, so close it to avoid leaking one per run.
    plt.close(fig)
    return canvas
# Image source: either a user upload or one of the bundled example images.
st.subheader('Upload Custom Image')
image_file = st.file_uploader("Upload Images", type=["png", "jpg", "jpeg"])

st.subheader('Example Images')
example_imgs = [
    'example_imgs/basura_4_2.jpg',
    'example_imgs/basura_1.jpg',
    'example_imgs/basura_3.jpg',
]

# st.button is True only on the rerun triggered by the click itself. Any later
# rerun (e.g. moving a detection slider) would otherwise forget the choice, so
# the selected example is remembered in session_state.
clicked_example = None
for number, example_path in enumerate(example_imgs, start=1):
    with st.container():
        # The caption and the widget key both use the 1-based position.
        st.image(example_path, width=150, caption=str(number))
        if st.button('Select Image', key=f'Image_{number}'):
            clicked_example = example_path

if clicked_example is not None:
    # A click on this run takes precedence over an uploaded file.
    st.session_state['selected_example'] = clicked_example
    image_file = clicked_example
elif image_file is None:
    # Nothing uploaded: reuse the example picked on an earlier run, if any.
    image_file = st.session_state.get('selected_example')
# User-tunable thresholds passed to predict() and prepare_prediction() below.
# Both sliders share the same 0.0-1.0 range and 0.1 step; only the defaults differ.
st.subheader('Detection parameters')
_slider_range = dict(min_value=0.0, max_value=1.0, step=0.1)
detection_threshold = st.slider('Detection threshold', value=0.5, **_slider_range)
nms_threshold = st.slider('NMS threshold', value=0.3, **_slider_range)
# Run detection, then per-box classification, then render the result.
st.subheader('Prediction')
if image_file is not None:
    print('Getting predictions')
    if isinstance(image_file, str):
        # Example image: predict() is given the file path.
        data = image_file
    elif hasattr(image_file, 'getvalue'):
        # Uploaded file (BytesIO-like): getvalue() returns the full contents
        # regardless of the stream position. read() would return only what is
        # left after any earlier read.
        data = image_file.getvalue()
    else:
        data = image_file.read()
    pred_dict = predict(model, data, detection_threshold)
    print('Fixing the preds')
    boxes, image = prepare_prediction(pred_dict, nms_threshold)
    print('Predicting classes')
    labels = predict_class(classifier, image, boxes)
    print('Plotting')
    plot_img_no_mask(image, boxes, labels)
    # plot_img_no_mask renders the annotated figure to img.png. Hand the path
    # straight to st.image, as is done for the example thumbnails, instead of
    # going through PIL.Image, which a bare `import PIL` does not guarantee
    # is loaded.
    # NOTE(review): img.png is one fixed path shared by every session of the
    # app, so concurrent users could overwrite each other's result. Consider
    # rendering in memory instead.
    st.image('img.png', width=750)