# StoryStar / app.py
# openfree's picture
# Update app.py
# 5bb381c verified
# raw
# history blame
# 13 kB
import os
from dotenv import load_dotenv
import gradio as gr
from huggingface_hub import InferenceClient
import pandas as pd
from typing import List, Tuple
import json
from datetime import datetime
# Environment variable setup.
# `load_dotenv` is imported at the top of the file but was never called, so a
# local `.env` file was silently ignored. Calling it lets HF_TOKEN come from a
# `.env` file during local development; by default it does not override
# variables that are already set in the process environment.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
# Registry of selectable LLMs: display name -> Hugging Face repo id.
# get_client() uses the Cohere entry by default and the Llama entry as backup.
LLM_MODELS = dict(
    [
        ("Cohere c4ai-crp-08-2024", "CohereForAI/c4ai-command-r-plus-08-2024"),  # default
        ("Meta Llama3.3-70B", "meta-llama/Llama-3.3-70B-Instruct"),  # backup
    ]
)
class ChatHistory:
    """Conversation log persisted to a JSON file.

    Each stored entry has the shape::

        {"timestamp": "<ISO-8601 string>",
         "messages": [{"role": "user", "content": ...},
                      {"role": "assistant", "content": ...}]}

    Persistence is best effort: load/save failures are printed and the
    in-memory history keeps working.
    """

    def __init__(self, history_file="/tmp/chat_history.json"):
        """Create the log and load any history already stored on disk.

        Args:
            history_file: Path of the JSON file used for persistence.
        """
        self.history = []
        self.history_file = history_file
        self.load_history()

    def add_conversation(self, user_msg: str, assistant_msg: str):
        """Append one user/assistant exchange and persist the whole history."""
        conversation = {
            "timestamp": datetime.now().isoformat(),
            "messages": [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg},
            ],
        }
        self.history.append(conversation)
        self.save_history()

    def format_for_display(self):
        """Return the history as ``[[user, assistant], ...]`` pairs for gr.Chatbot."""
        return [
            [conv["messages"][0]["content"], conv["messages"][1]["content"]]
            for conv in self.history
        ]

    def get_messages_for_api(self):
        """Return the history flattened into a chat-completion message list."""
        messages = []
        for conv in self.history:
            messages.append({"role": "user", "content": conv["messages"][0]["content"]})
            messages.append({"role": "assistant", "content": conv["messages"][1]["content"]})
        return messages

    def clear_history(self):
        """Drop every conversation and persist the empty history."""
        self.history = []
        self.save_history()

    def save_history(self):
        """Write the history to ``history_file`` (best effort).

        The JSON is written to a sibling temp file which then replaces the
        target via ``os.replace``. An interrupted write therefore cannot leave
        a truncated file behind, which would wipe the history on the next load.
        """
        tmp_path = f"{self.history_file}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.history, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.history_file)
        except Exception as e:
            print(f"νžˆμŠ€ν† λ¦¬ μ €μž₯ μ‹€νŒ¨: {e}")

    @staticmethod
    def _is_valid_conversation(conv):
        """Return True if *conv* has the user/assistant shape add_conversation writes."""
        if not isinstance(conv, dict):
            return False
        messages = conv.get("messages")
        return (
            isinstance(messages, list)
            and len(messages) >= 2
            and all(isinstance(m, dict) and "content" in m for m in messages[:2])
        )

    def load_history(self):
        """Load history from ``history_file`` if it exists (best effort).

        Entries that do not have the expected shape are dropped. A file that
        parses but is not a list yields an empty history. Without this check,
        format_for_display() would crash while the UI is being built.
        """
        try:
            if os.path.exists(self.history_file):
                with open(self.history_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    self.history = [c for c in data if self._is_valid_conversation(c)]
                else:
                    self.history = []
        except Exception as e:
            print(f"νžˆμŠ€ν† λ¦¬ λ‘œλ“œ μ‹€νŒ¨: {e}")
            self.history = []
# Module-level ChatHistory instance used by both the chat handler and the UI.
# Constructing it loads any history already saved on disk.
# NOTE(review): being module-level, this appears to be shared by every session
# of the app (all users see and append to one history) — confirm intended.
chat_history = ChatHistory()
def get_client(model_name="Cohere c4ai-crp-08-2024"):
    """Return an InferenceClient for the model registered as *model_name*.

    If looking up *model_name* in LLM_MODELS, or constructing its client,
    raises any exception, a client for the "Meta Llama3.3-70B" backup model
    is returned instead.
    """
    backup_name = "Meta Llama3.3-70B"
    try:
        repo_id = LLM_MODELS[model_name]
        client = InferenceClient(repo_id, token=HF_TOKEN)
    except Exception:
        client = InferenceClient(LLM_MODELS[backup_name], token=HF_TOKEN)
    return client
def _markdown_table_lines(content):
    """Return the first contiguous run of lines that start with '|'.

    ``DataFrame.to_markdown`` output (used for parquet/CSV previews) is a pipe
    table: a header row, a ``|---|`` separator row, then one row per record.
    """
    table = []
    for line in content.split('\n'):
        stripped = line.strip()
        if stripped.startswith('|'):
            table.append(stripped)
        elif table:
            break  # first non-table line after the table: it has ended
    return table


def analyze_file_content(content, file_type):
    """Analyze file content and return a one-line structural summary.

    Args:
        content: Text produced by ``read_uploaded_file``. For ``'parquet'`` it
            is a Markdown preview table; for ``'csv'`` the table follows a
            caption line and is followed by extra info lines.
        file_type: ``'parquet'``, ``'csv'``, or anything else (treated as text).

    Returns:
        A human-readable summary string. For tabular files the row count is
        the number of rows in the preview table, not the whole file.
    """
    if file_type in ['parquet', 'csv']:
        # Locate the table instead of assuming its header is line 0: the CSV
        # content starts with a caption, which made the column count -1.
        table = _markdown_table_lines(content)
        if not table:
            return "❌ 데이터셋 ꡬ쑰 뢄석 μ‹€νŒ¨"
        # A header row "| a | b |" has one more pipe than it has columns.
        columns = table[0].count('|') - 1
        # Everything after the header and separator rows is a data row.
        rows = max(len(table) - 2, 0)
        return f"πŸ“Š 데이터셋 ꡬ쑰: {columns}개 컬럼, {rows}개 데이터"
    lines = content.split('\n')
    total_lines = len(lines)
    # Heuristic: any of these keywords makes the content count as source code.
    if any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']):
        functions = sum('def ' in line for line in lines)
        classes = sum('class ' in line for line in lines)
        imports = sum('import ' in line or 'from ' in line for line in lines)
        return f"πŸ’» μ½”λ“œ ꡬ쑰: {total_lines}쀄 (ν•¨μˆ˜: {functions}, 클래슀: {classes}, μž„ν¬νŠΈ: {imports})"
    paragraphs = content.count('\n\n') + 1
    words = len(content.split())
    return f"πŸ“ λ¬Έμ„œ ꡬ쑰: {total_lines}쀄, {paragraphs}단락, μ•½ {words}단어"
def _describe_csv_frame(df):
    """Build the CSV preview/summary text sent to the model for *df*."""
    parts = [
        f"πŸ“Š 데이터 미리보기:\n{df.head(10).to_markdown(index=False)}\n\n",
        "\nπŸ“ˆ 데이터 정보:\n",
        f"- 전체 ν–‰ 수: {len(df)}\n",
        f"- 전체 μ—΄ 수: {len(df.columns)}\n",
        # str() each label: str.join() raises on non-string column labels.
        f"- 컬럼 λͺ©λ‘: {', '.join(map(str, df.columns))}\n",
        "\nπŸ“‹ 컬럼 데이터 νƒ€μž…:\n",
    ]
    parts.extend(f"- {col}: {dtype}\n" for col, dtype in df.dtypes.items())
    null_counts = df.isnull().sum()
    if null_counts.any():
        parts.append("\n⚠️ 결츑치:\n")
        parts.extend(
            f"- {col}: {null_count}개 λˆ„λ½\n"
            for col, null_count in null_counts[null_counts > 0].items()
        )
    return "".join(parts)


def read_uploaded_file(file):
    """Read an uploaded file and return its content for the LLM prompt.

    Args:
        file: Value of the Gradio File component: a path string or an object
            exposing the path as ``.name``; ``None`` when nothing is uploaded.

    Returns:
        ``(content, kind)`` where kind is ``"parquet"`` (Markdown preview of
        the first 10 rows), ``"csv"`` (preview plus column summary),
        ``"text"`` (raw file text), or ``"error"`` (content is a user-facing
        error message). Returns ``("", "")`` when *file* is None.
    """
    if file is None:
        return "", ""
    # Tried in order; 'latin1' decodes any byte sequence, so it is the fallback.
    encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
    # Returned directly rather than raised: UnicodeDecodeError's constructor
    # needs five arguments, so raising it with just a message fails with a
    # TypeError that hides this text.
    undecodable_msg = f"❌ μ§€μ›λ˜λŠ” μΈμ½”λ”©μœΌλ‘œ νŒŒμΌμ„ 읽을 수 μ—†μŠ΅λ‹ˆλ‹€ ({', '.join(encodings)})"
    try:
        # Accept both a plain path string and a file-like object with `.name`.
        path = getattr(file, "name", file)
        file_ext = os.path.splitext(path)[1].lower()
        if file_ext == '.parquet':
            df = pd.read_parquet(path, engine='pyarrow')
            return df.head(10).to_markdown(index=False), "parquet"
        if file_ext == '.csv':
            for encoding in encodings:
                try:
                    df = pd.read_csv(path, encoding=encoding)
                except UnicodeDecodeError:
                    continue
                return _describe_csv_frame(df), "csv"
            return undecodable_msg, "error"
        for encoding in encodings:
            try:
                with open(path, 'r', encoding=encoding) as f:
                    return f.read(), "text"
            except UnicodeDecodeError:
                continue
        return undecodable_msg, "error"
    except Exception as e:
        return f"❌ 파일 읽기 였λ₯˜: {str(e)}", "error"
def chat(message, history, uploaded_file, system_message="", max_tokens=4000, temperature=0.7, top_p=0.9):
    """Stream a model reply to *message*, optionally grounded on an uploaded file.

    This is a generator. Every UI update is yielded as
    ``("", history_pairs)``, so the bound Gradio event clears the textbox and
    refreshes the Chatbot.

    Args:
        message: The user's text. The trigger phrase set by the file-upload
            handler is replaced with a file-analysis request.
        history: Prior ``[user, assistant]`` pairs from the Chatbot; ``None``
            is treated as an empty history.
        uploaded_file: Value of the File component, or ``None``.
        system_message: Extra system-prompt text appended after the built-in
            prompt; the uploaded file's content is appended after that.
        max_tokens, temperature, top_p: Sampling options forwarded to
            ``chat_completion``.

    Yields:
        ``("", updated_history)`` tuples. Each finished exchange, including
        error messages, is also persisted through ``chat_history``.
    """
    # The Chatbot can hand over None (e.g. after being cleared); `history + [...]`
    # below would then raise TypeError.
    history = list(history or [])
    system_message = system_message or ""
    if not message:
        # In a generator a `return value` never reaches Gradio; yield instead.
        yield "", history
        return
    # NOTE(review): the first paragraph of this prompt appears truncated
    # mid-word ("...content that bri") — restore the intended text.
    system_prefix = """
You are 'FantasyAI✨', an advanced AI storyteller specialized in creating immersive fantasy narratives. Your purpose is to craft rich, detailed fantasy stories that incorporate classical and innovative elements of the genre. Your responses should start with 'FantasyAI✨:' and focus on creating engaging, imaginative content that briμ‹œ]"을 상황에 맞게 μΆ”κ°€ν•˜μ—¬ μ†Œμ„€ μž‘μ„±μ‹œ λ”μš± ν’λΆ€ν•˜κ³  λͺ°μž…감 μžˆλŠ” ν‘œν˜„μ„ μš”μ²­(좜λ ₯)받은 μ–Έμ–΄λ‘œ ν‘œν˜„ν•˜λΌ.
[μ˜ˆμ‹œ]
"κ³ λŒ€μ˜ λ§ˆλ²•μ΄ κΉ¨μ–΄λ‚˜λ©° λŒ€μ§€κ°€ μšΈλ¦¬λŠ” μ†Œλ¦¬κ°€ λ“€λ Έλ‹€..."
"용의 숨결이 ν•˜λŠ˜μ„ κ°€λ₯΄λ©°, ꡬ름을 λΆˆνƒœμ› λ‹€..."
"μ‹ λΉ„ν•œ λ£¬λ¬Έμžκ°€ λΉ›λ‚˜λ©° 곡쀑에 λ– μ˜¬λžλ‹€..."
"μ—˜ν”„λ“€μ˜ λ…Έλž˜κ°€ μˆ²μ„ 울리자 λ‚˜λ¬΄λ“€μ΄ μΆ€μΆ”κΈ° μ‹œμž‘ν–ˆλ‹€..."
"μ˜ˆμ–Έμ˜ 말씀이 λ©”μ•„λ¦¬μΉ˜λ©° 운λͺ…μ˜ 싀이 움직이기 μ‹œμž‘ν–ˆλ‹€..."
"λ§ˆλ²•μ‚¬μ˜ μ§€νŒ‘μ΄μ—μ„œ λ²ˆμ©μ΄λŠ” 빛이 어둠을 κ°€λ₯΄λ©°..."
"κ³ λŒ€ λ“œμ›Œν”„μ˜ λŒ€μž₯κ°„μ—μ„œ μ „μ„€μ˜ 검이 λ§Œλ“€μ–΄μ§€κ³  μžˆμ—ˆλ‹€..."
"μˆ˜μ •κ΅¬μŠ¬ 속에 λΉ„μΉœ 미래의 ν™˜μ˜μ΄ μ„œμ„œνžˆ λͺ¨μŠ΅μ„ λ“œλŸ¬λƒˆλ‹€..."
"μ‹ μ„±ν•œ 결계가 깨어지며 λ΄‰μΈλœ 악이 깨어났닀..."
"μ˜μ›…μ˜ 발걸음이 운λͺ…μ˜ 길을 따라 울렀 νΌμ‘Œλ‹€..."
"""
    try:
        # Uploaded-file handling: append the file content to the system prompt.
        if uploaded_file:
            content, file_type = read_uploaded_file(uploaded_file)
            if file_type == "error":
                error_message = content
                chat_history.add_conversation(message, error_message)
                # Yield (not return) so the error actually shows in the Chatbot.
                yield "", history + [[message, error_message]]
                return
            file_summary = analyze_file_content(content, file_type)
            if file_type in ['parquet', 'csv']:
                system_message += f"\n\n파일 λ‚΄μš©:\n```markdown\n{content}\n```"
            else:
                system_message += f"\n\n파일 λ‚΄μš©:\n```\n{content}\n```"
            # The upload handler sends this trigger phrase; expand it into an
            # analysis request that includes the structural summary.
            if message == "파일 뢄석을 μ‹œμž‘ν•©λ‹ˆλ‹€...":
                message = f"""[파일 ꡬ쑰 뢄석] {file_summary}
λ‹€μŒ κ΄€μ μ—μ„œ 도움을 λ“œλ¦¬κ² μŠ΅λ‹ˆλ‹€:
1. πŸ“‹ μ „λ°˜μ μΈ λ‚΄μš© νŒŒμ•…
2. πŸ’‘ μ£Όμš” νŠΉμ§• μ„€λͺ…
3. 🎯 μ‹€μš©μ μΈ ν™œμš© λ°©μ•ˆ
4. ✨ κ°œμ„  μ œμ•ˆ
5. πŸ’¬ μΆ”κ°€ μ§ˆλ¬Έμ΄λ‚˜ ν•„μš”ν•œ μ„€λͺ…"""
        # Build the message list: system prompt, prior turns, new user turn.
        messages = [{"role": "system", "content": system_prefix + system_message}]
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})
        # Call the API and stream the reply token by token.
        client = get_client()
        partial_message = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.get('content', None)
            if token:
                partial_message += token
                yield "", history + [[message, partial_message]]
        if not partial_message:
            # Nothing was streamed; still show the user's turn in the Chatbot.
            yield "", history + [[message, partial_message]]
        # Persist the completed exchange.
        chat_history.add_conversation(message, partial_message)
    except Exception as e:
        error_msg = f"❌ 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {str(e)}"
        chat_history.add_conversation(message, error_msg)
        yield "", history + [[message, error_msg]]
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", title="GiniGEN πŸ€–") as demo:
    # Seed the Chatbot with the history already persisted by chat_history.
    initial_history = chat_history.format_for_display()
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                value=initial_history,  # start from the saved history
                height=600,
                label="λŒ€ν™”μ°½ πŸ’¬",
                show_label=True,
            )
            msg = gr.Textbox(
                label="λ©”μ‹œμ§€ μž…λ ₯",
                show_label=False,
                placeholder="무엇이든 λ¬Όμ–΄λ³΄μ„Έμš”... πŸ’­",
                container=False,
            )
            with gr.Row():
                clear = gr.ClearButton([msg, chatbot], value="λŒ€ν™”λ‚΄μš© μ§€μš°κΈ°")
                send = gr.Button("보내기 πŸ“€")
        with gr.Column(scale=1):
            gr.Markdown("### GiniGEN πŸ€– [파일 μ—…λ‘œλ“œ] πŸ“\n지원 ν˜•μ‹: ν…μŠ€νŠΈ, μ½”λ“œ, CSV, Parquet 파일")
            file_upload = gr.File(
                label="파일 선택",
                file_types=["text", ".csv", ".parquet"],
                type="filepath",
            )
            with gr.Accordion("κ³ κΈ‰ μ„€μ • βš™οΈ", open=False):
                system_message = gr.Textbox(label="μ‹œμŠ€ν…œ λ©”μ‹œμ§€ πŸ“", value="")
                max_tokens = gr.Slider(minimum=1, maximum=8000, value=4000, label="μ΅œλŒ€ 토큰 수 πŸ“Š")
                temperature = gr.Slider(minimum=0, maximum=1, value=0.7, label="μ°½μ˜μ„± μˆ˜μ€€ 🌑️")
                top_p = gr.Slider(minimum=0, maximum=1, value=0.9, label="응닡 λ‹€μ–‘μ„± πŸ“ˆ")

    # Example prompts that fill the message box when clicked.
    gr.Examples(
        examples=[
            ["ν₯미둜운 μ†Œμž¬ 10가지λ₯Ό μ œμ‹œν•΄μ€˜μš” 🀝"],
            ["λ”μš± 자극적이고 λ¬˜μ‚¬λ₯Ό μžμ„Ένžˆν•΄μ€˜μš” πŸ“š"],
            ["μ‘°μ„ μ‹œλŒ€ 배경으둜 ν•΄μ€˜μš” 🎯"],
            ["금기된 μš•λ§μ„ μ•Œλ €μ€˜μš” ✨"],
            ["계속 μ΄μ–΄μ„œ μž‘μ„±ν•΄μ€˜ πŸ€”"],
        ],
        inputs=msg,
    )

    def clear_chat():
        """Wipe the persisted history and reset the textbox and the Chatbot."""
        chat_history.clear_history()
        # An empty list rather than None: `chat` concatenates onto the history
        # it receives, which fails when the Chatbot value is None.
        return "", []

    def file_trigger_message(uploaded_file):
        """Return the file-analysis trigger phrase only when a file is present.

        ``.change`` also fires when the file is removed. Returning "" then
        makes the chained ``chat`` call a no-op instead of sending the trigger
        phrase to the model with no file attached.
        """
        return "파일 뢄석을 μ‹œμž‘ν•©λ‹ˆλ‹€..." if uploaded_file else ""

    # Every chat-triggering event uses the same inputs and outputs.
    chat_inputs = [msg, chatbot, file_upload, system_message, max_tokens, temperature, top_p]
    chat_outputs = [msg, chatbot]

    msg.submit(chat, inputs=chat_inputs, outputs=chat_outputs)
    send.click(chat, inputs=chat_inputs, outputs=chat_outputs)
    clear.click(clear_chat, outputs=chat_outputs)
    # Automatic analysis on upload: set the trigger phrase, then run chat.
    file_upload.change(
        file_trigger_message,
        inputs=file_upload,
        outputs=msg,
    ).then(
        chat,
        inputs=chat_inputs,
        outputs=chat_outputs,
    )
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()