Arnaudding001 commited on
Commit
e9f92a9
1 Parent(s): e884345

Create raft_evaluate.py

Browse files
Files changed (1) hide show
  1. raft_evaluate.py +195 -0
raft_evaluate.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('core')
3
+
4
+ from PIL import Image
5
+ import argparse
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import matplotlib.pyplot as plt
12
+
13
+ import datasets
14
+ from utils import flow_viz
15
+ from utils import frame_utils
16
+
17
+ from raft import RAFT
18
+ from utils.utils import InputPadder, forward_interpolate
19
+
20
+
21
+ @torch.no_grad()
22
+ def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
23
+ """ Create submission for the Sintel leaderboard """
24
+ model.eval()
25
+ for dstype in ['clean', 'final']:
26
+ test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
27
+
28
+ flow_prev, sequence_prev = None, None
29
+ for test_id in range(len(test_dataset)):
30
+ image1, image2, (sequence, frame) = test_dataset[test_id]
31
+ if sequence != sequence_prev:
32
+ flow_prev = None
33
+
34
+ padder = InputPadder(image1.shape)
35
+ image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
36
+
37
+ flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
38
+ flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
39
+
40
+ if warm_start:
41
+ flow_prev = forward_interpolate(flow_low[0])[None].cuda()
42
+
43
+ output_dir = os.path.join(output_path, dstype, sequence)
44
+ output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
45
+
46
+ if not os.path.exists(output_dir):
47
+ os.makedirs(output_dir)
48
+
49
+ frame_utils.writeFlow(output_file, flow)
50
+ sequence_prev = sequence
51
+
52
+
53
+ @torch.no_grad()
54
+ def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
55
+ """ Create submission for the Sintel leaderboard """
56
+ model.eval()
57
+ test_dataset = datasets.KITTI(split='testing', aug_params=None)
58
+
59
+ if not os.path.exists(output_path):
60
+ os.makedirs(output_path)
61
+
62
+ for test_id in range(len(test_dataset)):
63
+ image1, image2, (frame_id, ) = test_dataset[test_id]
64
+ padder = InputPadder(image1.shape, mode='kitti')
65
+ image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
66
+
67
+ _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
68
+ flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
69
+
70
+ output_filename = os.path.join(output_path, frame_id)
71
+ frame_utils.writeFlowKITTI(output_filename, flow)
72
+
73
+
74
+ @torch.no_grad()
75
+ def validate_chairs(model, iters=24):
76
+ """ Perform evaluation on the FlyingChairs (test) split """
77
+ model.eval()
78
+ epe_list = []
79
+
80
+ val_dataset = datasets.FlyingChairs(split='validation')
81
+ for val_id in range(len(val_dataset)):
82
+ image1, image2, flow_gt, _ = val_dataset[val_id]
83
+ image1 = image1[None].cuda()
84
+ image2 = image2[None].cuda()
85
+
86
+ _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
87
+ epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
88
+ epe_list.append(epe.view(-1).numpy())
89
+
90
+ epe = np.mean(np.concatenate(epe_list))
91
+ print("Validation Chairs EPE: %f" % epe)
92
+ return {'chairs': epe}
93
+
94
+
95
+ @torch.no_grad()
96
+ def validate_sintel(model, iters=32):
97
+ """ Peform validation using the Sintel (train) split """
98
+ model.eval()
99
+ results = {}
100
+ for dstype in ['clean', 'final']:
101
+ val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
102
+ epe_list = []
103
+
104
+ for val_id in range(len(val_dataset)):
105
+ image1, image2, flow_gt, _ = val_dataset[val_id]
106
+ image1 = image1[None].cuda()
107
+ image2 = image2[None].cuda()
108
+
109
+ padder = InputPadder(image1.shape)
110
+ image1, image2 = padder.pad(image1, image2)
111
+
112
+ flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
113
+ flow = padder.unpad(flow_pr[0]).cpu()
114
+
115
+ epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
116
+ epe_list.append(epe.view(-1).numpy())
117
+
118
+ epe_all = np.concatenate(epe_list)
119
+ epe = np.mean(epe_all)
120
+ px1 = np.mean(epe_all<1)
121
+ px3 = np.mean(epe_all<3)
122
+ px5 = np.mean(epe_all<5)
123
+
124
+ print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
125
+ results[dstype] = np.mean(epe_list)
126
+
127
+ return results
128
+
129
+
130
+ @torch.no_grad()
131
+ def validate_kitti(model, iters=24):
132
+ """ Peform validation using the KITTI-2015 (train) split """
133
+ model.eval()
134
+ val_dataset = datasets.KITTI(split='training')
135
+
136
+ out_list, epe_list = [], []
137
+ for val_id in range(len(val_dataset)):
138
+ image1, image2, flow_gt, valid_gt = val_dataset[val_id]
139
+ image1 = image1[None].cuda()
140
+ image2 = image2[None].cuda()
141
+
142
+ padder = InputPadder(image1.shape, mode='kitti')
143
+ image1, image2 = padder.pad(image1, image2)
144
+
145
+ flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
146
+ flow = padder.unpad(flow_pr[0]).cpu()
147
+
148
+ epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
149
+ mag = torch.sum(flow_gt**2, dim=0).sqrt()
150
+
151
+ epe = epe.view(-1)
152
+ mag = mag.view(-1)
153
+ val = valid_gt.view(-1) >= 0.5
154
+
155
+ out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
156
+ epe_list.append(epe[val].mean().item())
157
+ out_list.append(out[val].cpu().numpy())
158
+
159
+ epe_list = np.array(epe_list)
160
+ out_list = np.concatenate(out_list)
161
+
162
+ epe = np.mean(epe_list)
163
+ f1 = 100 * np.mean(out_list)
164
+
165
+ print("Validation KITTI: %f, %f" % (epe, f1))
166
+ return {'kitti-epe': epe, 'kitti-f1': f1}
167
+
168
+
169
+ if __name__ == '__main__':
170
+ parser = argparse.ArgumentParser()
171
+ parser.add_argument('--model', help="restore checkpoint")
172
+ parser.add_argument('--dataset', help="dataset for evaluation")
173
+ parser.add_argument('--small', action='store_true', help='use small model')
174
+ parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
175
+ parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
176
+ args = parser.parse_args()
177
+
178
+ model = torch.nn.DataParallel(RAFT(args))
179
+ model.load_state_dict(torch.load(args.model))
180
+
181
+ model.cuda()
182
+ model.eval()
183
+
184
+ # create_sintel_submission(model.module, warm_start=True)
185
+ # create_kitti_submission(model.module)
186
+
187
+ with torch.no_grad():
188
+ if args.dataset == 'chairs':
189
+ validate_chairs(model.module)
190
+
191
+ elif args.dataset == 'sintel':
192
+ validate_sintel(model.module)
193
+
194
+ elif args.dataset == 'kitti':
195
+ validate_kitti(model.module)