# ml-summit/src/scripts/plot.ju1.py
# Author: facat | commit 2fc4496 (unverified) | message: "init"
# Page metadata: raw · history · blame · contribute · delete | No virus | 3.91 kB
# %%
from transformers import AutoTokenizer, AutoModelForCausalLM
import datasets
import plotly.graph_objects as go
import numpy as np
import polars as pl
# In this script the tokenizer is only used to measure how long each sample is.
tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-34B", trust_remote_code=True)


def _count_tokens(example):
    """Return a new ``tokens`` column holding the token count of ``example["text"]``."""
    input_ids = tokenizer(example["text"])["input_ids"]
    return {"tokens": len(input_ids)}


alpaca = datasets.load_dataset("tatsu-lab/alpaca", split="train").map(
    _count_tokens, num_proc=4
)

# Polars copy of the dataset plus a 0-based row number ("index"), which the
# batching sweep below integer-divides to assign rows to consecutive batches.
pdf = pl.DataFrame(alpaca.to_pandas())
pdf = pdf.with_columns(index=pl.int_range(0, pl.count()))

# Per-sample token counts as a NumPy array, shared by the plotting helpers.
tokens = pdf["tokens"].to_numpy()
# %%
def plot_batch(batch_size, lengths=None):
    """Visualise the padding waste of one naive (padded) batch.

    Draws one horizontal bar per sample: the blue segment is the sample's
    token count and the red segment is the padding needed to reach the
    longest sample in the batch. The two segments are stacked.

    Args:
        batch_size: Number of leading samples (in dataset order) to include.
        lengths: Optional 1-D sequence of per-sample token counts. Defaults
            to the ``tokens`` column of the module-level ``pdf`` frame.

    Returns:
        A ``plotly.graph_objects.Figure``. It is returned, not shown; the
        caller (e.g. a notebook cell) is responsible for displaying it.
    """
    if lengths is None:
        lengths = pdf["tokens"].to_numpy()
    # A slice is enough: the data is never mutated, so no full copy is needed.
    data = np.asarray(lengths)[:batch_size]
    rows = np.arange(1, len(data) + 1)
    # Every sample is padded up to the longest one; an empty batch needs none.
    max_value = int(data.max()) if len(data) else 0

    fig = go.Figure()
    # One trace for the real tokens and one for the padding, covering all rows
    # at once (instead of two single-bar traces per sample). Trace order
    # makes the red padding stack on the right of the blue bar.
    fig.add_trace(go.Bar(x=data, y=rows, orientation="h", marker_color="blue"))
    fig.add_trace(
        go.Bar(x=max_value - data, y=rows, orientation="h", marker_color="red")
    )
    fig.update_layout(
        barmode="stack",
        showlegend=False,
        xaxis=dict(range=[0, max_value]),
    )
    return fig
def packing(pocket=8192, lengths=None):
    """Return the compute-cost ratio of greedy sequential packing.

    Samples are taken in their given order and concatenated into fixed-size
    "pockets" of ``pocket`` tokens. A new pocket is opened whenever the next
    sample would overflow the current one. The cost ratio is the total pocket
    capacity divided by the number of real tokens (1.0 means no padding).

    NOTE: a single sample longer than ``pocket`` is counted as occupying one
    pocket on its own (as if truncated), so the ratio is only meaningful when
    ``pocket >= max(lengths)``.

    Args:
        pocket: Capacity of one packed sequence, in tokens.
        lengths: Optional 1-D sequence of per-sample token counts. Defaults
            to the module-level ``tokens`` array.

    Returns:
        ``num_pockets * pocket / total_tokens`` as a float.

    Raises:
        ValueError: If there are no tokens to pack (empty or all-zero input).
    """
    if lengths is None:
        lengths = tokens
    # Plain Python ints iterate far faster than NumPy scalars in this loop.
    sizes = np.asarray(lengths).tolist()
    total = sum(sizes)
    if not total:
        raise ValueError("packing() needs at least one non-empty sample")

    num_pocket = 0
    used = 0  # tokens already placed in the currently open pocket
    for size in sizes:
        # Close the open pocket only if it holds something; otherwise an
        # oversized first sample would be charged an extra, empty pocket.
        if used and used + size > pocket:
            num_pocket += 1
            used = size
        else:
            used += size
    if used:
        num_pocket += 1  # the last, partially filled pocket
    return num_pocket * pocket / total
# %%
# Padding picture for the first 30 samples (last expression, so a notebook renders it).
plot_batch(batch_size=30)
# %%
# Batching baseline: walk the dataset in its original order, cut it into
# consecutive batches of `batch_size` samples and pad each batch to its
# longest sample. One (throughput, cost) point per batch size.
arrs = []
for batch_size in range(1, 100):
    batch_id = pl.col("index") // batch_size
    arr = pdf.select(
        # Real tokens per batch: the window sum is broadcast to every row and
        # then averaged over rows (so a trailing partial batch weighs less).
        pl.col("tokens").sum().over(batch_id).mean().alias("throughput"),
        # Cost = padded tokens / real tokens. Each row is padded to its batch
        # maximum, so the padded total is the sum of per-row batch maxima.
        # A ratio of totals (not a mean of per-sample ratios) keeps this
        # comparable with `packing()` and the max/mean "Worst" line.
        (
            pl.col("tokens").max().over(batch_id).sum() / pl.col("tokens").sum()
        ).alias("cost"),
    ).to_numpy()
    arrs.append(arr)
x_values, y_values = np.concatenate(arrs).transpose()

# Packing: sweep pocket sizes from the longest single sample (the smallest
# pocket every sample fits into) up to the largest batching throughput.
pxs = np.linspace(tokens.max(), x_values[-1], 100)
pys = [packing(pocket) for pocket in pxs]
# Cost of padding every sample to the longest sample of the whole dataset:
# (count * max) / sum == max / mean.
worst = tokens.max() / tokens.mean()
# Pocket size highlighted on the plot.
chosen_pocket = 8192

curves = [
    # Padded batching, one point per batch size.
    go.Scatter(x=x_values, y=y_values, mode="lines", name="Batching"),
    # Greedy sequential packing, one point per pocket size.
    go.Scatter(x=pxs, y=pys, mode="lines", name="Packing"),
    # Dashed horizontal reference at the whole-dataset padding cost.
    go.Scatter(
        x=x_values,
        y=[worst] * len(x_values),
        mode="lines",
        name="Worst",
        line=dict(dash="dash"),
    ),
    # Single marker for the chosen pocket size.
    go.Scatter(
        x=[chosen_pocket],
        y=[packing(chosen_pocket)],
        mode="markers",
        name="Chosen",
    ),
]

fig = go.Figure()
for curve in curves:
    fig.add_trace(curve)

fig.update_layout(
    xaxis_title="throughput(tokens)",
    yaxis_title="computational cost(ratio)",
    # One unit of headroom above the worst-case line.
    yaxis=dict(range=[0, worst + 1]),
)
# Optional export: fig.write_image("../../docs/1227-moda/figures/packing.png")
fig.show()