Update app.py
Browse files
app.py
CHANGED
@@ -21,7 +21,7 @@ from PIL import Image
|
|
21 |
from decord import VideoReader
|
22 |
from decord import cpu
|
23 |
from videomamba_image import videomamba_image_tiny
|
24 |
-
from videomamba_video import
|
25 |
from kinetics_class_index import kinetics_classnames
|
26 |
from imagenet_class_index import imagenet_classnames
|
27 |
from transforms import (
|
@@ -39,7 +39,7 @@ device = "cuda"
|
|
39 |
model_video_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_m16_k400_f16_res224.pth")
|
40 |
model_image_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_in1k_res224.pth")
|
41 |
# Pick a pretrained model
|
42 |
-
model_video =
|
43 |
video_sd = torch.load(model_video_path, map_location='cpu')
|
44 |
model_video.load_state_dict(video_sd)
|
45 |
model_image = videomamba_image_tiny()
|
|
|
21 |
from decord import VideoReader
|
22 |
from decord import cpu
|
23 |
from videomamba_image import videomamba_image_tiny
|
24 |
+
from videomamba_video import videomamba_middle
|
25 |
from kinetics_class_index import kinetics_classnames
|
26 |
from imagenet_class_index import imagenet_classnames
|
27 |
from transforms import (
|
|
|
39 |
model_video_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_m16_k400_f16_res224.pth")
|
40 |
model_image_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_in1k_res224.pth")
|
41 |
# Pick a pretrained model
|
42 |
+
model_video = videomamba_middle(num_classes=400, num_frames=16)
|
43 |
video_sd = torch.load(model_video_path, map_location='cpu')
|
44 |
model_video.load_state_dict(video_sd)
|
45 |
model_image = videomamba_image_tiny()
|