dgbkn commited on
Commit
4bbacde
·
1 Parent(s): 25ed33b
Files changed (5) hide show
  1. best.pt +3 -0
  2. main.py +49 -0
  3. pillmodel.py +107 -0
  4. requirements.txt +9 -0
  5. upload.html +52 -0
best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:504dd2ca20d1e8e7be2d925b840c90c78742e6abe41fb95507dd915d212ca371
3
+ size 92233042
main.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from fastapi import FastAPI, File, UploadFile
3
+ from fastapi.responses import JSONResponse,HTMLResponse
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ import cv2
6
+ import numpy as np
7
+ from pillmodel import get_prediction
8
+ import base64
9
+
10
+ app = FastAPI()
11
+
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"],
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+
21
+
22
+ @app.post("/predict")
23
+ async def predict(image: UploadFile = File(...)):
24
+ contents = await image.read()
25
+ nparr = np.frombuffer(contents, np.uint8)
26
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
27
+
28
+ # Save the image to a temporary location
29
+ # temp_image_path = "temp_image.jpg"
30
+ # cv2.imwrite(temp_image_path, img)
31
+
32
+ # Prediction
33
+ predicted_image, count_dict = get_prediction(img)
34
+ # Encode predicted image to base64
35
+ _, buffer = cv2.imencode('.jpg', predicted_image)
36
+ predicted_image_str = base64.b64encode(buffer).decode('utf-8')
37
+
38
+ # Send a confirmation message
39
+ message_to_send = (
40
+ f"There are {count_dict.get('capsules', 0)} capsules and {count_dict.get('tablets', 0)} tablets. "
41
+ f"A total of {count_dict.get('capsules', 0) + count_dict.get('tablets', 0)} pills."
42
+ )
43
+
44
+ return JSONResponse(content={"message": message_to_send, "count": count_dict, "predicted_image": predicted_image_str})
45
+
46
+ @app.get("/", response_class=HTMLResponse)
47
+ async def read_root():
48
+ with open("index.html", "r") as file:
49
+ return HTMLResponse(content=file.read(), status_code=200)
pillmodel.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import cv2
4
+ from shapely.geometry import Polygon
5
+ from io import BytesIO
6
+ from ultralytics import YOLO
7
+ import torch
8
+
9
+ # Define colors
10
+ COLORS = [(98, 231, 4), (228, 161, 0)] # Green and blue
11
+ CLASSES = ['capsules', 'tablets']
12
+
13
+
14
+
15
+ def get_prediction(image):
16
+ '''
17
+ Gets image from telebot, make predictions,
18
+ counts predicted classes, draws dots on image,
19
+ returns dict with counts and labelled image
20
+ '''
21
+ # Load a model
22
+ model = YOLO('yolov8l.pt') # load an official model
23
+ model = YOLO('best.pt') # load a custom model
24
+ # image = cv2.imread(image_path)
25
+
26
+ # Get prediction
27
+ prediction = model(image)
28
+
29
+ # Get predicted classes
30
+ predicted_classes = prediction[0].boxes.cls
31
+ # Get predicted confidence of each class
32
+ prediction_confidences = prediction[0].boxes.conf
33
+
34
+ # Get polygons
35
+ polygons = prediction[0].masks.xy
36
+ # Convert polygons to int32
37
+ polygons = [polygon.astype(np.int32) for polygon in polygons]
38
+
39
+ # Create indices mask that shows what is overlapping polygon has smaller confidence score
40
+ indices_mask = remove_overlapping_polygons(polygons, prediction_confidences)
41
+
42
+ # Create new fixed lists with predicted classes and polygons
43
+ fixed_predicted_classes = predicted_classes[np.array(indices_mask, dtype=bool)]
44
+ fixed_polygons = [polygons[i] for i in range(len(indices_mask)) if indices_mask[i] == 1]
45
+ # fixed_predicted_classes = [predicted_classes[i] for i in range(len(indices_mask)) if indices_mask[i] == 1]
46
+
47
+ # Get counts of classes
48
+ unique, counts = torch.unique(fixed_predicted_classes, return_counts=True)
49
+ # Get dicts with counts of classes
50
+ count_dict = {CLASSES[int(key)]: value for key, value in zip(unique.tolist(), counts.tolist())}
51
+
52
+ # # Draw polygons
53
+ # for polygon, predicted_class in zip(fixed_polygons, fixed_predicted_classes):
54
+ # cv2.polylines(image, [polygon], True, COLORS[int(predicted_class)])
55
+
56
+ # Draw dots
57
+ for polygon, predicted_class in zip(fixed_polygons, fixed_predicted_classes):
58
+ # Find center of polygon
59
+ center_coordinates = (np.mean(polygon[:, 0], dtype=np.int32), np.mean(polygon[:, 1], dtype=np.int32)) # x and y respectively
60
+ # Draw a circle
61
+ cv2.circle(image, center_coordinates, 5, COLORS[int(predicted_class)], 2, cv2.LINE_AA)
62
+
63
+ # # Show image with predictions on it
64
+ # cv2.imshow("Image", image)
65
+ # cv2.waitKey(0)
66
+ # cv2.destroyAllWindows()
67
+ # from google.colab.patches import cv2_imshow
68
+ # cv2_imshow(image)
69
+ return image, count_dict
70
+
71
+
72
+ def remove_overlapping_polygons(polygons, prediction_confidences):
73
+ '''
74
+ Takes polygons, finds overlapping regions,
75
+ intersection area, overlap percentage,
76
+ creates indices mask that shows what
77
+ overlapping polygon has smaller confidence.
78
+ '''
79
+
80
+ # Convert the NumPy arrays to Shapely polygons
81
+ shapely_polygons = [Polygon(polygon) for polygon in polygons]
82
+ # Create an empty list with overlapping pairs
83
+ overlapping_pairs = []
84
+
85
+ # Check for overlaps between all pairs of polygons
86
+ for i in range(len(shapely_polygons)):
87
+ for j in range(i+1, len(shapely_polygons)):
88
+ if shapely_polygons[i].intersects(shapely_polygons[j]):
89
+ # Calculate the percentage of overlap
90
+ intersection_area = shapely_polygons[i].intersection(shapely_polygons[j]).area
91
+ overlap_percentage = intersection_area / shapely_polygons[i].area
92
+ # Add overlapping polygons indexes to list
93
+ if overlap_percentage > 0.5:
94
+ overlapping_pairs.append((i, j))
95
+
96
+ # Mask of remains indices
97
+ indices_mask = [1 for i in range(len(shapely_polygons))]
98
+
99
+ # Remove one of the overlapping polygons
100
+ for first_over_polygon_ind, second_over_polygon_ind in overlapping_pairs:
101
+ # Find index that has the smallest prediction confidence
102
+ first_has_bigger_conf = prediction_confidences[first_over_polygon_ind] >= prediction_confidences[second_over_polygon_ind]
103
+ index_small_conf = [first_over_polygon_ind, second_over_polygon_ind][first_has_bigger_conf]
104
+ # Set value with smaller confidence to 0 in indices_mask
105
+ indices_mask[index_small_conf] = 0
106
+
107
+ return indices_mask
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ cvzone
2
+ ultralytics
3
+ numpy
4
+ pandas
5
+ torch
6
+
7
+ opencv-python
8
+ shapely
9
+ ultralytics
upload.html ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Image Upload</title>
7
+ <link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
8
+ <style>
9
+ body {
10
+ font-family: Arial, sans-serif;
11
+ }
12
+ </style>
13
+ </head>
14
+ <body class="flex flex-col items-center justify-center h-screen bg-gray-100">
15
+ <h1 class="text-2xl font-bold mb-4">Upload an Image</h1>
16
+ <input type="file" id="fileInput" accept="image/*" class="mb-4">
17
+ <button onclick="uploadImage()" class="bg-blue-500 text-white px-4 py-2 rounded">Upload</button>
18
+ <div id="output" class="mt-4"></div>
19
+ <script>
20
+ async function uploadImage() {
21
+ const fileInput = document.getElementById('fileInput');
22
+ const output = document.getElementById('output');
23
+ const file = fileInput.files[0];
24
+ if (!file) {
25
+ output.innerHTML = 'No image selected.';
26
+ return;
27
+ }
28
+
29
+ const formData = new FormData();
30
+ formData.append('image', file);
31
+
32
+ output.innerHTML = 'Uploading...';
33
+
34
+ try {
35
+ const response = await fetch('/predict', {
36
+ method: 'POST',
37
+ body: formData
38
+ });
39
+ const result = await response.json();
40
+ const predictedImageSrc = `data:image/jpeg;base64,${result.predicted_image}`;
41
+ output.innerHTML = `
42
+ <p>${result.message}</p>
43
+ <img src="${predictedImageSrc}" alt="Predicted Image" class="mt-4">
44
+ `;
45
+ } catch (error) {
46
+ output.innerHTML = 'Failed to get prediction';
47
+ console.error(error);
48
+ }
49
+ }
50
+ </script>
51
+ </body>
52
+ </html>