import gradio as gr
import sys
import os
import subprocess
import zipfile
from datasets import load_dataset
from typing import List

MAX_BASE_LLM_NUM = 20
MIN_BASE_LLM_NUM = 3
SOURCE_MAX_LENGTH = 256
DEFAULT_SOURCE_MAX_LENGTH = 128
CANDIDATE_MAX_LENGTH = 256
DEFAULT_CANDIDATE_MAX_LENGTH = 128
FUSER_MAX_NEW_TOKENS = 512
DEFAULT_FUSER_MAX_NEW_TOKENS = 256

DESCRIPTIONS = """# LLM-BLENDER

LLM-Blender is an innovative ensembling framework to attain consistently superior performance by leveraging the diverse strengths of multiple open-source large language models (LLMs). LLM-Blender cut the weaknesses through ranking and integrate the strengths through fusing generation to enhance the capability of LLMs.
"""

# Build the clickable examples from 100 shuffled validation rows of mix-instruct.
# CANDIDATE_EXAMPLES maps (instruction + input) -> that row's candidate outputs,
# so selecting an example can also populate the base-LLM outputs.
EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42, buffer_size=1000)
EXAMPLES = []
CANDIDATE_EXAMPLES = {}
for example in SHUFFLED_EXAMPLES_DATASET.take(100):
    EXAMPLES.append([
        example['instruction'],
        example['input'],
    ])
    CANDIDATE_EXAMPLES[example['instruction'] + example['input']] = example['candidates']

# Download ranker checkpoint
if not os.path.exists("pairranker-deberta-v3-large.zip"):
    # Pass argv as a list (no shell involved, so the '?' in the URL is never
    # glob-expanded). check=True makes a failed download stop here with a
    # clear error instead of surfacing later as a confusing zipfile error.
    subprocess.run(
        ["gdown", "https://drive.google.com/uc?id=1EpvFu_qYY0MaIu0BAAhK-sYKHVWtccWg"],
        check=True,
    )
if not os.path.exists("pairranker-deberta-v3-large"):
    with zipfile.ZipFile("pairranker-deberta-v3-large.zip", 'r') as zip_ref:
        zip_ref.extractall(".")

# Load Blender
import llm_blender
from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks

ranker_config = llm_blender.RankerConfig()
ranker_config.ranker_type = "pairranker"
ranker_config.model_type = "deberta"
ranker_config.model_name = "microsoft/deberta-v3-large"  # ranker backbone
ranker_config.load_checkpoint = "./pairranker-deberta-v3-large"  # ranker checkpoint
ranker_config.source_maxlength = DEFAULT_SOURCE_MAX_LENGTH
ranker_config.candidate_maxlength = DEFAULT_CANDIDATE_MAX_LENGTH
# Number of signals used to train the ranker. This checkpoint is trained using
# BARTScore only, thus being 1.
ranker_config.n_tasks = 1
fuser_config = llm_blender.GenFuserConfig()
fuser_config.model_name = "llm-blender/gen_fuser_3b"  # our pre-trained fuser
fuser_config.max_length = 1024
fuser_config.candidate_maxlength = DEFAULT_CANDIDATE_MAX_LENGTH
blender_config = llm_blender.BlenderConfig()
blender_config.load_in_8bit = True
blender_config.device = "cuda"  # blender ranker and fuser device
blender = llm_blender.Blender(blender_config, ranker_config, fuser_config)


def update_base_llms_num(k, llm_outputs):
    """Resize the base-LLM slots to ``k``.

    Returns:
        A two-item list:
          - a dropdown update listing ``LLM-1`` .. ``LLM-k`` with ``LLM-1`` selected;
          - a new outputs dict with exactly those keys, keeping text already saved
            for a surviving slot and using "" for new ones.
    """
    k = int(k)
    return [
        gr.Dropdown.update(
            choices=[f"LLM-{i+1}" for i in range(k)],
            value="LLM-1" if k >= 1 else "",
            visible=True,
        ),
        {f"LLM-{i+1}": llm_outputs.get(f"LLM-{i+1}", "") for i in range(k)},
    ]


def display_llm_output(llm_outputs, selected_base_llm_name):
    """Show the saved output (or "") of the selected base LLM in the editor textbox."""
    return gr.Textbox.update(
        value=llm_outputs.get(selected_base_llm_name, ""),
        label=selected_base_llm_name + " (Click Save to save current content)",
        placeholder=f"Enter {selected_base_llm_name} output here",
        show_label=True,
    )


def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
    """Store the editor text under the selected LLM name; mutates and returns ``llm_outputs``."""
    llm_outputs.update({selected_base_llm_name: selected_base_llm_output})
    return llm_outputs


def get_preprocess_examples(inst, input):
    """Expand a clicked example into the textboxes, the LLM count and a lookup key.

    The last return value (instruction + input) is written to a hidden textbox.
    Its change event loads the example's candidates as base-LLM outputs.
    """
    # get the num_of_base_llms
    candidates = CANDIDATE_EXAMPLES[inst + input]
    num_candiates = len(candidates)
    dummy_text = inst + input
    return inst, input, num_candiates, dummy_text


def update_base_llm_dropdown_along_examples(dummy_text):
    """Load an example's candidate texts as LLM-1..N outputs and clear rank/fuse outputs."""
    candidates = CANDIDATE_EXAMPLES[dummy_text]
    ex_llm_outputs = {f"LLM-{i+1}": candidates[i]['text'] for i in range(len(candidates))}
    return ex_llm_outputs, "", ""


def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
    """Validate the ranking inputs and snapshot them as the new blender state.

    Raises:
        gr.Error: if both instruction and input are empty, if no base-LLM output
            has been saved, or if any saved base-LLM output is empty.
    """
    if not inst and not input:
        raise gr.Error("Please enter instruction or input context")
    if not llm_outputs:
        # all() over an empty collection is True, so without this guard the
        # initial empty state would reach blender.rank with zero candidates.
        raise gr.Error("Please enter base LLM outputs first")
    empty_llm_names = [llm_name for llm_name, llm_output in llm_outputs.items() if not llm_output]
    if empty_llm_names:
        # Format the message string itself. Calling .format() on the gr.Error
        # instance (as before) raised AttributeError instead of this message.
        raise gr.Error("Please enter base LLM outputs for LLMs: {}".format(empty_llm_names))
    return {
        "inst": inst,
        "input": input,
        "candidates": list(llm_outputs.values()),
    }


def check_fuser_inputs(blender_state, blender_config, ranks):
    """Ensure a source exists and ranking has run before fusing.

    Returns:
        None. The fuser-output textbox is this handler's output, so it is cleared.

    Raises:
        gr.Error: if instruction and input are both missing, or if nothing has been ranked.
    """
    if not (blender_state.get("inst", None) or blender_state.get("input", None)):
        raise gr.Error("Please enter instruction or input context")
    if "candidates" not in blender_state or len(ranks) == 0:
        raise gr.Error("Please rank LLM outputs first")
    return


def llms_rank(inst, input, llm_outputs, blender_config):
    """Rank the saved base-LLM outputs with the pairwise ranker.

    Returns:
        ``[ranks, text]``, where ``text`` pairs each LLM slot name with its rank.
    """
    candidates = list(llm_outputs.values())
    # NOTE(review): blender_config["source_max_length"] / ["candidate_max_length"]
    # (set by the Advanced options sliders) are not forwarded here, so ranking
    # uses the lengths configured when the ranker was loaded. TODO confirm
    # whether blender.rank accepts per-call length overrides.
    ranks = blender.rank(instructions=[inst], inputs=[input], candidates=[candidates])[0]
    # llm_outputs keeps insertion order, which need not be LLM-1, LLM-2, ... (for
    # example, if LLM-2 was saved before LLM-1). So label each rank with the key
    # its candidate came from, not its position.
    rank_text = ", ".join(f"{llm_name}: {rank}" for llm_name, rank in zip(llm_outputs, ranks))
    return [ranks, rank_text]


def llms_fuse(blender_state, blender_config, ranks):
    """Fuse the top-k ranked candidates into a single output.

    ``top_k_for_fuser`` and ``source_max_length`` are removed from a copy of
    ``blender_config``. The remaining keys are forwarded to ``blender.fuse`` as
    keyword arguments, together with ``no_repeat_ngram_size=3``.

    Returns:
        ``[fused_text, fused_text]``: one copy for the session state and one for the textbox.
    """
    inst = blender_state['inst']
    input = blender_state['input']
    candidates = blender_state['candidates']
    top_k_for_fuser = blender_config['top_k_for_fuser']
    # Copy so the pops below do not mutate the per-session config state.
    fuse_params = blender_config.copy()
    fuse_params.pop("top_k_for_fuser")
    fuse_params.pop("source_max_length")
    fuse_params['no_repeat_ngram_size'] = 3
    top_k_candidates = get_topk_candidates_from_ranks([ranks], [candidates], top_k=top_k_for_fuser)[0]
    fuser_outputs = blender.fuse(
        instructions=[inst],
        inputs=[input],
        candidates=[top_k_candidates],
        **fuse_params,
        batch_size=1,
    )[0]
    return [fuser_outputs, fuser_outputs]


def display_fuser_output(fuser_output):
    """Identity passthrough for the fuser output."""
    return fuser_output


with gr.Blocks(theme='ParityError/Anime') as demo:
    gr.Markdown(DESCRIPTIONS)
    gr.Markdown("## Input and Base LLMs")
    with gr.Row():
        with gr.Column():
            inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
            input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
        with gr.Column():
            saved_llm_outputs = gr.State(value={})
            with gr.Group():
                selected_base_llm_name_dropdown = gr.Dropdown(
                    label="Base LLM",
                    choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)],
                    value="LLM-1",
                    show_label=True,
                )
                selected_base_llm_output = gr.Textbox(
                    lines=4,
                    label="LLM-1 (Click Save to save current content)",
                    placeholder="Enter LLM-1 output here",
                    show_label=True,
                )
            with gr.Row():
                base_llm_outputs_save_button = gr.Button('Save', variant='primary')
                base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
                base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
            base_llms_num = gr.Slider(
                label='Number of base llms',
                minimum=MIN_BASE_LLM_NUM,
                maximum=MAX_BASE_LLM_NUM,
                step=1,
                value=MIN_BASE_LLM_NUM,
            )

    blender_state = gr.State(value={})
    saved_rank_outputs = gr.State(value=[])
    saved_fuse_outputs = gr.State(value=[])
    gr.Markdown("## Blender Outputs")
    with gr.Group():
        rank_outputs = gr.Textbox(lines=1, label="Ranks of each LLM's output", placeholder="Ranking outputs", show_label=True)
        fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
    with gr.Row():
        rank_button = gr.Button('Rank LLM Outputs', variant='primary')
        fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary')
        clear_button = gr.Button('Clear Blender Outputs', variant='primary')

    # NOTE: this rebinds the module-level name `blender_config`. Earlier it was
    # the llm_blender.BlenderConfig used to build `blender`; from here on it is
    # the per-session settings dict that the callbacks above receive.
    # Each default mirrors the initial value of its Advanced-options slider.
    blender_config = gr.State(value={
        "source_max_length": DEFAULT_SOURCE_MAX_LENGTH,
        "candidate_max_length": DEFAULT_CANDIDATE_MAX_LENGTH,
        "top_k_for_fuser": 3,
        "max_new_tokens": DEFAULT_FUSER_MAX_NEW_TOKENS,
        "temperature": 0.7,
        "top_p": 1.0,
        # Matches beam_size's initial value. Previously this key existed only
        # after the slider was moved, so the displayed default of 4 was ignored.
        "num_beams": 4,
    })

    with gr.Accordion(label='Advanced options', open=False):
        top_k_for_fuser = gr.Slider(
            label='Top-k ranked candidates to fuse',
            minimum=1,
            maximum=3,
            step=1,
            value=3,
        )
        source_max_length = gr.Slider(
            label='Max length of Instruction + Input',
            minimum=1,
            maximum=SOURCE_MAX_LENGTH,
            step=1,
            value=DEFAULT_SOURCE_MAX_LENGTH,
        )
        candidate_max_length = gr.Slider(
            label='Max length of LLM-Output Candidate',
            minimum=1,
            maximum=CANDIDATE_MAX_LENGTH,
            step=1,
            value=DEFAULT_CANDIDATE_MAX_LENGTH,
        )
        max_new_tokens = gr.Slider(
            label='Max new tokens fuser can generate',
            minimum=1,
            maximum=FUSER_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_FUSER_MAX_NEW_TOKENS,
        )
        # temperature = gr.Slider(
        #     label='Temperature of fuser generation',
        #     minimum=0.1,
        #     maximum=2.0,
        #     step=0.1,
        #     value=0.7,
        # )
        # top_p = gr.Slider(
        #     label='Top-p of fuser generation',
        #     minimum=0.05,
        #     maximum=1.0,
        #     step=0.05,
        #     value=1.0,
        # )
        beam_size = gr.Slider(
            label='Beam size of fuser generation',
            minimum=1,
            maximum=10,
            step=1,
            value=4,
        )

    # Hidden carrier: the examples write instruction+input here, and its change
    # event loads the matching candidates into saved_llm_outputs.
    examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
    batch_examples = gr.Examples(
        examples=EXAMPLES,
        fn=get_preprocess_examples,
        cache_examples=True,
        examples_per_page=5,
        inputs=[inst_textbox, input_textbox],
        outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
    )

    base_llms_num.change(
        fn=update_base_llms_num,
        inputs=[base_llms_num, saved_llm_outputs],
        outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
    )

    examples_dummy_textbox.change(
        fn=update_base_llm_dropdown_along_examples,
        inputs=[examples_dummy_textbox],
        outputs=[saved_llm_outputs, rank_outputs, fuser_outputs],
    ).then(
        fn=display_llm_output,
        inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
        outputs=selected_base_llm_output,
    )

    selected_base_llm_name_dropdown.change(
        fn=display_llm_output,
        inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
        outputs=selected_base_llm_output,
    )

    base_llm_outputs_save_button.click(
        fn=save_llm_output,
        inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
        outputs=saved_llm_outputs,
    )
    base_llm_outputs_clear_all_button.click(
        fn=lambda: [{}, ""],
        inputs=[],
        outputs=[saved_llm_outputs, selected_base_llm_output],
    )
    base_llm_outputs_clear_single_button.click(
        fn=lambda: "",
        inputs=[],
        outputs=selected_base_llm_output,
    )

    # Validate first; the ranking step only runs if validation succeeds.
    rank_button.click(
        fn=check_save_ranker_inputs,
        inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
        outputs=blender_state,
    ).success(
        fn=llms_rank,
        inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
        outputs=[saved_rank_outputs, rank_outputs],
    )

    fuse_button.click(
        fn=check_fuser_inputs,
        inputs=[blender_state, blender_config, saved_rank_outputs],
        outputs=fuser_outputs,
    ).success(
        fn=llms_fuse,
        inputs=[blender_state, blender_config, saved_rank_outputs],
        outputs=[saved_fuse_outputs, fuser_outputs],
    )

    clear_button.click(
        fn=lambda: ["", "", {}, []],
        inputs=[],
        outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
    )

    # update blender config
    # Each handler updates the settings dict in place; dict.update() returns
    # None, so `or y` hands the same dict back as the new state value.
    source_max_length.change(
        fn=lambda x, y: y.update({"source_max_length": x}) or y,
        inputs=[source_max_length, blender_config],
        outputs=blender_config,
    )
    candidate_max_length.change(
        fn=lambda x, y: y.update({"candidate_max_length": x}) or y,
        inputs=[candidate_max_length, blender_config],
        outputs=blender_config,
    )
    top_k_for_fuser.change(
        fn=lambda x, y: y.update({"top_k_for_fuser": x}) or y,
        inputs=[top_k_for_fuser, blender_config],
        outputs=blender_config,
    )
    max_new_tokens.change(
        fn=lambda x, y: y.update({"max_new_tokens": x}) or y,
        inputs=[max_new_tokens, blender_config],
        outputs=blender_config,
    )
    # temperature.change(
    #     fn=lambda x, y: y.update({"temperature": x}) or y,
    #     inputs=[temperature, blender_config],
    #     outputs=blender_config,
    # )
    # top_p.change(
    #     fn=lambda x, y: y.update({"top_p": x}) or y,
    #     inputs=[top_p, blender_config],
    #     outputs=blender_config,
    # )
    beam_size.change(
        fn=lambda x, y: y.update({"num_beams": x}) or y,
        inputs=[beam_size, blender_config],
        outputs=blender_config,
    )

demo.queue(max_size=20).launch()