File size: 4,557 Bytes
3708810
9ad7e40
3708810
9ad7e40
 
 
3708810
9ad7e40
 
3708810
 
9ad7e40
3708810
 
 
 
 
 
 
 
 
9ad7e40
 
 
3708810
9ad7e40
 
 
 
 
 
 
 
 
3708810
9ad7e40
 
 
 
3708810
 
 
9ad7e40
 
 
 
 
3708810
 
 
 
 
9ad7e40
3708810
 
 
 
 
 
 
 
 
9ad7e40
 
3708810
 
 
9ad7e40
 
3708810
 
9ad7e40
3708810
9ad7e40
 
3708810
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ad7e40
 
 
 
 
 
 
3708810
 
 
 
9ad7e40
 
3708810
 
 
 
 
 
 
 
 
 
 
 
9ad7e40
3708810
 
 
9ad7e40
3708810
9ad7e40
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
from string import Template
from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer


# Hub credentials come from the environment. When HUGGINGFACE_TOKEN is unset
# or empty, `True` is passed as `use_auth_token` instead of a token string.
auth_token = os.environ.get("HUGGINGFACE_TOKEN")
_hub_auth = auth_token or True

# The tokenizer and the fp16 weights are loaded from two separate repositories.
tokenizer = AutoTokenizer.from_pretrained(
    "CarperAI/vicuna-13b-fine-tuned-rlhf", use_auth_token=_hub_auth
)
model = AutoModelForCausalLM.from_pretrained(
    "CarperAI/vicuna-13b-fine-tuned-rlhf-fp16",
    use_auth_token=_hub_auth,
    torch_dtype=torch.float16,
    device_map="auto",
    offload_folder="./offload",
    low_cpu_mem_usage=True,  # Not required for demo but leave for now
)
# NOTE(review): the model was loaded with device_map="auto" and an offload
# folder, then is moved wholesale with .cuda() -- confirm the two do not
# conflict on the target hardware.
model.cuda()

# Generation budget: `max_new_tokens` is reserved out of the model's context
# window when the prompt is truncated in `bot`.
max_new_tokens = 500
max_context_length = model.config.max_position_embeddings


# Layout of one chat turn in the prompt: the human message, then the assistant
# reply on the following line, with no trailing newline after the reply.
prompt_template = Template(
    "### Human: $human\n"
    "### Assistant: $bot"
)


def _build_prompt(history):
    """Render chat turns into the ``### Human`` / ``### Assistant`` prompt text."""
    rendered_turns = []
    for human_msg, bot_msg in history:
        if bot_msg is None:
            # Pending turn: leave the assistant slot empty for the model to fill.
            bot_msg = ""
        else:
            # Stored replies may contain "<br>" (presumably added by the chat
            # widget); turn them back into newlines and drop trailing space.
            bot_msg = bot_msg.replace("<br>", "\n").rstrip()
        rendered_turns.append(
            prompt_template.substitute(human=human_msg, bot=bot_msg)
        )
    return "\n\n".join(rendered_turns).rstrip()


def bot(history):
    """Stream the assistant's reply for the last turn of ``history``.

    This is a generator: after every chunk received from the streamer it
    writes the reply so far into ``history[-1][1]`` and yields ``history``.

    Args:
        history: List of ``[human, bot]`` pairs; the last pair is the turn to
            answer. ``None`` or an empty list is yielded back unchanged as an
            empty history without running generation.

    Yields:
        ``history`` (mutated in place) with the partial reply in its last turn.

    Returns:
        The final reply text, as the generator's return value.
    """
    history = history or []
    if not history:
        # No turn to write a reply into; avoid indexing history[-1] below.
        yield history
        return ""

    # Hack to inject prompt formatting into the history
    messages = _build_prompt(history)

    # Use only the most recent context up to the maximum context length with
    # room left over for the max new tokens.
    prompt_budget = max_context_length - max_new_tokens
    encoded = tokenizer(messages, return_tensors='pt').to('cuda')
    inputs = {name: tensor[:, -prompt_budget:] for name, tensor in encoded.items()}
    # token_type_ids, when the tokenizer produces them, are not passed on.
    inputs.pop("token_type_ids", None)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    # Generate the response
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
    )

    # Generation runs in a worker thread so this generator can consume the
    # streamer and yield partial text while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    generated = ""
    partial_text = ""
    for new_text in streamer:
        generated += new_text
        # Process out the prompt separator. NOTE: we should tune with special
        # tokens for this. Searching the accumulated text (not just the new
        # chunk) also catches a "###" or "<br>" split across two chunks.
        reply, separator, _ = generated.replace("<br>", "\n").partition("###")
        # Only trailing whitespace is trimmed, so blank lines inside the reply
        # survive and the last chunk is not glued onto the previous word.
        partial_text = reply.rstrip()
        history[-1][1] = partial_text
        # Yield before leaving the loop so the final text reaches the caller.
        yield history
        if separator:
            # The model started a new "###" turn; the reply is complete.
            # NOTE(review): the worker thread is not told to stop here.
            break

    return partial_text


def user(user_message, history):
    """Queue the submitted message as a new, not-yet-answered chat turn.

    Args:
        user_message: Text taken from the message box.
        history: Existing chat turns as ``[human, bot]`` pairs. ``None`` is
            treated as an empty history, matching what ``bot`` does.

    Returns:
        A 2-tuple: an empty string (wired to the message box, which clears
        it) and a new history list ending with ``[user_message, None]``.
        The list passed in is not mutated.
    """
    history = history or []
    return "", history + [[user_message, None]]


# Build the chat UI. Components appear in the order they are created, so the
# statement order below is significant.
with gr.Blocks() as demo:
    gr.Markdown("Chat-RLHF by CarperAI")
    # Link to the model card (the <code> element is now properly closed).
    gr.HTML("<a href='https://huggingface.co/CarperAI/vicuna-13b-fine-tuned-rlhf'><code>CarperAI/vicuna-13b-fine-tuned-rlhf</code></a>")
    gr.HTML('''<center><a href="https://huggingface.co/spaces/CarperAI/chat-rlhf?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')

    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=512)
    # NOTE(review): `state` is not referenced by any event handler below.
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box",
                             show_label=False).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    # Pressing Enter in the textbox and clicking Submit do the same thing:
    # `user` appends the message and clears the box, then `bot` streams the
    # reply into the chatbot.
    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=True).then(
        bot, chatbot, chatbot)
    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=True).then(
        bot, chatbot, chatbot)
    # Stop runs no function of its own; it only cancels the two events above.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[
               submit_event, submit_click_event], queue=False)
    # Clear resets the chatbot by sending it None.
    clear.click(lambda: None, None, chatbot, queue=True)

# Queueing is enabled before launch; `bot` is a generator that streams updates.
demo.queue(max_size=32, concurrency_count=2)
demo.launch(share=True)