Spaces:
Sleeping
Sleeping
File size: 4,094 Bytes
6bcba58 0fea303 6bcba58 5103369 6bcba58 0ad74aa 6bcba58 6b02e11 454b0bf 0ad74aa 1fc0a63 454b0bf 8861375 e17f0b6 0ad74aa e17f0b6 6bcba58 0ad74aa 2932ae3 6bcba58 0ad74aa 6bcba58 0ad74aa 6bcba58 0fea303 6bcba58 2932ae3 6bcba58 478b5dd 6bcba58 1195994 6bcba58 4924586 6bcba58 96ac3aa 6bcba58 0ad74aa 6bcba58 6b02e11 6bcba58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# HTML heading shown at the top of the page; passed to gr.Markdown when the
# Blocks layout is built.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">OpenChat 3.6</h1>
</div>
'''
# HTML passed as the `placeholder` of the gr.Chatbot component: logo, title
# and a short prompt inviting the user to ask something.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://raw.githubusercontent.com/imoneoi/openchat/master/assets/logo_new.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">OpenChat 3.6</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
# Custom stylesheet handed to gr.Blocks(css=...): centers <h1> headings,
# styles an element with id "duplicate-button" and hides the page footer.
# NOTE(review): no component with id "duplicate-button" is created in the
# visible part of this file - confirm whether that rule is still needed.
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
footer {
visibility: hidden
}
"""
# Checkpoint identifier shared by the tokenizer and the model weights, kept in
# one place so the two can never drift apart.
_MODEL_ID = "openchat/openchat-3.6-8b-20240522"

# Load the tokenizer and the causal LM once at import time.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
# device_map="auto" is passed instead of moving the model to a fixed device
# by hand (the original kept `to("cuda:0")` as a commented-out alternative).
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, device_map="auto")

# Token ids on which generation stops: the tokenizer's EOS token plus the
# "<|eot_id|>" token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
@spaces.GPU(duration=120)
def chat_openchat_36(message: str,
                     history: list,
                     temperature: float,
                     max_new_tokens: int
                     ) -> Iterator[str]:
    """
    Stream a response from the openchat-3.6 model.

    This is a generator: it is meant to be used as the ``fn`` of a
    ChatInterface, which re-renders the reply every time a value is yielded.

    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface, as
            (user, assistant) pairs.
        temperature (float): The temperature for generating the response.
            A value of 0 (or below) selects greedy decoding.
        max_new_tokens (int): The maximum number of new tokens to generate.

    Yields:
        str: The response accumulated so far, emitted again each time the
            streamer delivers a new piece of text.
    """
    # Rebuild the dialogue in the role/content format the chat template expects.
    conversation = []
    for user, assistant in history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
    )
    # Sampling needs a strictly positive temperature. For 0 we fall back to
    # greedy decoding and leave the temperature out altogether, instead of
    # sending temperature=0 next to do_sample=False.
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        generate_kwargs["do_sample"] = False

    # generate() blocks until it is done, so it runs in a worker thread while
    # this generator drains the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    response = ""
    for text in streamer:
        response += text
        yield response

    # The streamer is only exhausted once generation has signalled its end, so
    # this should return promptly; it avoids leaving the worker thread dangling.
    t.join()
# Gradio block
# The chat display is built up front and handed to ChatInterface below.
chatbot = gr.Chatbot(
    height=450,
    placeholder=PLACEHOLDER,
    show_label=False,
    layout="panel",
    avatar_images=(None, "bot.png"),
    likeable=True,
    show_copy_button=True,
)

with gr.Blocks(fill_height=True, css=css, theme="theme-repo/STONE_Theme") as demo:
    gr.Markdown(DESCRIPTION)

    # Extra generation controls. They are created with render=False and
    # passed to ChatInterface, which receives them as additional inputs
    # (temperature first, then max new tokens - matching the handler's
    # parameter order).
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.95,
        label="Temperature",
        render=False,
    )
    max_new_tokens_slider = gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        label="Max new tokens",
        render=False,
    )

    gr.ChatInterface(
        fn=chat_openchat_36,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_new_tokens_slider],
        cache_examples=False,
    )

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|