kaidewang commited on
Commit
6f50ee4
·
verified ·
1 Parent(s): cf4bac6

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +95 -0
  2. config.py +39 -0
  3. shared.py +1 -0
  4. utils.py +196 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from pathlib import Path
3
+ import config
4
+ from utils import load_model, infer_uploaded_image, infer_uploaded_video, infer_uploaded_webcam
5
+ from shared import upload_records
6
+ import pandas as pd
7
+
8
+ # 设置网页标题和布局
9
+ st.set_page_config(page_title="工业零件检测", layout="wide", initial_sidebar_state="expanded")
10
+
11
+ st.markdown(
12
+ """
13
+ <h1 style='text-align: center;'>缺陷发现者</h1>
14
+ <h1 style='text-align: right; font-size: 24px; color: #333; font-weight: bold; margin-bottom: 20px;'>
15
+ ——基于YOLOv8的工业零件检测
16
+ </h1>
17
+ """, unsafe_allow_html=True)
18
+
19
+ # 侧边栏:模型配置
20
+ st.sidebar.header("模型配置")
21
+ task_options = ["目标检测", "实例分割", "图像分类"]
22
+ if 'task_type' not in st.session_state:
23
+ st.session_state['task_type'] = task_options[0]
24
+
25
+ task_type = st.sidebar.selectbox(
26
+ "任务选择",
27
+ task_options,
28
+ key="task_type"
29
+ )
30
+ model_path = ""
31
+ if task_type == "目标检测":
32
+ model_type = st.sidebar.selectbox(
33
+ "模型选择",
34
+ config.DETECTION_MODEL_LIST,
35
+ key="model_type_selectbox"
36
+ )
37
+ model_path = Path(config.DETECTION_MODEL_DIR, model_type)
38
+ elif task_type == "实例分割":
39
+ model_type = st.sidebar.selectbox(
40
+ "模型选择",
41
+ config.INSTANCE_SEGMENTATION_MODEL_LIST,
42
+ key="model_type_selectbox",
43
+ index=1 # 默认选择 best-2.pt
44
+ )
45
+ model_path = Path(config.INSTANCE_SEGMENTATION_MODEL_DIR, model_type)
46
+ elif task_type == "图像分类":
47
+ model_type = st.sidebar.selectbox(
48
+ "模型选择",
49
+ config.CLASSIFICATION_MODEL_LIST,
50
+ key="model_type_selectbox",
51
+ index=2 # 默认选择 best-3.pt
52
+ )
53
+ model_path = Path(config.CLASSIFICATION_MODEL_DIR, model_type)
54
+ else:
55
+ st.error("目前仅实现‘目标检测’、‘实例分割’和‘图像分类’功能")
56
+
57
+ confidence = float(st.sidebar.slider("选择模型置信度", 30, 100, 30)) / 100
58
+
59
+ # 加载模型
60
+ try:
61
+ model = load_model(model_path)
62
+ except Exception as e:
63
+ st.error(f"Unable to load model. Please check the specified path: {model_path}")
64
+
65
+ # 图像/视频配置
66
+ st.sidebar.header("图像/视频配置")
67
+ source_selectbox = st.sidebar.selectbox(
68
+ "选择上传类型",
69
+ config.SOURCES_LIST,
70
+ key="source_selectbox"
71
+ )
72
+ save_path = st.sidebar.text_input("输入保存结果的文件夹路径", "请输入路径", key="save_path_input")
73
+
74
+ # 根据选择的上传类型调用相应的函数
75
+ if source_selectbox == config.SOURCES_LIST[0]: # Image
76
+ infer_uploaded_image(confidence, model, save_path, task_type)
77
+ elif source_selectbox == config.SOURCES_LIST[1]: # Video
78
+ infer_uploaded_video(confidence, model, save_path, task_type)
79
+ elif source_selectbox == config.SOURCES_LIST[2]: # Webcam
80
+ infer_uploaded_webcam(confidence, model, save_path, task_type)
81
+ else:
82
+ st.error("目前只支持‘图片’、‘视频’和‘摄像头’类型的上传")
83
+
84
+ # 显示上传记录
85
+ st.subheader("上传记录")
86
+ if upload_records:
87
+ # 创建 DataFrame
88
+ upload_df = pd.DataFrame(upload_records)
89
+ # 重置索引,并且不把旧索引添加到 DataFrame 中
90
+ upload_df.columns = ['文件名', '文件类型', '上传时间']
91
+ upload_df = upload_df.reset_index(drop=True)
92
+ # 显示 DataFrame
93
+ st.table(upload_df)
94
+ else:
95
+ st.write("没有上传记录。")
config.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import sys
3
+
4
+ # Get the absolute path of the current file
5
+ file_path = Path(__file__).resolve()
6
+
7
+ # Get the parent directory of the current file
8
+ root_path = file_path.parent
9
+
10
+ # Add the root path to the sys.path list if it is not already there
11
+ if root_path not in sys.path:
12
+ sys.path.append(str(root_path))
13
+
14
+ # Get the relative path of the root directory with respect to the current working directory
15
+ ROOT = root_path.relative_to(Path.cwd())
16
+
17
+ # Source
18
+ SOURCES_LIST = ["图片", "视频", "摄像头"]
19
+
20
+ # DL model config
21
+ DETECTION_MODEL_DIR = ROOT / 'weights' / 'detection'
22
+ INSTANCE_SEGMENTATION_MODEL_DIR = ROOT / 'weights' / 'instance_segmentation'
23
+ CLASSIFICATION_MODEL_DIR = ROOT / 'weights' / 'classification'
24
+ YOLOv8n = DETECTION_MODEL_DIR / "yolov8n.pt"
25
+ YOLOv8s = DETECTION_MODEL_DIR / "yolov8s.pt"
26
+ YOLOv8m = DETECTION_MODEL_DIR / "yolov8m.pt"
27
+ YOLOv8l = DETECTION_MODEL_DIR / "yolov8l.pt"
28
+ YOLOv8x = DETECTION_MODEL_DIR / "yolov8x.pt"
29
+
30
+ DETECTION_MODEL_LIST = [
31
+ "目标检测.pt","实例分割.pt","图像分类.pt"]
32
+
33
+ INSTANCE_SEGMENTATION_MODEL_LIST = [
34
+ "目标检测.pt","实例分割.pt","图像分类.pt"]
35
+ DEFAULT_INSTANCE_SEGMENTATION_MODEL = "实例分割.pt"
36
+
37
+ CLASSIFICATION_MODEL_LIST = [
38
+ "目标检测.pt","实例分割.pt","图像分类.pt"]
39
+ DEFAULT_CLASSIFICATION_MODEL = "图像分类.pt"
shared.py ADDED
@@ -0,0 +1 @@
 
 
1
+ upload_records = []
utils.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shared import upload_records
2
+ from ultralytics import YOLO
3
+ import streamlit as st
4
+ import cv2
5
+ from PIL import Image
6
+ import numpy as np
7
+ import tempfile
8
+ import datetime
9
+ import os
10
+ import io
11
+ import time
12
+
13
+ def _display_detected_frames(conf, model, st_frame, image, save_path, task_type):
14
+ """
15
+ Display the detected objects on a video frame using the YOLO model.
16
+ :param conf (float): Confidence threshold for object detection.
17
+ :param model (YOLO): An instance of the YOLO class containing the YOLO model.
18
+ :param st_frame (Streamlit object): A Streamlit object to display the detected video.
19
+ :param image (numpy array): A numpy array representing the video frame.
20
+ :param save_path (str): The path to save the results.
21
+ :param task_type (str): The type of task, either 'detection' or 'segmentation'.
22
+ :return: None
23
+ """
24
+ # Ensure the image is a 3-channel彩色图像
25
+ if image.ndim == 2 or image.shape[2] == 1: # 灰度图像或单通道
26
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
27
+ elif image.shape[2] == 4: # 四通道RGBA图像
28
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
29
+
30
+ # Resize the image to the standard size expected by the model
31
+ image_resized = cv2.resize(image, (640, 480))
32
+
33
+ # Perform object detection or segmentation using the YOLO model
34
+ results = model.predict(image_resized, conf=conf)
35
+
36
+ # Convert the results to the correct format for display and saving
37
+ if task_type == 'detection':
38
+ result_image = results[0].plot()
39
+ else: # segmentation
40
+ result_image = results[0].plot()
41
+
42
+ # Convert from BGR to RGB for Streamlit display
43
+ result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
44
+
45
+ # Resize the result image to the fixed output size (750, 500) while maintaining aspect ratio
46
+ h, w = result_image_rgb.shape[:2]
47
+ scale_factor = min(550 / w, 450 / h)
48
+ new_w, new_h = int(w * scale_factor), int(h * scale_factor)
49
+ result_image_resized = cv2.resize(result_image_rgb, (new_w, new_h))
50
+
51
+ # Pad the image to ensure it is 750x500
52
+ padded_image = np.full((500, 750, 3), 255, dtype=np.uint8) # Create a white background
53
+ start_x = (750 - new_w) // 2
54
+ start_y = (500 - new_h) // 2
55
+ padded_image[start_y:start_y+new_h, start_x:start_x+new_w, :] = result_image_resized
56
+
57
+ # Display the frame with detections or segmentations in the Streamlit app
58
+ st_frame.image(
59
+ padded_image, # Directly use RGB image for display
60
+ caption=f'运行结果',
61
+ use_column_width=True
62
+ )
63
+
64
+ # If a save path is provided, save the frame with detections or segmentations
65
+ if save_path:
66
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
67
+ filename = f"{task_type}_frame_{timestamp}.png"
68
+ save_path_full = os.path.join(save_path, filename)
69
+ # Save the padded image in RGB format
70
+ cv2.imwrite(save_path_full, result_image_resized) # Save in RGB format
71
+ st.write(f"文件保存在: {save_path_full}")
72
+ @st.cache_resource
73
+ def load_model(model_path):
74
+ """
75
+ Loads a YOLO object detection or segmentation model from the specified model_path.
76
+ Parameters:
77
+ model_path (str): The path to the YOLO model file.
78
+ Returns:
79
+ A YOLO object detection or segmentation model.
80
+ """
81
+ model = YOLO(model_path)
82
+ return model
83
+
84
+ def infer_uploaded_image(conf, model, save_path, task_type):
85
+ """
86
+ Execute inference for uploaded images in batch.
87
+ :param conf: Confidence of YOLO model
88
+ :param model: An instance of the YOLO class containing the YOLO model.
89
+ :param save_path: The path to save the results.
90
+ :param task_type: The type of task, either 'detection' or 'segmentation'.
91
+ :return: None
92
+ """
93
+ source_imgs = st.sidebar.file_uploader(
94
+ "选择图像",
95
+ type=("jpg", "jpeg", "png", 'bmp', 'webp'),
96
+ accept_multiple_files=True,
97
+ )
98
+
99
+ if source_imgs:
100
+ for img_info in source_imgs:
101
+ file_type = os.path.splitext(img_info.name)[1][1:].lower()
102
+ upload_records.append({
103
+ "file_name": img_info.name,
104
+ "file_type": file_type,
105
+ "uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
106
+ })
107
+
108
+ uploaded_image = Image.open(img_info)
109
+ img_byte_arr = io.BytesIO()
110
+ uploaded_image.save(img_byte_arr, format=file_type.upper() if file_type != 'jpg' else 'JPEG')
111
+ img_byte_arr = img_byte_arr.getvalue()
112
+ image = np.array(Image.open(io.BytesIO(img_byte_arr)))
113
+
114
+ st.image(
115
+ img_byte_arr,
116
+ caption=f"上传的图像: {img_info.name}",
117
+ use_column_width=True
118
+ )
119
+
120
+ with st.spinner("正在运行..."):
121
+ _display_detected_frames(conf, model, st.empty(), image, save_path, task_type)
122
+
123
+ def infer_uploaded_video(conf, model, save_path, task_type):
124
+ """
125
+ Execute inference for uploaded video and display the detected objects on the video.
126
+ :param conf: Confidence of YOLO model
127
+ :param model: An instance of the YOLO class containing the YOLO model.
128
+ :param save_path: The path to save the results.
129
+ :param task_type: The type of task, either 'detection' or 'segmentation'.
130
+ :return: None
131
+ """
132
+ source_video = st.sidebar.file_uploader(
133
+ "选择视频",
134
+ accept_multiple_files=True
135
+ )
136
+
137
+ if source_video:
138
+ for video_file in source_video:
139
+ file_type = os.path.splitext(video_file.name)[1][1:].lower()
140
+ upload_records.append({
141
+ "file_name": video_file.name,
142
+ "file_type": file_type,
143
+ "uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
144
+ })
145
+
146
+ st.video(video_file)
147
+
148
+ if st.button("开始运行"):
149
+ with st.spinner("运行中..."):
150
+ try:
151
+ tfile = tempfile.NamedTemporaryFile()
152
+ tfile.write(video_file.read())
153
+ vid_cap = cv2.VideoCapture(tfile.name)
154
+ st_frame = st.empty()
155
+ frame_rate = vid_cap.get(cv2.CAP_PROP_FPS)
156
+ delay = int(1000 / frame_rate)
157
+
158
+ start_time = time.time()
159
+ while True:
160
+ success, image = vid_cap.read()
161
+ if not success:
162
+ break
163
+
164
+ current_time = time.time()
165
+ if current_time - start_time >= 1.0:
166
+ _display_detected_frames(conf, model, st_frame, image, save_path, task_type)
167
+ start_time = current_time
168
+
169
+ vid_cap.release()
170
+ except Exception as e:
171
+ st.error(f"Error loading video: {e}")
172
+
173
+ def infer_uploaded_webcam(conf, model, save_path, task_type):
174
+ """
175
+ Execute inference for webcam.
176
+ :param conf: Confidence of YOLO model
177
+ :param model: An instance of the YOLO class containing the YOLO model.
178
+ :param save_path: The path to save the results.
179
+ :param task_type: The type of task, either 'detection' or 'segmentation'.
180
+ :return: None
181
+ """
182
+ try:
183
+ flag = st.button(
184
+ "关闭摄像头"
185
+ )
186
+ vid_cap = cv2.VideoCapture(0)
187
+ st_frame = st.empty()
188
+ while not flag:
189
+ success, image = vid_cap.read()
190
+ if success:
191
+ _display_detected_frames(conf, model, st_frame, image, save_path, task_type)
192
+ else:
193
+ vid_cap.release()
194
+ break
195
+ except Exception as e:
196
+ st.error(f"Error loading video: {str(e)}")