from dataclasses import dataclass


@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """One pass of one microbatch on one pipeline stage, with simulated timing.

    ``get_interleaved_variation`` returns these, grouped per stage in
    execution order.
    """

    # Pass kind, upper-cased: 'F' (forward) or 'B' (backward). Use `chunk`
    # to tell which of the stage's two model chunks the pass belongs to.
    type: str
    # Model-chunk index on the stage (0 or 1).
    chunk: int
    # Pipeline stage index, 0 .. _p - 1.
    stage: int
    # Microbatch index, 0 .. _n - 1.
    minibatch: int
    # Simulated start time (completion_time minus the pass cost), in the
    # same units as the cost tuple.
    start_time: int
    # Simulated completion time of the pass.
    completion_time: int


def _interleaved_pass_orders(_p, _n):
    """Return ``(f_order, b_order)``: the forward and backward pass sequences.

    Passes are ``(kind, minibatch)`` pairs. ``'F'``/``'b'`` belong to chunk 0
    and ``'f'``/``'B'`` to chunk 1. Both lists have length ``2 * _n``.
    """
    warmup = min(_n, _p)
    f_order = [('F', mb) for mb in range(warmup)]
    f_order += [('f', mb) for mb in range(warmup)]
    b_order = [('B', mb) for mb in range(warmup)]

    # Each iteration pairs one new forward pass with one backward pass.
    # The counter advances by 3 per iteration. Once it reaches _p, the loop
    # switches chunk on every iteration while the other chunk has passes left.
    left = [max(0, _n - _p), max(0, _n - _p)]
    step = 0
    cur = 0
    while left[0] > 0 or left[1] > 0:
        if step >= _p and left[1 - cur] > 0:
            cur = 1 - cur
        if left[cur] > 0:
            mb = _n - left[cur]
            if cur == 0:
                f_order.append(('F', mb))
                b_order.append(('b', mb - _p))
            else:
                f_order.append(('f', mb))
                b_order.append(('B', mb))
            left[cur] -= 1
        step += 3

    # Cool-down: chunk-0 backward passes of the last `warmup` microbatches.
    # Offset by `warmup`, not `_p`: with _n < _p the latter yields negative
    # microbatch ids.
    b_order += [('b', _n - warmup + mb) for mb in range(warmup)]
    return f_order, b_order


def _interleaved_stage_schedules(_p, f_order, b_order):
    """Build each stage's pass order and its same-stage predecessor map.

    Stage ``s`` first runs ``min(2 * _p - s, len(f_order))`` forward passes,
    then alternates one backward and one forward pass, and finally drains the
    remaining backward passes. ``local_prev`` maps ``(stage, kind, mb)`` to
    the pass executed immediately before it on the same stage.
    """
    schedule = []
    local_prev = {}
    for stage in range(_p):
        diff = min(2 * _p - stage, len(f_order))
        stage_schedule = list(f_order[:diff])
        for k in range(len(f_order) - diff):
            stage_schedule.append(b_order[k])
            stage_schedule.append(f_order[k + diff])
        stage_schedule.extend(b_order[len(b_order) - diff:])
        for prev, node in zip(stage_schedule, stage_schedule[1:]):
            local_prev[(stage, *node)] = (stage, *prev)
        schedule.append(stage_schedule)
    return schedule, local_prev


def get_interleaved_variation(_p, _n, cost):
    """Build and time an interleaved two-chunk pipeline schedule.

    Each of the ``_p`` stages holds two model chunks. The dependency chain
    per microbatch runs in four legs:

    1. Chunk 0's forward (``'F'``) runs stage 0 -> ``_p - 1``.
    2. Chunk 1's forward (``'f'``) then starts at stage 0 and runs to
       ``_p - 1``.
    3. Chunk 1's backward (``'B'``) runs ``_p - 1`` -> 0.
    4. Chunk 0's backward (``'b'``) starts again at ``_p - 1`` and runs -> 0.

    Args:
        _p: Number of pipeline stages.
        _n: Number of microbatches.
        cost: ``(_f, _b, _w, _c)``. Forward passes take ``_f``. Backward
            passes take ``_b + _w``. ``_c`` is added on every dependency that
            crosses stages (presumably communication latency).

    Returns:
        One list per stage, each holding that stage's ``ScheduledNode`` objects
        in execution order. Times come from a dependency simulation. A pass
        starts once the previous pass on its stage has finished and its
        predecessor in the chain has finished (plus ``_c`` when that
        predecessor is on another stage).
    """
    _f, _b, _w, _c = cost
    f_order, b_order = _interleaved_pass_orders(_p, _n)
    schedule, local_prev = _interleaved_stage_schedules(_p, f_order, b_order)

    pass_cost = {'F': _f, 'f': _f, 'B': _b + _w, 'b': _b + _w}
    # Predecessor kind where a chain wraps around to the next leg.
    pred = {'f': 'F', 'B': 'f', 'b': 'B'}
    last = _p - 1
    time_map = {}

    def get_time(stage, kind, minibatch):
        """Return the memoized completion time of one pass."""
        key = (stage, kind, minibatch)
        if key in time_map:
            return time_map[key]
        ready = 0
        if key in local_prev:
            ready = get_time(*local_prev[key])
        if stage > 0 and kind in ('F', 'f'):
            ready = max(ready, get_time(stage - 1, kind, minibatch) + _c)
        if stage == 0 and kind == 'f':
            ready = max(ready, get_time(last, pred[kind], minibatch) + _c)
        if stage != last and kind in ('B', 'b'):
            ready = max(ready, get_time(stage + 1, kind, minibatch) + _c)
        if stage == last and kind == 'b':
            ready = max(ready, get_time(0, pred[kind], minibatch) + _c)
        if stage == last and kind == 'B':
            # Same-stage dependency: no _c added.
            ready = max(ready, get_time(stage, pred[kind], minibatch))
        time_map[key] = ready + pass_cost[kind]
        return time_map[key]

    result = []
    for sid, stage_schedule in enumerate(schedule):
        result_stage = []
        for kind, minibatch in stage_schedule:
            end = get_time(sid, kind, minibatch)
            result_stage.append(ScheduledNode(
                type=kind.upper(),
                chunk=int(kind in ('f', 'B')),
                stage=sid,
                minibatch=minibatch,
                start_time=end - pass_cost[kind],
                completion_time=end,
            ))
        result.append(result_stage)
    return result