# app.py — commit 0e40292 ("Update app.py") by aixsatoshi, 4.61 kB.
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
# Optional Hugging Face access token from the environment.
# NOTE(review): not referenced by the code shown here (the from_pretrained
# calls below do not pass it); huggingface_hub can read the HF_TOKEN
# environment variable on its own — confirm whether this is still needed.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Checkpoint that is actually loaded for generation.
MODEL_ID = "elyza/Llama-3-ELYZA-JP-8B"

# Repo id shown in the page header. Fall back to MODEL_ID when the MODELS
# env var is unset or empty; previously `None.split` crashed at import time.
MODELS = os.environ.get("MODELS") or MODEL_ID
MODEL_NAME = MODELS.split("/")[-1]

TITLE = "<h1><center>Llama-3-ELYZA-JP-8B Chat webui</center></h1>"
DESCRIPTION = f"""
<h3>MODEL: <a href="https://hf.co/{MODELS}">{MODEL_NAME}</a></h3>
<center>
<p>Llama-3-ELYZA-JP-8B is a large language model built by ELYZA.
<br>
Feel free to test without log.
</p>
</center>
"""
# Custom stylesheet passed to gr.Blocks(css=...):
# - .duplicate-button: horizontally centered (margin: auto), white-on-black,
#   pill-shaped (border-radius: 100vh).
#   NOTE(review): no element in this file appears to use this class.
# - h3: centered headings (e.g. the MODEL line in DESCRIPTION).
# - .chatbox .messages .message.{user,bot}: background colours for user and
#   bot messages. NOTE(review): confirm these selectors still match the
#   Chatbot DOM of the Gradio version in use.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
.chatbox .messages .message.user {
background-color: #e1f5fe;
}
.chatbox .messages .message.bot {
background-color: #eeeeee;
}
"""
# Tokenizer and fp16 weights for MODEL_ID are loaded once, at import time,
# and shared by every request. device_map="auto" lets transformers decide
# where the weights are placed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
)
def _build_conversation(message: str, history: list) -> list:
    """Normalise Gradio chat history plus the new message into chat-template messages.

    ``history`` items may be ``(user, assistant)`` pairs (Gradio "tuples"
    format) or ``{"role": ..., "content": ...}`` dicts ("messages" format).
    Both are converted to role/content dicts, and the new user ``message``
    is appended last.
    """
    conversation = []
    for turn in history:
        if isinstance(turn, dict):
            conversation.append({"role": turn["role"], "content": turn["content"]})
            continue
        prompt, answer = turn
        conversation.append({"role": "user", "content": prompt})
        # Skip a missing reply rather than passing None content to the template.
        if answer is not None:
            conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    """Stream the model's reply to ``message`` given the prior chat ``history``.

    Args:
        message: The new user message.
        history: Prior turns, as (user, assistant) pairs or role/content dicts.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling threshold (only used when sampling).
        top_k: Top-k cutoff (only used when sampling).
        penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The accumulated reply text after each streamed chunk.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')
    conversation = _build_conversation(message, history)

    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # Llama-3 chat templates already begin the rendered prompt with the BOS
    # token; letting the tokenizer add special tokens would prepend a second one.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(penalty),
        # Llama-3 terminators: <|end_of_text|> (128001) and <|eot_id|> (128009).
        eos_token_id=[128001, 128009],
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # The UI slider allows 0, but transformers rejects temperature=0 when
        # sampling, so treat it as a request for greedy decoding.
        generate_kwargs.update(do_sample=False)

    # generate() blocks until done, so run it in a worker thread and relay
    # text from the streamer as it arrives.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
chatbot = gr.Chatbot(height=500)

# The theme is set on the root Blocks: a ChatInterface rendered inside another
# Blocks does not apply its own `theme` argument to the page.
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        # Order must match stream_chat's parameters after (message, history):
        # temperature, max_new_tokens, top_p, top_k, penalty.
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.8,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            # transformers requires a strictly positive repetition penalty,
            # so 0.0 is not offered (it made generation fail).
            gr.Slider(
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["語彙の勉強を助けてください: 空欄を埋める問題のための例文を作成してください, またそのための選択肢も生成してください"],
            ["子供の夏休みの自由研究のための、5つのアイデアと、その手法を簡潔に教えてください。"],
            ["パズルゲームのスクリプト作成のためにアドバイスお願いします"],
            ["マークダウン記法にて、ブロック崩しのゲーム作成の教科書作成してください"],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()