from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import mdtex2html
import torch
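
# Gradio demo for BPO: the model rewrites a user prompt into "optimized" prompts that aim to
# elicit more helpful and harmless responses from an LLM (see https://huggingface.co/THUDM/BPO).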

"""Override Chatbot.postprocess"""


model_path = 'THUDM/BPO'

device = 'cuda:0'
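
# Load the tokenizer and the 8-bit quantized model (requires the bitsandbytes package);
# the demo only runs on GPU, so nothing is loaded on CPU-only machines.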

if torch.cuda.is_available():
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, add_prefix_space=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map=device, load_in_8bit=True)
    model = model.eval()


DESCRIPTION = """This Space demonstrates model [BPO](https://huggingface.co/THUDM/BPO), which is built on LLaMA-2-7b-chat.
BPO aims to improve the alignment of LLMs with human preferences by optimizing user prompts.

Feel free to play with it, or duplicate the Space to run generations without a queue! 🔎 For more details about the BPO model, take a look [at our paper](https://arxiv.org/pdf/2311.04155.pdf).
"""

LICENSE = """
---
As BPO is a fine-tuned version of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/CCCCCC/BPO_demo/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/CCCCCC/BPO_demo/blob/main/USE_POLICY.md).
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# BPO wraps the raw user prompt in a Llama-2 [INST] chat template together with a fixed rewriting instruction.
prompt_template = "[INST] You are an expert prompt engineer. Please help me improve this prompt to get a more helpful and harmless response:\n{} [/INST]"


def postprocess(self, y):
    """Override gr.Chatbot.postprocess: render each (message, response) pair from Markdown/TeX to HTML via mdtex2html."""
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """Convert model output to chatbot-ready HTML, preserving ``` code fences; copied from https://github.com/GaiZhenbiao/ChuanhuChatGPT/."""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>"+line
    text = "".join(lines)
    return text


def predict(input, chatbot, max_length, top_p, temperature, history):
    """Stream optimized prompts: first a "stable" rewrite using the UI sampling settings, then five "aggressive" rewrites sampled at top_p=0.9, temperature=0.9."""
    if input.strip() == "":
        chatbot = [(parse_text(input), parse_text("Please input a valid user prompt. Empty string is not supported."))]
        # predict is a generator, so the empty-input message must be yielded rather than returned.
        yield chatbot, history
        return

    prompt = prompt_template.format(input)
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**model_inputs, max_length=max_length, do_sample=True, top_p=top_p, 
                            temperature=temperature, num_beams=1)
    # Keep only the text generated after the [/INST] marker (the optimized prompt).
    resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip()

    optimized_prompt = """Here are several optimized prompts:

====================Stable Optimization====================
"""
    optimized_prompt += resp
    chatbot = [(parse_text(input), parse_text(optimized_prompt))]
    yield chatbot, history

    optimized_prompt += "\n\n====================Aggressive Optimization===================="

    texts = [input] * 5
    num = 0
    for text in texts:
        num += 1
        # Draw a fresh random seed so each of the five aggressive samples differs.
        torch.seed()
        prompt = prompt_template.format(text)
        # min_length counts the prompt too, so this forces at least len(input) + 5 newly generated tokens.
        min_length = len(tokenizer(prompt)['input_ids']) + len(tokenizer(text)['input_ids']) + 5
        model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # Block bracketed control tags (e.g. [PROTECT], [KEEP], [INSTRUCTION]) so they do not appear in the rewritten prompt.
        bad_words_ids = [tokenizer(bad_word, add_special_tokens=False).input_ids for bad_word in ["[PROTECT]", "\n\n[PROTECT]", "[KEEP", "[INSTRUCTION]"]]
        # Stop at EOS or at a newline (token id 13 in the Llama tokenizer).
        eos_token_ids = [tokenizer.eos_token_id, 13]
        output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9, temperature=0.9, bad_words_ids=bad_words_ids, num_beams=1, eos_token_id=eos_token_ids, min_length=min_length)
        # Keep only the rewritten prompt: drop the echoed instruction and truncate at any leaked control tags.
        resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].split('[KE')[0].split('[INS')[0].split('[PRO')[0].strip()
        
        optimized_prompt += f"\n{num}. {resp}"

        chatbot = [(parse_text(input), parse_text(optimized_prompt))]
        yield chatbot, history


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], []

def update_textbox_from_dropdown(selected_example):
    return selected_example
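

# Gradio UI: example dropdown and prompt textbox on the left, generation controls on the right.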

with gr.Blocks(css="sty.css") as demo:
    gr.HTML("""<h1 align="center">Prompt Preference Optimizer</h1>""")

    gr.Markdown(DESCRIPTION)

    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    
    chatbot = gr.Chatbot(label="Prompt Optimization Chatbot")
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                dropdown = gr.Dropdown(["tell me about harry potter", "give me 3 tips to learn English", "write a story about love"], label="Choose an example input")
                user_input = gr.Textbox(show_label=False, placeholder="User Prompt...", lines=5).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.9, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.6, step=0.01, label="Temperature", interactive=True)

    gr.Markdown(LICENSE)
    
    dropdown.change(update_textbox_from_dropdown, dropdown, user_input)
    
    history = gr.State([])

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

# queue() is required so the generator-based predict() can stream intermediate results to the chatbot.
demo.queue().launch(share=False, inbrowser=True)