# MIT-Fishery-App / app_utils.py
# aus10powell's picture
# Update app_utils.py
# 62b7bc7
"""
app_utils.py
Description: This file contains utility functions to support the Streamlit app.
These functions handle file processing, video conversion, and running
inference on uploaded images and videos.
Author: Austin Powell
"""
import io
import logging
import os
import tempfile

import altair as alt
import av
import numpy as np
import pandas as pd
import streamlit as st
from tqdm import tqdm
def extract_file_datetime(fname):
    """Extract the datetime encoded in a file name.

    The base name (any directory part is ignored) is split on "_": the second
    field is taken as the date and the third field, up to its first ".", as a
    "HH-MM-SS" time, e.g. ``cam_2021-05-01_12-30-45.mp4``.

    Args:
        fname (str): File name or path.

    Returns:
        pd.Timestamp: Datetime parsed from the date and time fields.

    Raises:
        ValueError: If the name lacks the date field or the "HH-MM-SS" time field.
    """
    parts = os.path.basename(fname).split("_")
    if len(parts) < 3:
        raise ValueError(
            f"Expected '<prefix>_<date>_<HH-MM-SS>.<ext>' file name, got {fname!r}"
        )
    date_part = parts[1]
    # Drop the extension (everything after the first ".") before splitting the time.
    time_fields = parts[2].split(".")[0].split("-")
    if len(time_fields) != 3:
        raise ValueError(f"Expected an 'HH-MM-SS' time field in {fname!r}")
    hours, minutes, seconds = time_fields
    return pd.to_datetime(f"{date_part} {hours}:{minutes}:{seconds}")
def frames_to_video(frames=None, fps=12):
    """Encode a sequence of frames into an in-memory H.264 MP4 for Streamlit.

    Args:
        frames: Indexable sequence of numpy frames whose shape starts with
            (height, width), in BGR channel order (e.g. from cv2.VideoCapture).
            Each frame is cast to uint8 before encoding. The output size is
            taken from the first frame.
        fps: Output frame rate. A lower value slows playback, which can help
            when reviewing inference output.

    Returns:
        io.BytesIO: The encoded MP4 data, rewound to position 0.

    Raises:
        ValueError: If ``frames`` is None or empty.
    """
    # len() rather than truthiness so a numpy array of frames is also accepted.
    if frames is None or len(frames) == 0:
        raise ValueError("frames_to_video requires at least one frame")

    height, width = frames[0].shape[:2]
    output_memory_file = io.BytesIO()

    # The context manager closes the container (writing the MP4 trailer) even
    # if encoding raises part-way through.
    with av.open(output_memory_file, "w", format="mp4") as output:
        stream = output.add_stream("h264", str(fps))
        stream.width = width
        stream.height = height
        # yuv420p for broad player compatibility.
        stream.pix_fmt = "yuv420p"
        # Low CRF = high quality, at the price of a larger file.
        stream.options = {"crf": "17"}

        logging.info("INFO: Encoding frames and writing to MP4 format.")
        for frame in tqdm(frames):
            video_frame = av.VideoFrame.from_ndarray(frame.astype(np.uint8), format="bgr24")
            # "Mux" the encoded packet(s) into the MP4 container.
            output.mux(stream.encode(video_frame))

        # Encoding None flushes packets still buffered in the encoder.
        output.mux(stream.encode(None))

    # Rewind so consumers read the MP4 from the start.
    output_memory_file.seek(0)
    return output_memory_file
# NOTE(review): `Image` (PIL), `cv` (OpenCV), `inference` and `plot_count_date`
# are used below but are not imported/defined in the visible part of this
# module — confirm they are brought into scope, otherwise these paths raise NameError.


def _uploaded_file_extension(uploaded_file):
    """Return the lower-cased extension (without the dot) of an uploaded file's name."""
    return os.path.splitext(uploaded_file.name)[1].lstrip(".").lower()


def _run_inference_on_upload(uploaded_file):
    """Copy the upload to a temporary file, run ``inference.main`` on it, and clean up.

    Returns whatever ``inference.main`` returns for the opened capture.
    """
    suffix = os.path.splitext(uploaded_file.name)[1]
    # delete=False keeps the file after it is closed so OpenCV can reopen it by
    # name; closing here guarantees the bytes are flushed to disk first.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
        tfile.write(uploaded_file.getvalue())
    try:
        vf = cv.VideoCapture(tfile.name)
        try:
            return inference.main(vf)
        finally:
            vf.release()
    finally:
        # Best-effort cleanup: a failed delete should not mask the inference result/error.
        try:
            os.remove(tfile.name)
        except OSError:
            logging.warning("Could not remove temporary file %s", tfile.name)


def _process_uploaded_video(uploaded_file):
    """Run inference on an uploaded video and render video, download button and chart."""
    st.video(uploaded_file)
    with st.spinner("Running inference..."):
        frames, counts, timestamps = _run_inference_on_upload(uploaded_file)
        logging.info("INFO: Completed running inference on frames")
    st.balloons()

    # Convert in-memory OpenCV numpy frames to an MP4 byte stream for Streamlit.
    streamlit_video_file = frames_to_video(frames=frames, fps=11)
    st.video(streamlit_video_file)
    st.download_button(
        label="Download processed video",
        # Pass raw bytes so the download does not depend on the stream position.
        data=streamlit_video_file.getvalue(),
        mime="video/mp4",
        file_name="processed_video.mp4",
    )

    # NOTE(review): timestamps[1:] assumes inference.main returns one more
    # timestamp than counts — confirm, or the DataFrame lengths will mismatch.
    df_counts_time = pd.DataFrame(
        data={"fish_count": counts, "timestamps": timestamps[1:]}
    )
    st.altair_chart(
        plot_count_date(dataframe=df_counts_time),
        use_container_width=True,
    )


def process_uploaded_file():
    """Render the upload widget and process an uploaded image or video.

    Images are only displayed; single-image inference is not implemented yet.
    Videos are displayed and run through inference. The annotated video, a
    download button and a fish-count-over-time chart are then rendered.
    """
    st.subheader("Upload your own video...")
    img_types = ["jpg", "png", "jpeg"]
    video_types = ["mp4", "avi"]
    uploaded_file = st.file_uploader(
        "Select an image or video file...", type=img_types + video_types
    )
    if uploaded_file is None:
        st.write("No file uploaded")
        return

    # Use the file extension rather than the MIME subtype: AVI files report
    # MIME types such as "video/x-msvideo", which never matched "avi".
    extension = _uploaded_file_extension(uploaded_file)
    if extension in img_types:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded image", use_column_width=True)
        # TBD: Inference code to run and display for single image
    elif extension in video_types:
        _process_uploaded_video(uploaded_file)
    else:
        st.warning(f"Unsupported file type: {extension!r}")