ryo2's picture
Update app.py
c604388 verified
raw
history blame
1.94 kB
import glob
import gradio as gr
import yolov5
# Load the trained pill-bug (dango-mushi) YOLOv5 weights once at import time;
# `inference` reuses this single model object for every request.
model = yolov5.load("model/dango.pt")
def inference(gr_input):
    """Detect a pill bug in ``gr_input`` and return its cropped region.

    Args:
        gr_input: The image delivered by the ``gr.Image`` component. This is
            presumably an ``H x W x C`` numpy array (Gradio's default image
            type; TODO confirm if the component is reconfigured). It is
            ``None`` when the button is pressed before an image is uploaded.

    Returns:
        A list of images suitable for ``gr.Gallery``: ``[crop]`` containing
        the region of the single kept detection, or ``[]`` when there is no
        input image or nothing was detected.
    """
    import numpy as np

    # No image uploaded yet: show an empty gallery instead of crashing in the model.
    if gr_input is None:
        return []

    # Set model parameters.
    model.conf = 0.45  # NMS confidence threshold
    model.iou = 0.45  # NMS IoU threshold
    model.agnostic = False  # NMS class-agnostic
    model.multi_label = False  # NMS multiple labels per box
    model.max_det = 1  # maximum number of detections per image
    results = model(gr_input)

    # Rows for the first (only) image: x1, y1, x2, y2, score, class.
    predictions = results.pred[0]
    # Nothing detected: skip this image.
    if len(predictions) == 0:
        return []

    x1, y1, x2, y2 = (int(v) for v in predictions[0, :4].tolist())
    # Guard against negative starts, which would wrap around when slicing.
    x1, y1 = max(x1, 0), max(y1, 0)

    # Crop in memory (rows are y, columns are x) and hand the array straight
    # to the gallery; no files are written to disk or re-read via glob.
    img_crop = np.asarray(gr_input)[y1:y2, x1:x2]
    return [img_crop]
with gr.Blocks() as app:
    # Visitor-counter badge at the top of the page.
    # NOTE(review): page_id is "alrab222.Cinderella", which does not look like
    # this Space's id; confirm whether it was copied from another project.
    gr.Markdown(
        value='<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=alrab222.Cinderella" />'
    )
    # Title "Pill bug capture" and subtitle "A model that uses machine
    # learning to identify pill bugs from images of their underside".
    gr.Markdown(
        value=(
            "# <center> ダンゴムシ捕捉\n"
            "## <center> ダンゴムシの腹側からの画像を、機械学習で判別できるモデルです\n"
        )
    )
    # Input image, "Results" gallery, and the button that runs `inference`.
    inputs = gr.Image()
    output = gr.Gallery(label="結果")
    btn = gr.Button(value="judge")
    btn.click(fn=inference, inputs=inputs, outputs=output)

# Start the Gradio server.
app.launch()