"""Gradio chat playground for the StarChat models listed in ``model2endpoint``.

Each turn is sent to the selected inference endpoint through the
``text_generation`` client, streamed back into the chat window and, when the
user has opted in and ``HF_TOKEN`` is set, pushed to a dataset on the Hub.
"""

import datetime
import os
import random
import re
from io import StringIO

import gradio as gr
import pandas as pd
from huggingface_hub import upload_file
from text_generation import Client

HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_TOKEN = os.environ.get("API_TOKEN", None)
DIALOGUES_DATASET = "HuggingFaceH4/starchat_playground_dialogues"

model2endpoint = {
    "starchat-alpha": "https://api-inference.huggingface.co/models/HuggingFaceH4/starcoderbase-finetuned-oasst1",
    "starchat-beta": "https://api-inference.huggingface.co/models/HuggingFaceH4/starchat-beta",
}
model_names = list(model2endpoint.keys())

# Special tokens of the dialogue format used to build the prompt.
# NOTE(review): taken from the public StarChat model card, not from this file;
# confirm they match the deployed checkpoints.
SYSTEM_TOKEN = "<|system|>"
USER_TOKEN = "<|user|>"
ASSISTANT_TOKEN = "<|assistant|>"
END_TOKEN = "<|end|>"


def randomize_seed_generator():
    """Return a random integer seed in ``[0, 1000000]``."""
    return random.randint(0, 1000000)


def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs, model):
    """Upload one prompt/completion record to ``DIALOGUES_DATASET`` as JSON lines.

    Args:
        now: ``datetime.datetime`` used for the ``<date>/<hour>/`` folder and
            for the timestamp in the file name.
        inputs: Prompt that was sent to the model.
        outputs: Text generated by the model.
        generate_kwargs: Generation parameters used for the request.
        model: Name of the model that produced ``outputs``.
    """
    # Derive the file name from the same instant as the folder so the two
    # can never disagree (e.g. across an hour boundary).
    timestamp = now.strftime("%Y-%m-%dT%H:%M:%S.%f")
    file_name = f"prompts_{timestamp}.jsonl"
    data = {"model": model, "inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}

    with StringIO() as buffer:
        pd.DataFrame([data]).to_json(buffer, orient="records", lines=True)

        # Push to Hub
        upload_file(
            path_in_repo=f"{now.date()}/{now.hour}/{file_name}",
            path_or_fileobj=buffer.getvalue().encode(),
            repo_id=DIALOGUES_DATASET,
            token=HF_TOKEN,
            repo_type="dataset",
        )


def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Concatenate the pre-prompt, past turns and new input into one string.

    Each ``(user, model)`` pair in ``chatbot`` is prefixed with ``user_name``
    and ``sep + assistant_name`` when the prefix is missing, and terminated
    with ``sep``. The result ends with ``sep`` followed by the right-stripped
    ``assistant_name`` so the model continues as the assistant.
    """
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data

        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    return preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()


def wrap_html_code(text):
    """Wrap ``text`` in triple backticks if it contains anything tag-like (``<...>``)."""
    if re.search(r"<.*?>", text):
        return f"```{text}```"
    return text


def has_no_history(chatbot, history):
    """Return True when both the chat display and the history are empty."""
    return not chatbot and not history


def _build_prompt(system_message, messages):
    """Render the system message and dialogue turns into a single prompt string.

    ``messages`` is a list of ``{"role": "user" | "assistant", "content": str}``
    dicts. The prompt ends with the assistant token so that the model produces
    the next assistant turn.
    """
    parts = [f"{SYSTEM_TOKEN}\n{system_message or ''}{END_TOKEN}\n"]
    for message in messages:
        role_token = USER_TOKEN if message["role"] == "user" else ASSISTANT_TOKEN
        parts.append(f"{role_token}\n{message['content']}{END_TOKEN}\n")
    parts.append(ASSISTANT_TOKEN)
    return "".join(parts)


def _history_to_chat(history):
    """Pair the flat ``[user, bot, user, bot, ...]`` history into chat tuples."""
    return [
        (wrap_html_code(user_text.strip()), wrap_html_code(bot_text.strip()))
        for user_text, bot_text in zip(history[::2], history[1::2])
    ]


def generate(
    RETRY_FLAG,
    model_name,
    system_message,
    user_message,
    chatbot,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
    do_save=True,
):
    """Stream a model response for ``user_message`` and update the chat state.

    Args:
        RETRY_FLAG: When truthy, ``user_message`` is assumed to already be the
            last entry of ``history`` and a random seed is used; otherwise the
            message is appended to ``history`` and the seed is fixed to 42.
        model_name: Key into ``model2endpoint``.
        system_message: System prompt placed at the start of the prompt.
        user_message: Latest user message.
        chatbot: Past ``(user, assistant)`` pairs shown in the chat window.
        history: Flat list alternating user and assistant messages; mutated
            in place.
        temperature: Sampling temperature, clamped to at least ``1e-2``.
        top_k: Accepted for interface compatibility; currently not forwarded
            to the endpoint.
        top_p: Nucleus sampling probability.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Repetition penalty forwarded to the endpoint.
        do_save: Whether to push the prompt and completion to the Hub
            (only done when ``HF_TOKEN`` is also set).

    Yields:
        ``(chat, history, user_message, "")`` after every streamed token; the
        trailing empty string clears the input textbox.
    """
    client = Client(
        model2endpoint[model_name],
        headers={"Authorization": f"Bearer {API_TOKEN}"},
        timeout=60,
    )

    # Empty input is only logged; generation still proceeds.
    if not user_message:
        print("Empty input")

    if not RETRY_FLAG:
        history.append(user_message)
        seed = 42
    else:
        seed = randomize_seed_generator()

    past_messages = []
    for user_data, model_data in chatbot:
        past_messages.extend(
            [{"role": "user", "content": user_data}, {"role": "assistant", "content": model_data.rstrip()}]
        )

    # The full dialogue (not just the system message) must reach the model,
    # and the same string is what gets stored alongside the completion.
    prompt = _build_prompt(system_message, past_messages + [{"role": "user", "content": user_message}])

    generate_kwargs = {
        "temperature": max(float(temperature), 1e-2),
        "max_new_tokens": max_new_tokens,
        "top_p": float(top_p),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "truncate": 4096,
        "seed": seed,
        "stop_sequences": [END_TOKEN],
    }

    stream = client.generate_stream(prompt, **generate_kwargs)

    output = ""
    response_started = False
    chat = _history_to_chat(history)
    for response in stream:
        if response.token.special:
            continue
        output += response.token.text
        # Track the first *text* token explicitly: the first streamed token may
        # be special, and keying on the stream index would then overwrite the
        # user's message in ``history`` instead of appending the response.
        if not response_started:
            history.append(" " + output)
            response_started = True
        else:
            history[-1] = output

        chat = _history_to_chat(history)
        yield chat, history, user_message, ""

    if not response_started:
        # Keep ``history`` in (user, assistant) pairs even when nothing was
        # generated, so later turns do not get paired with the wrong message.
        history.append("")
        chat = _history_to_chat(history)
        yield chat, history, user_message, ""

    if HF_TOKEN and do_save:
        # Best effort: a failed upload must not break the chat.
        try:
            now = datetime.datetime.now()
            current_time = now.strftime("%Y-%m-%d %H:%M:%S")
            print(f"[{current_time}] Pushing prompt and completion to the Hub")
            save_inputs_and_outputs(now, prompt, output, generate_kwargs, model_name)
        except Exception as e:
            print(e)

    return chat, history, user_message, ""


def clear_chat():
    """Return empty chat and history lists."""
    return [], []


def delete_last_turn(chat, history):
    """Remove the most recent (user, assistant) turn from ``chat`` and ``history``."""
    if chat and history:
        chat.pop()
        # Drops the last user/assistant pair; tolerates a one-element history.
        del history[-2:]
    return chat, history


def process_example(args):
    """Run ``generate`` to completion and return the final ``[chat, history]``.

    NOTE(review): assumes ``args`` is the sequence of positional arguments
    expected by ``generate`` -- confirm against the caller.
    """
    chat, history = [], []
    for chat, history, *_ in generate(*args):
        pass
    return [chat, history]


# Regenerate response
def retry_last_answer(
    selected_model,
    system_message,
    user_message,
    chat,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
    do_save,
):
    """Discard the last assistant answer and generate a new one for the same user message.

    When there is no complete previous turn, the current state is yielded
    unchanged.
    """
    if not chat or len(history) < 2:
        # Nothing to regenerate: hand the current state back untouched.
        yield chat, history, user_message, user_message
        return

    # Removing the previous conversation from chat
    chat.pop()
    # Removing bot response from the history
    history.pop()
    # Getting last message from user
    user_message = history[-1]

    yield from generate(
        True,  # RETRY_FLAG: the user message is already in ``history``
        selected_model,
        system_message,
        user_message,
        chat,
        history,
        temperature,
        top_k,
        top_p,
        max_new_tokens,
        repetition_penalty,
        do_save,
    )


with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        do_save = gr.Checkbox(
            value=True,
            label="Store data",
            info="You agree to the storage of your prompt and generated text for research and development purposes:",
        )

    with gr.Row():
        selected_model = gr.Radio(choices=model_names, value=model_names[1], label="Select a model")

    with gr.Accordion(label="System Prompt", open=False, elem_id="parameters-accordion"):
        system_message = gr.Textbox(
            elem_id="system-message",
            placeholder="Below is a conversation between a human user and a helpful AI coding assistant.",
            show_label=False,
        )

    with gr.Row():
        with gr.Box():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")

    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")

            with gr.Row():
                send_button = gr.Button("Send", elem_id="send-btn", visible=True)
                regenerate_button = gr.Button("Regenerate", elem_id="retry-btn", visible=True)
                delete_turn_button = gr.Button("Delete last turn", elem_id="delete-btn", visible=True)
                clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)

            with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                top_k = gr.Slider(
                    label="Top-k",
                    value=50,
                    minimum=0.0,
                    maximum=100,
                    step=1,
                    interactive=True,
                    info="Sample from a shortlist of top-k tokens",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.95,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=512,
                    minimum=0,
                    maximum=32000,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    value=1.2,
                    minimum=0.0,
                    maximum=10,
                    step=0.1,
                    interactive=True,
                    info="The parameter for repetition penalty. 1.0 means no penalty.",
                )

    history = gr.State([])
    RETRY_FLAG = gr.Checkbox(value=False, visible=False)

    # To clear out "message" input textbox and use this to regenerate message
    last_user_message = gr.State("")

    user_message.submit(
        generate,
        inputs=[
            RETRY_FLAG,
            selected_model,
            system_message,
            user_message,
            chatbot,
            history,
            temperature,
            top_k,
            top_p,
            max_new_tokens,
            repetition_penalty,
            do_save,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    send_button.click(
        generate,
        inputs=[
            RETRY_FLAG,
            selected_model,
            system_message,
            user_message,
            chatbot,
            history,
            temperature,
            top_k,
            top_p,
            max_new_tokens,
            repetition_penalty,
            do_save,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    regenerate_button.click(
        retry_last_answer,
        inputs=[
            selected_model,
            system_message,
            user_message,
            chatbot,
            history,
            temperature,
            top_k,
            top_p,
            max_new_tokens,
            repetition_penalty,
            do_save,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    delete_turn_button.click(delete_last_turn, [chatbot, history], [chatbot, history])
    clear_chat_button.click(clear_chat, outputs=[chatbot, history])
    selected_model.change(clear_chat, outputs=[chatbot, history])

demo.queue(concurrency_count=16).launch(debug=True)