File size: 1,675 Bytes
ab5f5f1
 
 
 
 
 
 
 
0232cf1
 
dc685a9
 
ab5f5f1
0232cf1
 
 
 
ab5f5f1
 
 
 
 
 
 
 
0232cf1
ab5f5f1
0232cf1
 
ab5f5f1
 
 
 
 
 
 
 
 
 
57896bb
ab5f5f1
 
 
 
 
dc685a9
ab5f5f1
 
 
 
 
 
 
 
 
 
 
0232cf1
ab5f5f1
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
import plotly.express as px


# Column names attached to every scatter point as plotly ``custom_data``.
# ``get_lat_score_mem_fig`` enumerates this list to build its hover template,
# so the order here is the order in which fields appear in the tooltip.
# NOTE(review): the symbol suffixes on several names look like mis-encoded
# emoji (mojibake). They are left untouched because each entry is passed to
# plotly as a column name and presumably has to match the DataFrame's columns
# exactly -- confirm the upstream encoding before changing any of them.
SCORE_MEMORY_LATENCY_DATA = [
    "Model πŸ€—",
    "DType πŸ“₯",
    "Backend 🏭",
    "Params (B)",
    "Architecture πŸ›οΈ",
    "Optimization πŸ› οΈ",
    "Quantization πŸ—œοΈ",
    "Open LLM Score (%)",
    "Prefill (s)",
    "Decode (tokens/s)",
    "Memory (MB)",
    "End-to-End (s)",
]


def get_lat_score_mem_fig(llm_perf_df):
    """Build the "Latency vs. Score vs. Memory" scatter figure.

    Each row of ``llm_perf_df`` becomes one marker: the x position is the
    "End-to-End (s)" column, the y position is "Open LLM Score (%)", the
    marker size is "Memory (MB)" and the colour is taken from the
    architecture column. Every column named in ``SCORE_MEMORY_LATENCY_DATA``
    is attached as custom data and shown in the hover tooltip.

    Args:
        llm_perf_df: Object exposing ``.copy()`` whose result is handed to
            ``px.scatter`` (presumably a pandas DataFrame containing all of
            the columns referenced above -- confirm against the caller).

    Returns:
        The figure produced by ``px.scatter`` after the hover template and
        layout have been applied.
    """
    # plotly receives a copy rather than the caller's own object
    frame = llm_perf_df.copy()

    # One tooltip line per custom-data column; ``customdata[i]`` indexes the
    # columns in the same order they are passed via ``custom_data`` below.
    hover_lines = [
        f"<b>{label}:</b> %{{customdata[{position}]}}"
        for position, label in enumerate(SCORE_MEMORY_LATENCY_DATA)
    ]
    hover_template = "<br>".join(hover_lines)

    # Title centred horizontally and anchored near the top of the figure.
    title_spec = dict(
        text="Latency vs. Score vs. Memory",
        y=0.95,
        x=0.5,
        xanchor="center",
        yanchor="top",
    )

    figure = px.scatter(
        frame,
        x="End-to-End (s)",
        y="Open LLM Score (%)",
        size="Memory (MB)",
        color="Architecture πŸ›οΈ",
        custom_data=SCORE_MEMORY_LATENCY_DATA,
        color_discrete_sequence=px.colors.qualitative.Light24,
    )
    figure.update_traces(hovertemplate=hover_template)
    figure.update_layout(
        title=title_spec,
        xaxis_title="Time To Generate 256 Tokens (s)",
        yaxis_title="Open LLM Score (%)",
        legend_title="LLM Architecture",
        width=1200,
        height=600,
    )

    return figure


def create_lat_score_mem_plot(llm_perf_df):
    """Create the hover hint and the latency/score/memory plot component.

    Instantiates a ``gr.HTML`` hint first and then a ``gr.components.Plot``
    holding the figure from ``get_lat_score_mem_fig``. Presumably meant to be
    called while a Gradio layout context is open so both components are
    placed in it -- confirm against the caller.

    Args:
        llm_perf_df: Data forwarded unchanged to ``get_lat_score_mem_fig``.

    Returns:
        The ``gr.components.Plot`` instance wrapping the figure.
    """
    # The hint previously contained the UTF-8 bytes of U+1F446 (up-pointing
    # hand) decoded with a single-byte code page, which displayed as garbage.
    # Spelling the code point as an escape keeps this source ASCII-only, so
    # the text cannot be mangled again by a wrong file encoding.
    pointing_up = "\U0001F446"
    gr.HTML(
        f"{pointing_up} Hover over the points {pointing_up} for additional information. ",
        elem_id="text",
    )

    fig = get_lat_score_mem_fig(llm_perf_df)

    # show_label=False: the figure carries its own title, so no component label
    plot = gr.components.Plot(
        value=fig,
        elem_id="plot",
        show_label=False,
    )

    return plot