Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files
app.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from pathlib import Path
|
3 |
+
import config
|
4 |
+
from utils import load_model, infer_uploaded_image, infer_uploaded_video, infer_uploaded_webcam
|
5 |
+
from shared import upload_records
|
6 |
+
import pandas as pd
|
7 |
+
|
8 |
+
# 设置网页标题和布局
|
9 |
+
st.set_page_config(page_title="工业零件检测", layout="wide", initial_sidebar_state="expanded")
|
10 |
+
|
11 |
+
st.markdown(
|
12 |
+
"""
|
13 |
+
<h1 style='text-align: center;'>缺陷发现者</h1>
|
14 |
+
<h1 style='text-align: right; font-size: 24px; color: #333; font-weight: bold; margin-bottom: 20px;'>
|
15 |
+
——基于YOLOv8的工业零件检测
|
16 |
+
</h1>
|
17 |
+
""", unsafe_allow_html=True)
|
18 |
+
|
19 |
+
# 侧边栏:模型配置
|
20 |
+
st.sidebar.header("模型配置")
|
21 |
+
task_options = ["目标检测", "实例分割", "图像分类"]
|
22 |
+
if 'task_type' not in st.session_state:
|
23 |
+
st.session_state['task_type'] = task_options[0]
|
24 |
+
|
25 |
+
task_type = st.sidebar.selectbox(
|
26 |
+
"任务选择",
|
27 |
+
task_options,
|
28 |
+
key="task_type"
|
29 |
+
)
|
30 |
+
model_path = ""
|
31 |
+
if task_type == "目标检测":
|
32 |
+
model_type = st.sidebar.selectbox(
|
33 |
+
"模型选择",
|
34 |
+
config.DETECTION_MODEL_LIST,
|
35 |
+
key="model_type_selectbox"
|
36 |
+
)
|
37 |
+
model_path = Path(config.DETECTION_MODEL_DIR, model_type)
|
38 |
+
elif task_type == "实例分割":
|
39 |
+
model_type = st.sidebar.selectbox(
|
40 |
+
"模型选择",
|
41 |
+
config.INSTANCE_SEGMENTATION_MODEL_LIST,
|
42 |
+
key="model_type_selectbox",
|
43 |
+
index=1 # 默认选择 best-2.pt
|
44 |
+
)
|
45 |
+
model_path = Path(config.INSTANCE_SEGMENTATION_MODEL_DIR, model_type)
|
46 |
+
elif task_type == "图像分类":
|
47 |
+
model_type = st.sidebar.selectbox(
|
48 |
+
"模型选择",
|
49 |
+
config.CLASSIFICATION_MODEL_LIST,
|
50 |
+
key="model_type_selectbox",
|
51 |
+
index=2 # 默认选择 best-3.pt
|
52 |
+
)
|
53 |
+
model_path = Path(config.CLASSIFICATION_MODEL_DIR, model_type)
|
54 |
+
else:
|
55 |
+
st.error("目前仅实现‘目标检测’、‘实例分割’和‘图像分类’功能")
|
56 |
+
|
57 |
+
confidence = float(st.sidebar.slider("选择模型置信度", 30, 100, 30)) / 100
|
58 |
+
|
59 |
+
# 加载模型
|
60 |
+
try:
|
61 |
+
model = load_model(model_path)
|
62 |
+
except Exception as e:
|
63 |
+
st.error(f"Unable to load model. Please check the specified path: {model_path}")
|
64 |
+
|
65 |
+
# 图像/视频配置
|
66 |
+
st.sidebar.header("图像/视频配置")
|
67 |
+
source_selectbox = st.sidebar.selectbox(
|
68 |
+
"选择上传类型",
|
69 |
+
config.SOURCES_LIST,
|
70 |
+
key="source_selectbox"
|
71 |
+
)
|
72 |
+
save_path = st.sidebar.text_input("输入保存结果的文件夹路径", "请输入路径", key="save_path_input")
|
73 |
+
|
74 |
+
# 根据选择的上传类型调用相应的函数
|
75 |
+
if source_selectbox == config.SOURCES_LIST[0]: # Image
|
76 |
+
infer_uploaded_image(confidence, model, save_path, task_type)
|
77 |
+
elif source_selectbox == config.SOURCES_LIST[1]: # Video
|
78 |
+
infer_uploaded_video(confidence, model, save_path, task_type)
|
79 |
+
elif source_selectbox == config.SOURCES_LIST[2]: # Webcam
|
80 |
+
infer_uploaded_webcam(confidence, model, save_path, task_type)
|
81 |
+
else:
|
82 |
+
st.error("目前只支持‘图片’、‘视频’和‘摄像头’类型的上传")
|
83 |
+
|
84 |
+
# 显示上传记录
|
85 |
+
st.subheader("上传记录")
|
86 |
+
if upload_records:
|
87 |
+
# 创建 DataFrame
|
88 |
+
upload_df = pd.DataFrame(upload_records)
|
89 |
+
# 重置索引,并且不把旧索引添加到 DataFrame 中
|
90 |
+
upload_df.columns = ['文件名', '文件类型', '上传时间']
|
91 |
+
upload_df = upload_df.reset_index(drop=True)
|
92 |
+
# 显示 DataFrame
|
93 |
+
st.table(upload_df)
|
94 |
+
else:
|
95 |
+
st.write("没有上传记录。")
|
config.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import sys
|
3 |
+
|
4 |
+
# Get the absolute path of the current file
|
5 |
+
file_path = Path(__file__).resolve()
|
6 |
+
|
7 |
+
# Get the parent directory of the current file
|
8 |
+
root_path = file_path.parent
|
9 |
+
|
10 |
+
# Add the root path to the sys.path list if it is not already there
|
11 |
+
if root_path not in sys.path:
|
12 |
+
sys.path.append(str(root_path))
|
13 |
+
|
14 |
+
# Get the relative path of the root directory with respect to the current working directory
|
15 |
+
ROOT = root_path.relative_to(Path.cwd())
|
16 |
+
|
17 |
+
# Source
|
18 |
+
SOURCES_LIST = ["图片", "视频", "摄像头"]
|
19 |
+
|
20 |
+
# DL model config
|
21 |
+
DETECTION_MODEL_DIR = ROOT / 'weights' / 'detection'
|
22 |
+
INSTANCE_SEGMENTATION_MODEL_DIR = ROOT / 'weights' / 'instance_segmentation'
|
23 |
+
CLASSIFICATION_MODEL_DIR = ROOT / 'weights' / 'classification'
|
24 |
+
YOLOv8n = DETECTION_MODEL_DIR / "yolov8n.pt"
|
25 |
+
YOLOv8s = DETECTION_MODEL_DIR / "yolov8s.pt"
|
26 |
+
YOLOv8m = DETECTION_MODEL_DIR / "yolov8m.pt"
|
27 |
+
YOLOv8l = DETECTION_MODEL_DIR / "yolov8l.pt"
|
28 |
+
YOLOv8x = DETECTION_MODEL_DIR / "yolov8x.pt"
|
29 |
+
|
30 |
+
DETECTION_MODEL_LIST = [
|
31 |
+
"目标检测.pt","实例分割.pt","图像分类.pt"]
|
32 |
+
|
33 |
+
INSTANCE_SEGMENTATION_MODEL_LIST = [
|
34 |
+
"目标检测.pt","实例分割.pt","图像分类.pt"]
|
35 |
+
DEFAULT_INSTANCE_SEGMENTATION_MODEL = "实例分割.pt"
|
36 |
+
|
37 |
+
CLASSIFICATION_MODEL_LIST = [
|
38 |
+
"目标检测.pt","实例分割.pt","图像分类.pt"]
|
39 |
+
DEFAULT_CLASSIFICATION_MODEL = "图像分类.pt"
|
shared.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
upload_records = []
|
utils.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from shared import upload_records
|
2 |
+
from ultralytics import YOLO
|
3 |
+
import streamlit as st
|
4 |
+
import cv2
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
import tempfile
|
8 |
+
import datetime
|
9 |
+
import os
|
10 |
+
import io
|
11 |
+
import time
|
12 |
+
|
13 |
+
def _display_detected_frames(conf, model, st_frame, image, save_path, task_type):
|
14 |
+
"""
|
15 |
+
Display the detected objects on a video frame using the YOLO model.
|
16 |
+
:param conf (float): Confidence threshold for object detection.
|
17 |
+
:param model (YOLO): An instance of the YOLO class containing the YOLO model.
|
18 |
+
:param st_frame (Streamlit object): A Streamlit object to display the detected video.
|
19 |
+
:param image (numpy array): A numpy array representing the video frame.
|
20 |
+
:param save_path (str): The path to save the results.
|
21 |
+
:param task_type (str): The type of task, either 'detection' or 'segmentation'.
|
22 |
+
:return: None
|
23 |
+
"""
|
24 |
+
# Ensure the image is a 3-channel彩色图像
|
25 |
+
if image.ndim == 2 or image.shape[2] == 1: # 灰度图像或单通道
|
26 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
27 |
+
elif image.shape[2] == 4: # 四通道RGBA图像
|
28 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
|
29 |
+
|
30 |
+
# Resize the image to the standard size expected by the model
|
31 |
+
image_resized = cv2.resize(image, (640, 480))
|
32 |
+
|
33 |
+
# Perform object detection or segmentation using the YOLO model
|
34 |
+
results = model.predict(image_resized, conf=conf)
|
35 |
+
|
36 |
+
# Convert the results to the correct format for display and saving
|
37 |
+
if task_type == 'detection':
|
38 |
+
result_image = results[0].plot()
|
39 |
+
else: # segmentation
|
40 |
+
result_image = results[0].plot()
|
41 |
+
|
42 |
+
# Convert from BGR to RGB for Streamlit display
|
43 |
+
result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
|
44 |
+
|
45 |
+
# Resize the result image to the fixed output size (750, 500) while maintaining aspect ratio
|
46 |
+
h, w = result_image_rgb.shape[:2]
|
47 |
+
scale_factor = min(550 / w, 450 / h)
|
48 |
+
new_w, new_h = int(w * scale_factor), int(h * scale_factor)
|
49 |
+
result_image_resized = cv2.resize(result_image_rgb, (new_w, new_h))
|
50 |
+
|
51 |
+
# Pad the image to ensure it is 750x500
|
52 |
+
padded_image = np.full((500, 750, 3), 255, dtype=np.uint8) # Create a white background
|
53 |
+
start_x = (750 - new_w) // 2
|
54 |
+
start_y = (500 - new_h) // 2
|
55 |
+
padded_image[start_y:start_y+new_h, start_x:start_x+new_w, :] = result_image_resized
|
56 |
+
|
57 |
+
# Display the frame with detections or segmentations in the Streamlit app
|
58 |
+
st_frame.image(
|
59 |
+
padded_image, # Directly use RGB image for display
|
60 |
+
caption=f'运行结果',
|
61 |
+
use_column_width=True
|
62 |
+
)
|
63 |
+
|
64 |
+
# If a save path is provided, save the frame with detections or segmentations
|
65 |
+
if save_path:
|
66 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
67 |
+
filename = f"{task_type}_frame_{timestamp}.png"
|
68 |
+
save_path_full = os.path.join(save_path, filename)
|
69 |
+
# Save the padded image in RGB format
|
70 |
+
cv2.imwrite(save_path_full, result_image_resized) # Save in RGB format
|
71 |
+
st.write(f"文件保存在: {save_path_full}")
|
72 |
+
@st.cache_resource
|
73 |
+
def load_model(model_path):
|
74 |
+
"""
|
75 |
+
Loads a YOLO object detection or segmentation model from the specified model_path.
|
76 |
+
Parameters:
|
77 |
+
model_path (str): The path to the YOLO model file.
|
78 |
+
Returns:
|
79 |
+
A YOLO object detection or segmentation model.
|
80 |
+
"""
|
81 |
+
model = YOLO(model_path)
|
82 |
+
return model
|
83 |
+
|
84 |
+
def infer_uploaded_image(conf, model, save_path, task_type):
|
85 |
+
"""
|
86 |
+
Execute inference for uploaded images in batch.
|
87 |
+
:param conf: Confidence of YOLO model
|
88 |
+
:param model: An instance of the YOLO class containing the YOLO model.
|
89 |
+
:param save_path: The path to save the results.
|
90 |
+
:param task_type: The type of task, either 'detection' or 'segmentation'.
|
91 |
+
:return: None
|
92 |
+
"""
|
93 |
+
source_imgs = st.sidebar.file_uploader(
|
94 |
+
"选择图像",
|
95 |
+
type=("jpg", "jpeg", "png", 'bmp', 'webp'),
|
96 |
+
accept_multiple_files=True,
|
97 |
+
)
|
98 |
+
|
99 |
+
if source_imgs:
|
100 |
+
for img_info in source_imgs:
|
101 |
+
file_type = os.path.splitext(img_info.name)[1][1:].lower()
|
102 |
+
upload_records.append({
|
103 |
+
"file_name": img_info.name,
|
104 |
+
"file_type": file_type,
|
105 |
+
"uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
106 |
+
})
|
107 |
+
|
108 |
+
uploaded_image = Image.open(img_info)
|
109 |
+
img_byte_arr = io.BytesIO()
|
110 |
+
uploaded_image.save(img_byte_arr, format=file_type.upper() if file_type != 'jpg' else 'JPEG')
|
111 |
+
img_byte_arr = img_byte_arr.getvalue()
|
112 |
+
image = np.array(Image.open(io.BytesIO(img_byte_arr)))
|
113 |
+
|
114 |
+
st.image(
|
115 |
+
img_byte_arr,
|
116 |
+
caption=f"上传的图像: {img_info.name}",
|
117 |
+
use_column_width=True
|
118 |
+
)
|
119 |
+
|
120 |
+
with st.spinner("正在运行..."):
|
121 |
+
_display_detected_frames(conf, model, st.empty(), image, save_path, task_type)
|
122 |
+
|
123 |
+
def infer_uploaded_video(conf, model, save_path, task_type):
|
124 |
+
"""
|
125 |
+
Execute inference for uploaded video and display the detected objects on the video.
|
126 |
+
:param conf: Confidence of YOLO model
|
127 |
+
:param model: An instance of the YOLO class containing the YOLO model.
|
128 |
+
:param save_path: The path to save the results.
|
129 |
+
:param task_type: The type of task, either 'detection' or 'segmentation'.
|
130 |
+
:return: None
|
131 |
+
"""
|
132 |
+
source_video = st.sidebar.file_uploader(
|
133 |
+
"选择视频",
|
134 |
+
accept_multiple_files=True
|
135 |
+
)
|
136 |
+
|
137 |
+
if source_video:
|
138 |
+
for video_file in source_video:
|
139 |
+
file_type = os.path.splitext(video_file.name)[1][1:].lower()
|
140 |
+
upload_records.append({
|
141 |
+
"file_name": video_file.name,
|
142 |
+
"file_type": file_type,
|
143 |
+
"uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
144 |
+
})
|
145 |
+
|
146 |
+
st.video(video_file)
|
147 |
+
|
148 |
+
if st.button("开始运行"):
|
149 |
+
with st.spinner("运行中..."):
|
150 |
+
try:
|
151 |
+
tfile = tempfile.NamedTemporaryFile()
|
152 |
+
tfile.write(video_file.read())
|
153 |
+
vid_cap = cv2.VideoCapture(tfile.name)
|
154 |
+
st_frame = st.empty()
|
155 |
+
frame_rate = vid_cap.get(cv2.CAP_PROP_FPS)
|
156 |
+
delay = int(1000 / frame_rate)
|
157 |
+
|
158 |
+
start_time = time.time()
|
159 |
+
while True:
|
160 |
+
success, image = vid_cap.read()
|
161 |
+
if not success:
|
162 |
+
break
|
163 |
+
|
164 |
+
current_time = time.time()
|
165 |
+
if current_time - start_time >= 1.0:
|
166 |
+
_display_detected_frames(conf, model, st_frame, image, save_path, task_type)
|
167 |
+
start_time = current_time
|
168 |
+
|
169 |
+
vid_cap.release()
|
170 |
+
except Exception as e:
|
171 |
+
st.error(f"Error loading video: {e}")
|
172 |
+
|
173 |
+
def infer_uploaded_webcam(conf, model, save_path, task_type):
|
174 |
+
"""
|
175 |
+
Execute inference for webcam.
|
176 |
+
:param conf: Confidence of YOLO model
|
177 |
+
:param model: An instance of the YOLO class containing the YOLO model.
|
178 |
+
:param save_path: The path to save the results.
|
179 |
+
:param task_type: The type of task, either 'detection' or 'segmentation'.
|
180 |
+
:return: None
|
181 |
+
"""
|
182 |
+
try:
|
183 |
+
flag = st.button(
|
184 |
+
"关闭摄像头"
|
185 |
+
)
|
186 |
+
vid_cap = cv2.VideoCapture(0)
|
187 |
+
st_frame = st.empty()
|
188 |
+
while not flag:
|
189 |
+
success, image = vid_cap.read()
|
190 |
+
if success:
|
191 |
+
_display_detected_frames(conf, model, st_frame, image, save_path, task_type)
|
192 |
+
else:
|
193 |
+
vid_cap.release()
|
194 |
+
break
|
195 |
+
except Exception as e:
|
196 |
+
st.error(f"Error loading video: {str(e)}")
|