import os

import gradio as gr
from huggingface_hub import Repository
from text_generation import Client

# from dialogues import DialogueTemplate
from share_btn import (community_icon_html, loading_icon_html, share_btn_css,
                       share_js)

HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_TOKEN = os.environ.get("API_TOKEN", None)
# BUG FIX: the original read API_URL from the environment and then
# unconditionally clobbered it with the hard-coded endpoint (dead store).
# Keep the hard-coded endpoint as the *default* so an env override is honoured.
API_URL = os.environ.get(
    "API_URL",
    "https://api-inference.huggingface.co/models/timdettmers/guanaco-33b-merged",
)

# Inference client for the hosted model endpoint; API_TOKEN may be None,
# in which case the Authorization header carries "Bearer None" (public models
# still work on the inference API without a token).
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {API_TOKEN}"},
)
repo = None


def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Assemble the full prompt string sent to the model.

    Each past (user, assistant) pair from ``chatbot`` is normalised so the
    user turn starts with ``user_name`` and the assistant turn starts with
    ``sep + assistant_name``; the pairs are concatenated after ``preprompt``,
    followed by the new ``inputs`` and a trailing assistant cue so the model
    continues as the assistant.

    Args:
        inputs: The new user message.
        chatbot: List of (user_text, assistant_text) pairs of prior turns.
        preprompt: System/preamble text placed before the conversation.
        user_name: Marker string that prefixes user turns.
        assistant_name: Marker string that prefixes assistant turns.
        sep: Separator inserted between turns.

    Returns:
        The concatenated prompt string.
    """
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    return preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()


def has_no_history(chatbot, history):
    """Return True when neither the visible chat nor the token history has entries."""
    return not chatbot and not history


header = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
prompt_template = "### Human: {query}\n ### Assistant:{response}" def generate( user_message, chatbot, history, temperature, top_p, max_new_tokens, repetition_penalty, ): # Don't return meaningless message when the input is empty if not user_message: print("Empty input") history.append(user_message) past_messages = [] for data in chatbot: user_data, model_data = data past_messages.extend( [{"role": "user", "content": user_data}, {"role": "assistant", "content": model_data.rstrip()}] ) if len(past_messages) < 1: prompt = header + prompt_template.format(query=user_message, response="") else: prompt = header for i in range(0, len(past_messages), 2): intermediate_prompt = prompt_template.format(query=past_messages[i]["content"], response=past_messages[i+1]["content"]) print("intermediate: ", intermediate_prompt) prompt = prompt + '\n' + intermediate_prompt prompt = prompt + prompt_template.format(query=user_message, response="") generate_kwargs = { "temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, } temperature = float(temperature) if temperature < 1e-2: temperature = 1e-2 top_p = float(top_p) generate_kwargs = dict( temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty, do_sample=True, truncate=999, seed=42, ) stream = client.generate_stream( prompt, **generate_kwargs, ) output = "" for idx, response in enumerate(stream): if response.token.text == '': break if response.token.special: continue output += response.token.text if idx == 0: history.append(" " + output) else: history[-1] = output chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)] yield chat, history, user_message, "" return chat, history, user_message, "" examples = [ "A Llama entered in my garden, what should I do?" ] def clear_chat(): return [], [] def process_example(args): for [x, y] in generate(args): pass return [x, y] title = """