import streamlit as st
from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
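# Build Mixtral's instruction-format prompt: each past user turn is wrapped in
# [INST] ... [/INST] and each assistant reply is terminated with </s>.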
def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
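# Stream a completion from the Inference API, yielding the accumulated text
# after every token so the caller can render partial output.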
def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # The backend requires a strictly positive temperature.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)

    output = ""
    for response in stream:
        output += response.token.text
        yield output
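# UI: chat message, system prompt, and sampling controls.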
message_input = st.text_input("You:", "")
system_prompt_input = st.text_input("System Prompt:", "You are a helpful assistant.")
col1, col2, col3, col4 = st.columns(4)
temperature_slider = col1.slider("Temperature", 0.0, 1.0, 0.9, key="temperature_slider")
max_new_tokens_slider = col2.slider("Max new tokens", 1, 1024, 256, key="max_new_tokens_slider")
top_p_slider = col3.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, key="top_p_slider")
repetition_penalty_slider = col4.slider("Repetition penalty", 1.0, 2.0, 1.0, key="repetition_penalty_slider")
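# On "Generate": call the model and stream the growing reply into a placeholder.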
if st.button("Generate"):
    # Keep history in session state so it survives Streamlit reruns.
    if "history" not in st.session_state:
        st.session_state["history"] = []
    history = st.session_state["history"]

    placeholder = st.empty()
    output = ""
    for output in generate(message_input, history, system_prompt_input,
                           temperature=temperature_slider, max_new_tokens=max_new_tokens_slider,
                           top_p=top_p_slider, repetition_penalty=repetition_penalty_slider):
        placeholder.markdown(f"Assistant: {output}")

    history.append((message_input, output))
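# To try the app locally (assuming this file is saved as app.py):
#   streamlit run app.py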