import gradio as gr

from src.llm_perf import get_llm_perf_df
from src.leaderboard import get_leaderboard_df
from src.latency_score_memory import get_lat_score_mem_fig
from src.bettertransformer import get_bt_prefill_fig, get_bt_decode_fig
from src.flashattentionv2 import get_fa2_prefill_fig, get_fa2_decode_fig
from src.quantization_kernels import get_quant_prefill_fig, get_quant_decode_fig


# Options offered by the control panel's checkbox groups. Every option is
# selected by default, so the unfiltered leaderboard shows all rows.
BACKEND_CHOICES = ("pytorch",)
DTYPE_CHOICES = ("float32", "float16", "bfloat16")
OPTIMIZATION_CHOICES = ("None", "BetterTransformer", "FlashAttentionV2")
QUANTIZATION_CHOICES = (
    "None",
    "BnB.4bit",
    "BnB.8bit",
    "GPTQ.4bit",
    "GPTQ.4bit+ExllamaV1",
    "GPTQ.4bit+ExllamaV2",
    "AWQ.4bit+GEMM",
    "AWQ.4bit+GEMV",
)
# Upper bound (and default) of the memory slider, in MB.
MAX_MEMORY_MB = 80 * 1024


def create_control_panel(machine: str = "hf-dgx-01"):
    """Build the leaderboard filter controls.

    The components are created in whatever Gradio layout context is active
    when this is called (presumably a ``gr.Blocks``/tab opened by the caller).

    Args:
        machine: Value stored in a hidden textbox so the filter callback knows
            which machine's results to load.

    Returns:
        A tuple ``(filter_button, machine_textbox, search_bar, score_slider,
        memory_slider, backend_checkboxes, datatype_checkboxes,
        optimization_checkboxes, quantization_checkboxes)`` - the order
        ``create_control_callback`` expects for its first nine arguments.
    """
    # descriptive text
    gr.HTML("Use this control panel to filter the leaderboard.", elem_id="text")
    # controls
    machine_textbox = gr.Textbox(value=machine, visible=False)
    with gr.Row():
        with gr.Column():
            search_bar = gr.Textbox(
                label="Model 🤗",
                info="🔍 Search for a model name",
                elem_id="search-bar",
            )
    with gr.Row():
        with gr.Column(scale=1, variant="panel"):
            score_slider = gr.Slider(
                label="Open LLM Score (%) 📈",
                info="🎚️ Slide to minimum Open LLM score",
                value=0,
                elem_id="threshold-slider",
            )
        with gr.Column(scale=1, variant="panel"):
            memory_slider = gr.Slider(
                label="Peak Memory (MB) 📈",
                info="🎚️ Slide to maximum Peak Memory",
                minimum=0,
                maximum=MAX_MEMORY_MB,
                value=MAX_MEMORY_MB,
                elem_id="memory-slider",
            )
        with gr.Column(scale=1):
            backend_checkboxes = gr.CheckboxGroup(
                label="Backends 🏭",
                choices=list(BACKEND_CHOICES),
                value=list(BACKEND_CHOICES),
                info="☑️ Select the backends",
                elem_id="backend-checkboxes",
            )
    with gr.Row():
        with gr.Column(scale=1, variant="panel"):
            datatype_checkboxes = gr.CheckboxGroup(
                label="Load DTypes 📥",
                choices=list(DTYPE_CHOICES),
                value=list(DTYPE_CHOICES),
                info="☑️ Select the load data types",
                elem_id="dtype-checkboxes",
            )
        with gr.Column(scale=1, variant="panel"):
            optimization_checkboxes = gr.CheckboxGroup(
                label="Optimizations 🛠️",
                choices=list(OPTIMIZATION_CHOICES),
                value=list(OPTIMIZATION_CHOICES),
                info="☑️ Select the optimization",
                elem_id="optimization-checkboxes",
            )
        with gr.Column(scale=2):
            quantization_checkboxes = gr.CheckboxGroup(
                label="Quantizations 🗜️",
                choices=list(QUANTIZATION_CHOICES),
                value=list(QUANTIZATION_CHOICES),
                info="☑️ Select the quantization schemes",
                elem_id="quantization-checkboxes",
            )
    with gr.Row():
        filter_button = gr.Button(
            value="Filter 🚀",
            elem_id="filter-button",
        )

    return (
        filter_button,
        machine_textbox,
        search_bar,
        score_slider,
        memory_slider,
        backend_checkboxes,
        datatype_checkboxes,
        optimization_checkboxes,
        quantization_checkboxes,
    )


def filter_fn(
    machine,
    model,
    backends,
    datatypes,
    optimizations,
    quantizations,
    score,
    memory,
):
    """Filter ``machine``'s benchmark results and rebuild every view.

    Args:
        machine: Machine name forwarded to ``get_llm_perf_df``.
        model: Free-text search; rows whose "Model 🤗" contains it
            (case-insensitive, literal substring) are kept.
        backends, datatypes, optimizations, quantizations: Allowed values for
            the corresponding columns; rows with any other value are dropped.
        score: Minimum "Open LLM Score (%)" to keep (inclusive).
        memory: Maximum "Allocated Memory (MB)" to keep (inclusive).

    Returns:
        A list of eight outputs, in order: leaderboard table,
        latency/score/memory figure, BetterTransformer prefill and decode
        figures, FlashAttentionV2 prefill and decode figures, quantization
        prefill and decode figures - each as produced by its ``get_*`` helper.
    """
    raw_df = get_llm_perf_df(machine=machine)

    # The search text comes straight from a user textbox: match it as a literal
    # substring (pandas defaults to regex=True, where input like "c++" or "("
    # raises re.error and "." matches any character). na=False keeps rows with
    # a missing model name out of the result instead of yielding a NaN mask.
    model_mask = raw_df["Model 🤗"].str.contains(
        model, case=False, regex=False, na=False
    )
    # NOTE(review): the memory slider is labelled "Peak Memory (MB)" but filters
    # on the "Allocated Memory (MB)" column - confirm this is the intended one.
    filtered_df = raw_df[
        model_mask
        & raw_df["Backend 🏭"].isin(backends)
        & raw_df["DType 📥"].isin(datatypes)
        & raw_df["Optimization 🛠️"].isin(optimizations)
        & raw_df["Quantization 🗜️"].isin(quantizations)
        & (raw_df["Open LLM Score (%)"] >= score)
        & (raw_df["Allocated Memory (MB)"] <= memory)
    ]

    filtered_leaderboard_df = get_leaderboard_df(filtered_df)
    filtered_lat_score_mem_fig = get_lat_score_mem_fig(filtered_df)
    filtered_bt_prefill_fig = get_bt_prefill_fig(filtered_df)
    filtered_bt_decode_fig = get_bt_decode_fig(filtered_df)
    filtered_fa2_prefill_fig = get_fa2_prefill_fig(filtered_df)
    filtered_fa2_decode_fig = get_fa2_decode_fig(filtered_df)
    filtered_quant_prefill_fig = get_quant_prefill_fig(filtered_df)
    filtered_quant_decode_fig = get_quant_decode_fig(filtered_df)

    return [
        filtered_leaderboard_df,
        filtered_lat_score_mem_fig,
        filtered_bt_prefill_fig,
        filtered_bt_decode_fig,
        filtered_fa2_prefill_fig,
        filtered_fa2_decode_fig,
        filtered_quant_prefill_fig,
        filtered_quant_decode_fig,
    ]


def create_control_callback(
    # button
    filter_button,
    # inputs
    machine_textbox,
    search_bar,
    score_slider,
    memory_slider,
    backend_checkboxes,
    datatype_checkboxes,
    optimization_checkboxes,
    quantization_checkboxes,
    # outputs
    leaderboard_table,
    lat_score_mem_plot,
    bt_prefill_plot,
    bt_decode_plot,
    fa2_prefill_plot,
    fa2_decode_plot,
    quant_prefill_plot,
    quant_decode_plot,
):
    """Wire the filter button so that clicking it runs ``filter_fn``.

    The first nine arguments are the components returned by
    ``create_control_panel`` (in that order); the remaining eight are the
    output components, in the same order as ``filter_fn``'s returned list.
    """
    filter_button.click(
        fn=filter_fn,
        # Order must match filter_fn's positional parameters, which differs
        # from the order the controls are passed to this function.
        inputs=[
            machine_textbox,
            search_bar,
            backend_checkboxes,
            datatype_checkboxes,
            optimization_checkboxes,
            quantization_checkboxes,
            score_slider,
            memory_slider,
        ],
        outputs=[
            leaderboard_table,
            lat_score_mem_plot,
            bt_prefill_plot,
            bt_decode_plot,
            fa2_prefill_plot,
            fa2_decode_plot,
            quant_prefill_plot,
            quant_decode_plot,
        ],
    )