fffiloni commited on
Commit
be6d4fe
1 Parent(s): dbbc6a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  """
2
  =====================================================
3
  Optical Flow: Predicting movement with the RAFT model
@@ -74,10 +76,8 @@ _ = urlretrieve(video_url, video_path)
74
  from torchvision.io import read_video
75
  frames, _, _ = read_video(str(video_path), output_format="TCHW")
76
 
77
- img1_batch = torch.stack([frames[100], frames[150]])
78
- img2_batch = torch.stack([frames[101], frames[151]])
79
-
80
- plot(img1_batch)
81
 
82
  #########################
83
  # The RAFT model accepts RGB images. We first get the frames from
@@ -92,15 +92,15 @@ weights = Raft_Large_Weights.DEFAULT
92
  transforms = weights.transforms()
93
 
94
 
95
- def preprocess(img1_batch, img2_batch):
96
- img1_batch = F.resize(img1_batch, size=[520, 960])
97
- img2_batch = F.resize(img2_batch, size=[520, 960])
98
- return transforms(img1_batch, img2_batch)
99
 
100
 
101
- img1_batch, img2_batch = preprocess(img1_batch, img2_batch)
102
 
103
- print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}")
104
 
105
 
106
  ####################################
@@ -120,7 +120,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
120
  model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
121
  model = model.eval()
122
 
123
- list_of_flows = model(img1_batch.to(device), img2_batch.to(device))
124
  print(f"type = {type(list_of_flows)}")
125
  print(f"length = {len(list_of_flows)} = number of iterations of the model")
126
 
@@ -160,9 +160,9 @@ from torchvision.utils import flow_to_image
160
  flow_imgs = flow_to_image(predicted_flows)
161
 
162
  # The images have been mapped into [-1, 1] but for plotting we want them in [0, 1]
163
- img1_batch = [(img1 + 1) / 2 for img1 in img1_batch]
164
 
165
- grid = [[img1, flow_img] for (img1, flow_img) in zip(img1_batch, flow_imgs)]
166
  plot(grid)
167
 
168
  ####################################
 
1
+ import gradio as gr
2
+
3
  """
4
  =====================================================
5
  Optical Flow: Predicting movement with the RAFT model
 
76
  from torchvision.io import read_video
77
  frames, _, _ = read_video(str(video_path), output_format="TCHW")
78
 
79
+ img1= [frames[100]
80
+ img2 = [frames[101]
 
 
81
 
82
  #########################
83
  # The RAFT model accepts RGB images. We first get the frames from
 
92
  transforms = weights.transforms()
93
 
94
 
95
+ def preprocess(img, img2):
96
+ img1 = F.resize(img1, size=[520, 960])
97
+ img2 = F.resize(img2, size=[520, 960])
98
+ return transforms(img1, img2)
99
 
100
 
101
+ img1, img2 = preprocess(img1, img2)
102
 
103
+ print(f"shape = {img1.shape}, dtype = {img1.dtype}")
104
 
105
 
106
  ####################################
 
120
  model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
121
  model = model.eval()
122
 
123
+ list_of_flows = model(img1.to(device), img2.to(device))
124
  print(f"type = {type(list_of_flows)}")
125
  print(f"length = {len(list_of_flows)} = number of iterations of the model")
126
 
 
160
  flow_imgs = flow_to_image(predicted_flows)
161
 
162
  # The images have been mapped into [-1, 1] but for plotting we want them in [0, 1]
163
+ img1 = [(img1 + 1) / 2 for img1 in img1]
164
 
165
+ grid = [[img1, flow_img] for (img1, flow_img) in zip(img1, flow_imgs)]
166
  plot(grid)
167
 
168
  ####################################