Spaces:
Runtime error
Runtime error
File size: 5,974 Bytes
ddd83b8 f100357 ddd83b8 f100357 ddd83b8 d2cf9fe ddd83b8 24d2597 ddd83b8 24d2597 0a6b310 ddd83b8 54e3eca ddd83b8 54e3eca ddd83b8 54e3eca ddd83b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import mdtex2html
import torch
"""Override Chatbot.postprocess"""
# Checkpoint identifier handed to both `from_pretrained` calls below.
model_path = 'THUDM/BPO'
# Device the tokenized prompts are moved to before generation.
device = 'cuda'

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
    add_prefix_space=True,
)

# The model is requested in 8-bit and switched to eval mode as soon as it is loaded.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    load_in_8bit=True,
).eval()

# Instruction wrapper sent to the model; `{}` receives the raw user prompt.
prompt_template = (
    "[INST] You are an expert prompt engineer. "
    "Please help me improve this prompt to get a more helpful and harmless response:"
    "\n{} [/INST]"
)
def postprocess(self, y):
    """Convert every chat exchange to HTML for display.

    Intended as a replacement for ``gr.Chatbot.postprocess``.

    Args:
        y: Iterable of ``(message, response)`` pairs, or ``None``. Either
            element of a pair may itself be ``None``.

    Returns:
        A new list of ``(message, response)`` tuples in which every
        non-``None`` element has been passed through ``mdtex2html.convert``
        and every ``None`` element is kept as ``None``. Returns ``[]`` when
        ``y`` is ``None``.
    """
    if y is None:
        return []

    def to_html(cell):
        # None marks a missing side of the exchange and must stay None.
        return None if cell is None else mdtex2html.convert(cell)

    # Build a fresh list instead of overwriting the entries of `y`, so the
    # caller's object is never modified behind its back.
    return [(to_html(message), to_html(response)) for message, response in y]
# Monkey-patch: replace the `postprocess` attribute on gradio's Chatbot class
# with the module-level `postprocess` function defined above.
gr.Chatbot.postprocess = postprocess
# Characters that are replaced inside an open code block so that neither the
# Markdown renderer nor the browser interprets them as markup.
_CODE_ESCAPES = str.maketrans({
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Convert chat text containing fenced code blocks into an HTML fragment.

    Adapted from https://github.com/GaiZhenbiao/ChuanhuChatGPT/

    Empty lines are dropped. A line containing a triple-backtick fence
    alternately opens a code block (``<pre><code class="language-...">``,
    where the language is whatever follows the last backtick) and closes it
    (``<br></code></pre>``). Every other line except the very first one is
    prefixed with ``<br>``; while a code block is open, its markup-significant
    characters are additionally replaced by HTML entities so they are shown
    literally.

    Args:
        text: Raw text with newline-separated lines.

    Returns:
        The converted lines concatenated into one string.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                # Odd fence: opening marker, the tail after the backticks is the language.
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if fence_count % 2 == 1:
                # Inside an open code block: escape in a single pass.
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)
def predict(input, chatbot, max_length, top_p, temperature, history):
    """Stream optimized rewrites of the user prompt to the chatbot.

    This is a generator: every yielded ``(chatbot, history)`` pair replaces
    what is currently displayed. It first yields one "stable" rewrite sampled
    with the caller-supplied settings, then five "aggressive" rewrites sampled
    with fixed settings, re-yielding the accumulated text after each one.

    Relies on the module-level ``prompt_template``, ``tokenizer``, ``model``,
    ``device`` and ``parse_text``.

    Args:
        input: Raw user prompt. The name shadows the builtin but is kept so
            the signature stays unchanged.
        chatbot: Current chatbot value. It is not read; a fresh
            single-exchange list is emitted instead.
        max_length: ``max_length`` passed to generation for the stable rewrite.
        top_p: ``top_p`` passed to generation for the stable rewrite.
        temperature: ``temperature`` passed to generation for the stable rewrite.
        history: Passed through unchanged.

    Yields:
        ``(chatbot, history)`` where ``chatbot`` is a one-element list holding
        ``(parse_text(input), parse_text(answer_so_far))``.
    """
    if input.strip() == "":
        chatbot = [(parse_text(input), parse_text("Please input a valid user prompt. Empty string is not supported."))]
        # This function is a generator, so the message must be yielded; a value
        # attached to `return` only ends up on StopIteration and is never shown.
        yield chatbot, history
        return

    prompt = prompt_template.format(input)
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Stage 1: a single sample using the settings chosen in the UI.
    output = model.generate(**model_inputs, max_length=max_length, do_sample=True, top_p=top_p,
                            temperature=temperature, num_beams=1)
    # The decoded text starts with the prompt; keep what follows the [/INST] marker.
    resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip()
    optimized_prompt = (
        "Here are several optimized prompts:\n"
        "====================Stable Optimization====================\n"
    )
    optimized_prompt += resp
    yield [(parse_text(input), parse_text(optimized_prompt))], history

    # Stage 2: five samples with fixed sampling settings.
    optimized_prompt += "\n\n====================Aggressive Optimization===================="

    # Every aggressive sample uses the same prompt, so the encoded inputs and
    # the generation constraints are computed once instead of once per sample.
    min_length = len(tokenizer(prompt)['input_ids']) + len(tokenizer(input)['input_ids']) + 5
    bad_words_ids = [
        tokenizer(bad_word, add_special_tokens=False).input_ids
        for bad_word in ["[PROTECT]", "\n\n[PROTECT]", "[KEEP", "[INSTRUCTION]"]
    ]
    # eos and \n
    eos_token_ids = [tokenizer.eos_token_id, 13]

    for num in range(1, 6):
        # Draw a fresh seed before each sample.
        seed = torch.seed()
        torch.manual_seed(seed)
        output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9,
                                temperature=0.9, bad_words_ids=bad_words_ids, num_beams=1,
                                eos_token_id=eos_token_ids, min_length=min_length)
        decoded = tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep the text after [/INST] and cut it at the first leftover control tag.
        resp = decoded.split('[/INST]')[1].split('[KE')[0].split('[INS')[0].split('[PRO')[0].strip()
        optimized_prompt += f"\n{num}. {resp}"
        yield [(parse_text(input), parse_text(optimized_prompt))], history
def reset_user_input():
    """Build the gradio update that sets the target component's value to ''."""
    cleared = gr.update(value='')
    return cleared
def reset_state():
    """Return two fresh, independent empty lists (chatbot value, history)."""
    empty_chat = []
    empty_history = []
    return empty_chat, empty_history
def update_textbox_from_dropdown(selected_example):
    """Hand the selected dropdown entry back unchanged as the new textbox value."""
    textbox_value = selected_example
    return textbox_value
# Gradio UI: a chatbot pane, a prompt box with example dropdown and buttons,
# and sliders for the sampling settings of the first ("stable") rewrite.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the statement order -- confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Prompt Preference Optimizer</h1>""")

    chatbot = gr.Chatbot(label="Prompt Optimization Chatbot")
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                dropdown = gr.Dropdown(
                    ["tell me about harry potter", "give me 3 tips to learn English", "write a story about love"],
                    label="Choose an example input",
                )
                user_input = gr.Textbox(show_label=False, placeholder="User Prompt...", lines=5).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.9, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.6, step=0.01, label="Temperature", interactive=True)

    # Picking an example copies it into the prompt box.
    dropdown.change(update_textbox_from_dropdown, dropdown, user_input)

    history = gr.State([])

    # Submit triggers two handlers: `predict` streams the rewrites into the
    # chatbot, and `reset_user_input` clears the prompt box.
    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

# queue() is enabled before launching; `predict` is a generator handler.
demo.queue().launch(share=False, inbrowser=True)