# yolos_demo / app.py
# Author: zoheb; commit b18860c ("second commit"); 6.34 kB.
import shutil
import cv2
from PIL import Image
import streamlit as st
from transformers import AutoModelForObjectDetection, AutoFeatureExtractor
import torch
import matplotlib.pyplot as plt
from stqdm import stqdm
from pathlib import Path
from moviepy.editor import VideoFileClip
# Load the fine-tuned YOLOS balloon detector and its preprocessing config.
best_model_path = "zoheb/yolos-small-balloon"
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
# Preprocessing resizes images using size=512 / max_size=864.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    best_model_path,
    size=512,
    max_size=864,
)
model_pt = AutoModelForObjectDetection.from_pretrained(best_model_path)
model_pt = model_pt.to(device)
# Box edge colours for visualization: RGB triples with components in [0, 1].
COLORS = [
    [0.000, 0.447, 0.741],  # blue
    [0.850, 0.325, 0.098],  # orange
    [0.929, 0.694, 0.125],  # yellow
    [0.494, 0.184, 0.556],  # purple
    [0.466, 0.674, 0.188],  # green
    [0.301, 0.745, 0.933],  # light blue
]
# Edit video
def cut_video(clip=None, output_path=None):
    """Let the user trim *clip* and write the trimmed segment to disk.

    Start/end times (whole seconds) are picked with sidebar number inputs;
    the cut is only performed when the form's "Edit Out" button is pressed.

    Args:
        clip: A moviepy clip (needs ``duration``, ``subclip`` and
            ``write_videofile``). ``None`` is tolerated and yields ``None``.
        output_path: Where the trimmed video is written. Defaults to
            ``edit.mp4`` in this script's directory, which is where the main
            block reads the edited video from.

    Returns:
        The trimmed sub-clip when the form was submitted on this run,
        otherwise ``None`` so the caller can tell nothing was edited yet.
    """
    if clip is None:
        return None
    if output_path is None:
        output_path = Path(__file__).parent.absolute() / "edit.mp4"

    duration = int(clip.duration)
    # Minimum trim length is 3 s, unless the whole video is shorter.
    min_len = min(3, duration)
    with st.form("edit"):
        st.write("Edit a small part of video")
        # Cap the start so start + min_len never exceeds duration; otherwise
        # the end input's min_value would exceed its max_value, which
        # Streamlit rejects.
        start = st.sidebar.number_input('Start time (seconds):',
                                        max_value=duration - min_len)
        end = st.sidebar.number_input('End time (seconds):',
                                      min_value=start + min_len,
                                      max_value=duration)
        submitted = st.form_submit_button("Edit Out")

    if not submitted:
        # Nothing edited on this run: signal it instead of returning the
        # untouched input clip.
        return None
    trimmed = clip.subclip(start, end)
    trimmed.write_videofile(str(output_path))
    return trimmed
# Convert Video to Frames
def video_to_frames(video, dir, step=5):
    """Save every ``step``-th frame of *video* as a JPEG in *dir*.

    Frames are written as ``frame_<k>.jpg`` where ``k`` counts only the
    saved frames (0, 1, 2, ...), so they can be sorted numerically later.

    Args:
        video: Path (or str) of the input video.
        dir: Existing directory that receives the JPEG frames.
        step: Keep one frame out of every ``step`` frames (default 5).

    Returns:
        The number of frames written.

    Raises:
        IOError: If OpenCV cannot open *video* or fails to write a frame.
    """
    cap = cv2.VideoCapture(str(video))
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video {video}")
        frame_count = 0
        success, image = cap.read()
        while success:
            # CAP_PROP_POS_FRAMES is the 0-based index of the *next* frame,
            # i.e. the 1-based number of the frame just read.
            frame_id = int(round(cap.get(cv2.CAP_PROP_POS_FRAMES)))
            if frame_id % step == 0:
                frame_path = f"{dir}/frame_{frame_count}.jpg"
                if not cv2.imwrite(frame_path, image):
                    raise IOError(f"Could not write frame {frame_path}")
                frame_count += 1
            success, image = cap.read()
    finally:
        # Release the capture even if reading/writing fails.
        cap.release()
    return frame_count
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corners.

    Args:
        x: Tensor whose last dimension has size 4 and holds
            ``(cx, cy, w, h)``; any leading (batch) dimensions are allowed.

    Returns:
        Tensor of the same shape holding ``(x_min, y_min, x_max, y_max)``.
    """
    # Work on the last dimension so (N, 4), (4,) and (B, N, 4) all work;
    # for 2-D input this is identical to unbinding/stacking dim 1.
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)
# rescale bboxes
def rescale_bboxes(out_bbox, size):
    """Scale normalized (cx, cy, w, h) boxes to absolute pixel corners.

    Args:
        out_bbox: Tensor of boxes in (cx, cy, w, h) format with coordinates
            normalized to [0, 1].
        size: ``(width, height)`` of the image, e.g. ``PIL.Image.size``.

    Returns:
        Tensor of ``(x_min, y_min, x_max, y_max)`` boxes in pixels.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Create the scale on the boxes' device so CUDA tensors work as well as
    # CPU ones (dtype kept at float32 as before).
    scale = torch.tensor([img_w, img_h, img_w, img_h],
                         dtype=torch.float32, device=b.device)
    return b * scale
# Save predicted frame
def save_results(pil_img, prob, boxes, mod_img_path, id2label=None):
    """Draw detections on *pil_img* and save the rendered figure.

    Args:
        pil_img: Image to draw on (anything ``plt.imshow`` accepts).
        prob: Per-detection class probabilities, one row per box.
        boxes: Tensor of ``(x_min, y_min, x_max, y_max)`` pixel boxes, one
            per row of *prob*.
        mod_img_path: Destination file path for the rendered figure.
        id2label: Mapping from class index to label text; defaults to the
            single ``balloon`` class. Unknown indices are shown as numbers.
    """
    if id2label is None:
        id2label = {0: 'balloon'}
    fig = plt.figure(figsize=(18, 10))
    try:
        plt.imshow(pil_img)
        ax = plt.gca()
        for i, (p, (xmin, ymin, xmax, ymax)) in enumerate(zip(prob, boxes.tolist())):
            # Cycle through the palette so no box is ever dropped, however
            # many detections there are.
            c = COLORS[i % len(COLORS)]
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=c, linewidth=3))
            cl = p.argmax()
            label = id2label.get(cl.item(), str(cl.item()))
            text = f'{label}: {p[cl]:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        plt.axis('off')
        plt.tight_layout(pad=0)
        plt.savefig(mod_img_path, transparent=True)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
# Save predictions
def save_predictions(image, outputs, mod_img_path, threshold=0.9):
    """Keep confident detections and save an annotated copy of *image*.

    Args:
        image: PIL image the model was run on; its ``size`` maps the
            normalized boxes to pixels.
        outputs: Detector output exposing ``logits`` and ``pred_boxes``;
            only the first item of the batch is used.
        mod_img_path: Where the annotated image is written.
        threshold: Keep detections whose best class probability is strictly
            greater than this value.
    """
    # Softmax over classes, dropping the last class column (presumably the
    # "no object" class of the DETR-style head — confirm against the model
    # config). Moved to CPU once so per-box drawing does not touch the GPU.
    probas = outputs.logits.softmax(-1)[0, :, :-1].cpu()
    # keep only predictions with confidence > threshold
    keep = probas.max(-1).values > threshold
    # convert predicted boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0].cpu()[keep], image.size)
    # save results
    save_results(image, probas[keep], bboxes_scaled, mod_img_path)
# Predict on frames
def predict_on_frames(dir, mod_dir):
    """Run the detector on each ``frame_<k>.jpg`` in *dir* and save the
    annotated frames, under the same file names, into *mod_dir*.

    Args:
        dir: ``pathlib.Path`` of the directory holding the extracted frames.
        mod_dir: Directory that receives the annotated frames.
    """
    files = [f for f in dir.glob('*.jpg') if f.is_file()]
    # Sort numerically on the index after the "frame_" prefix so that
    # frame_10 comes after frame_9.
    files.sort(key=lambda x: int(x.stem[len("frame_"):]))
    for file in stqdm(files, desc="Generating... this is a slow task"):
        # glob() already yields the full path; joining it with dir again
        # would duplicate a relative dir. The context manager closes the file.
        with Image.open(file) as img:
            img_ftr = feature_extractor(images=img, return_tensors="pt")
            pixel_values = img_ftr["pixel_values"].to(device)
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                outputs = model_pt(pixel_values=pixel_values)
            mod_img_path = Path(mod_dir, file.name)
            save_predictions(img, outputs, mod_img_path)
# Convert frames to video
def frames_to_video(dir, path, fps=5):
    """Assemble the ``frame_<k>.jpg`` images in *dir* into a video at *path*.

    Frames are written in numeric order of ``k``. The output size is taken
    from the first frame; all frames are expected to share that size.

    Args:
        dir: ``pathlib.Path`` of the directory holding the frames.
        path: Output video file path.
        fps: Frame rate of the output video.

    Raises:
        ValueError: If *dir* contains no frames.
        IOError: If a frame cannot be read.
    """
    files = [f for f in dir.glob('*.jpg') if f.is_file()]
    if not files:
        raise ValueError(f"No frames found in {dir}")
    # for sorting the file names properly (numeric index after "frame_")
    files.sort(key=lambda x: int(x.stem[len("frame_"):]))
    out = None
    try:
        # Stream frames straight into the writer instead of buffering all.
        for file in files:
            img = cv2.imread(str(file))
            if img is None:
                raise IOError(f"Could not read frame {file}")
            if out is None:
                height, width = img.shape[:2]
                # NOTE(review): 'mp4v' may not play in every browser's
                # <video> element; 'avc1' (H.264) would if the OpenCV build
                # supports it — verify before changing.
                out = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*'mp4v'),
                                      fps, (width, height))
            out.write(img)
    finally:
        if out is not None:
            out.release()
# Display video
def display(path):
    """Render the video file at *path* in the Streamlit page."""
    # Context manager closes the file handle as soon as the bytes are read.
    with open(str(path), 'rb') as video_file:
        video_bytes = video_file.read()
    st.video(video_bytes)
# Main
if __name__ == '__main__':
    st.title('Detect Balloons using YOLOS')
    # All dir and Files
    BASE_DIR = Path(__file__).parent.absolute()
    FRAMES_DIR = Path(BASE_DIR, "extracted_images")
    MOD_DIR = Path(BASE_DIR, "modified_images")
    # Start every run with empty working directories.
    for work_dir in (FRAMES_DIR, MOD_DIR):
        if work_dir.is_dir():
            shutil.rmtree(work_dir)
        work_dir.mkdir(parents=True, exist_ok=True)
    edited_video = Path(BASE_DIR, "edit.mp4")
    generated_video = Path(BASE_DIR, "balloons.mp4")
    # Upload the video
    uploaded_file = st.sidebar.file_uploader("Upload a small video containing Balloons", type=["mp4", "mpeg"])
    if uploaded_file is not None:
        st.video(uploaded_file)
        vid = uploaded_file.name
        st.info(f'Uploaded {vid}')
        # The file name is supplied by the client: keep only its final
        # component so the upload is always stored inside BASE_DIR.
        uploaded_video = Path(BASE_DIR, Path(vid).name)
        uploaded_video.write_bytes(uploaded_file.getvalue())
        # Close the source clip's readers once editing is done.
        with VideoFileClip(str(uploaded_video)) as source_clip:
            clip = cut_video(source_clip)
        if clip is not None:
            # Detect balloon in the frames and generate video
            try:
                st.info('View Edited Clip')
                display(edited_video)
                video_to_frames(edited_video, FRAMES_DIR)
                predict_on_frames(FRAMES_DIR, MOD_DIR)
                frames_to_video(MOD_DIR, generated_video)
                st.success("Successfully Generated!!")
                # Video file Generated
                display(generated_video)
                st.download_button('Download the Video', generated_video.read_bytes(), file_name=generated_video.name)
            except Exception as e:
                # Top-level UI boundary: report any pipeline failure to the user.
                st.error(f"Could not convert the file due to {e}")
        else:
            st.error("Please submit an edited clip.")
    else:
        st.info('File Not Uploaded Yet!!!')