Fix corner case when m is too small
Browse files- adaptive_schedule.py +59 -12
adaptive_schedule.py
CHANGED
@@ -345,10 +345,7 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
|
|
345 |
def squeeze_without_change_order(schedules, m):
|
346 |
p = len(schedules)
|
347 |
squeezed = [[' '] * len(schedules[_]) for _ in range(p)]
|
348 |
-
max_len =
|
349 |
-
for seq in squeezed:
|
350 |
-
assert max_len == 0 or max_len == len(seq)
|
351 |
-
max_len = max(max_len, len(seq))
|
352 |
|
353 |
identifier_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
|
354 |
identifier_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p * m)]
|
@@ -389,6 +386,9 @@ def squeeze_without_change_order(schedules, m):
|
|
389 |
identifier_cnt[i][identifier] = _cnt + 1
|
390 |
identifier_index[_cnt * p + i][identifier] = index
|
391 |
stage_index[i] = index + 1
|
|
|
|
|
|
|
392 |
return squeezed
|
393 |
|
394 |
|
@@ -454,6 +454,7 @@ def process_cooldown(schedules, m):
|
|
454 |
schedules[i][index] = 'B'
|
455 |
|
456 |
# 2: add W back in cooldown phase
|
|
|
457 |
for i in range(p):
|
458 |
c_w, c_ww = 0, 0
|
459 |
last_w_index = -1
|
@@ -478,12 +479,57 @@ def process_cooldown(schedules, m):
|
|
478 |
elif c_ww > 0:
|
479 |
schedules[i][j] = 'w'
|
480 |
c_ww -= 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
|
482 |
schedules = squeeze_without_change_order(schedules, m)
|
483 |
return schedules
|
484 |
|
485 |
|
486 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
"""
|
488 |
We iterate all the cells from left to right. If a vacant cell (which means a bubble) is encountered, we try to
|
489 |
find a computation pass to fill this bubble. We iterate all the following computation passes in the same device,
|
@@ -491,17 +537,15 @@ def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index =
|
|
491 |
to the vacant cell, and the bubble is filled.
|
492 |
"""
|
493 |
p = len(schedules)
|
494 |
-
max_len = 0
|
495 |
-
for seq in schedules:
|
496 |
-
assert max_len == 0 or max_len == len(seq)
|
497 |
-
max_len = max(max_len, len(seq))
|
498 |
if starting_index is not None:
|
499 |
assert isinstance(starting_index, list) and len(starting_index) == p
|
500 |
if ending_index is not None:
|
501 |
assert isinstance(ending_index, list) and len(ending_index) == p
|
|
|
|
|
|
|
502 |
starting_index = starting_index or [0] * p
|
503 |
ending_index = ending_index or [max_len] * p
|
504 |
-
|
505 |
last_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p)]
|
506 |
for i in range(p):
|
507 |
for j in range(max_len):
|
@@ -510,7 +554,6 @@ def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index =
|
|
510 |
continue
|
511 |
last_index[i][identifier] = j
|
512 |
|
513 |
-
peak_mem = get_peak_mem(schedules)
|
514 |
stage_mem = [0] * p
|
515 |
def update_mem(stage_i, pass_c):
|
516 |
if pass_c in "Ff":
|
@@ -645,6 +688,7 @@ def check_correctness(schedules, m, raise_exception=False):
|
|
645 |
return False
|
646 |
return True
|
647 |
|
|
|
648 |
def relabel_w(schedules, m):
|
649 |
p = len(schedules)
|
650 |
c_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
|
@@ -654,7 +698,7 @@ def relabel_w(schedules, m):
|
|
654 |
continue
|
655 |
c_cnt[i][schedules[i][j]] += 1
|
656 |
for c in "FfBbWw":
|
657 |
-
assert c_cnt[i][c] == m
|
658 |
for i in range(p):
|
659 |
w_queue = deque(maxlen=2 * m)
|
660 |
for j in range(len(schedules[i])):
|
@@ -722,6 +766,8 @@ def schedule_by_building_block(p, m, building_block, max_mem, keep_stable_phase=
|
|
722 |
if m < redundant_m:
|
723 |
# 4. remove redundancy
|
724 |
schedules = remove_redundancy(schedules, m)
|
|
|
|
|
725 |
schedules = squeeze_without_change_order(schedules, m)
|
726 |
print_schedules(schedules, "after removing redundancy")
|
727 |
init_peak_mem = peak_mem = get_peak_mem(schedules)
|
@@ -820,6 +866,7 @@ def schedule(p, m, cost, max_mem):
|
|
820 |
[4, -1, 4, -1],
|
821 |
[5, -1, 5, -1]
|
822 |
]
|
|
|
823 |
|
824 |
best_schedule = None
|
825 |
best_bubble = None
|
|
|
345 |
def squeeze_without_change_order(schedules, m):
|
346 |
p = len(schedules)
|
347 |
squeezed = [[' '] * len(schedules[_]) for _ in range(p)]
|
348 |
+
max_len = check_and_get_schedule_len(schedules)
|
|
|
|
|
|
|
349 |
|
350 |
identifier_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
|
351 |
identifier_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p * m)]
|
|
|
386 |
identifier_cnt[i][identifier] = _cnt + 1
|
387 |
identifier_index[_cnt * p + i][identifier] = index
|
388 |
stage_index[i] = index + 1
|
389 |
+
new_len = max(stage_index)
|
390 |
+
for i in range(p):
|
391 |
+
squeezed[i] = squeezed[i][:new_len]
|
392 |
return squeezed
|
393 |
|
394 |
|
|
|
454 |
schedules[i][index] = 'B'
|
455 |
|
456 |
# 2: add W back in cooldown phase
|
457 |
+
max_len = 0
|
458 |
for i in range(p):
|
459 |
c_w, c_ww = 0, 0
|
460 |
last_w_index = -1
|
|
|
479 |
elif c_ww > 0:
|
480 |
schedules[i][j] = 'w'
|
481 |
c_ww -= 1
|
482 |
+
for _ in range(c_w):
|
483 |
+
schedules[i].append('W')
|
484 |
+
for _ in range(c_ww):
|
485 |
+
schedules[i].append('w')
|
486 |
+
max_len = max(max_len, len(schedules[i]))
|
487 |
+
for i in range(p):
|
488 |
+
for _ in range(len(schedules[i]), max_len):
|
489 |
+
schedules[i].append(' ')
|
490 |
|
491 |
schedules = squeeze_without_change_order(schedules, m)
|
492 |
return schedules
|
493 |
|
494 |
|
495 |
+
def check_and_get_schedule_len(schedules):
    """Return the common length of all rows in ``schedules``.

    Args:
        schedules: a sequence of per-stage sequences (one row per stage).

    Returns:
        The length shared by every row, or 0 when ``schedules`` is empty.

    Raises:
        AssertionError: if the rows do not all have the same length.
    """
    lengths = {len(seq) for seq in schedules}
    # Compare every row against every other one. The previous incremental
    # check (`max_len == 0 or max_len == len(seq)`) let leading empty rows slip
    # through, so [[], ['F']] passed while [['F'], []] failed. Raising
    # explicitly also keeps the check alive under `python -O`.
    if len(lengths) > 1:
        raise AssertionError(
            f"schedule rows have different lengths: {sorted(lengths)}"
        )
    return lengths.pop() if lengths else 0
|
501 |
+
|
502 |
+
|
503 |
+
def release_w_in_warmup_if_under_memory(schedules, peak_mem=None):
    """Postpone early W/w passes on stages that are below the memory budget.

    Examples (one row each, before -> after; the rows are shown after a
    subsequent squeeze):
        FF fBWfBW bwbw      ->  FF fBfBWW bwbw
        FF f fBW BW bwbw    ->  FF f fBWBW bwbw
        FF f f BW BbWbww    ->  FF f f BWBbWbww
        FfFf BbWBbwWw       ->  FfFf BbBbWwWw

    When the number of micro-batches is small relative to the available
    memory, the warmup phase is not optimal. For each stage, the earliest
    W/w passes are blanked out and re-appended at the end of the row, as long
    as the stage's peak memory plus the number of moved passes stays below
    ``peak_mem``. This uses the spare memory to remove unnecessary bubbles.

    Args:
        schedules: list of per-stage lists of single-character pass labels;
            all rows must have the same length. Modified in place.
        peak_mem: memory budget. A falsy value (``None`` or 0) means "use the
            largest per-stage peak reported by ``get_peak_mem``".

    Returns:
        The same ``schedules`` object. Every row is extended by the same
        number of cells (``peak_mem - min(per-stage peak)``, if positive), so
        the rows keep a common length.
    """
    # Nothing to do for an empty schedule; also avoids max()/min() raising
    # ValueError on an empty sequence below.
    if not schedules:
        return schedules

    max_len = check_and_get_schedule_len(schedules)
    # NOTE(review): presumably one peak-memory value per stage, in stage
    # order, and integer-valued -- confirm against get_peak_mem.
    all_peak_mem = get_peak_mem(schedules, return_all=True)
    if not peak_mem:
        peak_mem = max(all_peak_mem)
    # Every row is padded by the same amount so the rows stay rectangular.
    pad_len = max(peak_mem - min(all_peak_mem), 0)

    for row, stage_peak in zip(schedules, all_peak_mem):
        # Number of W/w passes this stage may postpone within the budget.
        headroom = peak_mem - stage_peak
        released = []
        for j in range(max_len):
            if len(released) >= headroom:
                break
            if row[j] in "Ww":
                released.append(row[j])
                row[j] = ' '
        # Re-append the postponed passes in their original order, then fill
        # the remainder of the padding with blanks.
        row.extend(released)
        row.extend([" "] * (pad_len - len(released)))
    return schedules
|
530 |
+
|
531 |
+
|
532 |
+
def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index = None, ending_index = None, peak_mem = None):
|
533 |
"""
|
534 |
We iterate all the cells from left to right. If a vacant cell (which means a bubble) is encountered, we try to
|
535 |
find a computation pass to fill this bubble. We iterate all the following computation passes in the same device,
|
|
|
537 |
to the vacant cell, and the bubble is filled.
|
538 |
"""
|
539 |
p = len(schedules)
|
|
|
|
|
|
|
|
|
540 |
if starting_index is not None:
|
541 |
assert isinstance(starting_index, list) and len(starting_index) == p
|
542 |
if ending_index is not None:
|
543 |
assert isinstance(ending_index, list) and len(ending_index) == p
|
544 |
+
|
545 |
+
peak_mem = peak_mem or get_peak_mem(schedules)
|
546 |
+
max_len = check_and_get_schedule_len(schedules)
|
547 |
starting_index = starting_index or [0] * p
|
548 |
ending_index = ending_index or [max_len] * p
|
|
|
549 |
last_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p)]
|
550 |
for i in range(p):
|
551 |
for j in range(max_len):
|
|
|
554 |
continue
|
555 |
last_index[i][identifier] = j
|
556 |
|
|
|
557 |
stage_mem = [0] * p
|
558 |
def update_mem(stage_i, pass_c):
|
559 |
if pass_c in "Ff":
|
|
|
688 |
return False
|
689 |
return True
|
690 |
|
691 |
+
|
692 |
def relabel_w(schedules, m):
|
693 |
p = len(schedules)
|
694 |
c_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
|
|
|
698 |
continue
|
699 |
c_cnt[i][schedules[i][j]] += 1
|
700 |
for c in "FfBbWw":
|
701 |
+
assert c_cnt[i][c] == m, f"{i}, {c}, {c_cnt[i][c]}"
|
702 |
for i in range(p):
|
703 |
w_queue = deque(maxlen=2 * m)
|
704 |
for j in range(len(schedules[i])):
|
|
|
766 |
if m < redundant_m:
|
767 |
# 4. remove redundancy
|
768 |
schedules = remove_redundancy(schedules, m)
|
769 |
+
if m <= p and 2 * m <= max_mem:
|
770 |
+
schedules = release_w_in_warmup_if_under_memory(schedules, peak_mem=min(2 * p, peak_mem))
|
771 |
schedules = squeeze_without_change_order(schedules, m)
|
772 |
print_schedules(schedules, "after removing redundancy")
|
773 |
init_peak_mem = peak_mem = get_peak_mem(schedules)
|
|
|
866 |
[4, -1, 4, -1],
|
867 |
[5, -1, 5, -1]
|
868 |
]
|
869 |
+
# available_starting_patterns = available_starting_patterns[:1]
|
870 |
|
871 |
best_schedule = None
|
872 |
best_bubble = None
|