CCCCCC committed
Commit ddd83b8 (1 parent: 7139dc1)

Create app.py

Files changed (1): app.py (+148, −0)
app.py ADDED
@@ -0,0 +1,148 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import mdtex2html
import torch

"""Override Chatbot.postprocess"""


# model_path = '/cjl/llm_finetuning/output/prompt_engineer_en_final/bpo_model'
model_path = 'lmsys/vicuna-7b-v1.5'

device = 'cpu'

# Load the tokenizer and the prompt-optimization model (CPU inference here).
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, add_prefix_space=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device)
model = model.eval()

prompt_template = "[INST] You are an expert prompt engineer. Please help me improve this prompt to get a more helpful and harmless response:\n{} [/INST]"
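# Illustrative example with a hypothetical input:
#   prompt_template.format("Plan a three-day trip to Paris")
# expands to
#   "[INST] You are an expert prompt engineer. Please help me improve this prompt
#    to get a more helpful and harmless response:\nPlan a three-day trip to Paris [/INST]"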


def postprocess(self, y):
    # Render both sides of each chat turn to HTML via mdtex2html.
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            # Toggle in/out of a fenced code block; items[-1] is the language tag.
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a code block: escape characters that markdown/HTML
                    # would otherwise interpret.
                    line = line.replace("`", "\\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
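# Illustrative example with a hypothetical input: a fenced block becomes HTML,
# and markdown-sensitive characters inside it are escaped:
#   parse_text("```python\nprint(1)\n```")
#   -> '<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>'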


def predict(input, chatbot, max_length, top_p, temperature, history):
    if input.strip() == "":
        chatbot = [(parse_text(input), parse_text("Please input a valid user prompt. Empty string is not supported."))]
        yield chatbot, history
        return

    # Stable optimization: one sample with the user-selected decoding parameters.
    prompt = prompt_template.format(input)
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**model_inputs, max_length=max_length, do_sample=True, top_p=top_p,
                            temperature=temperature, num_beams=1)
    resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip()

    optimized_prompt = """Here are several optimized prompts:

====================Stable Optimization====================
"""
    optimized_prompt += resp
    chatbot = [(parse_text(input), parse_text(optimized_prompt))]
    yield chatbot, history

    # Aggressive optimization: five more samples at fixed top_p/temperature,
    # re-seeding the RNG before each one.
    optimized_prompt += "\n\n====================Aggressive Optimization===================="

    texts = [input] * 5
    num = 0
    for text in texts:
        num += 1
        seed = torch.seed()  # torch.seed() already re-seeds the RNG with a fresh seed
        torch.manual_seed(seed)
        prompt = prompt_template.format(text)
        min_length = len(tokenizer(prompt)['input_ids']) + len(tokenizer(text)['input_ids']) + 5
        model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # Disallow these marker strings in the generated text.
        bad_words_ids = [tokenizer(bad_word, add_special_tokens=False).input_ids for bad_word in ["[PROTECT]", "\n\n[PROTECT]", "[KEEP", "[INSTRUCTION]"]]
        # Stop on eos or "\n" (token id 13).
        eos_token_ids = [tokenizer.eos_token_id, 13]
        output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9, temperature=0.9,
                                bad_words_ids=bad_words_ids, num_beams=1, eos_token_id=eos_token_ids, min_length=min_length)
        # Keep only the reply after [/INST] and strip any leaked marker fragments.
        resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].split('[KE')[0].split('[INS')[0].split('[PRO')[0].strip()

        optimized_prompt += f"\n{num}. {resp}"

        chatbot = [(parse_text(input), parse_text(optimized_prompt))]
        yield chatbot, history
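# Because predict is a generator, Gradio streams each yielded (chatbot, history)
# pair to the UI as it arrives. Illustrative standalone use (hypothetical call):
#   for chat, hist in predict("Make this prompt better", [], 2048, 0.9, 0.6, []):
#       print(chat[-1][1])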


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], []


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Prompt Preference Optimizer</h1>""")

    chatbot = gr.Chatbot(label="Prompt Optimization Chatbot")
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="User Prompt...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.9, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.6, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)
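A minimal sketch for running this file locally, assuming only the dependencies implied by the imports (this commit pins no versions; the `.style()` call requires a Gradio 3.x release, since `.style()` was removed in Gradio 4):

    pip install torch transformers "gradio<4" mdtex2html sentencepiece
    python app.py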