Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
"""🎬 Keras Video Classification CNN-RNN model

Hugging Face Space demonstrating how to use the model.

Author:
    - Thomas Chaigneau @ChainYo
"""
import os | |
import cv2 | |
import imageio | |
import gradio as gr | |
import numpy as np | |
from tensorflow import keras | |
from tensorflow_docs.vis import embed | |
from huggingface_hub import from_pretrained_keras | |
# Side length (pixels) of the square frames fed to the feature extractor.
IMG_SIZE = 224
# Length of the per-frame feature vector stored by prepare_video.
NUM_FEATURES = 2048

model = from_pretrained_keras("keras-io/video-classification-cnn-rnn")

# Example rows for the Gradio interface: [video path, 20]. The 20 presumably
# mirrors prepare_video's default max_seq_length -- TODO confirm.
# The listing is sorted because os.listdir returns entries in arbitrary order.
# (The previous loop also computed an unused `tag = file.split("_")[1]`, which
# crashed with IndexError on any filename without an underscore.)
samples = [[f"samples/{name}", 20] for name in sorted(os.listdir("samples"))]
def crop_center_square(frame):
    """Return the largest square region centred in ``frame``.

    Only the first two axes (rows, columns) are cropped; any trailing axes
    such as colour channels are kept as-is. The result is a slice (view) of
    the input, not a copy.
    """
    height, width = frame.shape[:2]
    side = min(height, width)
    # Offsets that centre a side x side window on the frame.
    top = height // 2 - side // 2
    left = width // 2 - side // 2
    return frame[top:top + side, left:left + side]
def load_video(path, max_frames=0, resize=(IMG_SIZE, IMG_SIZE)):
    """Decode a video into an array of square, resized RGB frames.

    Each decoded frame is centre-cropped to a square, resized with
    ``cv2.resize`` and has its channel order reversed from OpenCV's BGR to
    RGB.

    Args:
        path: Video location handed to ``cv2.VideoCapture``.
        max_frames: Stop once this many frames have been collected. The
            default 0 never matches a non-empty count, so the whole video
            is read.
        resize: ``dsize`` argument for ``cv2.resize``.

    Returns:
        np.ndarray stacking the collected frames. It is empty when no frame
        could be read.
    """
    capture = cv2.VideoCapture(path)
    collected = []
    try:
        while True:
            grabbed, image = capture.read()
            if not grabbed:
                # End of stream, or the file could not be decoded.
                break
            square = cv2.resize(crop_center_square(image), resize)
            collected.append(square[:, :, [2, 1, 0]])  # BGR -> RGB
            if len(collected) == max_frames:
                break
    finally:
        # Release the capture handle even if decoding raises.
        capture.release()
    return np.array(collected)
def build_feature_extractor():
    """Build the per-frame image feature extractor.

    The returned model takes raw ``(IMG_SIZE, IMG_SIZE, 3)`` images and
    applies InceptionV3's own ``preprocess_input``. It then runs an
    ImageNet-pretrained InceptionV3 backbone with its classification head
    removed and global average pooling applied, yielding one pooled vector
    per image.

    Returns:
        keras.Model: A model named ``"feature_extractor"``.
    """
    image_shape = (IMG_SIZE, IMG_SIZE, 3)
    backbone = keras.applications.InceptionV3(
        weights="imagenet",
        include_top=False,
        pooling="avg",
        input_shape=image_shape,
    )
    raw_images = keras.Input(image_shape)
    # Preprocessing lives inside the model, so callers can pass raw pixels.
    scaled = keras.applications.inception_v3.preprocess_input(raw_images)
    embeddings = backbone(scaled)
    return keras.Model(raw_images, embeddings, name="feature_extractor")


# Built once at import time and reused by prepare_video for every request.
feature_extractor = build_feature_extractor()
def prepare_video(frames, max_seq_length: int = 20):
    """Convert decoded frames into the (features, mask) pair the RNN consumes.

    Args:
        frames: Frames of a single video, as returned by ``load_video``.
            Presumably shaped (num_frames, IMG_SIZE, IMG_SIZE, 3) -- TODO
            confirm. A list of frames is also accepted.
        max_seq_length: Number of time steps in the output. Frames beyond
            this are ignored, and shorter videos are zero-padded.

    Returns:
        tuple ``(frame_features, frame_mask)``:
        - ``frame_features``: float32 array of shape
          (1, max_seq_length, NUM_FEATURES), holding one extractor feature
          row per kept frame and zeros elsewhere.
        - ``frame_mask``: bool array of shape (1, max_seq_length). True means
          "not masked" (a real frame); False marks padding.
    """
    frames = np.asarray(frames)
    frame_mask = np.zeros(shape=(1, max_seq_length), dtype="bool")
    frame_features = np.zeros(
        shape=(1, max_seq_length, NUM_FEATURES), dtype="float32"
    )

    length = min(max_seq_length, len(frames))
    if length > 0:
        # One batched forward pass over the kept frames, instead of one
        # predict() call (with its fixed per-call overhead) per frame.
        frame_features[0, :length, :] = feature_extractor.predict(frames[:length])
        frame_mask[0, :length] = True  # True = not masked, False = masked
    return frame_features, frame_mask
def sequence_prediction(path):
    """Classify the video at ``path`` and build a GIF preview of its frames.

    Args:
        path: Video file path (as supplied by the Gradio video input).

    Returns:
        tuple ``(preds, gif)``:
        - ``preds``: dict mapping each class name to its probability as a
          Python float, inserted from most to least likely.
        - ``gif``: whatever ``to_gif`` returns for the decoded frames.
    """
    # Index order must match the model's output units -- assumed to follow
    # the training label lookup; verify before changing.
    labels = ["CricketShot", "PlayingCello", "Punch", "ShavingBeard", "TennisSwing"]
    video_frames = load_video(path)
    features, mask = prepare_video(video_frames)
    # predict() returns a batch; this single video is row 0.
    scores = model.predict([features, mask])[0]
    ranking = np.argsort(scores)[::-1]
    confidences = {labels[idx]: float(scores[idx]) for idx in ranking}
    return confidences, to_gif(video_frames)
def to_gif(images, path="animation.gif", fps=10):
    """Write ``images`` to an animated GIF and return the file's path.

    Args:
        images: Sequence/array of frames. Values are cast to uint8.
        path: Output file. It defaults to the previously hard-coded
            "animation.gif", which is overwritten on every call.
        fps: Playback rate passed to ``imageio.mimsave``.

    Returns:
        str: ``path``, so a caller that loads images from files (such as a
        Gradio image output) can display the GIF.
    """
    frames = np.asarray(images).astype(np.uint8)
    imageio.mimsave(path, frames, fps=fps)
    # The previous version returned embed.embed_file(frames). embed_file
    # expects a file path, not pixel data, so every call raised. It also
    # yields notebook HTML, which an image output cannot render, so return
    # the path instead.
    return path
# HTML footer shown under the interface (credits for the Space and the
# original Keras example).
article = (
    "<div style='text-align: center;'>"
    "<a href='https://github.com/ChainYo' target='_blank'>Space by Thomas Chaigneau</a><br>"
    "<a href='https://keras.io/examples/vision/video_classification/' target='_blank'>Keras example by Sayak Paul</a>"
    "</div>"
)

app = gr.Interface(
    sequence_prediction,
    inputs=[gr.inputs.Video(label="Video", type="avi")],
    outputs=[
        gr.outputs.Label(label="Prediction", type="confidences"),
        # "gif" is not a gr.outputs.Image type (numpy / pil / file / plot /
        # auto). "file" displays an image from a file path, so the
        # function's second return value must be the GIF's path.
        gr.outputs.Image(label="GIF", type="file"),
    ],
    title="Keras Video Classification CNN-RNN model",
    description="Keras Working Group",
    article=article,
    # NOTE(review): each `samples` row holds two values ([path, 20]) while
    # the interface declares a single input. Presumably that is why examples
    # are disabled; reshape the rows before re-enabling.
    # examples=samples
)
# Keep `app` bound to the Interface itself rather than to launch()'s return
# value.
app.launch(enable_queue=True)