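"""Gradio chat demo that streams responses from Mixtral-8x7B-Instruct-v0.1 via the Hugging Face Inference API."""
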
import os
import logging
from huggingface_hub import InferenceClient
import gradio as gr

log_level = os.environ.get("LOG_LEVEL", "WARNING")
logging.basicConfig(encoding='utf-8', level=log_level)

logging.info("Creating Inference Client")
client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)


def format_prompt(message, history):
    """Formats the prompt for the AI"""
    logging.info("Formatting Prompt")
    logging.debug("Input Message: %s", message)
    logging.debug("Input History: %s", history)
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += f"<s> [INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
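
# For illustration, a one-turn history such as [("Hi", "Hello!")] followed by the
# message "How are you?" produces a prompt in the Mixtral instruction format:
#   <s> [INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]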


def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Generates the text based on the user prompt."""
    logging.info("Generating Response")
    logging.debug("Input Prompt: %s", prompt)
    logging.debug("Input History: %s", history)
    logging.debug("Input System Prompt: %s", system_prompt)
    logging.debug("Input Temperature: %s", temperature)
    logging.debug("Input Max New Tokens: %s", max_new_tokens)
    logging.debug("Input Top P: %s", top_p)
    logging.debug("Input Repetition Penalty: %s", repetition_penalty)

    logging.info("Converting Parameters to Correct Type")
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    logging.debug("Temperature: %s", temperature)
    logging.debug("Top P: %s", top_p)

    logging.info("Creating Generate kwargs")
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    logging.debug("Generate Args: %s", generate_kwargs)

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    logging.debug("Prompt: %s", formatted_prompt)

    logging.info("Generating Text")
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)

    logging.info("Creating Output")
    output = ""
    for response in stream:
        output += response.token.text
        yield output
    logging.debug("Output: %s", output)
    return output


additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

examples = []

logging.info("Creating Chat Interface")
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False,
                       show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mixtral Instruct",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False)
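
# Usage sketch (assuming this file is saved as app.py and a Hugging Face token is
# available in the environment so InferenceClient can authenticate):
#   LOG_LEVEL=DEBUG python app.py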