fffiloni commited on
Commit
cf309f8
1 Parent(s): 7e3803b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -51
app.py CHANGED
@@ -21,39 +21,11 @@ import numpy as np
21
  import torch
22
  import matplotlib.pyplot as plt
23
  import torchvision.transforms.functional as F
 
 
 
24
 
25
 
26
- plt.rcParams["savefig.bbox"] = "tight"
27
- # sphinx_gallery_thumbnail_number = 2
28
-
29
-
30
- def plot(imgs, **imshow_kwargs):
31
- if not isinstance(imgs[0], list):
32
- # Make a 2d grid even if there's just 1 row
33
- imgs = [imgs]
34
-
35
- num_rows = len(imgs)
36
- num_cols = len(imgs[0])
37
- _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
38
- for row_idx, row in enumerate(imgs):
39
- for col_idx, img in enumerate(row):
40
- ax = axs[row_idx, col_idx]
41
- img = F.to_pil_image(img.to("cpu"))
42
- ax.imshow(np.asarray(img), **imshow_kwargs)
43
- ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
44
-
45
- plt.tight_layout()
46
-
47
- ###################################
48
- # Reading Videos Using Torchvision
49
- # --------------------------------
50
- # We will first read a video using :func:`~torchvision.io.read_video`.
51
- # Alternatively one can use the new :class:`~torchvision.io.VideoReader` API (if
52
- # torchvision is built from source).
53
- # The video we will use here is free of use from `pexels.com
54
- # <https://www.pexels.com/video/a-man-playing-a-game-of-basketball-5192157/>`_,
55
- # credits go to `Pavel Danilyuk <https://www.pexels.com/@pavel-danilyuk>`_.
56
-
57
 
58
  import tempfile
59
  from pathlib import Path
@@ -64,29 +36,15 @@ def infer():
64
  video_path = Path(tempfile.mkdtemp()) / "basketball.mp4"
65
  _ = urlretrieve(video_url, video_path)
66
 
67
- #########################
68
- # :func:`~torchvision.io.read_video` returns the video frames, audio frames and
69
- # the metadata associated with the video. In our case, we only need the video
70
- # frames.
71
- #
72
- # Here we will just make 2 predictions between 2 pre-selected pairs of frames,
73
- # namely frames (100, 101) and (150, 151). Each of these pairs corresponds to a
74
- # single model input.
75
 
76
- from torchvision.io import read_video
77
- frames, _, _ = read_video(str(video_path), output_format="TCHW")
78
 
79
- img1= [frames[100]
80
- img2 = [frames[101]
81
 
82
- #########################
83
- # The RAFT model accepts RGB images. We first get the frames from
84
- # :func:`~torchvision.io.read_video` and resize them to ensure their
85
- # dimensions are divisible by 8. Then we use the transforms bundled into the
86
- # weights in order to preprocess the input and rescale its values to the
87
- # required ``[-1, 1]`` interval.
88
 
89
- from torchvision.models.optical_flow import Raft_Large_Weights
90
 
91
  weights = Raft_Large_Weights.DEFAULT
92
  transforms = weights.transforms()
@@ -112,7 +70,7 @@ def infer():
112
  # We also provide the :func:`~torchvision.models.optical_flow.raft_small` model
113
  # builder, which is smaller and faster to run, sacrificing a bit of accuracy.
114
 
115
- from torchvision.models.optical_flow import raft_large
116
 
117
  # If you can, run this example on a GPU, it will be a lot faster.
118
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
21
  import torch
22
  import matplotlib.pyplot as plt
23
  import torchvision.transforms.functional as F
24
+ from torchvision.io import read_video
25
+ from torchvision.models.optical_flow import Raft_Large_Weights
26
+ from torchvision.models.optical_flow import raft_large
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  import tempfile
31
  from pathlib import Path
 
36
  video_path = Path(tempfile.mkdtemp()) / "basketball.mp4"
37
  _ = urlretrieve(video_url, video_path)
38
 
 
 
 
 
 
 
 
 
39
 
 
 
40
 
41
+
42
+ frames, _, _ = read_video(str(video_path), output_format="TCHW")
43
 
44
+ img1= frames[100]
45
+ img2 = frames[101]
 
 
 
 
46
 
47
+
48
 
49
  weights = Raft_Large_Weights.DEFAULT
50
  transforms = weights.transforms()
 
70
  # We also provide the :func:`~torchvision.models.optical_flow.raft_small` model
71
  # builder, which is smaller and faster to run, sacrificing a bit of accuracy.
72
 
73
+
74
 
75
  # If you can, run this example on a GPU, it will be a lot faster.
76
  device = "cuda" if torch.cuda.is_available() else "cpu"