import gradio as gr
from threading import Thread

from open_lm.hf import *  # registers the DCLM architecture with transformers' Auto classes
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch

# Map display names shown in the dropdown to Hugging Face Hub model IDs.
MODEL_OPTIONS = {
    "TRI-ML/DCLM-1B": "TRI-ML/DCLM-1B",
    "Apple DCLM-Baseline-7B": "apple/DCLM-Baseline-7B",
}

# Global references to the currently loaded model and tokenizer.
current_model = None
current_tokenizer = None


def load_model(model_name):
    """Load the selected model and tokenizer, moving the model to GPU if available."""
    global current_model, current_tokenizer
    current_tokenizer = AutoTokenizer.from_pretrained(MODEL_OPTIONS[model_name])
    current_model = AutoModelForCausalLM.from_pretrained(MODEL_OPTIONS[model_name])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    current_model = current_model.to(device)
    return f"Loaded model: {model_name}"


def generate(
    prompt,
    model_choice,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Stream a completion for `prompt`, yielding the growing output as HTML."""
    global current_model, current_tokenizer
    if current_model is None or current_tokenizer is None:
        # This function is a generator, so the message must be yielded, not returned.
        yield "Please load a model first."
        return

    # Clamp temperature away from zero to keep sampling numerically stable.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)

    streamer = TextIteratorStreamer(
        current_tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=current_tokenizer.eos_token_id,
    )

    # Run generation on a background thread so tokens can be streamed as they arrive.
    thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
    thread.start()

    # Write the prompt in blue; the completion is appended in the default color.
    output = f"<span style='color: blue'>{prompt}</span>"
    for new_text in streamer:
        # TextIteratorStreamer yields already-decoded strings.
        output += new_text
        yield output
    thread.join()


additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # DCLM Text Completion Demo
        This demo allows you to generate text using a DCLM model. These models are
        trained to predict the next word in a sequence of text and can be used to
        generate text completions; they are not chatbots.
        First select a model from the dropdown and wait for it to load. Then enter
        some text in the text box and click "Generate" to see the model's completion.
        """
    )
    with gr.Row():
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model")
        model_status = gr.Textbox(label="Model Status")
    model_dropdown.select(
        load_model, inputs=[model_dropdown], outputs=[model_status]
    )

    text_input = gr.Textbox(lines=3, label="Input Text")
    text_output = gr.HTML(label="Generated Text")
    generate_button = gr.Button("Generate")
    generate_button.click(
        generate,
        inputs=[text_input, model_dropdown, *additional_inputs],
        outputs=[text_output],
    )

    with gr.Accordion(label="Advanced Options", open=False):
        # Place the sliders here unless they were already rendered elsewhere.
        for input_component in additional_inputs:
            if not input_component.is_rendered:
                input_component.render()

demo.launch()
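
# ---------------------------------------------------------------------------
# Notes (assumptions, not part of the original demo): the `open_lm` package is
# installed from the mlfoundations/open_lm repository; importing `open_lm.hf`
# is what lets transformers' Auto classes resolve the DCLM checkpoints. A
# minimal sketch of exercising the helpers outside the UI, e.g. as a quick
# smoke test placed before demo.launch() (launch() blocks):
#
#   load_model("TRI-ML/DCLM-1B")
#   for partial in generate("The quick brown fox", "TRI-ML/DCLM-1B"):
#       last = partial
#   print(last)  # blue-wrapped prompt followed by the streamed completion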