File size: 9,698 Bytes
9123479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import gradio as gr
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
from datasets import load_dataset
from typing import List

MAX_BASE_LLM_NUM = 20
MIN_BASE_LLM_NUM = 3
DESCRIPTIONS = """
"""
MAX_MAX_NEW_TOKENS=1024
DEFAULT_MAX_NEW_TOKENS=256
EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42, buffer_size=1000)
EXAMPLES = []
CANDIDATE_EXAMPLES = {}
for example in SHUFFLED_EXAMPLES_DATASET.take(100):
    EXAMPLES.append([
        example['instruction'],
        example['input'],
    ])
    CANDIDATE_EXAMPLES[example['instruction']+example['input']] = example['candidates']

# Download ranker checkpoint
if not os.path.exists("pairranker-deberta-v3-large.zip"):
    os.system("gdown https://drive.google.com/uc?id=1EpvFu_qYY0MaIu0BAAhK-sYKHVWtccWg")
if not os.path.exists("pairranker-deberta-v3-large"):
    os.system("unzip pairranker-deberta-v3-large.zip")

# Load Blender
import llm_blender
from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks
ranker_config = llm_blender.RankerConfig()
ranker_config.ranker_type = "pairranker"
ranker_config.model_type = "deberta"
ranker_config.model_name = "microsoft/deberta-v3-large" # ranker backbone
ranker_config.load_checkpoint = "./pairranker-deberta-v3-large" # ranker checkpoint <your checkpoint path>
ranker_config.source_maxlength = 128
ranker_config.candidate_maxlength = 128
ranker_config.n_tasks = 1 # number of singal that has been used to train the ranker. This checkpoint is trained using BARTScore only, thus being 1.
fuser_config = llm_blender.GenFuserConfig()
fuser_config.model_name = "llm-blender/gen_fuser_3b" # our pre-trained fuser
fuser_config.max_length = 1024
fuser_config.candidate_maxlength = 128
blender_config = llm_blender.BlenderConfig()
blender_config.device = "cpu" # blender ranker and fuser device
blender = llm_blender.Blender(blender_config, ranker_config, fuser_config)

def update_base_llms_num(k, llm_outputs):
    k = int(k)
    return [gr.Dropdown.update(choices=[f"LLM-{i+1}" for i in range(k)], 
        value=f"LLM-1" if k >= 1 else "", visible=True),
        {f"LLM-{i+1}": llm_outputs.get(f"LLM-{i+1}", "") for i in range(k)}]
    

def display_llm_output(llm_outputs, selected_base_llm_name):
    return gr.Textbox.update(value=llm_outputs.get(selected_base_llm_name, ""), 
        label=selected_base_llm_name + " (Click Save to save current content)", 
        placeholder=f"Enter {selected_base_llm_name} output here", show_label=True)

def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
    llm_outputs.update({selected_base_llm_name: selected_base_llm_output})
    return llm_outputs

def get_preprocess_examples(inst, input):
    # get the num_of_base_llms
    candidates = CANDIDATE_EXAMPLES[inst+input]
    num_candiates = len(candidates)
    dummy_text = inst+input
    return inst, input, num_candiates, dummy_text

def update_base_llm_dropdown_along_examples(dummy_text):
    candidates = CANDIDATE_EXAMPLES[dummy_text]
    ex_llm_outputs = {f"LLM-{i+1}": candidates[i]['text'] for i in range(len(candidates))}
    return ex_llm_outputs
    
def check_save_ranker_inputs(inst, input, llm_outputs):
    if not inst and not input:
        raise gr.Error("Please enter instruction or input context")
    
    if not all([x for x in llm_outputs.values()]):
        empty_llm_names = [llm_name for llm_name, llm_output in llm_outputs.items() if not llm_output]
        raise gr.Error("Please enter base LLM outputs for LLMs: {}").format(empty_llm_names)
    return {
        "inst": inst,
        "input": input,
        "candidates": list(llm_outputs.values()),
    }

def check_fuser_inputs(blender_state, top_k_for_fuser, ranks):
    pass

def llms_rank(inst, input, llm_outputs):
    candidates = list(llm_outputs.values())
    
    return blender.rank(instructions=[inst], inputs=[input], candidates=[candidates])[0]

def display_ranks(ranks):
    return ",  ".join([f"LLM-{i+1}: {rank}" for i, rank in enumerate(ranks)])

def llms_fuse(blender_state, top_k_for_fuser, ranks):
    inst = blender_state['inst']
    input = blender_state['input']
    candidates = blender_state['candidates']
    top_k_candidates = get_topk_candidates_from_ranks([ranks], [candidates], top_k=top_k_for_fuser)[0]
    return blender.fuse(instructions=[inst], inputs=[input], candidates=[top_k_candidates])[0]

def display_fuser_output(fuser_output):
    return fuser_output

        
with gr.Blocks(theme='ParityError/Anime') as demo:
    gr.Markdown(DESCRIPTIONS)
    with gr.Row():
        with gr.Column():
            inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
            input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
        with gr.Column():
            saved_llm_outputs = gr.State(value={})
            selected_base_llm_name_dropdown = gr.Dropdown(label="Base LLM",
                choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)], value="LLM-1", show_label=True)
            selected_base_llm_output = gr.Textbox(lines=4, label="LLM-1 (Click Save to save current content)",
                placeholder="Enter LLM-1 output here", show_label=True)
            with gr.Row():
                base_llm_outputs_save_button = gr.Button('Save', variant='primary')
                
                base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
                
                base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
            base_llms_num = gr.Slider(
                    label='Number of base llms',
                    minimum=MIN_BASE_LLM_NUM,
                    maximum=MAX_BASE_LLM_NUM,
                    step=1,
                    value=MIN_BASE_LLM_NUM,
                )
    
    blender_state = gr.State(value={})
    with gr.Tab("Ranking outputs"):
        saved_rank_outputs = gr.State(value=[])
        rank_outputs = gr.Textbox(lines=4, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
    with gr.Tab("Fusing outputs"):
        saved_fuse_outputs = gr.State(value=[])
        fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
    with gr.Row():
        rank_button = gr.Button('Rank LLM Outputs', variant='primary',
            scale=1, min_width=0)
        fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary',
            scale=1, min_width=0)
        clear_button = gr.Button('Clear Blender', variant='primary',
            scale=1, min_width=0)
        
    with gr.Accordion(label='Advanced options', open=False):
        
        top_k_for_fuser = gr.Slider(
            label='Top k for fuser',
            minimum=1,
            maximum=3,
            step=1,
            value=1,
        )
    
    examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)     
    batch_examples = gr.Examples(
        examples=EXAMPLES,
        fn=get_preprocess_examples,
        cache_examples=True,
        examples_per_page=5,
        inputs=[inst_textbox, input_textbox],
        outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
    )
        
    base_llms_num.change(
        fn=update_base_llms_num,
        inputs=[base_llms_num, saved_llm_outputs],
        outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
    )
    
    examples_dummy_textbox.change(
        fn=update_base_llm_dropdown_along_examples,
        inputs=[examples_dummy_textbox],
        outputs=saved_llm_outputs,
    ).then(
        fn=display_llm_output,
        inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
        outputs=selected_base_llm_output,
    )
    
    selected_base_llm_name_dropdown.change(
        fn=display_llm_output,
        inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
        outputs=selected_base_llm_output,
    )
    
    base_llm_outputs_save_button.click(
        fn=save_llm_output,
        inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
        outputs=saved_llm_outputs,
    )
    base_llm_outputs_clear_all_button.click(
        fn=lambda: [{}, ""],
        inputs=[],
        outputs=[saved_llm_outputs, selected_base_llm_output],
    )
    base_llm_outputs_clear_single_button.click(
        fn=lambda: "",
        inputs=[],
        outputs=selected_base_llm_output,
    )
        

    rank_button.click(
        fn=check_save_ranker_inputs,
        inputs=[inst_textbox, input_textbox, saved_llm_outputs],
        outputs=blender_state,
    ).success(
        fn=llms_rank,
        inputs=[inst_textbox, input_textbox, saved_llm_outputs],
        outputs=[saved_rank_outputs],
    ).then(
        fn=display_ranks,
        inputs=[saved_rank_outputs],
        outputs=rank_outputs,
    )
    
    fuse_button.click(
        fn=check_fuser_inputs,
        inputs=[blender_state, top_k_for_fuser, saved_rank_outputs],
        outputs=[],
    ).success(
        fn=llms_fuse,
        inputs=[blender_state, top_k_for_fuser, saved_rank_outputs],
        outputs=[saved_fuse_outputs],
    ).then(
        fn=display_fuser_output,
        inputs=[saved_fuse_outputs],
        outputs=fuser_outputs,
    )
    
    clear_button.click(
        fn=lambda: ["", "", {}, []],
        inputs=[],
        outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
    )
    
        
    

demo.queue(max_size=20).launch()