# Page header captured along with this file (repository web view):
#   Niv Sardi — import python — 1a24a58
#   raw | history | blame — 2.68 kB
import os
from fastapi import FastAPI, WebSocket
from YOLOv6.yolov6.core.inferer import Inferer
import cv2
import yaml as YAML
import json
import csv
import ssl
import hashlib
from entity import read_entities
import imtool
# FastAPI application instance; the HTTP and WebSocket routes in this module
# are registered on it through the @app.get / @app.websocket decorators.
app = FastAPI()
# --- Inference configuration (module-level, read by the startup code and the
# --- /ws handler) ---

# Model checkpoint and dataset description handed to the Inferer constructor.
weights = './runs/train/exp27/weights/best_stop_aug_ckpt.pt'
yaml = './data.yaml'  # path string; the PyYAML module itself is imported as YAML

# Runtime settings, also handed to the Inferer constructor.
device = 'cpu'
img_size = [640] * 2
half = False  # passed as Inferer's `half` flag

# Detection filtering parameters handed to inferer.infer() on every frame.
conf_thres, iou_thres = 0.5, 0.45
classes = agnostic_nms = None
max_det = 1_000
def _load_certs(certs_dir, entity_map):
    """Decode every ``<bco>.cert`` PEM certificate file found in *certs_dir*.

    Args:
        certs_dir: directory to scan (top level only, via ``os.scandir``).
        entity_map: indexable by the file stem ``bco``; each value must have a
            ``.name`` attribute (this is how ``read_entities`` output is used).

    Returns:
        dict mapping entity name -> decoded certificate dict, with an extra
        ``'fingerprint'`` key holding the SHA-1 hex digest of the DER bytes.
        Entries whose extension is not ``.cert`` are ignored; certificates that
        fail to decode, or whose stem has no entity, are reported and skipped.

    Raises:
        OSError: if *certs_dir* itself cannot be scanned.
    """
    loaded = {}
    with os.scandir(certs_dir) as it:
        for entry in it:
            # splitext copes with names that have no dot or several dots
            # (README, .gitkeep, a.b.cert); unpacking name.split('.') raised
            # ValueError on those and aborted all startup loading.
            bco, ext = os.path.splitext(entry.name)
            if ext != '.cert':
                continue
            try:
                # NOTE(review): ssl._ssl._test_decode_cert is a private CPython
                # helper with no public stdlib equivalent; it may change or be
                # removed between Python versions.
                cert_dict = ssl._ssl._test_decode_cert(entry.path)
                with open(entry.path, 'r') as f:
                    der = ssl.PEM_cert_to_DER_cert(f.read())
                cert_dict['fingerprint'] = hashlib.sha1(der).hexdigest()
            except Exception as e:
                print("Error decoding certificate: {:}".format(e))
                continue
            try:
                name = entity_map[bco].name
            except LookupError:
                # An unknown stem used to raise KeyError and abort startup.
                print(f'no entity for certificate {entry.name}, skipping')
                continue
            loaded[name] = cert_dict
    return loaded


# Startup loading. Every name is bound up front and each step is guarded on its
# own. A failure (bad YAML, missing entities or certs directory) is reported
# without leaving the other names, in particular the Inferer, undefined.
classes_data = None
entities = None
certs = {}
inferer = None

try:
    with open(yaml, 'r') as f:
        classes_data = YAML.safe_load(f.read())
except Exception as e:
    print('error', e)

try:
    entities = read_entities('../data/entities.csv')
    certs = _load_certs('../data/certs', entities)
except Exception as e:
    print('error', e)

# A YAML without a 'names' key (or a failed load) reports 0 classes instead of
# raising and skipping the model load.
_class_names = classes_data.get('names') if isinstance(classes_data, dict) else None
print(f'loaded {len(certs)} certs, got {len(_class_names or [])} classes')

try:
    inferer = Inferer(weights, device, yaml, img_size, half)
except Exception as e:
    print('error', e)
@app.get("/")
async def root():
    """Health-check endpoint: reply with a fixed JSON message."""
    status = {"message": "API is working"}
    return status
@app.websocket("/ws")
async def websockets_cb(websocket: WebSocket):
    """Run detection on images streamed over a WebSocket, one reply per message.

    Each received text message is decoded with ``imtool.read_base64`` into an
    image. The handler runs the shared ``inferer`` on it and sends back
    ``<infer() result> + '@@@@' + '[d0,d1,...]'``, where the bracketed list is
    the decoded image's ``shape``. Any exception raised in the loop ends the
    handler after being printed; nothing is re-raised.
    """
    try:
        await websocket.accept()
        while True:
            data = await websocket.receive_text()
            img = imtool.read_base64(data)
            # Debug artifacts: keep the last received frame on disk and remove
            # any stale debug.txt. Both are best-effort.
            cv2.imwrite("debug.png", img)
            try:
                os.remove("debug.txt")
            except OSError:
                pass  # absent or not removable: nothing to clean up
            # load() and infer() are synchronous calls with no await between
            # them, so other connections cannot interleave on the shared
            # inferer (at the cost of blocking the event loop meanwhile).
            inferer.load(img)
            ret = inferer.infer(conf_thres, iou_thres, classes, agnostic_nms, max_det)
            print(ret)
            # Same "[h,w,c]" text as the old '[%d,%d,%d]' format for 3-D
            # images, but it also works for other ranks (e.g. 2-D grayscale),
            # which previously raised TypeError and dropped the connection.
            shape = '[' + ','.join(str(d) for d in img.shape) + ']'
            await websocket.send_text(ret + '@@@@' + shape)
    except Exception as e:
        print("got: ", e)
@app.websocket("/bgws")
async def send_classes(websocket: WebSocket):
    """Push the class metadata and loaded certificates once, then hang up.

    The single text frame is a JSON object whose ``'classes'`` entry is the
    parsed dataset YAML and whose ``'certs'`` entry maps entity name to the
    decoded certificate dict.
    """
    await websocket.accept()
    payload = dict(classes=classes_data, certs=certs)
    await websocket.send_text(json.dumps(payload))
    await websocket.close()
if __name__ == "__main__":
    import uvicorn

    # Serve this module's `app` object directly. The old "api:app" import string
    # made uvicorn import this file a second time under the name "api", which
    # re-ran all startup loading (model weights, certificates). It also only
    # worked when the file was named api.py. An import string is needed only
    # for reload=True or workers > 1, and neither is used here.
    config = uvicorn.Config(app, port=5000, log_level="info")
    server = uvicorn.Server(config)
    server.run()