OpenChat_3.6 / app.py
killerbng's picture
Update app.py
f909466 verified
raw history blame
No virus
4.09 kB
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# HTML heading rendered at the top of the page via gr.Markdown(DESCRIPTION).
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">OpenChat 3.6</h1>
</div>
'''
# HTML passed as the `placeholder` of the gr.Chatbot component: the OpenChat
# logo, a title and an "Ask me anything..." hint.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://raw.githubusercontent.com/imoneoi/openchat/master/assets/logo_new.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">OpenChat 3.6</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
# Custom stylesheet handed to gr.Blocks(css=...): centers <h1> headings,
# styles an element with id "duplicate-button" and hides the page footer.
# NOTE(review): no component with id "duplicate-button" is created in the
# visible code -- confirm whether that rule is still needed.
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
footer {
visibility: hidden
}
"""
# Hugging Face Hub checkpoint shared by the tokenizer and the model weights.
MODEL_ID = "openchat/openchat-3.6-8b-20240522"

# Load the tokenizer and the causal language model once, at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# device_map="auto" leaves device placement of the weights to the loader.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

# Token ids that stop generation: the tokenizer's EOS token plus the
# "<|eot_id|>" end-of-turn marker.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
@spaces.GPU(duration=120)
def chat_openchat_36(message: str,
                     history: list,
                     temperature: float,
                     max_new_tokens: int
                     ) -> Iterator[str]:
    """
    Stream a response from the openchat-3.6 model.

    This is a generator: generation runs in a background thread and the
    reply accumulated so far is yielded every time the streamer delivers a
    new chunk of text.

    Args:
        message (str): The new user message.
        history (list): The conversation history supplied by ChatInterface,
            either as ``(user, assistant)`` pairs or as
            ``{"role": ..., "content": ...}`` dicts.
        temperature (float): Sampling temperature. A value of 0 (or below)
            switches to greedy decoding.
        max_new_tokens (int): The maximum number of new tokens to generate.

    Yields:
        str: The cumulative response text generated so far.
    """
    conversation = _history_to_messages(history)
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # The streamer raises if no new text arrives within `timeout` seconds,
    # so a stalled or failed generation cannot block this generator forever.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        # Coerce to int in case the UI slider hands over a float.
        max_new_tokens=int(max_new_tokens),
        eos_token_id=terminators,
    )
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Sampling requires a strictly positive temperature; fall back to
        # greedy decoding and do not forward the unusable temperature at all.
        generate_kwargs["do_sample"] = False

    # generate() blocks until completion, so it runs in a worker thread while
    # this generator consumes the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The streamer is exhausted, so generation has finished; reap the thread.
    worker.join()


def _history_to_messages(history: list) -> list:
    """Convert ChatInterface history into a list of chat-template messages.

    Accepts both ``(user, assistant)`` pairs and ``{"role", "content"}``
    dicts. Entries whose content is ``None`` are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            entries = [(turn.get("role"), turn.get("content"))]
        else:
            user, assistant = turn
            entries = [("user", user), ("assistant", assistant)]
        messages.extend(
            {"role": role, "content": content}
            for role, content in entries
            if content is not None
        )
    return messages
# Gradio UI: the chat display is built first so it can be handed to
# ChatInterface inside the Blocks layout below.
chatbot = gr.Chatbot(
    height=450,
    placeholder=PLACEHOLDER,
    show_label=False,
    layout="panel",
    avatar_images=(None, "bot.png"),
    likeable=True,
    show_copy_button=True,
)

with gr.Blocks(fill_height=True, css=css, theme="theme-repo/STONE_Theme") as demo:
    gr.Markdown(DESCRIPTION)

    # Generation controls are created with render=False and passed to
    # ChatInterface, which receives them as `additional_inputs`.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.95,
        label="Temperature",
        render=False,
    )
    max_new_tokens_slider = gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        label="Max new tokens",
        render=False,
    )

    # Slider order matches the (temperature, max_new_tokens) parameters of
    # chat_openchat_36.
    gr.ChatInterface(
        fn=chat_openchat_36,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_new_tokens_slider],
        cache_examples=False,
    )

# Start the web app only when executed as a script.
if __name__ == "__main__":
    demo.launch()