p2oileen commited on
Commit
f287a32
·
1 Parent(s): fa86813
Files changed (2) hide show
  1. requirements.txt +2 -0
  2. train.py +7 -6
requirements.txt CHANGED
@@ -6,3 +6,5 @@ scikit-image>=0.14.0
6
  torchvision>=0.2.1
7
  pillow>=7.2.0
8
  lpips>=0.1.3
 
 
 
6
  torchvision>=0.2.1
7
  pillow>=7.2.0
8
  lpips>=0.1.3
9
+ gdown
10
+ tqdm
train.py CHANGED
@@ -12,6 +12,7 @@ from data_loader import (FileDataset,
12
  RandomResizedCropWithAutoCenteringAndZeroPadding)
13
  from torch.utils.data.distributed import DistributedSampler
14
  from conr import CoNR
 
15
 
16
  def data_sampler(dataset, shuffle, distributed):
17
 
@@ -123,7 +124,7 @@ def infer(args, humanflowmodel, image_names_list):
123
  time_stamp = time.time()
124
  prev_frame_rgb = []
125
  prev_frame_a = []
126
- for i, data in enumerate(train_data):
127
  data_time_interval = time.time() - time_stamp
128
  time_stamp = time.time()
129
  with torch.no_grad():
@@ -137,11 +138,11 @@ def infer(args, humanflowmodel, image_names_list):
137
 
138
  train_time_interval = time.time() - time_stamp
139
  time_stamp = time.time()
140
- if i % 5 == 0 and args.local_rank == 0:
141
- print("[infer batch: %4d/%4d] time:%2f+%2f" % (
142
- i, train_num,
143
- data_time_interval, train_time_interval
144
- ))
145
  with torch.no_grad():
146
 
147
  if args.test_output_video:
 
12
  RandomResizedCropWithAutoCenteringAndZeroPadding)
13
  from torch.utils.data.distributed import DistributedSampler
14
  from conr import CoNR
15
+ from tqdm import tqdm
16
 
17
  def data_sampler(dataset, shuffle, distributed):
18
 
 
124
  time_stamp = time.time()
125
  prev_frame_rgb = []
126
  prev_frame_a = []
127
+ for i, data in tqdm(enumerate(train_data)):
128
  data_time_interval = time.time() - time_stamp
129
  time_stamp = time.time()
130
  with torch.no_grad():
 
138
 
139
  train_time_interval = time.time() - time_stamp
140
  time_stamp = time.time()
141
+ # if i % 5 == 0 and args.local_rank == 0:
142
+ # print("[infer batch: %4d/%4d] time:%2f+%2f" % (
143
+ # i, train_num,
144
+ # data_time_interval, train_time_interval
145
+ # ))
146
  with torch.no_grad():
147
 
148
  if args.test_output_video: