Spaces:
Sleeping
Sleeping
from shared import upload_records | |
from ultralytics import YOLO | |
import streamlit as st | |
import cv2 | |
from PIL import Image | |
import numpy as np | |
import tempfile | |
import datetime | |
import os | |
import io | |
import time | |
def _display_detected_frames(conf, model, st_frame, image, save_path, task_type):
    """
    Run the YOLO model on one frame, display the annotated result and optionally save it.

    The frame is normalised to three channels, resized to 640x480 and passed to
    ``model.predict``. The plotted result is scaled to fit a 550x450 box and
    centred on a white 750x500 canvas, which is what gets displayed.

    :param conf (float): Confidence threshold forwarded to ``model.predict``.
    :param model (YOLO): An instance of the YOLO class containing the YOLO model.
    :param st_frame (Streamlit object): A Streamlit placeholder whose ``image`` method renders the result.
    :param image (numpy array): The frame to analyse. The plotted result is converted with
        ``COLOR_BGR2RGB`` for display, so three-channel input is assumed to be in BGR order.
    :param save_path (str): Directory in which to save the annotated frame; a falsy value disables saving.
    :param task_type (str): 'detection' or 'segmentation'; only used as the saved file name prefix.
    :return: None
    """
    # Ensure the image is a 3-channel colour image.
    if image.ndim == 2 or image.shape[2] == 1:  # grayscale or single channel
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:  # four-channel RGBA image
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)

    # Resize the image to the fixed size fed to the model.
    image_resized = cv2.resize(image, (640, 480))

    # Perform object detection or segmentation using the YOLO model.
    results = model.predict(image_resized, conf=conf)

    # The same plot() call renders the annotations for both task types.
    result_image = results[0].plot()

    # Convert from BGR to RGB for Streamlit display.
    result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)

    # Scale the result to fit inside a 550x450 box while keeping the aspect ratio.
    h, w = result_image_rgb.shape[:2]
    scale_factor = min(550 / w, 450 / h)
    new_w, new_h = int(w * scale_factor), int(h * scale_factor)
    result_image_resized = cv2.resize(result_image_rgb, (new_w, new_h))

    # Centre the scaled image on a white 750x500 canvas.
    padded_image = np.full((500, 750, 3), 255, dtype=np.uint8)
    start_x = (750 - new_w) // 2
    start_y = (500 - new_h) // 2
    padded_image[start_y:start_y + new_h, start_x:start_x + new_w, :] = result_image_resized

    # Display the frame with detections or segmentations in the Streamlit app.
    st_frame.image(
        padded_image,  # already RGB
        caption='运行结果',
        use_column_width=True
    )

    # If a save path is provided, save the (unpadded) annotated frame.
    if save_path:
        # imwrite does not create missing directories; it just returns False.
        os.makedirs(save_path, exist_ok=True)
        # Microseconds keep frames saved within the same second from
        # overwriting each other.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{task_type}_frame_{timestamp}.png"
        save_path_full = os.path.join(save_path, filename)
        # cv2.imwrite expects BGR channel order, so convert back from the
        # RGB display image to avoid swapping red and blue in the saved file.
        image_to_save = cv2.cvtColor(result_image_resized, cv2.COLOR_RGB2BGR)
        if cv2.imwrite(save_path_full, image_to_save):
            st.write(f"文件保存在: {save_path_full}")
        else:
            st.error(f"文件保存失败: {save_path_full}")
def load_model(model_path):
    """
    Build a YOLO model from the file at ``model_path``.

    Parameters:
        model_path (str): The path handed to the ``YOLO`` constructor.

    Returns:
        The ``YOLO`` instance constructed from ``model_path``.
    """
    return YOLO(model_path)
def infer_uploaded_image(conf, model, save_path, task_type):
    """
    Execute inference for uploaded images in batch.

    Each uploaded file is appended to ``upload_records``, shown in the app,
    decoded to a three-channel BGR array and passed to
    ``_display_detected_frames``.

    :param conf: Confidence of YOLO model
    :param model: An instance of the YOLO class containing the YOLO model.
    :param save_path: The path to save the results.
    :param task_type: The type of task, either 'detection' or 'segmentation'.
    :return: None
    """
    source_imgs = st.sidebar.file_uploader(
        "选择图像",
        type=("jpg", "jpeg", "png", 'bmp', 'webp'),
        accept_multiple_files=True,
    )
    if not source_imgs:
        return

    for img_info in source_imgs:
        file_type = os.path.splitext(img_info.name)[1][1:].lower()
        upload_records.append({
            "file_name": img_info.name,
            "file_type": file_type,
            "uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        })

        try:
            uploaded_image = Image.open(img_info)
            # Normalise every PIL mode (palette, LA, CMYK, RGBA, ...) to RGB so
            # the array always has exactly three channels. convert() also
            # forces decoding, so a corrupt file is reported here.
            rgb_image = uploaded_image.convert("RGB")
        except OSError as e:
            # One unreadable file should not abort the rest of the batch.
            st.error(f"Error loading image: {e}")
            continue

        # Show the upload directly; re-encoding it first was lossy for JPEG
        # and could fail when the image mode did not suit the extension.
        st.image(
            rgb_image,
            caption=f"上传的图像: {img_info.name}",
            use_column_width=True
        )

        # PIL yields RGB, but the downstream OpenCV pipeline treats
        # three-channel arrays as BGR; convert to keep colours correct.
        image = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)

        with st.spinner("正在运行..."):
            _display_detected_frames(conf, model, st.empty(), image, save_path, task_type)
def infer_uploaded_video(conf, model, save_path, task_type):
    """
    Execute inference for uploaded videos and display the detected objects.

    Each uploaded file is appended to ``upload_records`` and shown with its
    own run button. When a button is pressed, the video is written to a
    temporary file, decoded with OpenCV, and roughly one frame per second of
    video is passed to ``_display_detected_frames``.

    :param conf: Confidence of YOLO model
    :param model: An instance of the YOLO class containing the YOLO model.
    :param save_path: The path to save the results.
    :param task_type: The type of task, either 'detection' or 'segmentation'.
    :return: None
    """
    source_video = st.sidebar.file_uploader(
        "选择视频",
        accept_multiple_files=True
    )
    if not source_video:
        return

    for index, video_file in enumerate(source_video):
        file_type = os.path.splitext(video_file.name)[1][1:].lower()
        upload_records.append({
            "file_name": video_file.name,
            "file_type": file_type,
            "uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        })
        st.video(video_file)

        # Each button needs a unique key; identical buttons created in a
        # loop otherwise collide when several videos are uploaded.
        if not st.button("开始运行", key=f"run_video_{index}_{video_file.name}"):
            continue

        with st.spinner("运行中..."):
            tmp_path = None
            vid_cap = None
            try:
                # Write the upload to disk and close the file before OpenCV
                # opens it by name, so the data is flushed and the path can
                # be reopened on every platform.
                suffix = os.path.splitext(video_file.name)[1]
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tfile:
                    tmp_path = tfile.name
                    # getvalue() returns the full content regardless of the
                    # current read position of the uploaded file.
                    tfile.write(video_file.getvalue())

                vid_cap = cv2.VideoCapture(tmp_path)
                st_frame = st.empty()

                # Sample by frame index rather than wall-clock time, so a
                # quickly decoded video still produces output. If the FPS is
                # unavailable (0 or NaN), fall back to processing every frame.
                frame_rate = vid_cap.get(cv2.CAP_PROP_FPS)
                frames_per_sample = max(1, int(round(frame_rate))) if frame_rate > 0 else 1

                frame_index = 0
                while True:
                    success, image = vid_cap.read()
                    if not success:
                        break
                    if frame_index % frames_per_sample == 0:
                        _display_detected_frames(conf, model, st_frame, image, save_path, task_type)
                    frame_index += 1
            except Exception as e:
                # UI boundary: report the failure instead of crashing the app.
                st.error(f"Error loading video: {e}")
            finally:
                # Release the decoder and remove the temporary copy even on error.
                if vid_cap is not None:
                    vid_cap.release()
                if tmp_path is not None and os.path.exists(tmp_path):
                    os.remove(tmp_path)
def infer_uploaded_webcam(conf, model, save_path, task_type):
    """
    Execute inference for webcam.

    Renders a "close camera" button, then reads frames from capture device 0
    and passes each one to ``_display_detected_frames`` until a read fails.
    The capture device is released when the loop ends or an error occurs.

    :param conf: Confidence of YOLO model
    :param model: An instance of the YOLO class containing the YOLO model.
    :param save_path: The path to save the results.
    :param task_type: The type of task, either 'detection' or 'segmentation'.
    :return: None
    """
    flag = st.button(
        "关闭摄像头"
    )
    if flag:
        # The close button was pressed: do not open the camera at all.
        return

    vid_cap = None
    try:
        vid_cap = cv2.VideoCapture(0)
        if not vid_cap.isOpened():
            st.error("Error loading video: unable to open webcam")
            return

        st_frame = st.empty()
        while True:
            success, image = vid_cap.read()
            if not success:
                break
            _display_detected_frames(conf, model, st_frame, image, save_path, task_type)
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        st.error(f"Error loading video: {str(e)}")
    finally:
        # Always give the camera back, including on errors.
        if vid_cap is not None:
            vid_cap.release()