Nyamdavaa Amar
Pipeline Parallelism with Controllable Memory
3d4d40d
raw
history blame
No virus
10.3 kB
import gradio as gr
import hand_schedule
import adaptive_schedule
import interleaved_variant
import type2
import schedule1f1bv
from PIL import Image
from svg_event import render_manual_graph
import pathlib
def percentage(x):
    """Format the fraction ``x`` as a percentage string with two decimals.

    Example: ``percentage(0.1234)`` returns ``'12.34%'``.
    """
    return "{:.2f}%".format(x * 100)
def get_schedule_time(result):
    """Return the longest per-stage span of compute work in a schedule.

    ``result`` holds one sequence of nodes per stage. Each node exposes
    ``type``, ``start_time`` and ``completion_time``. Only F/B/W nodes are
    considered. A stage's span is its last completion time minus its first
    start time, and the maximum span over all stages is returned.
    """
    compute_kinds = {'F', 'B', 'W'}
    spans = []
    for stage_nodes in result:
        nodes = [node for node in stage_nodes if node.type in compute_kinds]
        last_end = max(node.completion_time for node in nodes)
        first_start = min(node.start_time for node in nodes)
        spans.append(last_end - first_start)
    return max(spans)
def get_memory_usage(result):
    """Return the peak number of pending forward passes on any stage.

    Every F/f node adds one. A node releases one when it is W/w. If the whole
    schedule contains no W/w node, B/b nodes release instead. The running
    count is checked after each node, and the largest value seen on any stage
    is returned (0 for an empty schedule).
    """
    any_w = any(node.type in ('W', 'w') for stage in result for node in stage)
    # Choose once which node kind frees a pending forward pass.
    releasing = ('W', 'w') if any_w else ('B', 'b')
    peak = 0
    for stage in result:
        live = 0
        for node in stage:
            if node.type in ('F', 'f'):
                live += 1
            if node.type in releasing:
                live -= 1
            peak = max(peak, live)
    return peak
# Maximum number of rendered schedule images kept on disk before the oldest
# ones are deleted.
_MAX_CACHED_IMAGES = 32
# Paths returned by render_manual_graph, oldest first.
img_queue = []
def get_schedule_image(result, max_time):
    """Render the F/B/W nodes of ``result`` to an image file and return its path.

    ``max_time`` is forwarded to ``render_manual_graph``. The third argument
    is True when the first stage has at most 72 F/B/W nodes (presumably a
    detail/label toggle in svg_event; confirm there). Rendered paths are
    remembered in ``img_queue``, and once more than ``_MAX_CACHED_IMAGES``
    are held, the oldest files are deleted.
    """
    result = [
        [x for x in r if x.type in {'F', 'B', 'W'}] for r in result
    ]
    svg = render_manual_graph(result, max_time, len(result[0]) <= 72)
    img_queue.append(svg)
    # Evict the oldest renders. Tolerate files that are already gone so a
    # stale entry never breaks rendering for the current request.
    while len(img_queue) > _MAX_CACHED_IMAGES:
        pathlib.Path(img_queue.pop(0)).unlink(missing_ok=True)
    return pathlib.Path(svg)
def _schedule_metrics(result, mem_scale, baseline_time, f, b, w, m):
    """Summarize one scheduler's output as ``(time, acceleration, mem, bubble)``.

    All four values are None when ``result`` is None, so a scheduler that
    produced no schedule yields empty UI fields instead of a crash.
    ``mem_scale`` divides the measured memory (1 leaves it unchanged).
    ``acceleration`` is None when ``baseline_time`` is None.
    """
    if result is None:
        return None, None, None, None
    time = get_schedule_time(result)
    mem = get_memory_usage(result)
    if mem_scale != 1:
        mem = mem / mem_scale
    bubble = percentage(time / (f + b + w) / m - 1)
    acceleration = (
        percentage(baseline_time / time - 1) if baseline_time is not None else None
    )
    return time, acceleration, mem, bubble


def calculate(p, m, f, b, w, c, mem):
    """Run every scheduler and return the 20 values shown in the UI.

    Parameters mirror the UI inputs:
    - p: number of stages
    - m: number of microbatches
    - f, b, w: costs of F, B, W
    - c: cost of one P2P communication
    - mem: activation memory limit, in pending F per stage

    The return value holds (acceleration vs 1F1B, max memory, bubble rate,
    image path) for 1F1B, adaptive, 1F1B-V, type2 and interleaved, in that
    order. Fields of a scheduler that returns None are None.

    Raises gr.Error when m or F + B + W is not positive, since the bubble rate
    divides by both.
    """
    if m <= 0 or f + b + w <= 0:
        raise gr.Error(
            "The number of microbatches (m) and the total cost F + B + W must be positive."
        )

    # 1F1B baseline: B and W are merged into one backward pass (W cost 0),
    # so only F and B nodes are kept.
    baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
    baseline_result = [
        [x for x in r if x.type in {'F', 'B'}] for r in baseline_result
    ]
    baseline_time, _, baseline_mem, baseline_bubble = _schedule_metrics(
        baseline_result, 1, None, f, b, w, m
    )
    # The baseline is the reference point, so its acceleration is 0 by definition.
    baseline_acceleration = percentage(0)

    # Chunked schedules model two virtual stages per stage: costs are halved,
    # the memory limit is doubled, and the measured memory is halved back.
    adapt_result = adaptive_schedule.schedule(
        p, m, [f / 2, b / 2, w / 2, c], max_mem=mem * 2
    )
    adapt_time, adapt_acceleration, adapt_mem, adapt_bubble = _schedule_metrics(
        adapt_result, 2, baseline_time, f, b, w, m
    )

    schedule1f1bv_result = schedule1f1bv.schedule(p, m, [f / 2, b / 2, w / 2, c])
    (schedule1f1bv_time, schedule1f1bv_acceleration,
     schedule1f1bv_mem, schedule1f1bv_bubble) = _schedule_metrics(
        schedule1f1bv_result, 2, baseline_time, f, b, w, m
    )

    # Not chunked: full costs, memory used as-is.
    type2_result = type2.schedule(p, m, [f, b, w, c])
    type2_time, type2_acceleration, type2_mem, type2_bubble = _schedule_metrics(
        type2_result, 1, baseline_time, f, b, w, m
    )

    interleaved_result = interleaved_variant.get_interleaved_variation(
        p, m, [f / 2, b / 2, w / 2, c]
    )
    (interleaved_time, interleaved_acceleration,
     interleaved_mem, interleaved_bubble) = _schedule_metrics(
        interleaved_result, 2, baseline_time, f, b, w, m
    )

    # All images share the same time axis: the slowest schedule's time.
    max_time = max(
        t for t in (baseline_time, adapt_time, interleaved_time, type2_time, schedule1f1bv_time)
        if t is not None
    )

    def _image(result):
        """Render ``result`` on the shared time axis, or None if there is no schedule."""
        return None if result is None else get_schedule_image(result, max_time)

    baseline_image = _image(baseline_result)
    adapt_image = _image(adapt_result)
    interleaved_image = _image(interleaved_result)
    type2_image = _image(type2_result)
    schedule1f1bv_image = _image(schedule1f1bv_result)

    return [baseline_acceleration, baseline_mem, baseline_bubble, baseline_image,
            adapt_acceleration, adapt_mem, adapt_bubble, adapt_image,
            schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
            type2_acceleration, type2_mem, type2_bubble, type2_image,
            interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image]
with gr.Blocks() as demo:
    # Read the markdown blurbs via pathlib so the file handles are closed.
    gr.Markdown(pathlib.Path("description1.md").read_text(encoding="utf-8"))
    gr.Markdown("# Pipeline Scheduler Playground")
    # Preset name -> (p, m, F, B, W, P2P cost, memory setting).
    presets = {
        'Real Case': (6, 12, 1049, 1122, 903, 79, 'V-Half'),
        'Ideal Case': (6, 12, 20, 20, 20, 0, 'V-Min'),
        'Zero Bubble Case': (6, 12, 1049, 1122, 903, 79, 'V-ZB')
    }
    preset_buttons = {}
    with gr.Group():
        gr.Markdown("Preset Setups")
        with gr.Row():
            for name in presets:
                preset_buttons[name] = gr.Button(name, variant="secondary")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("Basic Parameters")
                with gr.Row():
                    p = gr.Number(label="Number of stages (p)", value=6, interactive=True, precision=0)
                    m = gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("Costs. All costs are used as integers. For chunked schedules, this is the time of two virtual stages on a stage combined.")
                with gr.Row():
                    f = gr.Number(label="Time of F", value=1049, interactive=True, precision=0)
                    b = gr.Number(label="Time of B", value=1122, interactive=True, precision=0)
                    w = gr.Number(label="Time of W", value=903, interactive=True, precision=0)
                    c = gr.Number(label="Time of one P2P communication", value=79, interactive=True, precision=0)
    with gr.Group():
        gr.Markdown("Activation memory limit.")

        def update_mem(p, s, mem):
            """Return the memory limit implied by setting ``s`` for ``p`` stages.

            "custom" keeps the user-entered ``mem``. Raises ValueError for an
            unknown setting (a real exception, unlike ``assert``, which is
            stripped under ``python -O``).
            """
            if s == "custom":
                return mem
            if s == "V-Min":
                return (p + 4) // 3
            if s == "V-Half":
                return (p + 2) // 2
            if s == "V-ZB":
                return p
            raise ValueError(f"Unknown memory setting: {s!r}")

        memsel = gr.Radio(choices=["V-Min", "V-Half", "V-ZB", "custom"], value="V-Half")
        mem = gr.Number(label="Custom memory limit in terms of pending F on a stage. For chunked schedules, this is relative to two virtual stages on a stage combined.", value=(p.value + 2) // 2, interactive=True, precision=0)
        memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
        p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
    button = gr.Button("Calculate", variant="primary")

    # The bubble rate is computed as time/(F+B+W)/m - 1 (see calculate), so the
    # label states that formula rather than its negation.
    _bubble_formula_label = "Bubble Rate. Calculated as (longest stage time/(F+B+W)/m - 1)."

    def _result_group(title, bubble_label="Bubble Rate"):
        """Lay out one scheduler's result panel.

        Returns the (acceleration, memory, bubble, image) components.
        """
        with gr.Group():
            gr.Markdown(title)
            with gr.Row():
                with gr.Column(scale=1):
                    acceleration_box = gr.Textbox("", label="Acceleration compared to 1F1B")
                    mem_box = gr.Textbox("", label="Maximum memory usage")
                    bubble_box = gr.Textbox("", label=bubble_label)
                with gr.Column(scale=4):
                    image_box = gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
        return acceleration_box, mem_box, bubble_box, image_box

    baseline_acceleration, baseline_mem, baseline_bubble, baseline_image = _result_group("1F1B")
    adapt_acceleration, adapt_mem, adapt_bubble, adapt_image = _result_group("Adaptive Scheduler")
    gr.Markdown(pathlib.Path("description2.md").read_text(encoding="utf-8"))
    (schedule1f1bv_acceleration, schedule1f1bv_mem,
     schedule1f1bv_bubble, schedule1f1bv_image) = _result_group("1F1B-V Schedule")
    type2_acceleration, type2_mem, type2_bubble, type2_image = _result_group(
        "Two microbatch in one building block schedule", bubble_label=_bubble_formula_label
    )
    (interleaved_acceleration, interleaved_mem,
     interleaved_bubble, interleaved_image) = _result_group(
        "Interleaved 1F1B Schedule", bubble_label=_bubble_formula_label
    )

    # Output order must match the list returned by calculate().
    result_outputs = [baseline_acceleration, baseline_mem, baseline_bubble, baseline_image,
                      adapt_acceleration, adapt_mem, adapt_bubble, adapt_image,
                      schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
                      type2_acceleration, type2_mem, type2_bubble, type2_image,
                      interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image]
    button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=result_outputs)

    def update_preset(pb, p, m, f, b, w, c, mem):
        """Apply preset ``pb`` (the clicked button's label) and recompute.

        Returns the preset's inputs, the memory limit derived from it, and
        calculate()'s outputs.
        """
        preset = presets[pb]
        # Derive the limit from the preset's own stage count, not the stage
        # count currently shown in the UI (which the preset is replacing).
        preset_mem = update_mem(preset[0], preset[-1], -1)
        return *preset, preset_mem, *calculate(*preset[:-1], preset_mem)

    for name in presets:
        preset_buttons[name].click(
            update_preset,
            inputs=[preset_buttons[name], p, m, f, b, w, c, mem],
            # Also write `mem` so the displayed limit matches the one used.
            outputs=[p, m, f, b, w, c, memsel, mem] + result_outputs,
        )
if __name__ == "__main__":
    # Start the server only when run as a script, not as a side effect of
    # importing this module (e.g. by tooling that only needs `demo`).
    demo.launch()