from collections import Counter
from dataclasses import dataclass

# Column stride used by `schedule`: every group of two micro-batches occupies
# `pattern_size` slots of a stage's row (two F slots and two B slots).
pattern_size = 6


@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """A single timed pass of one micro-batch on one pipeline stage."""

    # Pass kind: 'F', 'B' or 'W' (presumably forward, backward-for-input and
    # weight-gradient passes -- a naming convention, not enforced here).
    type: str
    # Stage index, i.e. the row of the schedule grid the pass came from.
    stage: int
    # 0-based index: the k-th pass of a given kind on a stage is micro-batch k.
    minibatch: int
    # Simulated start time (completion_time minus the pass's cost).
    start_time: int
    # Simulated completion time.
    completion_time: int


def transform_schedule(schedule, f, b, w, c):
    """Simulate timing for a character-grid schedule and return timed nodes.

    Only the order of non-blank entries within a row matters; the k-th
    occurrence of a letter in a row is micro-batch k. A pass starts when all
    of the following have completed (time starts at 0):

    * the previous pass on the same stage;
    * for F on a stage > 0: the same micro-batch's F on the previous stage,
      plus ``c``;
    * for B on a stage that is not the last: the same micro-batch's B on the
      next stage, plus ``c``.

    Args:
        schedule: one row per stage; each row is a sequence of 'F', 'B', 'W'
            or whitespace strings (whitespace marks an empty slot).
        f: duration of an F pass.
        b: duration of a B pass.
        w: duration of a W pass.
        c: delay added to each cross-stage dependency.

    Returns:
        One list of ScheduledNode per stage, in row order. Empty rows give
        empty lists, and an empty schedule gives ``[]``.

    Raises:
        KeyError: if a row contains a non-blank entry other than 'F'/'B'/'W'.
    """
    cost = {'F': f, 'B': b, 'W': w}
    num_stages = len(schedule)

    # Per stage: the non-blank passes in execution order as (kind, microbatch),
    # and, for each pass, the pass immediately before it on the same stage.
    stage_order = []
    local_prev = {}
    for sid, row in enumerate(schedule):
        seen = Counter()
        order = []
        for kind in row:
            if not kind.strip():
                continue
            mb = seen[kind]
            if order:
                local_prev[(sid, kind, mb)] = order[-1]
            order.append((kind, mb))
            seen[kind] += 1
        stage_order.append(order)

    time_map = {}

    def get_time(stage, kind, mb):
        """Return the (memoised) completion time of pass (stage, kind, mb)."""
        key = (stage, kind, mb)
        if key in time_map:
            return time_map[key]
        ready = 0
        if key in local_prev:
            ready = get_time(stage, *local_prev[key])
        if kind == 'F' and stage > 0:
            ready = max(ready, get_time(stage - 1, kind, mb) + c)
        if kind == 'B' and stage + 1 < num_stages:
            ready = max(ready, get_time(stage + 1, kind, mb) + c)
        time_map[key] = ready + cost[kind]
        return time_map[key]

    # Evaluate in row order: each pass's same-stage predecessor is then
    # already memoised, so recursion does not walk the whole row at once.
    result = []
    for sid, order in enumerate(stage_order):
        stage_nodes = []
        for kind, mb in order:
            done = get_time(sid, kind, mb)
            stage_nodes.append(
                ScheduledNode(kind, sid, mb, done - cost[kind], done)
            )
        result.append(stage_nodes)
    return result


def process_warmup_without_increasing_peak_mem(schedules, m):
    """Pull passes earlier into free slots of a character grid, in place.

    The grid is scanned column by column and, within a column, stage by
    stage. For every pass, the earliest legal slot is one past its latest
    already-placed dependency:

    * the previous pass of the same kind on the same stage;
    * W: the B of the same micro-batch on the same stage;
    * F on a stage > 0: the F of the same micro-batch on the previous stage;
    * B: the B of the same micro-batch on the next stage or, on the last
      stage, the F of the same micro-batch.

    B and W passes move to the first free slot at or after that point. F
    passes raise the running "F count minus W count" of their stage (likely
    a proxy for live activations -- confirm). They only move into a stretch
    of slots whose count is below the initial peak over all stages, so that
    peak is never increased. A pass is never moved later than where it is.

    Args:
        schedules: list of rows (one per stage). Each row is a mutable list of
            'F', 'B', 'W' or ' '. All rows are assumed to be as long as the
            first one.
        m: number of micro-batches. It sizes the placement table, so no stage
            may hold more than m + 1 passes of one kind.

    Returns:
        ``schedules`` itself, mutated in place. Empty input is returned
        unchanged.
    """
    if not schedules:
        return schedules
    num_stages = len(schedules)
    width = len(schedules[0])
    last_stage = num_stages - 1

    # mem[sid][i]: F passes minus W passes on stage sid up to and including
    # slot i. peak_mem is its maximum over every stage and slot.
    mem = [[0] * width for _ in range(num_stages)]
    peak_mem = 0
    for sid, row in enumerate(schedules):
        running = 0
        for i, kind in enumerate(row):
            if kind == 'F':
                running += 1
            elif kind == 'W':
                running -= 1
            mem[sid][i] = running
            peak_mem = max(peak_mem, running)

    # loc[sid][k][kind]: slot where the k-th (1-based) pass of that kind on
    # stage sid was finally placed, or -1 while it has not been visited yet.
    loc = [[{key: -1 for key in ('F', 'B', 'W')} for _ in range(m + 2)]
           for _ in range(num_stages)]
    seen = [{key: 0 for key in ('F', 'B', 'W')} for _ in range(num_stages)]

    for i in range(width):
        for sid in range(num_stages):
            row = schedules[sid]
            kind = row[i]
            if kind == ' ':
                continue
            seen[sid][kind] += 1
            cnt = seen[sid][kind]

            # NOTE(review): dependencies on stages visited later in this same
            # column still read -1 here and so do not constrain the move.
            pos = -1
            if cnt > 1:
                pos = loc[sid][cnt - 1][kind]
            if kind == 'W':
                pos = max(pos, loc[sid][cnt]['B'])
            elif kind == 'F' and sid > 0:
                pos = max(pos, loc[sid - 1][cnt]['F'])
            elif kind == 'B':
                if sid != last_stage:
                    pos = max(pos, loc[sid + 1][cnt]['B'])
                else:
                    pos = max(pos, loc[sid][cnt]['F'])
            pos += 1
            # Check the bound first so the row is never indexed past slot i.
            while pos < i and row[pos] != ' ':
                pos += 1
            if pos >= i:
                # No earlier free slot (or a dependency sits in this very
                # column): keep the pass where it is.
                loc[sid][cnt][kind] = i
                continue

            if kind in ('B', 'W'):
                row[pos] = kind
                row[i] = ' '
                if kind == 'W':
                    # The W now releases its unit of memory earlier.
                    for j in range(pos, i):
                        mem[sid][j] -= 1
                loc[sid][cnt][kind] = pos
                continue

            # F: walk left only while the slot before is below the peak, so
            # the +1 applied below never exceeds peak_mem.
            place = i
            while place > pos and mem[sid][place - 1] < peak_mem:
                place -= 1
            while place < i and row[place] != ' ':
                place += 1
            if place == i:
                loc[sid][cnt][kind] = i
                continue
            row[place] = kind
            row[i] = ' '
            for j in range(place, i):
                mem[sid][j] += 1
            loc[sid][cnt][kind] = place
    return schedules


def schedule(p, m, cost):
    """Build, compact and time a p-stage schedule for m micro-batches.

    Each stage's row first receives F/B pairs laid out every `pattern_size`
    slots. Stage ``sid`` places its first F at slot ``sid``, and each earlier
    stage starts its B passes one slot later. Idle slots after B passes are
    then filled with W passes, one per outstanding B. The grid is then
    compacted by `process_warmup_without_increasing_peak_mem` and timed by
    `transform_schedule`.

    Args:
        p: number of pipeline stages.
        m: number of micro-batches.
        cost: ``(f, b, w, c)`` tuple forwarded to `transform_schedule`.

    Returns:
        One list of ScheduledNode per stage (``[]`` when ``p == 0``).
    """
    width = pattern_size * m + 2 * p + pattern_size
    schedules = [[' '] * width for _ in range(p)]
    f_0, f_1, b_0, b_1 = p - 1, p + 1, p, p + 2
    for sid in range(p - 1, -1, -1):
        row = schedules[sid]
        for mid in range((m + 1) // 2):
            base = mid * pattern_size
            if mid * 2 < m:
                row[f_0 + base], row[b_0 + base] = 'F', 'B'
            if mid * 2 + 1 < m:
                row[f_1 + base], row[b_1 + base] = 'F', 'B'
        # The next (earlier) stage starts F one slot sooner and B one later.
        f_0 -= 1
        f_1 -= 1
        b_0 += 1
        b_1 += 1
        # Fill idle slots following B passes with W, one per outstanding B.
        pending = 0
        for i, slot in enumerate(row):
            if slot == 'B':
                pending += 1
            elif slot == ' ' and pending > 0:
                pending -= 1
                row[i] = 'W'
    schedules = process_warmup_without_increasing_peak_mem(schedules, m)
    return transform_schedule(schedules, *cost)