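# RAGOndevice / app.py
# Hugging Face Space file-viewer header captured together with the source:
#   uploader avatar: cutechicken's picture
#   commit: "Update app.py" (1d43078, verified)
#   viewer links: raw / history / blame
#   file size: 7.51 kB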
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import random
from datasets import load_dataset
# Runtime configuration, read once at import time.
# HF_TOKEN and MODELS come from the environment and are None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODELS = os.environ.get("MODELS")
# Repository name without the organisation prefix.
MODEL_NAME = MODEL_ID.split("/")[-1]
# Page heading. The Korean text had been stored as mis-decoded UTF-8
# (mojibake) and would have been rendered as garbage; it is restored here.
TITLE = "<h1><center>온디바이스 AI(Open LLM 모델)</center></h1>"
# Initial page stylesheet: pill-shaped duplicate button, centred h3, and
# background colours for user / bot chat messages.
# NOTE(review): CSS is assigned a second time further down this file, before
# gr.Blocks(css=CSS) reads it, so this first value is overwritten and never
# reaches the UI.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
.chatbox .messages .message.user {
background-color: #e1f5fe;
}
.chatbox .messages .message.bot {
background-color: #eeeeee;
}
"""
# ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ
dataset = load_dataset("elyza/ELYZA-tasks-100")
print(dataset)
split_name = "train" if "train" in dataset else "test"
examples_list = list(dataset[split_name])
examples = random.sample(examples_list, 50)
example_inputs = [[example['input']] for example in examples]
def _build_conversation(message: str, history: list) -> list:
    """Turn the chat history plus the new user message into chat-template messages.

    Accepts history either as (user, assistant) pairs or as a list of
    {"role": ..., "content": ...} dicts.
    """
    conversation = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Already in role/content form; copy only the fields the template needs.
            conversation.append({"role": turn["role"], "content": turn["content"]})
            continue
        prompt, answer = turn
        conversation.append({"role": "user", "content": prompt})
        # A pending turn has no answer yet; do not send None as assistant text.
        if answer is not None:
            conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    """Generate a reply to ``message`` and stream it back incrementally.

    Args:
        message: The new user message.
        history: Previous turns, as (user, assistant) pairs or role/content dicts.
        temperature: Sampling temperature; a value <= 0 switches to greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        penalty: Repetition penalty; a value <= 0 leaves the penalty at the
            generation default.

    Yields:
        The reply text accumulated so far, once per streamed chunk.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')

    conversation = _build_conversation(message, history)

    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # The rendered template already carries the special tokens, so they must
    # not be added again when the text is tokenized. The tensors follow the
    # model's device instead of assuming GPU index 0.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # Iterating the streamer raises if no new text arrives within 10 seconds.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    # UI sliders may deliver ints for whole numbers; normalise the types
    # before handing the values to generate().
    temperature = float(temperature)
    penalty = float(penalty)
    # The temperature slider allows 0, which sampling cannot accept; treat it
    # as a request for greedy decoding.
    do_sample = temperature > 0

    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        # NOTE(review): presumably the model's end-of-turn token id — confirm
        # against the tokenizer.
        eos_token_id=[255001],
    )
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=float(top_p), top_k=int(top_k))
    if penalty > 0:
        # Non-positive penalties are invalid; skip them rather than crash.
        generate_kwargs["repetition_penalty"] = penalty

    # Generation runs in a worker thread so this generator can consume the
    # streamer while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
# Chat display component (height 500); it is handed to gr.ChatInterface
# through its chatbot= argument further down.
chatbot = gr.Chatbot(height=500)
# Stylesheet passed to gr.Blocks(css=CSS). This assignment replaces the
# earlier CSS value defined near the top of the file.
# Changes from the previous revision:
#   * the embedded Korean CSS comments were mis-decoded UTF-8 and are restored;
#   * the message rules referenced `animation: messageIn` without any
#     matching @keyframes, so a simple fade-in definition is added.
CSS = """
/* 전체 페이지 스타일링 */
body {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
min-height: 100vh;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
/* 메인 컨테이너 */
.container {
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
background: rgba(255, 255, 255, 0.95);
border-radius: 20px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
backdrop-filter: blur(10px);
transform: perspective(1000px) translateZ(0);
transition: all 0.3s ease;
}
/* 제목 스타일링 */
h1 {
color: #2d3436;
font-size: 2.5rem;
text-align: center;
margin-bottom: 2rem;
text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);
transform: perspective(1000px) translateZ(20px);
}
h3 {
text-align: center;
color: #2d3436;
font-size: 1.5rem;
margin: 1rem 0;
}
/* 채팅박스 스타일링 */
.chatbox {
background: white;
border-radius: 15px;
box-shadow: 0 8px 32px rgba(31, 38, 135, 0.15);
backdrop-filter: blur(4px);
border: 1px solid rgba(255, 255, 255, 0.18);
padding: 1rem;
margin: 1rem 0;
transform: translateZ(0);
transition: all 0.3s ease;
}
/* 메시지 스타일링 */
.chatbox .messages .message.user {
background: linear-gradient(145deg, #e1f5fe, #bbdefb);
border-radius: 15px;
padding: 1rem;
margin: 0.5rem;
box-shadow: 5px 5px 15px rgba(0, 0, 0, 0.05);
transform: translateZ(10px);
animation: messageIn 0.3s ease-out;
}
.chatbox .messages .message.bot {
background: linear-gradient(145deg, #f5f5f5, #eeeeee);
border-radius: 15px;
padding: 1rem;
margin: 0.5rem;
box-shadow: 5px 5px 15px rgba(0, 0, 0, 0.05);
transform: translateZ(10px);
animation: messageIn 0.3s ease-out;
}
@keyframes messageIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
/* 버튼 스타일링 */
.duplicate-button {
background: linear-gradient(145deg, #24292e, #1a1e22) !important;
color: white !important;
border-radius: 100vh !important;
padding: 0.8rem 1.5rem !important;
box-shadow: 3px 3px 10px rgba(0, 0, 0, 0.2) !important;
transition: all 0.3s ease !important;
border: none !important;
cursor: pointer !important;
}
.duplicate-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3) !important;
}
/* 입력 필드 스타일링 */
"""
with gr.Blocks(css=CSS) as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
theme="soft",
additional_inputs_accordion=gr.Accordion(label="โš™๏ธ ์˜ต์…˜์…˜", open=False, render=False),
additional_inputs=[
gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.8,
label="์˜จ๋„",
render=False,
),
gr.Slider(
minimum=128,
maximum=1000000,
step=1,
value=100000,
label="์ตœ๋Œ€ ํ† ํฐ ์ˆ˜",
render=False,
),
gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.8,
label="์ƒ์œ„ ํ™•๋ฅ ",
render=False,
),
gr.Slider(
minimum=1,
maximum=20,
step=1,
value=20,
label="์ƒ์œ„ K",
render=False,
),
gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.0,
label="๋ฐ˜๋ณต ํŒจ๋„ํ‹ฐ",
render=False,
),
],
examples=[
["์•„์ด์˜ ์—ฌ๋ฆ„๋ฐฉํ•™ ๊ณผํ•™ ํ”„๋กœ์ ํŠธ๋ฅผ ์œ„ํ•œ 5๊ฐ€์ง€ ์•„์ด๋””์–ด๋ฅผ ์ฃผ์„ธ์š”."],
["๋งˆํฌ๋‹ค์šด์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ธŒ๋ ˆ์ดํฌ์•„์›ƒ ๊ฒŒ์ž„ ๋งŒ๋“ค๊ธฐ ํŠœํ† ๋ฆฌ์–ผ์„ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”."],
["์ดˆ๋Šฅ๋ ฅ์„ ๊ฐ€์ง„ ์ฃผ์ธ๊ณต์˜ SF ์ด์•ผ๊ธฐ ์‹œ๋‚˜๋ฆฌ์˜ค๋ฅผ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”. ๋ณต์„  ์„ค์ •, ํ…Œ๋งˆ์™€ ๋กœ๊ทธ๋ผ์ธ์„ ๋…ผ๋ฆฌ์ ์œผ๋กœ ์‚ฌ์šฉํ•ด์ฃผ์„ธ์š”"],
["์•„์ด์˜ ์—ฌ๋ฆ„๋ฐฉํ•™ ์ž์œ ์—ฐ๊ตฌ๋ฅผ ์œ„ํ•œ 5๊ฐ€์ง€ ์•„์ด๋””์–ด์™€ ๊ทธ ๋ฐฉ๋ฒ•์„ ๊ฐ„๋‹จํžˆ ์•Œ๋ ค์ฃผ์„ธ์š”."],
["ํผ์ฆ ๊ฒŒ์ž„ ์Šคํฌ๋ฆฝํŠธ ์ž‘์„ฑ์„ ์œ„ํ•œ ์กฐ์–ธ ๋ถ€ํƒ๋“œ๋ฆฝ๋‹ˆ๋‹ค"],
["๋งˆํฌ๋‹ค์šด ํ˜•์‹์œผ๋กœ ๋ธ”๋ก ๊นจ๊ธฐ ๊ฒŒ์ž„ ์ œ์ž‘ ๊ต๊ณผ์„œ๋ฅผ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”"],
["์‹ค๋ฒ„ ๅทๆŸณ๋ฅผ ์ƒ๊ฐํ•ด์ฃผ์„ธ์š”"],
["์ผ๋ณธ์–ด ๊ด€์šฉ๊ตฌ, ์†๋‹ด์— ๊ด€ํ•œ ์‹œํ—˜ ๋ฌธ์ œ๋ฅผ ๋งŒ๋“ค์–ด์ฃผ์„ธ์š”"],
["๋„๋ผ์—๋ชฝ์˜ ๋“ฑ์žฅ์ธ๋ฌผ์„ ์•Œ๋ ค์ฃผ์„ธ์š”"],
["์˜ค์ฝ”๋…ธ๋ฏธ์•ผํ‚ค ๋งŒ๋“œ๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ๋ ค์ฃผ์„ธ์š”"],
["๋ฌธ์ œ 9.11๊ณผ 9.9 ์ค‘ ์–ด๋Š ๊ฒƒ์ด ๋” ํฐ๊ฐ€์š”? step by step์œผ๋กœ ๋…ผ๋ฆฌ์ ์œผ๋กœ ์ƒ๊ฐํ•ด์ฃผ์„ธ์š”."],
],
cache_examples=False,
)
if __name__ == "__main__":
demo.launch()