import spaces
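# `spaces` provides the @spaces.GPU decorator, which requests a GPU for a call on Hugging Face ZeroGPU Spaces.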
import gradio as gr
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
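# The wildcard import below is needed for its side effect: open_lm.hf registers the DCLM
# model and config classes with transformers so the Auto* classes can load these checkpoints.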
from open_lm.hf import *
from open_lm.precision import get_autocast

# Define model options
MODEL_OPTIONS = {
    "TRI DCLM-1B": "TRI-ML/DCLM-1B",
    "Apple DCLM-Baseline-7B": "apple/DCLM-Baseline-7B",
    "[IT] TRI DCLM-1B": "TRI-ML/DCLM-1B-IT",
    "[IT] Apple DCLM-Baseline-7B": "mlfoundations/dclm-7b-it",
}

# Global variables for model and tokenizer
current_model = None
current_tokenizer = None

def load_model(model_name):
    global current_model, current_tokenizer
    current_tokenizer = AutoTokenizer.from_pretrained(MODEL_OPTIONS[model_name])
    current_model = AutoModelForCausalLM.from_pretrained(MODEL_OPTIONS[model_name])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    current_model = current_model.to(device)
    return f"Loaded model: {model_name}"

@spaces.GPU
def generate_completion(
    prompt, model_choice, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
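    # model_choice is not used directly here; the model is loaded by the dropdown's change handler.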
    global current_model, current_tokenizer
    
    if current_model is None or current_tokenizer is None:
        # This function is a generator, so the message must be yielded (a plain return would show nothing).
        yield "Please select a model first."
        return

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    
    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
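    # get_autocast("amp_bf16") comes from open_lm and is used below to run generation under a bfloat16 autocast context.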
    autocast = get_autocast("amp_bf16")

    with autocast():
        generate_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=current_tokenizer.eos_token_id
        )

        streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
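        # stop_signal is a custom attribute (not part of TextIteratorStreamer): the decoded EOS text,
        # used below to trim the end of the stream, since skip_special_tokens=False keeps special tokens.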
        streamer.stop_signal = current_tokenizer.decode(current_tokenizer.eos_token_id)
        generate_kwargs["streamer"] = streamer

        thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
        thread.start()
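        # generate() runs in the background thread; the streamer yields decoded text chunks as they arrive.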

        output = "<span style='color: blue;'>" + prompt + "</span>"
        for new_text in streamer:
            if isinstance(new_text, torch.Tensor):
                new_text = current_tokenizer.decode(new_text)
            if streamer.stop_signal in new_text:
                # Keep only the text before the EOS marker and stop streaming.
                output += new_text.split(streamer.stop_signal)[0]
                break
            output += new_text
            yield output

        thread.join()
    # Re-emit the final output so the chunk appended before the break is also shown.
    yield output

def format_prompt(message, history):
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += f"User: {user_prompt}\nAssistant: {bot_response}\n"
    prompt += f"User: {message}\nAssistant:"
    return prompt

@spaces.GPU
def generate_chat(
    message, chat_history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    global current_model, current_tokenizer
    
    if current_model is None or current_tokenizer is None:
        yield chat_history + [("Error", "Please select a model first.")]
        return

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    
    formatted_prompt = format_prompt(message, chat_history)
    inputs = current_tokenizer(formatted_prompt, return_tensors="pt").to(current_model.device)
    
    generate_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=current_tokenizer.eos_token_id
    )

    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
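    # Same custom stop_signal trick as in generate_completion: the decoded EOS text is used to trim the stream.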
    streamer.stop_signal = current_tokenizer.decode(current_tokenizer.eos_token_id)
    generate_kwargs["streamer"] = streamer

    thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
    thread.start()

    new_history = chat_history + [(message, "")]
    for new_text in streamer:
        if isinstance(new_text, torch.Tensor):
            new_text = current_tokenizer.decode(new_text)
        if streamer.stop_signal in new_text:
            # Trim everything from the EOS marker onward and stop streaming.
            new_text = new_text.split(streamer.stop_signal)[0]
            new_history[-1] = (message, new_history[-1][1] + new_text)
            break
        new_history[-1] = (message, new_history[-1][1] + new_text)
        yield new_history

    thread.join()
    # Emit the final history so the chunk appended before the break is also displayed.
    yield new_history

additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
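# The sliders above are created outside the Blocks context; they are attached to the layout
# later via .render() inside the "Advanced Options" accordion.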

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # DCLM Demo
        This demo allows you to generate text using DCLM models in two modes: 
        1. Text Completion:
            For non-Instruction-Tuned models, it generates the continuation of the input text.
        2. Chatbot:
            For Instruction-Tuned [IT] models, it generates responses to user messages as a chatbot.
        
        Select a model from the dropdown to start; it may take a few seconds to load.
        The interface will automatically switch between Text Completion and Chatbot modes based on the selected model.
        """
    )

    with gr.Row():
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model")
        model_status = gr.Textbox(label="Model Status")

    # Text Completion interface
    with gr.Row(visible=False) as completion_interface:
        with gr.Column():
            text_input = gr.Textbox(lines=3, label="Input Text")
            text_output = gr.Markdown(label="Generated Text")
            generate_button = gr.Button("Generate")

    # Chatbot interface
    with gr.Row(visible=False) as chat_interface:
        with gr.Column():
            chatbot = gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel")
            msg = gr.Textbox(label="Message")
            clear = gr.Button("Clear")

    with gr.Accordion("Advanced Options", open=False):
        for input_component in additional_inputs:
            input_component.render()

    def switch_interface(model_name):
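        # Display names starting with "[IT]" are instruction-tuned models, which use the chat interface.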
        is_it_model = model_name.startswith("[IT]")
        status = load_model(model_name)
        return (
            gr.Row(visible=not is_it_model),  # completion_interface
            gr.Row(visible=is_it_model),      # chat_interface
            status                            # model_status
        )

    model_dropdown.change(
        switch_interface,
        inputs=[model_dropdown],
        outputs=[completion_interface, chat_interface, model_status]
    )

    generate_button.click(
        generate_completion,
        inputs=[text_input, model_dropdown, *additional_inputs],
        outputs=[text_output]
    )

    msg.submit(generate_chat, [msg, chatbot, *additional_inputs], chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch()