# NOTE: "Spaces: Runtime error" — status banner captured from the hosting page; not part of the program.
import json | |
from datetime import datetime | |
from time import time | |
from typing import List, Optional, Tuple | |
import cv2 | |
import pandas as pd | |
import torch | |
from tap import Tap | |
from torch import Tensor | |
from transformers import ( | |
AutoFeatureExtractor, | |
TimesformerForVideoClassification, | |
VideoMAEFeatureExtractor, | |
) | |
from utils.img_container import ImgContainer | |
# Command-line options for the webcam activity-recognition script.
#
# NOTE(review): this class intentionally has no docstring and its trailing
# same-line comments are kept verbatim. Tap is understood to surface both in
# the generated --help output; confirm against the Tap docs before editing.
class ArgParser(Tap):
    # Initial recording state handed to ImgContainer by main().
    is_recording: Optional[bool] = False

    # Checkpoint name passed to `from_pretrained`. Names listed by the author:
    #   "facebook/timesformer-base-finetuned-k400"
    #   "facebook/timesformer-base-finetuned-k600"
    #   "facebook/timesformer-base-finetuned-ssv2"
    #   "facebook/timesformer-hr-finetuned-k600"
    #   "facebook/timesformer-hr-finetuned-k400"
    #   "facebook/timesformer-hr-finetuned-ssv2"
    #   "fcakyon/timesformer-large-finetuned-k400"
    #   "fcakyon/timesformer-large-finetuned-k600"
    model_name: Optional[str] = "facebook/timesformer-base-finetuned-k400"

    # main() adds a captured frame to the clip buffer once every
    # `num_skip_frames` frames.
    num_skip_frames: Optional[int] = 2

    # Not read by any code visible in this file.
    top_k: Optional[int] = 5

    # Path of a JSON file mapping class index -> label; None keeps the
    # mapping that ships with the checkpoint.
    id2label: Optional[str] = "labels/kinetics_400.json"

    # Minimum score a prediction needs before it is reported; None reports
    # every prediction.
    threshold: Optional[float] = 10.0  # 10.0

    # Not read by any code visible in this file.
    max_confidence: Optional[float] = 20.0  # Set None if not scale
class ActivityModel:
    """Wrap a TimeSformer video classifier and keep a diary of reported activities.

    ``__init__`` sets ``feature_extractor`` and ``model`` (from ``load_model``),
    ``args``, ``frames_per_video`` and ``diary``.
    """

    def __init__(self, args: ArgParser):
        """Load the checkpoint named by ``args.model_name`` and start an empty diary.

        Args:
            args: Parsed options; ``model_name``, ``id2label`` and ``threshold``
                are read by this class.
        """
        self.feature_extractor, self.model = self.load_model(args.model_name)
        self.args = args
        self.frames_per_video = self.get_frames_per_video(args.model_name)
        print(f"Frames per video: {self.frames_per_video}")
        # Needs both self.model and self.args, so it must run after they are set.
        self.load_json()
        # One row per reported prediction: (time, timestamp, activity, confidence).
        self.diary: List[Tuple[str, int, str, float]] = []

    def save_diary(self):
        """Write the diary to ``diary.csv`` and ``diary.xlsx`` in the working directory."""
        df = pd.DataFrame(
            self.diary, columns=["time", "timestamp", "activity", "confidence"]
        )
        df.to_csv("diary.csv")
        df.to_excel("diary.xlsx")

    def load_json(self):
        """Replace the model's ``id2label`` mapping with the file at ``args.id2label``.

        Does nothing when ``args.id2label`` is ``None``. The file must contain a
        JSON object whose keys are integer class indices written as strings.

        Raises:
            OSError: If the file cannot be opened.
            ValueError: If the file is not valid JSON or a key is not an integer.
        """
        # Read the instance's own arguments rather than a module-level global,
        # so the class also works when it is imported instead of run as a script.
        label_path = self.args.id2label
        if label_path is None:
            return
        with open(label_path, encoding="utf-8") as f:
            raw = json.load(f)
        # JSON object keys are always strings; predictions are looked up by int.
        self.model.config.id2label = {int(key): label for key, label in raw.items()}

    def load_model(
        self, model_name: str
    ) -> Tuple[VideoMAEFeatureExtractor, TimesformerForVideoClassification]:
        """Load the feature extractor and classifier for ``model_name``.

        Names containing ``base-finetuned-k400`` or ``base-finetuned-k600`` take
        their feature extractor from ``MCG-NJU/videomae-base-finetuned-kinetics``;
        every other name loads the extractor from ``model_name`` itself.

        Returns:
            ``(feature_extractor, model)``.
        """
        if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
            feature_extractor = AutoFeatureExtractor.from_pretrained(
                "MCG-NJU/videomae-base-finetuned-kinetics"
            )
        else:
            feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        model = TimesformerForVideoClassification.from_pretrained(model_name)
        return feature_extractor, model

    def inference(self, img_container: ImgContainer):
        """Classify the buffered clip and record the top prediction.

        Returns immediately when ``img_container.ready`` is false or when the
        winning class index has no entry in ``id2label``. Otherwise, if the
        score passes ``args.threshold`` (or the threshold is ``None``), the label
        is written to ``img_container.frame_rate.label`` and a row is appended
        to ``diary``.
        """
        if not img_container.ready:
            return
        inputs = self.feature_extractor(list(img_container.imgs), return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits: Tensor = outputs.logits

        # Only the single highest-scoring class is reported.
        max_index = logits.argmax(-1).item()
        if max_index not in self.model.config.id2label:
            return
        predicted_label = self.model.config.id2label[max_index]
        # This is the raw logit of the winning class (no softmax is applied),
        # even though the on-screen label below appends a "%" sign.
        confidence = logits[0][max_index].item()

        threshold = self.args.threshold
        if threshold is None or confidence >= threshold:
            img_container.frame_rate.label = f"{predicted_label}_{confidence:.2f}%"
            self.diary.append(
                (str(datetime.now()), int(time()), predicted_label, confidence)
            )

    def get_frames_per_video(self, model_name: str) -> int:
        """Return the clip length used for ``model_name``.

        8 for names containing ``base-finetuned``, 16 for ``hr-finetuned``,
        and 96 for anything else.
        """
        if "base-finetuned" in model_name:
            return 8
        if "hr-finetuned" in model_name:
            return 16
        return 96
def main(args: ArgParser):
    """Run live activity recognition on camera 0 until 'q' is pressed.

    Each loop iteration reads one frame, feeds every ``args.num_skip_frames``-th
    frame to the model, shows the annotated frame in a window and, while
    recording is on, appends it to ``activities.mp4``. Pressing 'r' toggles
    recording and 'q' quits. The activity diary is saved and all OpenCV
    resources are released on exit, including when the loop raises.

    Args:
        args: Parsed command-line options.
    """
    activity_model = ActivityModel(args)
    img_container = ImgContainer(activity_model.frames_per_video, args.is_recording)

    # A stride of 0 or None would make the modulo below raise; treat it as
    # "use every frame".
    frame_stride = max(1, args.num_skip_frames or 1)
    num_skips = 0

    # Open the default camera and size the output video to match it.
    camera = cv2.VideoCapture(0)
    frame_width = int(camera.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
    size = (frame_width, frame_height)
    video_output = cv2.VideoWriter(
        "activities.mp4", cv2.VideoWriter_fourcc(*"MP4V"), 10, size
    )

    try:
        if not camera.isOpened():
            print("Error reading video file")

        while camera.isOpened():
            ret, frame = camera.read()
            if not ret or frame is None:
                # No frame was delivered (device unplugged or stream ended);
                # stop instead of passing None to the model and to imshow.
                break

            num_skips = (num_skips + 1) % frame_stride
            img_container.img = frame
            img_container.frame_rate.count()
            if num_skips == 0:
                img_container.add_frame(frame)
                activity_model.inference(img_container)

            rs = img_container.frame_rate.show_fps(frame, img_container.is_recording)
            cv2.imshow("ActivityTracking", rs)
            if img_container.is_recording:
                video_output.write(rs)

            # 'q' quits, 'r' toggles recording.
            k = cv2.waitKey(1)
            if k == ord("q"):
                break
            elif k == ord("r"):
                img_container.toggle_recording()
    finally:
        try:
            activity_model.save_diary()
        finally:
            # Release the devices even if saving the diary fails.
            camera.release()
            video_output.release()
            cv2.destroyAllWindows()
# Script entry point: parse the command line, then start the capture loop.
if __name__ == "__main__":
    # Keep the parsed options bound to the module-level name `args`; do not
    # inline this into the call without first checking that nothing else in
    # the module reads that global.
    args = ArgParser().parse_args()
    main(args)