Arnav Chavan
remove control panel
aa8b4d6
raw
history blame
5.83 kB
from typing import List
import gradio as gr
import pandas as pd
from src.leaderboard import get_leaderboard_df
from src.llm_perf import get_llm_perf_df
# from attention_implementations import get_attn_decode_fig, get_attn_prefill_fig
# from custom_kernels import get_kernel_decode_fig, get_kernel_prefill_fig
def create_control_panel(
    machine: str,
    backends: List[str],
    hardware_provider: str,
    hardware_type: str,
):
    """Create the collapsed control panel with model-size and quantization filters.

    Args:
        machine: Machine identifier, kept in a ``gr.State`` for callbacks.
        backends: Backend names, kept in a ``gr.State`` for callbacks.
        hardware_provider: Hardware provider; only ``"ARM"`` is supported.
        hardware_type: Hardware type, kept in a ``gr.State`` for callbacks.

    Returns:
        Tuple ``(filter_button, machine_value, backends_value,
        hardware_type_value, memory_slider, quantization_checkboxes)``.

    Raises:
        ValueError: If ``hardware_provider`` is not supported.
    """
    # Validate before creating any component, so an unsupported provider fails
    # without leaving orphan State components in the caller's Blocks context.
    if hardware_provider == "ARM":
        quantizations = ["Q8_0", "Q4_K_M", "Q4_0_4_4"]
    else:
        raise ValueError(f"Unknown hardware provider: {hardware_provider}")

    # controls (backends_value holds the caller-supplied backends unchanged)
    machine_value = gr.State(value=machine)
    backends_value = gr.State(value=backends)
    hardware_type_value = gr.State(value=hardware_type)

    with gr.Accordion("Control Panel", open=False, elem_id="control-panel"):
        with gr.Row():
            with gr.Column(scale=2, variant="panel"):
                memory_slider = gr.Slider(
                    label="Model Size (GB)",
                    info="🎚️ Slide to maximum Model Size",
                    minimum=0,
                    maximum=16,
                    value=16,
                    elem_id="memory-slider",
                )
            with gr.Column(scale=1, variant="panel"):
                # All quantizations are selected by default.
                quantization_checkboxes = gr.CheckboxGroup(
                    label="Quantizations",
                    choices=quantizations,
                    value=quantizations,
                    info="☑️ Select the quantization schemes",
                    elem_id="quantization-checkboxes",
                    elem_classes="boxed-option",
                )
        with gr.Row():
            filter_button = gr.Button(
                value="Filter 🚀",
                elem_id="filter-button",
                elem_classes="boxed-option",
            )

    return (
        filter_button,
        machine_value,
        backends_value,
        hardware_type_value,
        memory_slider,
        quantization_checkboxes,
    )
def filter_rows_fn(
    machine,
    backends,
    hardware_type,
    # inputs
    memory,
    quantizations,
    # interactive
    columns,
    search,
):
    """Filter benchmark rows by the control-panel settings, then project columns.

    Args:
        machine: Machine identifier forwarded to ``get_llm_perf_df``.
        backends: Backend names forwarded to ``get_llm_perf_df``.
        hardware_type: Hardware type forwarded to ``get_llm_perf_df``.
        memory: Inclusive upper bound on the "Model Size (GB)" column.
        quantizations: Values of the "Quantization" column to keep.
        columns: Leaderboard columns to display.
        search: Case-insensitive literal substring matched against "Model".

    Returns:
        The leaderboard DataFrame produced by ``select_columns_fn`` for the
        filtered rows.
    """
    llm_perf_df = get_llm_perf_df(
        machine=machine, backends=backends, hardware_type=hardware_type
    )

    # The size comparison must be parenthesised: ``&`` binds tighter than
    # ``<=``, so without parentheses Python evaluated ``(a & b & size) <= memory``.
    # The search text is matched literally (regex=False), so input such as "("
    # or "+" cannot raise ``re.error``. na=False treats missing model names as
    # non-matches rather than putting NaN into the boolean mask.
    mask = (
        llm_perf_df["Model"].str.contains(search, case=False, regex=False, na=False)
        & llm_perf_df["Quantization"].isin(quantizations)
        & (llm_perf_df["Model Size (GB)"] <= memory)
    )
    filtered_llm_perf_df = llm_perf_df[mask]

    selected_filtered_llm_perf_df = select_columns_fn(
        machine, backends, hardware_type, columns, search, filtered_llm_perf_df
    )

    # Only one output component (the leaderboard table) is wired up, and Gradio
    # passes a single-output return value to that component as-is, so return
    # the bare DataFrame. If the prefill/decode plots are re-enabled, compute
    # them from ``filtered_llm_perf_df`` and return a tuple in the same order
    # as the callback's ``outputs``.
    return selected_filtered_llm_perf_df
def create_control_callback(
    # button
    filter_button,
    # fixed
    machine_value,
    backends_value,
    hardware_type_value,
    # inputs
    memory_slider,
    quantization_checkboxes,
    # interactive
    columns_checkboxes,
    search_bar,
    # outputs
    leaderboard_table,
    # attn_prefill_plot,
    # attn_decode_plot,
    # fa2_prefill_plot,
    # fa2_decode_plot,
    # quant_prefill_plot,
    # quant_decode_plot,
):
    """Register ``filter_rows_fn`` as the click handler of ``filter_button``.

    The handler reads the fixed State values, the control-panel inputs and the
    interactive column/search widgets, and writes into ``leaderboard_table``.
    The plot outputs are currently disabled (see commented parameters).
    """
    # Concatenation order must match filter_rows_fn's positional parameters:
    # fixed states, then panel inputs, then interactive widgets.
    fixed_states = [machine_value, backends_value, hardware_type_value]
    panel_inputs = [memory_slider, quantization_checkboxes]
    interactive_inputs = [columns_checkboxes, search_bar]

    filter_button.click(
        fn=filter_rows_fn,
        inputs=fixed_states + panel_inputs + interactive_inputs,
        outputs=[leaderboard_table],
    )
def select_columns_fn(
    machine, backends, hardware_type, columns, search, llm_perf_df=None
):
    """Build the leaderboard, keep rows whose model matches ``search``, project ``columns``.

    Args:
        machine: Machine identifier, used only when ``llm_perf_df`` is None.
        backends: Backend names, used only when ``llm_perf_df`` is None.
        hardware_type: Hardware type, used only when ``llm_perf_df`` is None.
        columns: Leaderboard columns to keep.
        search: Case-insensitive literal substring matched against "Model".
        llm_perf_df: Pre-fetched benchmark DataFrame; fetched via
            ``get_llm_perf_df`` when omitted.

    Returns:
        The filtered leaderboard DataFrame restricted to ``columns``.
    """
    if llm_perf_df is None:
        llm_perf_df = get_llm_perf_df(
            machine=machine,
            backends=backends,
            hardware_type=hardware_type,
        )
    selected_leaderboard_df = get_leaderboard_df(llm_perf_df)

    # Search text comes straight from a user-editable textbox: match it
    # literally (regex=False), so characters like "(" cannot raise re.error,
    # and treat missing model names as non-matches (na=False).
    matches = selected_leaderboard_df["Model"].str.contains(
        search, case=False, regex=False, na=False
    )
    selected_leaderboard_df = selected_leaderboard_df[matches]
    selected_leaderboard_df = selected_leaderboard_df[columns]
    return selected_leaderboard_df
def create_select_callback(
    # fixed
    machine_value,
    backends_value,
    hardware_type_value,
    # interactive
    columns_checkboxes,
    search_bar,
    # outputs
    leaderboard_table,
):
    """Re-render ``leaderboard_table`` via ``select_columns_fn`` on widget changes.

    One ``change`` listener is registered on ``columns_checkboxes`` and then one
    on ``search_bar``. Both use the same inputs and the same output.
    """
    # Positional order matches select_columns_fn(machine, backends,
    # hardware_type, columns, search).
    shared_inputs = [
        machine_value,
        backends_value,
        hardware_type_value,
        columns_checkboxes,
        search_bar,
    ]

    for trigger in (columns_checkboxes, search_bar):
        trigger.change(
            fn=select_columns_fn,
            inputs=list(shared_inputs),
            outputs=[leaderboard_table],
        )