LLM-Blender / app.py
DongfuJiang's picture
update
62174a3
raw
history blame
12.9 kB
import gradio as gr
import sys
import os
from datasets import load_dataset
from typing import List
# Bounds of the "Number of base llms" slider.
MIN_BASE_LLM_NUM = 3
MAX_BASE_LLM_NUM = 20

# Max length of instruction+input and of each LLM-output candidate:
# the slider maximum and the default value (the defaults also seed the
# ranker/fuser configuration).
DEFAULT_SOURCE_MAX_LENGTH = 128
SOURCE_MAX_LENGTH = 256
DEFAULT_CANDIDATE_MAX_LENGTH = 128
CANDIDATE_MAX_LENGTH = 256

# Fuser generation budget: slider maximum and default of max_new_tokens.
DEFAULT_FUSER_MAX_NEW_TOKENS = 256
FUSER_MAX_NEW_TOKENS = 512

# Markdown rendered at the top of the page.
DESCRIPTIONS = """# LLM-BLENDER
LLM-Blender is an innovative ensembling framework to attain consistently superior performance by leveraging the diverse strengths of multiple open-source large language models (LLMs). LLM-Blender cut the weaknesses through ranking and integrate the strengths through fusing generation to enhance the capability of LLMs.
"""
# Clickable examples: 100 rows from a seeded shuffle of the (streamed)
# mix-instruct validation split.
EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42, buffer_size=1000)
EXAMPLES = []  # [[instruction, input], ...] fed to gr.Examples
CANDIDATE_EXAMPLES = {}  # instruction + input -> the row's 'candidates'
for example in SHUFFLED_EXAMPLES_DATASET.take(100):
    instruction, context = example['instruction'], example['input']
    EXAMPLES.append([instruction, context])
    # Keyed by the plain concatenation instruction + input; the example
    # callbacks rebuild the same key to look the candidates up again.
    CANDIDATE_EXAMPLES[instruction + context] = example['candidates']
import subprocess

# Download and unpack the ranker checkpoint unless it is already present.
# An argument list (no shell) keeps the URL's '?' from being treated as a
# glob character, and check=True aborts startup with a clear
# CalledProcessError if the download or extraction fails, instead of
# continuing and failing later when the ranker checkpoint is loaded.
if not os.path.exists("pairranker-deberta-v3-large.zip"):
    subprocess.run(
        ["gdown", "https://drive.google.com/uc?id=1EpvFu_qYY0MaIu0BAAhK-sYKHVWtccWg"],
        check=True,
    )
if not os.path.exists("pairranker-deberta-v3-large"):
    subprocess.run(["unzip", "pairranker-deberta-v3-large.zip"], check=True)
# Load Blender
import llm_blender
from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks
# Ranker: PairRanker on a DeBERTa-v3-large backbone, loaded from the
# checkpoint directory downloaded/extracted above.
ranker_config = llm_blender.RankerConfig()
ranker_config.ranker_type = "pairranker"
ranker_config.model_type = "deberta"
ranker_config.model_name = "microsoft/deberta-v3-large" # ranker backbone
ranker_config.load_checkpoint = "./pairranker-deberta-v3-large" # ranker checkpoint <your checkpoint path>
# Length limits are fixed here at startup; the UI's length sliders write to
# the blender_config gr.State dict, not to ranker_config.
ranker_config.source_maxlength = DEFAULT_SOURCE_MAX_LENGTH
ranker_config.candidate_maxlength = DEFAULT_CANDIDATE_MAX_LENGTH
ranker_config.n_tasks = 1 # number of signals used to train the ranker. This checkpoint is trained using BARTScore only, hence 1.
# Fuser: the pre-trained generative fuser model.
fuser_config = llm_blender.GenFuserConfig()
fuser_config.model_name = "llm-blender/gen_fuser_3b" # our pre-trained fuser
fuser_config.max_length = 1024  # NOTE(review): presumably the fuser's input length limit — confirm in llm_blender
fuser_config.candidate_maxlength = DEFAULT_CANDIDATE_MAX_LENGTH
blender_config = llm_blender.BlenderConfig()
blender_config.device = "cpu" # blender ranker and fuser device
# Single Blender instance shared by the rank/fuse callbacks (llms_rank, llms_fuse).
blender = llm_blender.Blender(blender_config, ranker_config, fuser_config)
def update_base_llms_num(k, llm_outputs):
    """Resize the base-LLM selector to ``k`` entries.

    Args:
        k: Requested number of base LLMs (slider value; converted with int()).
        llm_outputs: Currently saved outputs, ``{"LLM-<i>": text}``.

    Returns:
        A two-item list: a Dropdown update offering ``LLM-1`` .. ``LLM-k``
        (selecting ``LLM-1``), and a new outputs dict with exactly those keys,
        keeping previously saved text and filling the rest with "".
    """
    count = int(k)
    names = [f"LLM-{idx}" for idx in range(1, count + 1)]
    dropdown_update = gr.Dropdown.update(
        choices=names,
        value="LLM-1" if count >= 1 else "",
        visible=True,
    )
    resized_outputs = {name: llm_outputs.get(name, "") for name in names}
    return [dropdown_update, resized_outputs]
def display_llm_output(llm_outputs, selected_base_llm_name):
    """Load the saved text of the selected base LLM into the editor textbox.

    Args:
        llm_outputs: Saved outputs, ``{"LLM-<i>": text}``.
        selected_base_llm_name: Dropdown value, e.g. ``"LLM-2"``.

    Returns:
        A Textbox update whose value is the saved text ("" if none) and whose
        label/placeholder name the selected LLM.
    """
    saved_text = llm_outputs.get(selected_base_llm_name, "")
    label_text = selected_base_llm_name + " (Click Save to save current content)"
    hint = f"Enter {selected_base_llm_name} output here"
    return gr.Textbox.update(value=saved_text, label=label_text, placeholder=hint, show_label=True)
def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
    """Store the editor text under the selected LLM name.

    Mutates ``llm_outputs`` in place and returns that same dict so it can be
    written back to the gr.State.
    """
    llm_outputs[selected_base_llm_name] = selected_base_llm_output
    return llm_outputs
def get_preprocess_examples(inst, input):
    """Prepare a clicked example for the UI.

    Args:
        inst: The example's instruction.
        input: The example's input context.

    Returns:
        ``(inst, input, num_candidates, key)``: the instruction and input
        unchanged, the number of stored candidate outputs (sets the
        "Number of base llms" slider), and the ``inst + input`` lookup key
        (written to the hidden examples textbox).

    Raises:
        KeyError: if ``inst + input`` is not a key of CANDIDATE_EXAMPLES.
    """
    example_key = inst + input
    num_candidates = len(CANDIDATE_EXAMPLES[example_key])
    return inst, input, num_candidates, example_key
def update_base_llm_dropdown_along_examples(dummy_text):
    """Turn a clicked example's stored candidates into saved LLM outputs.

    Args:
        dummy_text: The ``instruction + input`` key written to the hidden
            examples textbox by get_preprocess_examples.

    Returns:
        ``{"LLM-1": text_1, "LLM-2": text_2, ...}`` built from each
        candidate's ``'text'`` field, in stored order.

    Raises:
        KeyError: if ``dummy_text`` is not a key of CANDIDATE_EXAMPLES.
    """
    candidates = CANDIDATE_EXAMPLES[dummy_text]
    return {f"LLM-{i}": candidate['text'] for i, candidate in enumerate(candidates, start=1)}
def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
    """Validate the ranking inputs and package them for ``blender_state``.

    Args:
        inst: Instruction text.
        input: Input-context text.
        llm_outputs: Saved base-LLM outputs, ``{"LLM-<i>": text}``.
        blender_config: Not used here; the rank button wires the same inputs
            to this check and to llms_rank.

    Returns:
        ``{"inst": inst, "input": input, "candidates": [...]}`` with the
        candidates in ``llm_outputs`` insertion order.

    Raises:
        gr.Error: if both ``inst`` and ``input`` are empty, if no outputs
            have been saved, or if any saved output is empty.
    """
    if not inst and not input:
        raise gr.Error("Please enter instruction or input context")
    # all([]) is True, so an empty dict would otherwise pass and reach the
    # ranker with zero candidates.
    if not llm_outputs:
        raise gr.Error("Please enter and save base LLM outputs before ranking")
    empty_llm_names = [llm_name for llm_name, llm_output in llm_outputs.items() if not llm_output]
    if empty_llm_names:
        # Format the message before building the exception; calling .format()
        # on the exception object raises AttributeError instead.
        raise gr.Error("Please enter base LLM outputs for LLMs: {}".format(", ".join(empty_llm_names)))
    return {
        "inst": inst,
        "input": input,
        "candidates": list(llm_outputs.values()),
    }
def check_fuser_inputs(blender_state, blender_config, ranks):
    """Ensure a successful ranking exists before fusing.

    llms_fuse reads ``blender_state['inst']``, ``['input']`` and
    ``['candidates']`` and selects candidates with ``ranks``. Both are filled by
    the "Rank LLM Outputs" flow and reset to ``{}`` / ``[]`` by
    "Clear Blender Outputs".

    Args:
        blender_state: Dict stored by check_save_ranker_inputs.
        blender_config: Not used here; kept so the signature matches llms_fuse.
        ranks: Ranks stored by llms_rank.

    Raises:
        gr.Error: if the ranking inputs or ranks are missing, or if the number
            of ranks does not match the stored candidates.
    """
    required_keys = ("inst", "input", "candidates")
    if not blender_state or any(key not in blender_state for key in required_keys):
        raise gr.Error("Please rank the LLM outputs before fusing")
    # ranks may be array-like (ranker output), so avoid relying on truthiness.
    if ranks is None or len(ranks) == 0:
        raise gr.Error("Please rank the LLM outputs before fusing")
    if len(ranks) != len(blender_state["candidates"]):
        raise gr.Error("Ranks are out of date with the saved LLM outputs; please rank again")
def llms_rank(inst, input, llm_outputs, blender_config):
    """Rank the saved base-LLM outputs with the shared Blender ranker.

    Args:
        inst: Instruction text.
        input: Input-context text.
        llm_outputs: ``{"LLM-<i>": text}``; candidates are ranked in this
            dict's insertion order.
        blender_config: UI config dict. NOTE(review): its
            ``source_max_length`` / ``candidate_max_length`` entries are not
            forwarded to ``blender.rank``, so the ranker keeps the lengths set
            on ranker_config at startup. Forwarding them requires confirming
            which keyword names ``blender.rank`` accepts.

    Returns:
        ``[ranks, summary]``: the ranker output for this single example and a
        display string ``"LLM-1: <rank>, LLM-2: <rank>, ..."``.
    """
    candidates = list(llm_outputs.values())
    ranks = blender.rank(instructions=[inst], inputs=[input], candidates=[candidates])[0]
    summary = ", ".join(f"LLM-{i}: {rank}" for i, rank in enumerate(ranks, start=1))
    return [ranks, summary]
def llms_fuse(blender_state, blender_config, ranks):
    """Fuse the top-k ranked candidates into a single response.

    Args:
        blender_state: ``{"inst", "input", "candidates"}`` stored by the rank step.
        blender_config: UI config dict. ``top_k_for_fuser`` selects how many
            ranked candidates to fuse; every other entry is forwarded to
            ``blender.fuse`` as a keyword argument. NOTE(review): that includes
            ``source_max_length`` / ``candidate_max_length`` — confirm
            ``blender.fuse`` accepts them.
        ranks: Ranks produced by llms_rank for the same candidates.

    Returns:
        ``[fused, fused]``: the fused output for this single example, once
        for the saved state and once for the output textbox.
    """
    inst = blender_state['inst']
    input_text = blender_state['input']
    candidates = blender_state['candidates']
    fuse_params = blender_config.copy()
    # Slider values may arrive as floats; top-k is a count and max_new_tokens
    # a token budget, so coerce both to int.
    top_k_for_fuser = int(fuse_params.pop("top_k_for_fuser"))
    if "max_new_tokens" in fuse_params:
        fuse_params["max_new_tokens"] = int(fuse_params["max_new_tokens"])
    top_k_candidates = get_topk_candidates_from_ranks([ranks], [candidates], top_k=top_k_for_fuser)[0]
    fused = blender.fuse(instructions=[inst], inputs=[input_text], candidates=[top_k_candidates], **fuse_params)[0]
    return [fused, fused]
def display_fuser_output(fuser_output):
    """Identity helper: hand the fused text to a display component unchanged."""
    shown = fuser_output
    return shown
# Gradio UI: layout first, then event wiring (listeners are attached inside
# the Blocks context).
# NOTE(review): this file's indentation was lost; the nesting of the
# Row/Column/Group/Accordion containers below is reconstructed from statement
# order — confirm against the intended layout.
with gr.Blocks(theme='ParityError/Anime') as demo:
    gr.Markdown(DESCRIPTIONS)
    gr.Markdown("## Input and Base LLMs")
    with gr.Row():
        # Left: the instruction / input context shared by every base LLM.
        with gr.Column():
            inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
            input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
        # Right: edit one base-LLM output at a time. The textbox is a scratch
        # area; "Save" copies it into saved_llm_outputs ({"LLM-<i>": text}),
        # which is what the rank step reads.
        with gr.Column():
            saved_llm_outputs = gr.State(value={})
            with gr.Group():
                selected_base_llm_name_dropdown = gr.Dropdown(label="Base LLM",
                    choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)], value="LLM-1", show_label=True)
                selected_base_llm_output = gr.Textbox(lines=4, label="LLM-1 (Click Save to save current content)",
                    placeholder="Enter LLM-1 output here", show_label=True)
            with gr.Row():
                base_llm_outputs_save_button = gr.Button('Save', variant='primary')
                base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
                base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
            base_llms_num = gr.Slider(
                label='Number of base llms',
                minimum=MIN_BASE_LLM_NUM,
                maximum=MAX_BASE_LLM_NUM,
                step=1,
                value=MIN_BASE_LLM_NUM,
            )
    # blender_state: {"inst", "input", "candidates"} stored by
    # check_save_ranker_inputs; saved_rank_outputs: ranks from llms_rank;
    # saved_fuse_outputs: the last fused output from llms_fuse.
    blender_state = gr.State(value={})
    saved_rank_outputs = gr.State(value=[])
    saved_fuse_outputs = gr.State(value=[])
    gr.Markdown("## Blender Outputs")
    with gr.Group():
        rank_outputs = gr.Textbox(lines=1, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
        fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
    with gr.Row():
        rank_button = gr.Button('Rank LLM Outputs', variant='primary')
        fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary')
        clear_button = gr.Button('Clear Blender Outputs', variant='primary')
    # Per-session ranking/fusion settings, kept in sync with the
    # "Advanced options" sliders by the .change handlers at the end of this block.
    # NOTE: this rebinds the module-level name `blender_config`, which earlier
    # held the llm_blender.BlenderConfig; `blender` was already constructed
    # from it, so only the name is shadowed.
    blender_config = gr.State(value={
        "source_max_length": DEFAULT_SOURCE_MAX_LENGTH,
        "candidate_max_length": DEFAULT_CANDIDATE_MAX_LENGTH,
        "top_k_for_fuser": 3,
        "max_new_tokens": DEFAULT_FUSER_MAX_NEW_TOKENS,
        "temperature": 0.7,
        "top_p": 1.0,
    })
    with gr.Accordion(label='Advanced options', open=False):
        source_max_length = gr.Slider(
            label='Max length of Instruction + Input',
            minimum=1,
            maximum=SOURCE_MAX_LENGTH,
            step=1,
            value=DEFAULT_SOURCE_MAX_LENGTH,
        )
        candidate_max_length = gr.Slider(
            label='Max length of LLM-Output Candidate',
            minimum=1,
            maximum=CANDIDATE_MAX_LENGTH,
            step=1,
            value=DEFAULT_CANDIDATE_MAX_LENGTH,
        )
        # Capped at 3, which equals MIN_BASE_LLM_NUM, so top-k never exceeds
        # the number of base LLMs the slider allows.
        top_k_for_fuser = gr.Slider(
            label='Top-k ranked candidates to fuse',
            minimum=1,
            maximum=3,
            step=1,
            value=3,
        )
        max_new_tokens = gr.Slider(
            label='Max new tokens fuser can generate',
            minimum=1,
            maximum=FUSER_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_FUSER_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature of fuser generation',
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.7,
        )
        top_p = gr.Slider(
            label='Top-p of fuser generation',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=1.0,
        )
    # Hidden textbox: an example's outputs include the instruction+input key
    # (from get_preprocess_examples); its .change event then loads that
    # example's candidate outputs.
    examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
    batch_examples = gr.Examples(
        examples=EXAMPLES,
        fn=get_preprocess_examples,
        cache_examples=True,
        examples_per_page=5,
        inputs=[inst_textbox, input_textbox],
        outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
    )
    # Changing the LLM count resizes the dropdown and pads/truncates the
    # saved outputs to exactly LLM-1..LLM-k.
    # NOTE(review): an example click can update both base_llms_num and
    # examples_dummy_textbox, and both handlers write saved_llm_outputs; if this
    # one finishes last it may overwrite the example's candidates with the
    # previous outputs — confirm event ordering.
    base_llms_num.change(
        fn=update_base_llms_num,
        inputs=[base_llms_num, saved_llm_outputs],
        outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
    )
    # Example click: load its candidates into saved_llm_outputs, then refresh
    # the editor textbox for the currently selected LLM.
    examples_dummy_textbox.change(
        fn=update_base_llm_dropdown_along_examples,
        inputs=[examples_dummy_textbox],
        outputs=saved_llm_outputs,
    ).then(
        fn=display_llm_output,
        inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
        outputs=selected_base_llm_output,
    )
    selected_base_llm_name_dropdown.change(
        fn=display_llm_output,
        inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
        outputs=selected_base_llm_output,
    )
    base_llm_outputs_save_button.click(
        fn=save_llm_output,
        inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
        outputs=saved_llm_outputs,
    )
    # "Clear All" drops every saved output and empties the editor.
    base_llm_outputs_clear_all_button.click(
        fn=lambda: [{}, ""],
        inputs=[],
        outputs=[saved_llm_outputs, selected_base_llm_output],
    )
    # "Clear Single" only empties the editor; the saved value is unchanged
    # until Save is clicked.
    base_llm_outputs_clear_single_button.click(
        fn=lambda: "",
        inputs=[],
        outputs=selected_base_llm_output,
    )
    # Rank: validate and store the inputs in blender_state; .success() runs
    # the ranker only if the validation step did not raise.
    rank_button.click(
        fn=check_save_ranker_inputs,
        inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
        outputs=blender_state,
    ).success(
        fn=llms_rank,
        inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
        outputs=[saved_rank_outputs, rank_outputs],
    )
    # Fuse: validate, then fuse the top-k candidates stored in blender_state.
    fuse_button.click(
        fn=check_fuser_inputs,
        inputs=[blender_state, blender_config, saved_rank_outputs],
        outputs=[],
    ).success(
        fn=llms_fuse,
        inputs=[blender_state, blender_config, saved_rank_outputs],
        outputs=[saved_fuse_outputs, fuser_outputs],
    )
    # Reset both output textboxes plus the stored rank inputs and ranks.
    clear_button.click(
        fn=lambda: ["", "", {}, []],
        inputs=[],
        outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
    )
    # update blender config
    # Each slider writes its value into the shared blender_config dict;
    # dict.update() returns None, so `or y` yields the mutated dict as the
    # new state value.
    source_max_length.change(
        fn=lambda x, y: y.update({"source_max_length": x}) or y,
        inputs=[source_max_length, blender_config],
        outputs=blender_config,
    )
    candidate_max_length.change(
        fn=lambda x, y: y.update({"candidate_max_length": x}) or y,
        inputs=[candidate_max_length, blender_config],
        outputs=blender_config,
    )
    top_k_for_fuser.change(
        fn=lambda x, y: y.update({"top_k_for_fuser": x}) or y,
        inputs=[top_k_for_fuser, blender_config],
        outputs=blender_config,
    )
    max_new_tokens.change(
        fn=lambda x, y: y.update({"max_new_tokens": x}) or y,
        inputs=[max_new_tokens, blender_config],
        outputs=blender_config,
    )
    temperature.change(
        fn=lambda x, y: y.update({"temperature": x}) or y,
        inputs=[temperature, blender_config],
        outputs=blender_config,
    )
    top_p.change(
        fn=lambda x, y: y.update({"top_p": x}) or y,
        inputs=[top_p, blender_config],
        outputs=blender_config,
    )
# Enable request queuing (queue size capped at 20), then start the server.
demo.queue(max_size=20)
demo.launch()