"""Pattern-based construction of a two-chunk pipeline schedule.

Every stage is described by a row of single-character slots.  The six
notations used throughout this module are ``F f B b W w``; a blank slot is
``' '``.  Upper/lower case distinguishes the two chunks (see
``ScheduledNode.chunk``).
NOTE(review): the letters presumably stand for forward / backward /
weight-gradient passes -- confirm against the caller.
"""
from collections import Counter
from dataclasses import dataclass

# Number of slots in one period of a stage's repeating pattern.
pattern_size = 6


@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """One timed entry of the schedule returned by ``transform_schedule``."""

    type: str  # upper-cased notation of the slot ('F' or 'B')
    chunk: int  # 1 when the original notation is one of "fBW", else 0
    stage: int  # index of the stage (row) the node belongs to
    minibatch: int  # 0-based occurrence count of the notation in its stage
    start_time: int  # completion_time minus the cost of the node
    completion_time: int  # time at which the node finishes


def transform_schedule(schedule, f, b, w, c):
    """Turn per-stage slot rows into timed ``ScheduledNode`` lists.

    Args:
        schedule: one iterable of single-character slots per stage.  Blank
            (whitespace) slots are skipped.  Only ``F f B b`` have a cost
            here; any other notation raises ``KeyError``.
        f: cost of an ``F``/``f`` node.
        b, w: an ``B``/``b`` node costs ``b + w``.
        c: delay added when a node waits for the same node on the
            neighbouring stage.

    Returns:
        A list with one list of ``ScheduledNode`` per stage, in slot order.
        An empty schedule yields ``[]`` and an empty stage yields ``[]`` for
        that stage.
    """
    stage_order = []
    # (stage, notation, minibatch) -> (notation, minibatch) of the node that
    # directly precedes it on the same stage.
    local_prev = {}

    for sid, stage in enumerate(schedule):
        counter = Counter()
        order = []
        for p in stage:
            if not p.strip():
                continue
            mb = counter[p]
            if order:
                local_prev[(sid, p, mb)] = order[-1]
            order.append((p, mb))
            counter[p] += 1
        stage_order.append(order)

    num_stages = len(schedule)
    time_map = {}
    cost = {
        'F': f,
        'B': b + w,
        'f': f,
        'b': b + w,
    }

    def get_time(stage, kind, mb):
        """Return the memoised completion time of node (stage, kind, mb)."""
        key = (stage, kind, mb)
        if key in time_map:
            return time_map[key]
        time = 0
        if key in local_prev:
            time = get_time(stage, *local_prev[key])
        # 'F' and 'B' wait for the same node on the previous stage ...
        if kind in "FB" and stage > 0:
            time = max(time, get_time(stage - 1, kind, mb) + c)
        # ... while 'f' and 'b' wait for it on the next stage.
        if kind in "fb" and stage + 1 < num_stages:
            time = max(time, get_time(stage + 1, kind, mb) + c)
        time_map[key] = time + cost[kind]
        return time_map[key]

    result = []
    for sid, order in enumerate(stage_order):
        result_stage = []
        for p, mb in order:
            end = get_time(sid, p, mb)
            result_stage.append(
                ScheduledNode(
                    type=p.upper(),
                    chunk=int(p in "fBW"),
                    stage=sid,
                    minibatch=mb,
                    start_time=end - cost[p],
                    completion_time=end,
                )
            )
        result.append(result_stage)
    return result


def get_pattern_str(pos):
    """Render one pattern period as a string of ``pattern_size`` slots.

    ``pos[i]`` is the slot index of the i-th notation of ``"FfBbWw"``; a
    negative value leaves that notation out.  Unassigned slots are blanks.
    """
    pattern = [" "] * pattern_size
    notations = "FfBbWw"
    for i, v in enumerate(pos):
        if v < 0:
            continue
        pattern[v] = notations[i]
    return "".join(pattern)


def init_repeated_schedule(p, m, patterns):
    """Repeat each stage's pattern ``4 * p + m + 1`` times.

    Returns one list of single-character slots per stage, built from
    ``patterns[0] .. patterns[p - 1]``.
    """
    repeats = 4 * p + m + 1
    return [list(get_pattern_str(patterns[i]) * repeats) for i in range(p)]


def clear_invalid(repeated, stage, pos, offset=-1):
    """Blank slot ``pos`` and every ``pattern_size``-th slot after it.

    The walk moves in the direction given by ``offset`` (-1 backwards, +1
    forwards) until it leaves the row.  ``repeated`` is modified in place
    and also returned.
    """
    row = repeated[stage]
    while 0 <= pos < len(row):
        row[pos] = ' '
        pos += offset * pattern_size
    return repeated


def clear_invalid_index(repeated, m):
    """Keep ``m`` consecutive periods of every notation on every stage.

    A single cursor ``index`` sweeps forward through the rows: for each
    notation in ``"FfBb"`` the stages are visited in ascending order for
    ``F``/``B`` and descending order for ``f``/``b``.  At the first slot
    holding the notation, all copies before it and all copies from
    ``m`` periods later onward are blanked.  For ``B``/``b`` the matching
    ``W``/``w`` that follows is trimmed the same way.  ``repeated`` is
    modified in place and also returned.
    """
    p = len(repeated)
    index = pattern_size
    for identifier in "FfBb":
        if identifier in "FB":
            _iter = range(p)
        else:
            _iter = range(p - 1, -1, -1)
        for i in _iter:
            for _ in range(pattern_size):
                if repeated[i][index] == identifier:
                    clear_invalid(repeated, i, index - pattern_size, offset=-1)
                    clear_invalid(repeated, i, index + pattern_size * m, offset=1)
                    index += 1
                    if identifier in "Bb":
                        w_identifier = {'B': 'W', 'b': 'w'}[identifier]
                        for k in range(pattern_size):
                            if repeated[i][index + k] == w_identifier:
                                clear_invalid(repeated, i, index + k - pattern_size, offset=-1)
                                clear_invalid(repeated, i, index + k + pattern_size * m, offset=1)
                                break
                    break
                index += 1
    return repeated


def process_warmup_without_increasing_peak_mem(schedules, m):
    """Move nodes to earlier free slots without exceeding the peak counter.

    ``mem[sid][i]`` counts ``F``/``f`` minus ``W``/``w`` seen on stage
    ``sid`` up to slot ``i``; ``peak_mem`` is its maximum over all stages
    before anything is moved.  Slots are then visited column by column and
    each node is moved to the earliest free slot after the nodes it is
    ordered behind.  ``F``/``f`` only move across slots whose counter is
    still below ``peak_mem``.  ``schedules`` is modified in place and also
    returned.
    """
    peak_mem = 0
    mem = [[0 for _ in range(len(schedules[0]))] for _ in range(len(schedules))]
    # loc[sid][cnt][notation]: final slot of the cnt-th (1-based) occurrence
    # of the notation on stage sid, or -1 while it has not been placed.
    loc = [[{key: -1 for key in ('F', 'f', 'B', 'b', 'W', 'w')} for _ in range(m + 2)] for _ in range(len(schedules))]
    cntr = [{key: 0 for key in ('F', 'f', 'B', 'b', 'W', 'w')} for _ in range(len(schedules))]
    for sid in range(len(schedules)):
        cur = 0
        for i in range(len(schedules[sid])):
            if schedules[sid][i] in "Ff":
                cur += 1
            if schedules[sid][i] in "Ww":
                cur -= 1
            mem[sid][i] = cur
            peak_mem = max(peak_mem, cur)

    for i in range(len(schedules[0])):
        for sid in range(len(schedules)):
            if schedules[sid][i] == ' ':
                continue
            cntr[sid][schedules[sid][i]] += 1
            cnt = cntr[sid][schedules[sid][i]]
            # pos becomes the latest slot among the nodes this one must
            # follow; the node may be placed no earlier than pos + 1.
            pos = -1
            if cnt > 1:
                pos = loc[sid][cnt - 1][schedules[sid][i]]
            if schedules[sid][i] == 'W':
                pos = max(pos, loc[sid][cnt]['B'])
            if schedules[sid][i] == 'w':
                pos = max(pos, loc[sid][cnt]['b'])
            if schedules[sid][i] == 'F' and sid > 0:
                pos = max(pos, loc[sid - 1][cnt]['F'])
            if schedules[sid][i] == 'f':
                if sid != len(schedules) - 1:
                    pos = max(pos, loc[sid + 1][cnt]['f'])
                else:
                    pos = max(pos, loc[sid][cnt]['F'])
            if schedules[sid][i] == 'B':
                if sid != 0:
                    # Because B and W are always combined
                    pos = max(pos, loc[sid - 1][cnt]['W'])
                else:
                    pos = max(pos, loc[sid][cnt]['f'])
            if schedules[sid][i] == 'b':
                if sid != len(schedules) - 1:
                    # Because B and W are always combined
                    pos = max(pos, loc[sid + 1][cnt]['w'])
                else:
                    pos = max(pos, loc[sid][cnt]['W'])
            pos += 1
            while schedules[sid][pos] != ' ' and pos < i:
                pos += 1
            if schedules[sid][i] in "Bb":
                # B/b additionally needs the slot right after it to be free.
                while pos < i and (schedules[sid][pos] != ' ' or schedules[sid][pos + 1] != ' '):
                    pos += 1
            if pos == i:
                loc[sid][cnt][schedules[sid][i]] = i
                continue
            if schedules[sid][i] in "BbWw":
                schedules[sid][pos] = schedules[sid][i]
                schedules[sid][i] = ' '
                if schedules[sid][pos] in "Ww":
                    # W/w now happens earlier, so the counter drops earlier.
                    for j in range(pos, i):
                        mem[sid][j] -= 1
                loc[sid][cnt][schedules[sid][pos]] = pos
                continue

            # If F or f:
            place = i
            while place > pos and mem[sid][place - 1] < peak_mem:
                place -= 1
            while place < i and schedules[sid][place] != ' ':
                place += 1
            if place == i:
                loc[sid][cnt][schedules[sid][i]] = i
                continue
            pos = place
            schedules[sid][pos] = schedules[sid][i]
            schedules[sid][i] = ' '
            for j in range(pos, i):
                mem[sid][j] += 1
            loc[sid][cnt][schedules[sid][pos]] = pos
    return schedules


def schedule_by_pattern(p, m, patterns):
    """Build the slot rows for ``p`` stages from per-stage patterns.

    The repeated patterns are trimmed by ``clear_invalid_index``, compacted
    by ``process_warmup_without_increasing_peak_mem``, and finally every
    occurrence of a notation beyond the first ``m`` on a stage is blanked.
    """
    schedules = init_repeated_schedule(p, m, patterns)
    schedules = clear_invalid_index(schedules, m)
    schedules = process_warmup_without_increasing_peak_mem(schedules, m)
    for row in schedules:
        cnt = {_id: 0 for _id in "FfBbWw"}
        for i, slot in enumerate(row):
            if slot == ' ':
                continue
            if cnt[slot] >= m:
                row[i] = ' '
            else:
                cnt[slot] += 1
    return schedules


def create_whole_pattern(p):
    """Return the slot offsets of the six notations for each of ``p`` stages.

    Row ``i`` holds, in ``"FfBbWw"`` order, the offsets (modulo
    ``pattern_size``) fed to ``get_pattern_str`` for stage ``i``.
    """
    whole_pattern = [[0 for _ in range(pattern_size)] for _ in range(p)]
    now = 0
    # 'F' walks the stages in ascending order, 'f' in descending order.
    for i in range(p):
        now += 1
        whole_pattern[i][0] = now
    for i in range(p):
        now += 1
        whole_pattern[p - 1 - i][1] = now
    now += 1
    if p % 3 == 0:
        now += 3
    # Each B/b takes two consecutive offsets (itself and its W/w); after
    # every third pair three offsets are skipped.
    cyc = (3 - (p + 2) % 3) % 3
    for i in range(p):
        whole_pattern[i][2], whole_pattern[i][4] = now, now + 1
        cyc += 1
        now += 2
        if cyc == 3:
            cyc = 0
            now += 3
    for i in range(p):
        whole_pattern[p - 1 - i][3], whole_pattern[p - 1 - i][5] = now, now + 1
        cyc += 1
        now += 2
        if cyc == 3:
            cyc = 0
            now += 3
    for row in whole_pattern:
        for i in range(pattern_size):
            row[i] %= pattern_size
    return whole_pattern


def schedule(p, m, cost):
    """Create the timed schedule for ``p`` stages and ``m`` occurrences.

    ``cost`` is unpacked into the ``f, b, w, c`` arguments of
    ``transform_schedule``.  ``W``/``w`` slots are blanked first because
    their cost is already included in the ``B``/``b`` nodes.
    """
    whole_pattern = create_whole_pattern(p)
    s = schedule_by_pattern(p, m, whole_pattern)
    for row in s:
        for i, slot in enumerate(row):
            if slot in "Ww":
                row[i] = ' '
    return transform_schedule(s, *cost)