p2oileen commited on
Commit
4537523
·
1 Parent(s): f287a32
Files changed (1) hide show
  1. train.py +7 -35
train.py CHANGED
@@ -71,35 +71,6 @@ def test():
71
  conrmodel.dist()
72
  infer(args, conrmodel, image_names_list)
73
 
74
- # def test():
75
- # source_names_list = []
76
- # for name in os.listdir(args.test_input_person_images):
77
- # thissource = os.path.join(args.test_input_person_images, name)
78
- # if os.path.isfile(thissource):
79
- # source_names_list.append([thissource])
80
- # if os.path.isdir(thissource):
81
- # toadd = [os.path.join(thissource, this_file)
82
- # for this_file in os.listdir(thissource)]
83
- # if (toadd != []):
84
- # source_names_list.append(toadd)
85
- # else:
86
- # print("skipping empty folder :"+thissource)
87
- # image_names_list = []
88
- # for eachlist in source_names_list:
89
- # for name in sorted(os.listdir(args.test_input_poses_images)):
90
- # thistarget = os.path.join(args.test_input_poses_images, name)
91
- # if os.path.isfile(thistarget):
92
- # image_names_list.append([thistarget, *eachlist])
93
- # if os.path.isdir(thistarget):
94
- # print("skipping folder :"+thistarget)
95
-
96
- # print(image_names_list)
97
- # print("---building models...")
98
- # conrmodel = CoNR(args)
99
- # conrmodel.load_model(path=args.test_checkpoint_dir)
100
- # conrmodel.dist()
101
- # infer(args, conrmodel, image_names_list)
102
-
103
 
104
  def infer(args, humanflowmodel, image_names_list):
105
  print("---test images: ", len(image_names_list))
@@ -124,7 +95,9 @@ def infer(args, humanflowmodel, image_names_list):
124
  time_stamp = time.time()
125
  prev_frame_rgb = []
126
  prev_frame_a = []
127
- for i, data in tqdm(enumerate(train_data)):
 
 
128
  data_time_interval = time.time() - time_stamp
129
  time_stamp = time.time()
130
  with torch.no_grad():
@@ -138,11 +111,10 @@ def infer(args, humanflowmodel, image_names_list):
138
 
139
  train_time_interval = time.time() - time_stamp
140
  time_stamp = time.time()
141
- # if i % 5 == 0 and args.local_rank == 0:
142
- # print("[infer batch: %4d/%4d] time:%2f+%2f" % (
143
- # i, train_num,
144
- # data_time_interval, train_time_interval
145
- # ))
146
  with torch.no_grad():
147
 
148
  if args.test_output_video:
 
71
  conrmodel.dist()
72
  infer(args, conrmodel, image_names_list)
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  def infer(args, humanflowmodel, image_names_list):
76
  print("---test images: ", len(image_names_list))
 
95
  time_stamp = time.time()
96
  prev_frame_rgb = []
97
  prev_frame_a = []
98
+
99
+ pbar = tqdm(range(train_num), ncols=100)
100
+ for i, data in enumerate(train_data):
101
  data_time_interval = time.time() - time_stamp
102
  time_stamp = time.time()
103
  with torch.no_grad():
 
111
 
112
  train_time_interval = time.time() - time_stamp
113
  time_stamp = time.time()
114
+ if args.local_rank == 0:
115
+ pbar.set_description(f"Epoch {i}/{train_num}")
116
+ pbar.set_postfix({"data_time": data_time_interval, "train_time":train_time_interval})
117
+
 
118
  with torch.no_grad():
119
 
120
  if args.test_output_video: