|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import datasets |
|
import plotly.graph_objects as go |
|
import numpy as np |
|
import polars as pl |
|
|
|
|
|
# Tokenizer used only to measure how many tokens each sample occupies.
# NOTE(review): ``trust_remote_code=True`` executes Python code shipped with the
# model repository. Recent ``transformers`` releases may load this tokenizer
# without it -- confirm whether the flag is still needed before keeping it.
tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-34B", trust_remote_code=True)

# Annotate every Alpaca training example with the token count of its "text" field.
alpaca = datasets.load_dataset("tatsu-lab/alpaca", split="train").map(
    lambda ex: {"tokens": len(tokenizer(ex["text"])["input_ids"])}, num_proc=4
)

# Polars view of the dataset plus a 0-based row ``index``; consecutive rows are
# later grouped into batches via ``index // batch_size``.
pdf = pl.DataFrame(alpaca.to_pandas()).with_columns(index=pl.int_range(0, pl.count()))

# Per-sample token counts as a NumPy array, read by ``packing`` and the plots.
tokens = pdf["tokens"].to_numpy()
|
|
|
|
|
|
|
|
|
def plot_batch(batch_size, lengths=None):
    """Build a stacked horizontal bar chart showing the padding of one batch.

    Each of the first ``batch_size`` samples gets its own row (numbered from 1):
    the blue segment is the sample's length in tokens, and the red segment is
    the padding needed to reach the longest sample in the selection.

    Args:
        batch_size: Number of leading samples to plot.
        lengths: Per-sample token counts; defaults to the module-level
            ``tokens`` array.

    Returns:
        The ``plotly.graph_objects.Figure``; it is not displayed here.

    Raises:
        ValueError: if ``batch_size`` selects no samples.
    """
    if lengths is None:
        lengths = tokens
    # Slicing is enough: nothing below mutates ``data``, so no copy is needed.
    data = np.asarray(lengths)[:batch_size]
    if data.size == 0:
        raise ValueError("batch_size must select at least one sample")

    max_value = data.max()
    rows = np.arange(1, data.size + 1)

    fig = go.Figure()
    # Two traces (one per colour) instead of two per sample: with
    # barmode="stack", each red padding bar stacks onto the blue bar of the
    # same row, so the chart looks the same but the figure stays small.
    fig.add_trace(go.Bar(x=data, y=rows, orientation="h", marker_color="blue"))
    fig.add_trace(
        go.Bar(x=max_value - data, y=rows, orientation="h", marker_color="red")
    )
    fig.update_layout(
        barmode="stack",
        showlegend=False,
        xaxis=dict(range=[0, max_value]),
    )
    return fig
|
|
|
|
|
def packing(pocket=8192, lengths=None):
    """Return the compute-cost ratio of greedily packing samples into pockets.

    Samples are taken in order and appended to the current pocket until the
    next one would overflow it; that sample then starts a new pocket. A sample
    longer than ``pocket`` occupies a pocket on its own, and the next sample
    always starts a new one.

    Args:
        pocket: Capacity of one pocket, in tokens.
        lengths: Per-sample token counts; defaults to the module-level
            ``tokens`` array.

    Returns:
        ``pockets_used * pocket / total_tokens``, i.e. allocated tokens divided
        by real tokens (1.0 means no padding at all).

    Raises:
        ValueError: if ``lengths`` contains no tokens (nothing to pack).
    """
    if lengths is None:
        lengths = tokens
    lengths = np.asarray(lengths)
    total = lengths.sum()
    if total == 0:
        raise ValueError("lengths must contain at least one token")

    num_pocket = 0
    buffers = 0  # tokens already placed in the currently open pocket
    for token in lengths:
        # Only close a pocket that already holds something: an oversized sample
        # arriving at an empty pocket must not count an extra, empty pocket.
        if buffers and buffers + token > pocket:
            num_pocket += 1
            buffers = token
        else:
            buffers += token
    if buffers:
        num_pocket += 1  # the last, partially filled pocket
    return num_pocket * pocket / total
|
|
|
|
|
|
|
|
|
# Visualise the padding of the first 30 samples batched together. The figure
# must be shown explicitly: outside a notebook the bare call's return value is
# discarded and nothing is displayed.
plot_batch(30).show()

# Sweep batch sizes 1..99. Consecutive rows form a batch (row index //
# batch_size) and every sample is padded to its batch's longest sample.
#   x: unpadded token total of a sample's batch, averaged over samples
#   y: padded tokens / real tokens over the whole dataset
# y is a ratio of totals (not a mean of per-sample ratios), so it is measured
# the same way as ``packing`` and the "Worst" line: one batch holding every
# sample gives exactly tokens.max() / tokens.mean().
arrs = []
for batch_size in range(1, 100):
    group = pl.col("index") // batch_size
    arr = (
        pdf.with_columns(batch=pl.col("tokens").max().over(group))
        .select(
            pl.col("tokens").sum().over(group).mean(),
            pl.col("batch").sum() / pl.col("tokens").sum(),
        )
        .to_numpy()
    )
    arrs.append(arr)

x_values, y_values = np.concatenate(arrs).transpose()

# Packing curve: pocket sizes from the longest single sample (the smallest
# pocket no sample exceeds) up to the batching throughput at the largest batch
# size above.
pxs = np.linspace(tokens.max(), x_values[-1], 100)
pys = [packing(pocket) for pocket in pxs]

# Pocket size highlighted on the plot.
CHOSEN_POCKET = 8192

fig = go.Figure()
fig.add_trace(go.Scatter(x=x_values, y=y_values, mode="lines", name="Batching"))
fig.add_trace(go.Scatter(x=pxs, y=pys, mode="lines", name="Packing"))

# Cost of padding every sample to the longest sample in the whole dataset.
worst = tokens.max() / tokens.mean()
fig.add_trace(
    go.Scatter(
        x=x_values,
        y=[worst] * len(x_values),
        mode="lines",
        name="Worst",
        line=dict(dash="dash"),
    )
)
fig.add_trace(
    go.Scatter(
        x=[CHOSEN_POCKET],
        y=[packing(CHOSEN_POCKET)],
        mode="markers",
        name="Chosen",
    )
)

fig.update_layout(
    xaxis_title="throughput(tokens)",
    yaxis_title="computational cost(ratio)",
    yaxis=dict(range=[0, worst + 1]),
)
fig.show()
|
|