Spaces:
Runtime error
Runtime error
add tqdm
Browse files
train.py
CHANGED
@@ -71,35 +71,6 @@ def test():
|
|
71 |
conrmodel.dist()
|
72 |
infer(args, conrmodel, image_names_list)
|
73 |
|
74 |
-
# def test():
|
75 |
-
# source_names_list = []
|
76 |
-
# for name in os.listdir(args.test_input_person_images):
|
77 |
-
# thissource = os.path.join(args.test_input_person_images, name)
|
78 |
-
# if os.path.isfile(thissource):
|
79 |
-
# source_names_list.append([thissource])
|
80 |
-
# if os.path.isdir(thissource):
|
81 |
-
# toadd = [os.path.join(thissource, this_file)
|
82 |
-
# for this_file in os.listdir(thissource)]
|
83 |
-
# if (toadd != []):
|
84 |
-
# source_names_list.append(toadd)
|
85 |
-
# else:
|
86 |
-
# print("skipping empty folder :"+thissource)
|
87 |
-
# image_names_list = []
|
88 |
-
# for eachlist in source_names_list:
|
89 |
-
# for name in sorted(os.listdir(args.test_input_poses_images)):
|
90 |
-
# thistarget = os.path.join(args.test_input_poses_images, name)
|
91 |
-
# if os.path.isfile(thistarget):
|
92 |
-
# image_names_list.append([thistarget, *eachlist])
|
93 |
-
# if os.path.isdir(thistarget):
|
94 |
-
# print("skipping folder :"+thistarget)
|
95 |
-
|
96 |
-
# print(image_names_list)
|
97 |
-
# print("---building models...")
|
98 |
-
# conrmodel = CoNR(args)
|
99 |
-
# conrmodel.load_model(path=args.test_checkpoint_dir)
|
100 |
-
# conrmodel.dist()
|
101 |
-
# infer(args, conrmodel, image_names_list)
|
102 |
-
|
103 |
|
104 |
def infer(args, humanflowmodel, image_names_list):
|
105 |
print("---test images: ", len(image_names_list))
|
@@ -124,7 +95,9 @@ def infer(args, humanflowmodel, image_names_list):
|
|
124 |
time_stamp = time.time()
|
125 |
prev_frame_rgb = []
|
126 |
prev_frame_a = []
|
127 |
-
|
|
|
|
|
128 |
data_time_interval = time.time() - time_stamp
|
129 |
time_stamp = time.time()
|
130 |
with torch.no_grad():
|
@@ -138,11 +111,10 @@ def infer(args, humanflowmodel, image_names_list):
|
|
138 |
|
139 |
train_time_interval = time.time() - time_stamp
|
140 |
time_stamp = time.time()
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
# ))
|
146 |
with torch.no_grad():
|
147 |
|
148 |
if args.test_output_video:
|
|
|
71 |
conrmodel.dist()
|
72 |
infer(args, conrmodel, image_names_list)
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
def infer(args, humanflowmodel, image_names_list):
|
76 |
print("---test images: ", len(image_names_list))
|
|
|
95 |
time_stamp = time.time()
|
96 |
prev_frame_rgb = []
|
97 |
prev_frame_a = []
|
98 |
+
|
99 |
+
pbar = tqdm(range(train_num), ncols=100)
|
100 |
+
for i, data in enumerate(train_data):
|
101 |
data_time_interval = time.time() - time_stamp
|
102 |
time_stamp = time.time()
|
103 |
with torch.no_grad():
|
|
|
111 |
|
112 |
train_time_interval = time.time() - time_stamp
|
113 |
time_stamp = time.time()
|
114 |
+
if args.local_rank == 0:
|
115 |
+
pbar.set_description(f"Epoch {i}/{train_num}")
|
116 |
+
pbar.set_postfix({"data_time": data_time_interval, "train_time":train_time_interval})
|
117 |
+
|
|
|
118 |
with torch.no_grad():
|
119 |
|
120 |
if args.test_output_video:
|