# RAGOndevice / app.py
# Last commit by cutechicken: 67209ed "Update app.py" (verified); file size 8.4 kB.
# (Hub page links: raw | history | blame)
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import random
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
# Release any cached CUDA memory before the models below are loaded.
torch.cuda.empty_cache()

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODELS = os.environ.get("MODELS")  # read from the environment; not used elsewhere in this file
MODEL_NAME = MODEL_ID.split("/")[-1]

# Chat model and tokenizer used by stream_chat(). Both names are referenced
# there but were never defined, so every chat request failed with NameError.
# HF_TOKEN (otherwise unused) is forwarded for gated/private checkpoints.
# The model is moved to CUDA because stream_chat() runs generation on the GPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
).to("cuda")

# Sentence encoder (xlm-r-100langs checkpoint) used for retrieval.
embedding_model = SentenceTransformer('sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens')

# Q&A dataset that serves as the retrieval corpus.
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
print("Wikipedia dataset loaded:", wiki_dataset)

# Embed the questions of the first 10,000 rows only. Slicing the rows first
# (train[:10000]) avoids materialising the full question column and then
# throwing almost all of it away.
questions = wiki_dataset['train'][:10000]['question']
question_embeddings = embedding_model.encode(questions, convert_to_tensor=True)
def find_relevant_context(query, top_k=3):
    """Return the dataset Q&A pairs whose questions are most similar to ``query``.

    The query is embedded with the module-level ``embedding_model`` and scored
    against the precomputed ``question_embeddings`` (one row per entry in
    ``questions``) using cosine similarity.

    Args:
        query: Text to search for.
        top_k: Maximum number of pairs to return. Values <= 0 yield ``[]``.

    Returns:
        A list of ``{'question': ..., 'answer': ...}`` dicts ordered from most
        to least similar.
    """
    if top_k <= 0:
        # np.argsort(...)[-0:] would select every index, so handle this explicitly.
        return []

    query_embedding = embedding_model.encode(query, convert_to_tensor=True)
    # Cosine similarity of the query against every stored question embedding.
    similarities = cosine_similarity(
        query_embedding.cpu().numpy().reshape(1, -1),
        question_embeddings.cpu().numpy(),
    )[0]

    # Indices of the top_k most similar questions, best match first.
    top_indices = np.argsort(similarities)[-top_k:][::-1]

    train_split = wiki_dataset['train']
    # Fetch one row per hit: indexing the column first (train_split['answer'][idx])
    # can materialise the entire answer column on every access.
    return [
        {'question': questions[idx], 'answer': train_split[idx]['answer']}
        for idx in map(int, top_indices)
    ]
def _format_context_prompt(contexts):
    """Render retrieved Q&A pairs as the reference block placed before the question.

    Args:
        contexts: Iterable of dicts with 'question' and 'answer' keys, as
            returned by find_relevant_context().

    Returns:
        The reference-information header followed by one Q/A entry per context.
    """
    entries = "".join(f"Q: {ctx['question']}\nA: {ctx['answer']}\n\n" for ctx in contexts)
    return "\n\n๊ด€๋ จ ์ฐธ๊ณ  ์ •๋ณด:\n" + entries


def _history_to_messages(history):
    """Convert Gradio chat history into a list of chat-template messages.

    Accepts ``(user, assistant)`` pairs (the format this app originally assumed)
    as well as role/content dicts (Gradio's ``type="messages"`` format).
    Turns whose text is ``None`` are skipped because chat templates expect
    string content.
    """
    conversation = []
    for turn in history or []:
        if isinstance(turn, dict):
            conversation.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_text, assistant_text = turn
        if user_text is not None:
            conversation.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            conversation.append({"role": "assistant", "content": assistant_text})
    return conversation


@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    """Stream a retrieval-augmented reply from the chat model.

    Retrieves related Q&A pairs for ``message`` via find_relevant_context(),
    prepends them to the question, renders the conversation with the
    tokenizer's chat template and runs ``model.generate`` in a background
    thread, yielding the accumulated reply text as chunks arrive.

    Args:
        message: Current user message.
        history: Prior turns, as (user, assistant) pairs or role/content dicts.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        penalty: Repetition penalty; must be strictly positive.

    Yields:
        The reply generated so far, growing with each streamed chunk.

    Raises:
        gr.Error: If ``penalty`` is not strictly positive.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')

    if penalty <= 0:
        # transformers rejects a non-positive repetition penalty; fail fast here
        # instead of letting the generation thread crash while the streamer waits.
        raise gr.Error("Repetition penalty must be greater than 0.")

    # RAG: retrieve related Q&A pairs and place them before the current question.
    context_prompt = _format_context_prompt(find_relevant_context(message))
    conversation = _history_to_messages(history)
    final_message = context_prompt + "\nํ˜„์žฌ ์งˆ๋ฌธ: " + message
    conversation.append({"role": "user", "content": final_message})

    prompt_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains the special tokens (e.g. BOS);
    # add_special_tokens=False keeps the tokenizer from adding them a second time.
    inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        # Slider values may arrive as int; transformers' logits processors
        # type-check for float, and token counts must be integers.
        repetition_penalty=float(penalty),
        max_new_tokens=int(max_new_tokens),
        # NOTE(review): hard-coded stop token id, presumably the model's
        # end-of-turn token — confirm with tokenizer.convert_tokens_to_ids.
        eos_token_id=[255001],
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # The UI allows temperature 0, but sampling requires a strictly positive
        # temperature; treat 0 as greedy decoding instead.
        generate_kwargs["do_sample"] = False

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
# Chat display component, handed to gr.ChatInterface in the UI section below.
chatbot = gr.Chatbot(height=500)
# Custom stylesheet passed to gr.Blocks(css=CSS). The section comments inside the
# string are (mis-encoded) Korean labels: page, main container, title, chatbox,
# message, button and input-field styling.
# NOTE(review): the message rules use `animation: messageIn`, but no
# `@keyframes messageIn` is defined in this stylesheet, so that animation has no
# effect. Nothing in this file attaches the `chatbox` or `duplicate-button`
# classes (e.g. via elem_classes) — confirm those selectors match anything.
# The final input-field section contains no rules.
CSS = """
/* ์ „์ฒด ํŽ˜์ด์ง€ ์Šคํƒ€์ผ๋ง */
body {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
min-height: 100vh;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
/* ๋ฉ”์ธ ์ปจํ…Œ์ด๋„ˆ */
.container {
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
background: rgba(255, 255, 255, 0.95);
border-radius: 20px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
backdrop-filter: blur(10px);
transform: perspective(1000px) translateZ(0);
transition: all 0.3s ease;
}
/* ์ œ๋ชฉ ์Šคํƒ€์ผ๋ง */
h1 {
color: #2d3436;
font-size: 2.5rem;
text-align: center;
margin-bottom: 2rem;
text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);
transform: perspective(1000px) translateZ(20px);
}
h3 {
text-align: center;
color: #2d3436;
font-size: 1.5rem;
margin: 1rem 0;
}
/* ์ฑ„ํŒ…๋ฐ•์Šค ์Šคํƒ€์ผ๋ง */
.chatbox {
background: white;
border-radius: 15px;
box-shadow: 0 8px 32px rgba(31, 38, 135, 0.15);
backdrop-filter: blur(4px);
border: 1px solid rgba(255, 255, 255, 0.18);
padding: 1rem;
margin: 1rem 0;
transform: translateZ(0);
transition: all 0.3s ease;
}
/* ๋ฉ”์‹œ์ง€ ์Šคํƒ€์ผ๋ง */
.chatbox .messages .message.user {
background: linear-gradient(145deg, #e1f5fe, #bbdefb);
border-radius: 15px;
padding: 1rem;
margin: 0.5rem;
box-shadow: 5px 5px 15px rgba(0, 0, 0, 0.05);
transform: translateZ(10px);
animation: messageIn 0.3s ease-out;
}
.chatbox .messages .message.bot {
background: linear-gradient(145deg, #f5f5f5, #eeeeee);
border-radius: 15px;
padding: 1rem;
margin: 0.5rem;
box-shadow: 5px 5px 15px rgba(0, 0, 0, 0.05);
transform: translateZ(10px);
animation: messageIn 0.3s ease-out;
}
/* ๋ฒ„ํŠผ ์Šคํƒ€์ผ๋ง */
.duplicate-button {
background: linear-gradient(145deg, #24292e, #1a1e22) !important;
color: white !important;
border-radius: 100vh !important;
padding: 0.8rem 1.5rem !important;
box-shadow: 3px 3px 10px rgba(0, 0, 0, 0.2) !important;
transition: all 0.3s ease !important;
border: none !important;
cursor: pointer !important;
}
.duplicate-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3) !important;
}
/* ์ž…๋ ฅ ํ•„๋“œ ์Šคํƒ€์ผ๋ง */
"""
# UI: a ChatInterface rendered inside an outer gr.Blocks that carries the custom
# CSS. The "soft" theme is set on the outer Blocks: Gradio applies the theme of
# the top-level Blocks, so a theme passed to the nested ChatInterface (as before)
# is not applied — confirm for the pinned Gradio version.
# additional_inputs are passed to stream_chat positionally after
# (message, history): temperature, max_new_tokens, top_p, top_k, penalty.
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        # Collapsed "Options" accordion holding the generation controls.
        additional_inputs_accordion=gr.Accordion(label="โš™๏ธ ์˜ต์…˜์…˜", open=False, render=False),
        additional_inputs=[
            # Temperature (0 selects greedy decoding in the updated stream_chat).
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="์˜จ๋„",
                render=False,
            ),
            # Maximum number of new tokens.
            gr.Slider(
                minimum=128,
                maximum=8000,
                step=1,
                value=4000,
                label="์ตœ๋Œ€ ํ† ํฐ ์ˆ˜",
                render=False,
            ),
            # Top-p (nucleus sampling).
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.8,
                label="์ƒ์œ„ ํ™•๋ฅ ",
                render=False,
            ),
            # Top-k.
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="์ƒ์œ„ K",
                render=False,
            ),
            # Repetition penalty. transformers requires a strictly positive value,
            # so the minimum is 0.1 rather than 0.0.
            gr.Slider(
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="๋ฐ˜๋ณต ํŒจ๋„ํ‹ฐ",
                render=False,
            ),
        ],
        examples=[
            ["์•„์ด์˜ ์—ฌ๋ฆ„๋ฐฉํ•™ ๊ณผํ•™ ํ”„๋กœ์ ํŠธ๋ฅผ ์œ„ํ•œ 5๊ฐ€์ง€ ์•„์ด๋””์–ด๋ฅผ ์ฃผ์„ธ์š”."],
            ["๋งˆํฌ๋‹ค์šด์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ธŒ๋ ˆ์ดํฌ์•„์›ƒ ๊ฒŒ์ž„ ๋งŒ๋“ค๊ธฐ ํŠœํ† ๋ฆฌ์–ผ์„ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”."],
            ["์ดˆ๋Šฅ๋ ฅ์„ ๊ฐ€์ง„ ์ฃผ์ธ๊ณต์˜ SF ์ด์•ผ๊ธฐ ์‹œ๋‚˜๋ฆฌ์˜ค๋ฅผ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”. ๋ณต์„  ์„ค์ •, ํ…Œ๋งˆ์™€ ๋กœ๊ทธ๋ผ์ธ์„ ๋…ผ๋ฆฌ์ ์œผ๋กœ ์‚ฌ์šฉํ•ด์ฃผ์„ธ์š”"],
            ["์•„์ด์˜ ์—ฌ๋ฆ„๋ฐฉํ•™ ์ž์œ ์—ฐ๊ตฌ๋ฅผ ์œ„ํ•œ 5๊ฐ€์ง€ ์•„์ด๋””์–ด์™€ ๊ทธ ๋ฐฉ๋ฒ•์„ ๊ฐ„๋‹จํžˆ ์•Œ๋ ค์ฃผ์„ธ์š”."],
            ["ํผ์ฆ ๊ฒŒ์ž„ ์Šคํฌ๋ฆฝํŠธ ์ž‘์„ฑ์„ ์œ„ํ•œ ์กฐ์–ธ ๋ถ€ํƒ๋“œ๋ฆฝ๋‹ˆ๋‹ค"],
            ["๋งˆํฌ๋‹ค์šด ํ˜•์‹์œผ๋กœ ๋ธ”๋ก ๊นจ๊ธฐ ๊ฒŒ์ž„ ์ œ์ž‘ ๊ต๊ณผ์„œ๋ฅผ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”"],
            ["์‹ค๋ฒ„ ๅทๆŸณ๋ฅผ ์ƒ๊ฐํ•ด์ฃผ์„ธ์š”"],
            ["์ผ๋ณธ์–ด ๊ด€์šฉ๊ตฌ, ์†๋‹ด์— ๊ด€ํ•œ ์‹œํ—˜ ๋ฌธ์ œ๋ฅผ ๋งŒ๋“ค์–ด์ฃผ์„ธ์š”"],
            ["๋„๋ผ์—๋ชฝ์˜ ๋“ฑ์žฅ์ธ๋ฌผ์„ ์•Œ๋ ค์ฃผ์„ธ์š”"],
            ["์˜ค์ฝ”๋…ธ๋ฏธ์•ผํ‚ค ๋งŒ๋“œ๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ๋ ค์ฃผ์„ธ์š”"],
            ["๋ฌธ์ œ 9.11๊ณผ 9.9 ์ค‘ ์–ด๋Š ๊ฒƒ์ด ๋” ํฐ๊ฐ€์š”? step by step์œผ๋กœ ๋…ผ๋ฆฌ์ ์œผ๋กœ ์ƒ๊ฐํ•ด์ฃผ์„ธ์š”."],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()