from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import mdtex2html
import torch

"""Override Chatbot.postprocess"""


model_path = 'THUDM/BPO'

device = 'cuda'

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, add_prefix_space=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, load_in_8bit=True)
model = model.eval()

prompt_template = "[INST] You are an expert prompt engineer. Please help me improve this prompt to get a more helpful and harmless response:\n{} [/INST]"
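# For illustration, prompt_template.format("tell me about harry potter") yields:
#   "[INST] You are an expert prompt engineer. Please help me improve this prompt
#    to get a more helpful and harmless response:\ntell me about harry potter [/INST]"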


# Override gr.Chatbot.postprocess so that chat messages are rendered from
# Markdown (with LaTeX support) to HTML via mdtex2html.
def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess
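# With this override in place, every (message, response) pair that predict()
# sends to the Chatbot below is converted to HTML before display.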


def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>"+line
    text = "".join(lines)
    return text
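
# A rough worked example of the conversion above (hypothetical input):
#   parse_text("```python\nprint('hi')\n```")
# returns
#   '<pre><code class="language-python"><br>print&#40;'hi'&#41;<br></code></pre>'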


def predict(input, chatbot, max_length, top_p, temperature, history):
    """Yield optimized prompts: one "stable" sample at the user-selected
    decoding settings, then five "aggressive" higher-temperature samples."""
    # predict is a generator (it yields below), so a bare `return value` would
    # never reach the UI; yield the error message first, then stop.
    if input.strip() == "":
        chatbot = [(parse_text(input), parse_text("Please enter a valid user prompt; an empty string is not supported."))]
        yield chatbot, history
        return

    # "Stable" optimization: a single sample with the user-selected settings.
    prompt = prompt_template.format(input)
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**model_inputs, max_length=max_length, do_sample=True, top_p=top_p, 
                            temperature=temperature, num_beams=1)
    resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip()

    optimized_prompt = """Here are several optimized prompts:

====================Stable Optimization====================
"""
    optimized_prompt += resp
    chatbot = [(parse_text(input), parse_text(optimized_prompt))]
    yield chatbot, history

    optimized_prompt += "\n\n====================Aggressive Optimization===================="

    # "Aggressive" optimization: draw five higher-temperature samples,
    # re-seeding the generator before each one so the drafts differ.
    for num in range(1, 6):
        torch.seed()
        prompt = prompt_template.format(input)
        # Require the optimized prompt to be at least as long as the original
        # (plus a small margin) so the model cannot trivially truncate it.
        min_length = len(tokenizer(prompt)['input_ids']) + len(tokenizer(input)['input_ids']) + 5
        model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # Block the model's bracketed control tokens from being sampled.
        bad_words_ids = [tokenizer(bad_word, add_special_tokens=False).input_ids for bad_word in ["[PROTECT]", "\n\n[PROTECT]", "[KEEP", "[INSTRUCTION]"]]
        # Stop on EOS or newline (token id 13 is "\n" in the Llama tokenizer).
        eos_token_ids = [tokenizer.eos_token_id, 13]
        output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9,
                                temperature=0.9, bad_words_ids=bad_words_ids, num_beams=1,
                                eos_token_id=eos_token_ids, min_length=min_length)
        # Keep the text after [/INST] and strip any partially generated control tokens.
        resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].split('[KE')[0].split('[INS')[0].split('[PRO')[0].strip()
        
        optimized_prompt += f"\n{num}. {resp}"

        chatbot = [(parse_text(input), parse_text(optimized_prompt))]
        yield chatbot, history
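
# A minimal non-UI sketch of the "stable" path above, useful for calling BPO
# from a script. optimize_prompt is a hypothetical helper (not part of the
# original demo); it reuses the model, tokenizer, and prompt_template defined
# in this file.
def optimize_prompt(user_prompt, max_length=2048, top_p=0.9, temperature=0.6):
    prompt = prompt_template.format(user_prompt)
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Same decoding call as the "stable" branch of predict().
    output = model.generate(**model_inputs, max_length=max_length, do_sample=True,
                            top_p=top_p, temperature=temperature, num_beams=1)
    return tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip()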


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], []


def update_textbox_from_dropdown(selected_example):
    return selected_example

with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Prompt Preference Optimizer</h1>""")

    chatbot = gr.Chatbot(label="Prompt Optimization Chatbot")
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                dropdown = gr.Dropdown(["tell me about harry potter", "give me 3 tips to learn English", "write a story about love"], label="Choose an example input")
                user_input = gr.Textbox(show_label=False, placeholder="User Prompt...", lines=5).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.9, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.6, step=0.01, label="Temperature", interactive=True)

    dropdown.change(update_textbox_from_dropdown, dropdown, user_input)
    
    history = gr.State([])

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

# queue() enables streaming: each yield from predict() updates the Chatbot.
demo.queue().launch(share=False, inbrowser=True)