from shared import upload_records from ultralytics import YOLO import streamlit as st import cv2 from PIL import Image import numpy as np import tempfile import datetime import os import io import time def _display_detected_frames(conf, model, st_frame, image, save_path, task_type): """ Display the detected objects on a video frame using the YOLO model. :param conf (float): Confidence threshold for object detection. :param model (YOLO): An instance of the YOLO class containing the YOLO model. :param st_frame (Streamlit object): A Streamlit object to display the detected video. :param image (numpy array): A numpy array representing the video frame. :param save_path (str): The path to save the results. :param task_type (str): The type of task, either 'detection' or 'segmentation'. :return: None """ # Ensure the image is a 3-channel彩色图像 if image.ndim == 2 or image.shape[2] == 1: # 灰度图像或单通道 image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) elif image.shape[2] == 4: # 四通道RGBA图像 image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR) # Resize the image to the standard size expected by the model image_resized = cv2.resize(image, (640, 480)) # Perform object detection or segmentation using the YOLO model results = model.predict(image_resized, conf=conf) # Convert the results to the correct format for display and saving if task_type == 'detection': result_image = results[0].plot() else: # segmentation result_image = results[0].plot() # Convert from BGR to RGB for Streamlit display result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB) # Resize the result image to the fixed output size (750, 500) while maintaining aspect ratio h, w = result_image_rgb.shape[:2] scale_factor = min(550 / w, 450 / h) new_w, new_h = int(w * scale_factor), int(h * scale_factor) result_image_resized = cv2.resize(result_image_rgb, (new_w, new_h)) # Pad the image to ensure it is 750x500 padded_image = np.full((500, 750, 3), 255, dtype=np.uint8) # Create a white background start_x = (750 - new_w) // 2 start_y = (500 - new_h) // 2 padded_image[start_y:start_y+new_h, start_x:start_x+new_w, :] = result_image_resized # Display the frame with detections or segmentations in the Streamlit app st_frame.image( padded_image, # Directly use RGB image for display caption=f'运行结果', use_column_width=True ) # If a save path is provided, save the frame with detections or segmentations if save_path: timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"{task_type}_frame_{timestamp}.png" save_path_full = os.path.join(save_path, filename) # Save the padded image in RGB format cv2.imwrite(save_path_full, result_image_resized) # Save in RGB format st.write(f"文件保存在: {save_path_full}") @st.cache_resource def load_model(model_path): """ Loads a YOLO object detection or segmentation model from the specified model_path. Parameters: model_path (str): The path to the YOLO model file. Returns: A YOLO object detection or segmentation model. """ model = YOLO(model_path) return model def infer_uploaded_image(conf, model, save_path, task_type): """ Execute inference for uploaded images in batch. :param conf: Confidence of YOLO model :param model: An instance of the YOLO class containing the YOLO model. :param save_path: The path to save the results. :param task_type: The type of task, either 'detection' or 'segmentation'. :return: None """ source_imgs = st.sidebar.file_uploader( "选择图像", type=("jpg", "jpeg", "png", 'bmp', 'webp'), accept_multiple_files=True, ) if source_imgs: for img_info in source_imgs: file_type = os.path.splitext(img_info.name)[1][1:].lower() upload_records.append({ "file_name": img_info.name, "file_type": file_type, "uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") }) uploaded_image = Image.open(img_info) img_byte_arr = io.BytesIO() uploaded_image.save(img_byte_arr, format=file_type.upper() if file_type != 'jpg' else 'JPEG') img_byte_arr = img_byte_arr.getvalue() image = np.array(Image.open(io.BytesIO(img_byte_arr))) st.image( img_byte_arr, caption=f"上传的图像: {img_info.name}", use_column_width=True ) with st.spinner("正在运行..."): _display_detected_frames(conf, model, st.empty(), image, save_path, task_type) def infer_uploaded_video(conf, model, save_path, task_type): """ Execute inference for uploaded video and display the detected objects on the video. :param conf: Confidence of YOLO model :param model: An instance of the YOLO class containing the YOLO model. :param save_path: The path to save the results. :param task_type: The type of task, either 'detection' or 'segmentation'. :return: None """ source_video = st.sidebar.file_uploader( "选择视频", accept_multiple_files=True ) if source_video: for video_file in source_video: file_type = os.path.splitext(video_file.name)[1][1:].lower() upload_records.append({ "file_name": video_file.name, "file_type": file_type, "uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") }) st.video(video_file) if st.button("开始运行"): with st.spinner("运行中..."): try: tfile = tempfile.NamedTemporaryFile() tfile.write(video_file.read()) vid_cap = cv2.VideoCapture(tfile.name) st_frame = st.empty() frame_rate = vid_cap.get(cv2.CAP_PROP_FPS) delay = int(1000 / frame_rate) start_time = time.time() while True: success, image = vid_cap.read() if not success: break current_time = time.time() if current_time - start_time >= 1.0: _display_detected_frames(conf, model, st_frame, image, save_path, task_type) start_time = current_time vid_cap.release() except Exception as e: st.error(f"Error loading video: {e}") def infer_uploaded_webcam(conf, model, save_path, task_type): """ Execute inference for webcam. :param conf: Confidence of YOLO model :param model: An instance of the YOLO class containing the YOLO model. :param save_path: The path to save the results. :param task_type: The type of task, either 'detection' or 'segmentation'. :return: None """ try: flag = st.button( "关闭摄像头" ) vid_cap = cv2.VideoCapture(0) st_frame = st.empty() while not flag: success, image = vid_cap.read() if success: _display_detected_frames(conf, model, st_frame, image, save_path, task_type) else: vid_cap.release() break except Exception as e: st.error(f"Error loading video: {str(e)}")