Spaces:
Running
Running
dgbkn
commited on
Commit
·
4bbacde
1
Parent(s):
25ed33b
dnr
Browse files- best.pt +3 -0
- main.py +49 -0
- pillmodel.py +107 -0
- requirements.txt +9 -0
- upload.html +52 -0
best.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:504dd2ca20d1e8e7be2d925b840c90c78742e6abe41fb95507dd915d212ca371
|
3 |
+
size 92233042
|
main.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from fastapi import FastAPI, File, UploadFile
|
3 |
+
from fastapi.responses import JSONResponse,HTMLResponse
|
4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from pillmodel import get_prediction
|
8 |
+
import base64
|
9 |
+
|
10 |
+
app = FastAPI()
|
11 |
+
|
12 |
+
app.add_middleware(
|
13 |
+
CORSMiddleware,
|
14 |
+
allow_origins=["*"],
|
15 |
+
allow_credentials=True,
|
16 |
+
allow_methods=["*"],
|
17 |
+
allow_headers=["*"],
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
@app.post("/predict")
|
23 |
+
async def predict(image: UploadFile = File(...)):
|
24 |
+
contents = await image.read()
|
25 |
+
nparr = np.frombuffer(contents, np.uint8)
|
26 |
+
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
27 |
+
|
28 |
+
# Save the image to a temporary location
|
29 |
+
# temp_image_path = "temp_image.jpg"
|
30 |
+
# cv2.imwrite(temp_image_path, img)
|
31 |
+
|
32 |
+
# Prediction
|
33 |
+
predicted_image, count_dict = get_prediction(img)
|
34 |
+
# Encode predicted image to base64
|
35 |
+
_, buffer = cv2.imencode('.jpg', predicted_image)
|
36 |
+
predicted_image_str = base64.b64encode(buffer).decode('utf-8')
|
37 |
+
|
38 |
+
# Send a confirmation message
|
39 |
+
message_to_send = (
|
40 |
+
f"There are {count_dict.get('capsules', 0)} capsules and {count_dict.get('tablets', 0)} tablets. "
|
41 |
+
f"A total of {count_dict.get('capsules', 0) + count_dict.get('tablets', 0)} pills."
|
42 |
+
)
|
43 |
+
|
44 |
+
return JSONResponse(content={"message": message_to_send, "count": count_dict, "predicted_image": predicted_image_str})
|
45 |
+
|
46 |
+
@app.get("/", response_class=HTMLResponse)
|
47 |
+
async def read_root():
|
48 |
+
with open("index.html", "r") as file:
|
49 |
+
return HTMLResponse(content=file.read(), status_code=200)
|
pillmodel.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
from shapely.geometry import Polygon
|
5 |
+
from io import BytesIO
|
6 |
+
from ultralytics import YOLO
|
7 |
+
import torch
|
8 |
+
|
9 |
+
# Define colors
|
10 |
+
COLORS = [(98, 231, 4), (228, 161, 0)] # Green and blue
|
11 |
+
CLASSES = ['capsules', 'tablets']
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
def get_prediction(image):
|
16 |
+
'''
|
17 |
+
Gets image from telebot, make predictions,
|
18 |
+
counts predicted classes, draws dots on image,
|
19 |
+
returns dict with counts and labelled image
|
20 |
+
'''
|
21 |
+
# Load a model
|
22 |
+
model = YOLO('yolov8l.pt') # load an official model
|
23 |
+
model = YOLO('best.pt') # load a custom model
|
24 |
+
# image = cv2.imread(image_path)
|
25 |
+
|
26 |
+
# Get prediction
|
27 |
+
prediction = model(image)
|
28 |
+
|
29 |
+
# Get predicted classes
|
30 |
+
predicted_classes = prediction[0].boxes.cls
|
31 |
+
# Get predicted confidence of each class
|
32 |
+
prediction_confidences = prediction[0].boxes.conf
|
33 |
+
|
34 |
+
# Get polygons
|
35 |
+
polygons = prediction[0].masks.xy
|
36 |
+
# Convert polygons to int32
|
37 |
+
polygons = [polygon.astype(np.int32) for polygon in polygons]
|
38 |
+
|
39 |
+
# Create indices mask that shows what is overlapping polygon has smaller confidence score
|
40 |
+
indices_mask = remove_overlapping_polygons(polygons, prediction_confidences)
|
41 |
+
|
42 |
+
# Create new fixed lists with predicted classes and polygons
|
43 |
+
fixed_predicted_classes = predicted_classes[np.array(indices_mask, dtype=bool)]
|
44 |
+
fixed_polygons = [polygons[i] for i in range(len(indices_mask)) if indices_mask[i] == 1]
|
45 |
+
# fixed_predicted_classes = [predicted_classes[i] for i in range(len(indices_mask)) if indices_mask[i] == 1]
|
46 |
+
|
47 |
+
# Get counts of classes
|
48 |
+
unique, counts = torch.unique(fixed_predicted_classes, return_counts=True)
|
49 |
+
# Get dicts with counts of classes
|
50 |
+
count_dict = {CLASSES[int(key)]: value for key, value in zip(unique.tolist(), counts.tolist())}
|
51 |
+
|
52 |
+
# # Draw polygons
|
53 |
+
# for polygon, predicted_class in zip(fixed_polygons, fixed_predicted_classes):
|
54 |
+
# cv2.polylines(image, [polygon], True, COLORS[int(predicted_class)])
|
55 |
+
|
56 |
+
# Draw dots
|
57 |
+
for polygon, predicted_class in zip(fixed_polygons, fixed_predicted_classes):
|
58 |
+
# Find center of polygon
|
59 |
+
center_coordinates = (np.mean(polygon[:, 0], dtype=np.int32), np.mean(polygon[:, 1], dtype=np.int32)) # x and y respectively
|
60 |
+
# Draw a circle
|
61 |
+
cv2.circle(image, center_coordinates, 5, COLORS[int(predicted_class)], 2, cv2.LINE_AA)
|
62 |
+
|
63 |
+
# # Show image with predictions on it
|
64 |
+
# cv2.imshow("Image", image)
|
65 |
+
# cv2.waitKey(0)
|
66 |
+
# cv2.destroyAllWindows()
|
67 |
+
# from google.colab.patches import cv2_imshow
|
68 |
+
# cv2_imshow(image)
|
69 |
+
return image, count_dict
|
70 |
+
|
71 |
+
|
72 |
+
def remove_overlapping_polygons(polygons, prediction_confidences):
|
73 |
+
'''
|
74 |
+
Takes polygons, finds overlapping regions,
|
75 |
+
intersection area, overlap percentage,
|
76 |
+
creates indices mask that shows what
|
77 |
+
overlapping polygon has smaller confidence.
|
78 |
+
'''
|
79 |
+
|
80 |
+
# Convert the NumPy arrays to Shapely polygons
|
81 |
+
shapely_polygons = [Polygon(polygon) for polygon in polygons]
|
82 |
+
# Create an empty list with overlapping pairs
|
83 |
+
overlapping_pairs = []
|
84 |
+
|
85 |
+
# Check for overlaps between all pairs of polygons
|
86 |
+
for i in range(len(shapely_polygons)):
|
87 |
+
for j in range(i+1, len(shapely_polygons)):
|
88 |
+
if shapely_polygons[i].intersects(shapely_polygons[j]):
|
89 |
+
# Calculate the percentage of overlap
|
90 |
+
intersection_area = shapely_polygons[i].intersection(shapely_polygons[j]).area
|
91 |
+
overlap_percentage = intersection_area / shapely_polygons[i].area
|
92 |
+
# Add overlapping polygons indexes to list
|
93 |
+
if overlap_percentage > 0.5:
|
94 |
+
overlapping_pairs.append((i, j))
|
95 |
+
|
96 |
+
# Mask of remains indices
|
97 |
+
indices_mask = [1 for i in range(len(shapely_polygons))]
|
98 |
+
|
99 |
+
# Remove one of the overlapping polygons
|
100 |
+
for first_over_polygon_ind, second_over_polygon_ind in overlapping_pairs:
|
101 |
+
# Find index that has the smallest prediction confidence
|
102 |
+
first_has_bigger_conf = prediction_confidences[first_over_polygon_ind] >= prediction_confidences[second_over_polygon_ind]
|
103 |
+
index_small_conf = [first_over_polygon_ind, second_over_polygon_ind][first_has_bigger_conf]
|
104 |
+
# Set value with smaller confidence to 0 in indices_mask
|
105 |
+
indices_mask[index_small_conf] = 0
|
106 |
+
|
107 |
+
return indices_mask
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cvzone
|
2 |
+
ultralytics
|
3 |
+
numpy
|
4 |
+
pandas
|
5 |
+
torch
|
6 |
+
|
7 |
+
opencv-python
|
8 |
+
shapely
|
9 |
+
ultralytics
|
upload.html
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>Image Upload</title>
|
7 |
+
<link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
|
8 |
+
<style>
|
9 |
+
body {
|
10 |
+
font-family: Arial, sans-serif;
|
11 |
+
}
|
12 |
+
</style>
|
13 |
+
</head>
|
14 |
+
<body class="flex flex-col items-center justify-center h-screen bg-gray-100">
|
15 |
+
<h1 class="text-2xl font-bold mb-4">Upload an Image</h1>
|
16 |
+
<input type="file" id="fileInput" accept="image/*" class="mb-4">
|
17 |
+
<button onclick="uploadImage()" class="bg-blue-500 text-white px-4 py-2 rounded">Upload</button>
|
18 |
+
<div id="output" class="mt-4"></div>
|
19 |
+
<script>
|
20 |
+
async function uploadImage() {
|
21 |
+
const fileInput = document.getElementById('fileInput');
|
22 |
+
const output = document.getElementById('output');
|
23 |
+
const file = fileInput.files[0];
|
24 |
+
if (!file) {
|
25 |
+
output.innerHTML = 'No image selected.';
|
26 |
+
return;
|
27 |
+
}
|
28 |
+
|
29 |
+
const formData = new FormData();
|
30 |
+
formData.append('image', file);
|
31 |
+
|
32 |
+
output.innerHTML = 'Uploading...';
|
33 |
+
|
34 |
+
try {
|
35 |
+
const response = await fetch('/predict', {
|
36 |
+
method: 'POST',
|
37 |
+
body: formData
|
38 |
+
});
|
39 |
+
const result = await response.json();
|
40 |
+
const predictedImageSrc = `data:image/jpeg;base64,${result.predicted_image}`;
|
41 |
+
output.innerHTML = `
|
42 |
+
<p>${result.message}</p>
|
43 |
+
<img src="${predictedImageSrc}" alt="Predicted Image" class="mt-4">
|
44 |
+
`;
|
45 |
+
} catch (error) {
|
46 |
+
output.innerHTML = 'Failed to get prediction';
|
47 |
+
console.error(error);
|
48 |
+
}
|
49 |
+
}
|
50 |
+
</script>
|
51 |
+
</body>
|
52 |
+
</html>
|