MNGames committed on
Commit
16756b8
1 Parent(s): 7fcc53a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -1,27 +1,27 @@
1
  import gradio as gr
2
- from transformers import AutoModelForVideoClassification, AutoTokenizer, VideoClassificationPipeline
3
  import cv2 # OpenCV for video processing
4
 
5
  # Model ID for video classification (UCF101 subset)
6
- model_id = "sayakpaul/videomae-base-finetuned-ucf101-subset"
7
 
8
  def analyze_video(video):
9
  # Extract key frames from the video using OpenCV
10
  frames = extract_key_frames(video)
11
 
12
- # Load model and tokenizer manually
13
- model = AutoModelForVideoClassification.from_pretrained(model_id)
14
- tokenizer = AutoTokenizer.from_pretrained(model_id)
15
 
16
  # Create the pipeline
17
- classifier = VideoClassificationPipeline(model=model, tokenizer=tokenizer, device=-1)
18
 
19
  # Analyze key frames using video classification model
20
  results = []
21
  for frame in frames:
22
  # OpenCV uses BGR, convert to RGB for the model
23
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
24
- predictions = classifier(images=[frame_rgb]) # Assuming model outputs probabilities
25
  # Analyze predictions for insights related to the play
26
  result = analyze_predictions_ucf101(predictions)
27
  results.append(result)
 
1
  import gradio as gr
2
+ from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor, VideoClassificationPipeline
3
  import cv2 # OpenCV for video processing
4
 
5
  # Model ID for video classification — NOTE(review): MCG-NJU/videomae-base is a pretrained backbone, not fine-tuned on UCF101; it has no classification head, so pipeline predictions will be untrained — confirm intended checkpoint
6
+ model_id = "MCG-NJU/videomae-base"
7
 
8
  def analyze_video(video):
9
  # Extract key frames from the video using OpenCV
10
  frames = extract_key_frames(video)
11
 
12
+ # Load model and feature extractor manually
13
+ model = VideoMAEForVideoClassification.from_pretrained(model_id)
14
+ feature_extractor = VideoMAEFeatureExtractor.from_pretrained(model_id)
15
 
16
  # Create the pipeline
17
+ classifier = VideoClassificationPipeline(model=model, feature_extractor=feature_extractor, device=-1)
18
 
19
  # Analyze key frames using video classification model
20
  results = []
21
  for frame in frames:
22
  # OpenCV uses BGR, convert to RGB for the model
23
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
24
+ predictions = classifier([frame_rgb]) # Assuming model outputs probabilities
25
  # Analyze predictions for insights related to the play
26
  result = analyze_predictions_ucf101(predictions)
27
  results.append(result)