# RAGOndevice/app.py
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import random
from datasets import load_dataset
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
from typing import List, Tuple
import json
from datetime import datetime
# Release any cached GPU memory before loading the model
torch.cuda.empty_cache()
# Environment variables
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODELS = os.environ.get("MODELS")
MODEL_NAME = MODEL_ID.split("/")[-1]
# ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
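# Note: device_map="auto" lets accelerate place layers across the available
# devices, and torch.bfloat16 roughly halves memory use relative to float32.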
# ์œ„ํ‚คํ”ผ๋””์•„ ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
print("Wikipedia dataset loaded:", wiki_dataset)
# Initialize and fit the TF-IDF vectorizer
print("Fitting TF-IDF vectorizer...")
questions = wiki_dataset['train']['question'][:10000]  # use only the first 10,000 questions
vectorizer = TfidfVectorizer(max_features=1000)
question_vectors = vectorizer.fit_transform(questions)
print("TF-IDF vectorization complete")
class ChatHistory:
def __init__(self):
self.history = []
self.history_file = "/tmp/chat_history.json"
self.load_history()
def add_conversation(self, user_msg: str, assistant_msg: str):
conversation = {
"timestamp": datetime.now().isoformat(),
"messages": [
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_msg}
]
}
self.history.append(conversation)
self.save_history()
def format_for_display(self):
formatted = []
for conv in self.history:
formatted.append([
conv["messages"][0]["content"],
conv["messages"][1]["content"]
])
return formatted
def get_messages_for_api(self):
messages = []
for conv in self.history:
messages.extend([
{"role": "user", "content": conv["messages"][0]["content"]},
{"role": "assistant", "content": conv["messages"][1]["content"]}
])
return messages
def clear_history(self):
self.history = []
self.save_history()
def save_history(self):
try:
with open(self.history_file, 'w', encoding='utf-8') as f:
json.dump(self.history, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"ํžˆ์Šคํ† ๋ฆฌ ์ €์žฅ ์‹คํŒจ: {e}")
def load_history(self):
try:
if os.path.exists(self.history_file):
with open(self.history_file, 'r', encoding='utf-8') as f:
self.history = json.load(f)
except Exception as e:
print(f"ํžˆ์Šคํ† ๋ฆฌ ๋กœ๋“œ ์‹คํŒจ: {e}")
self.history = []
# Global ChatHistory instance
chat_history = ChatHistory()
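# Note: ChatHistory persists turns to /tmp/chat_history.json, but none of the
# Gradio callbacks below currently call it. A minimal sketch of how a finished
# turn could be recorded (hypothetical wiring, not part of the original flow):
#   chat_history.add_conversation(user_msg, assistant_msg)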
def find_relevant_context(query, top_k=3):
    # Vectorize the query against the fitted TF-IDF vocabulary
    query_vector = vectorizer.transform([query])
    # TfidfVectorizer L2-normalizes its rows, so this sparse dot product
    # is the cosine similarity between the query and each stored question
    similarities = (query_vector * question_vectors.T).toarray()[0]
    # Indices of the most similar questions, highest similarity first
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    # Collect the related contexts
relevant_contexts = []
for idx in top_indices:
if similarities[idx] > 0:
relevant_contexts.append({
'question': questions[idx],
'answer': wiki_dataset['train']['answer'][idx],
'similarity': similarities[idx]
})
return relevant_contexts
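# Illustrative use of the retriever (hypothetical query):
#   for ctx in find_relevant_context("ํ•œ๊ธ€์€ ๋ˆ„๊ฐ€ ๋งŒ๋“ค์—ˆ๋‚˜์š”?", top_k=2):
#       print(f"{ctx['similarity']:.3f}", ctx['question'])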
def analyze_file_content(content, file_type):
"""Analyze file content and return structural summary"""
    if file_type in ['parquet', 'csv']:
        try:
            lines = content.split('\n')
            header = lines[0]
            # A markdown table row like "| a | b |" has (columns + 1) pipes
            columns = header.count('|') - 1
            # Discount the header row, separator row, and trailing newline
            rows = len(lines) - 3
            return f"๐Ÿ“Š ๋ฐ์ดํ„ฐ์…‹ ๊ตฌ์กฐ: {columns}๊ฐœ ์ปฌ๋Ÿผ, {rows}๊ฐœ ๋ฐ์ดํ„ฐ"
        except Exception:
            return "โŒ ๋ฐ์ดํ„ฐ์…‹ ๊ตฌ์กฐ ๋ถ„์„ ์‹คํŒจ"
lines = content.split('\n')
total_lines = len(lines)
non_empty_lines = len([line for line in lines if line.strip()])
if any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']):
functions = len([line for line in lines if 'def ' in line])
classes = len([line for line in lines if 'class ' in line])
imports = len([line for line in lines if 'import ' in line or 'from ' in line])
return f"๐Ÿ’ป ์ฝ”๋“œ ๊ตฌ์กฐ: {total_lines}์ค„ (ํ•จ์ˆ˜: {functions}, ํด๋ž˜์Šค: {classes}, ์ž„ํฌํŠธ: {imports})"
paragraphs = content.count('\n\n') + 1
words = len(content.split())
return f"๐Ÿ“ ๋ฌธ์„œ ๊ตฌ์กฐ: {total_lines}์ค„, {paragraphs}๋‹จ๋ฝ, ์•ฝ {words}๋‹จ์–ด"
def read_uploaded_file(file):
if file is None:
return "", ""
try:
file_ext = os.path.splitext(file.name)[1].lower()
if file_ext == '.parquet':
df = pd.read_parquet(file.name, engine='pyarrow')
content = df.head(10).to_markdown(index=False)
return content, "parquet"
elif file_ext == '.csv':
encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
for encoding in encodings:
try:
df = pd.read_csv(file.name, encoding=encoding)
content = f"๐Ÿ“Š ๋ฐ์ดํ„ฐ ๋ฏธ๋ฆฌ๋ณด๊ธฐ:\n{df.head(10).to_markdown(index=False)}\n\n"
content += f"\n๐Ÿ“ˆ ๋ฐ์ดํ„ฐ ์ •๋ณด:\n"
content += f"- ์ „์ฒด ํ–‰ ์ˆ˜: {len(df)}\n"
content += f"- ์ „์ฒด ์—ด ์ˆ˜: {len(df.columns)}\n"
content += f"- ์ปฌ๋Ÿผ ๋ชฉ๋ก: {', '.join(df.columns)}\n"
content += f"\n๐Ÿ“‹ ์ปฌ๋Ÿผ ๋ฐ์ดํ„ฐ ํƒ€์ž…:\n"
for col, dtype in df.dtypes.items():
content += f"- {col}: {dtype}\n"
null_counts = df.isnull().sum()
if null_counts.any():
content += f"\nโš ๏ธ ๊ฒฐ์ธก์น˜:\n"
for col, null_count in null_counts[null_counts > 0].items():
content += f"- {col}: {null_count}๊ฐœ ๋ˆ„๋ฝ\n"
return content, "csv"
except UnicodeDecodeError:
continue
            # UnicodeDecodeError needs five positional args, so raise ValueError instead
            raise ValueError(f"โŒ ์ง€์›๋˜๋Š” ์ธ์ฝ”๋”ฉ์œผ๋กœ ํŒŒ์ผ์„ ์ฝ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค ({', '.join(encodings)})")
else:
encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
for encoding in encodings:
try:
with open(file.name, 'r', encoding=encoding) as f:
content = f.read()
return content, "text"
except UnicodeDecodeError:
continue
            # As above, ValueError instead of a mis-constructed UnicodeDecodeError
            raise ValueError(f"โŒ ์ง€์›๋˜๋Š” ์ธ์ฝ”๋”ฉ์œผ๋กœ ํŒŒ์ผ์„ ์ฝ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค ({', '.join(encodings)})")
except Exception as e:
return f"โŒ ํŒŒ์ผ ์ฝ๊ธฐ ์˜ค๋ฅ˜: {str(e)}", "error"
@spaces.GPU
def stream_chat(message: str, history: list, uploaded_file, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
print(f'message is - {message}')
print(f'history is - {history}')
# ํŒŒ์ผ ์—…๋กœ๋“œ ์ฒ˜๋ฆฌ
file_context = ""
if uploaded_file:
content, file_type = read_uploaded_file(uploaded_file)
if content:
file_context = f"\n\n์—…๋กœ๋“œ๋œ ํŒŒ์ผ ๋‚ด์šฉ:\n```\n{content}\n```"
# ๊ด€๋ จ ์ปจํ…์ŠคํŠธ ์ฐพ๊ธฐ
relevant_contexts = find_relevant_context(message)
wiki_context = "\n\n๊ด€๋ จ ์œ„ํ‚คํ”ผ๋””์•„ ์ •๋ณด:\n"
for ctx in relevant_contexts:
wiki_context += f"Q: {ctx['question']}\nA: {ctx['answer']}\n์œ ์‚ฌ๋„: {ctx['similarity']:.3f}\n\n"
    # Rebuild the running conversation from the chatbot history pairs
conversation = []
for prompt, answer in history:
conversation.extend([
{"role": "user", "content": prompt},
{"role": "assistant", "content": answer}
])
    # Compose the final prompt
final_message = file_context + wiki_context + "\nํ˜„์žฌ ์งˆ๋ฌธ: " + message
conversation.append({"role": "user", "content": final_message})
    # apply_chat_template with tokenize=False returns the prompt as a string
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # Move inputs to the model's device (device_map="auto" handles placement)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
inputs,
streamer=streamer,
top_k=top_k,
top_p=top_p,
repetition_penalty=penalty,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
        # Hard-coded end-of-turn token id for the Command R tokenizer
        eos_token_id=[255001],
)
    # Run generation on a worker thread; the streamer yields text as it is produced
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Yield a cleared textbox plus the updated history so the two outputs
        # match the [msg, chatbot] components these events are bound to
        yield "", history + [[message, buffer]]
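# The pattern above is the usual transformers streaming recipe: model.generate
# runs on a worker thread while TextIteratorStreamer exposes decoded tokens as
# an iterator here, letting Gradio render partial answers as they stream in.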
CSS = """
/* Page-wide styling */
body {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
min-height: 100vh;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
/* ... (previous CSS styles kept) ... */
"""
with gr.Blocks(css=CSS) as demo:
with gr.Row():
with gr.Column(scale=2):
chatbot = gr.Chatbot(height=500)
msg = gr.Textbox(
label="๋ฉ”์‹œ์ง€ ์ž…๋ ฅ",
show_label=False,
placeholder="๋ฌด์—‡์ด๋“  ๋ฌผ์–ด๋ณด์„ธ์š”... ๐Ÿ’ญ",
container=False
)
with gr.Row():
clear = gr.ClearButton([msg, chatbot], value="๋Œ€ํ™”๋‚ด์šฉ ์ง€์šฐ๊ธฐ")
send = gr.Button("๋ณด๋‚ด๊ธฐ ๐Ÿ“ค")
with gr.Column(scale=1):
gr.Markdown("### ํŒŒ์ผ ์—…๋กœ๋“œ ๐Ÿ“")
file_upload = gr.File(
label="ํŒŒ์ผ ์„ ํƒ",
file_types=["text", ".csv", ".parquet"],
type="filepath"
)
with gr.Accordion("๊ณ ๊ธ‰ ์„ค์ • โš™๏ธ", open=False):
temperature = gr.Slider(
                    minimum=0.1,  # do_sample=True requires a strictly positive temperature
maximum=1,
step=0.1,
value=0.8,
label="์˜จ๋„",
)
max_new_tokens = gr.Slider(
minimum=128,
maximum=8000,
step=1,
value=4000,
label="์ตœ๋Œ€ ํ† ํฐ ์ˆ˜",
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.8,
label="์ƒ์œ„ ํ™•๋ฅ ",
)
top_k = gr.Slider(
minimum=1,
maximum=20,
step=1,
value=20,
label="์ƒ์œ„ K",
)
penalty = gr.Slider(
                    minimum=0.1,  # repetition_penalty must be greater than zero
maximum=2.0,
step=0.1,
value=1.0,
label="๋ฐ˜๋ณต ํŒจ๋„ํ‹ฐ",
)
    # Example questions
gr.Examples(
examples=[
["ํ•œ๊ตญ์˜ ์ „ํ†ต ์ ˆ๊ธฐ์™€ 24์ ˆ๊ธฐ์— ๋Œ€ํ•ด ์ž์„ธํžˆ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”."],
["์šฐ๋ฆฌ๋‚˜๋ผ ์ „ํ†ต ์Œ์‹ ์ค‘ ๊ฑด๊ฐ•์— ์ข‹์€ ๋ฐœํšจ์Œ์‹ 5๊ฐ€์ง€๋ฅผ ์ถ”์ฒœํ•˜๊ณ  ๊ทธ ํšจ๋Šฅ์„ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”."],
["ํ•œ๊ตญ์˜ ๋Œ€ํ‘œ์ ์ธ ์‚ฐ๋“ค์„ ์†Œ๊ฐœํ•˜๊ณ , ๊ฐ ์‚ฐ์˜ ํŠน์ง•๊ณผ ๋“ฑ์‚ฐ ์ฝ”์Šค๋ฅผ ์ถ”์ฒœํ•ด์ฃผ์„ธ์š”."],
["์‚ฌ๋ฌผ๋†€์ด์˜ ์•…๊ธฐ ๊ตฌ์„ฑ๊ณผ ์žฅ๋‹จ์— ๋Œ€ํ•ด ์ดˆ๋ณด์ž๋„ ์ดํ•ดํ•˜๊ธฐ ์‰ฝ๊ฒŒ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”."],
["ํ•œ๊ตญ์˜ ์ „ํ†ต ๊ฑด์ถ•๋ฌผ์— ๋‹ด๊ธด ๊ณผํ•™์  ์›๋ฆฌ๋ฅผ ํ˜„๋Œ€์  ๊ด€์ ์—์„œ ๋ถ„์„ํ•ด์ฃผ์„ธ์š”."],
["์กฐ์„ ์‹œ๋Œ€ ๊ณผ๊ฑฐ ์‹œํ—˜ ์ œ๋„๋ฅผ ํ˜„๋Œ€์˜ ์ž…์‹œ ์ œ๋„์™€ ๋น„๊ตํ•˜์—ฌ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”."],
["ํ•œ๊ตญ์˜ 4๋Œ€ ๊ถ๊ถ์„ ๋น„๊ตํ•˜์—ฌ ๊ฐ๊ฐ์˜ ํŠน์ง•๊ณผ ์—ญ์‚ฌ์  ์˜๋ฏธ๋ฅผ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”."],
["ํ•œ๊ตญ์˜ ์ „ํ†ต ๋†€์ด๋ฅผ ํ˜„๋Œ€์ ์œผ๋กœ ์žฌํ•ด์„ํ•˜์—ฌ ์‹ค๋‚ด์—์„œ ํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์„ ์ œ์•ˆํ•ด์ฃผ์„ธ์š”."],
["ํ•œ๊ธ€ ์ฐฝ์ œ ๊ณผ์ •๊ณผ ํ›ˆ๋ฏผ์ •์Œ์˜ ๊ณผํ•™์  ์›๋ฆฌ๋ฅผ ์ƒ์„ธํžˆ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”."],
["ํ•œ๊ตญ์˜ ์ „ํ†ต ์ฐจ ๋ฌธํ™”์— ๋Œ€ํ•ด ์„ค๋ช…ํ•˜๊ณ , ๊ณ„์ ˆ๋ณ„๋กœ ์–ด์šธ๋ฆฌ๋Š” ์ „ํ†ต์ฐจ๋ฅผ ์ถ”์ฒœํ•ด์ฃผ์„ธ์š”."],
["ํ•œ๊ตญ์˜ ์ „ํ†ต ์˜๋ณต์ธ ํ•œ๋ณต์˜ ๊ตฌ์กฐ์™€ ํŠน์ง•์„ ๊ณผํ•™์ , ๋ฏธํ•™์  ๊ด€์ ์—์„œ ๋ถ„์„ํ•ด์ฃผ์„ธ์š”."],
["ํ•œ๊ตญ์˜ ์ „ํ†ต ๊ฐ€์˜ฅ ๊ตฌ์กฐ๋ฅผ ๊ธฐํ›„์™€ ํ™˜๊ฒฝ ๊ด€์ ์—์„œ ๋ถ„์„ํ•˜๊ณ , ํ˜„๋Œ€ ๊ฑด์ถ•์— ์ ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ์š”์†Œ๋ฅผ ์ œ์•ˆํ•ด์ฃผ์„ธ์š”."]
],
inputs=msg,
)
# ์ด๋ฒคํŠธ ๋ฐ”์ธ๋”ฉ
msg.submit(
stream_chat,
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
outputs=[msg, chatbot]
)
send.click(
stream_chat,
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
outputs=[msg, chatbot]
)
# ํŒŒ์ผ ์—…๋กœ๋“œ์‹œ ์ž๋™ ๋ถ„์„
file_upload.change(
lambda: "ํŒŒ์ผ ๋ถ„์„์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค...",
outputs=msg
).then(
stream_chat,
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
outputs=[msg, chatbot]
)
if __name__ == "__main__":
demo.launch()