# StoryStar / app.py
# (Hugging Face Space file, uploaded by "openfree"; file-viewer header lines
#  "Update app.py / da20c1b verified / raw / history / blame / 7.92 kB"
#  converted to this comment so the module parses.)
import gradio as gr
from huggingface_hub import InferenceClient
import os
import pandas as pd
from typing import List, Tuple
# LLM ๋ชจ๋ธ ์ •์˜
LLM_MODELS = {
"Default": "CohereForAI/c4ai-command-r-plus-08-2024", # ๊ธฐ๋ณธ ๋ชจ๋ธ
"Meta": "meta-llama/Llama-3.3-70B-Instruct",
"Mistral": "mistralai/Mistral-Nemo-Instruct-2407",
"Alibaba": "Qwen/QwQ-32B-Preview"
}
def get_client(model_name):
    """Return an InferenceClient for the chosen model, authenticated with the HF_TOKEN env var."""
    model_id = LLM_MODELS[model_name]
    hf_token = os.getenv("HF_TOKEN")
    return InferenceClient(model_id, token=hf_token)
def analyze_file_content(content, file_type):
    """Return a one-line structural summary (Korean) of the uploaded file's content.

    Args:
        content: The file's text; for parquet files this is a markdown table preview.
        file_type: "parquet", or anything else to be treated as text/code.

    Returns:
        A short Korean summary string describing the content's structure, or a
        Korean failure message if a parquet preview cannot be parsed.
    """
    if file_type == 'parquet':
        try:
            # The markdown preview's first row is the header; pipe count - 1
            # approximates the number of columns.
            columns = content.split('\n')[0].count('|') - 1
            # Exclude the header row and the separator row from the sample count.
            rows = content.count('\n') - 2
            return f"๋ฐ์ดํ„ฐ์…‹ ๊ตฌ์กฐ: {columns}๊ฐœ ์ปฌ๋Ÿผ, {rows}๊ฐœ ๋ฐ์ดํ„ฐ ์ƒ˜ํ”Œ"
        except Exception:
            # Was a bare "except:"; narrowed so SystemExit/KeyboardInterrupt propagate.
            return "๋ฐ์ดํ„ฐ์…‹ ๊ตฌ์กฐ ๋ถ„์„ ์‹คํŒจ"

    # Text / code files.
    lines = content.split('\n')
    total_lines = len(lines)

    # Heuristic: treat the file as source code if it contains common code keywords.
    if any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']):
        functions = len([line for line in lines if 'def ' in line])
        classes = len([line for line in lines if 'class ' in line])
        imports = len([line for line in lines if 'import ' in line or 'from ' in line])
        return f"์ฝ”๋“œ ๊ตฌ์กฐ ๋ถ„์„: ์ด {total_lines}์ค„ (ํ•จ์ˆ˜ {functions}๊ฐœ, ํด๋ž˜์Šค {classes}๊ฐœ, ์ž„ํฌํŠธ {imports}๊ฐœ)"

    # Plain text document: blank-line-separated paragraphs and whitespace-split words.
    paragraphs = content.count('\n\n') + 1
    words = len(content.split())
    return f"๋ฌธ์„œ ๊ตฌ์กฐ ๋ถ„์„: ์ด {total_lines}์ค„, {paragraphs}๊ฐœ ๋ฌธ๋‹จ, ์•ฝ {words}๊ฐœ ๋‹จ์–ด"
def read_uploaded_file(file):
    """Read an uploaded gradio file and return (content, kind).

    Args:
        file: The value from gr.File, or None when nothing is uploaded.
              Expected to expose a ``.name`` attribute holding the file path.

    Returns:
        (content, kind) where kind is "parquet" (content = markdown preview of
        the first 10 rows), "text" (content = decoded UTF-8 text), "error"
        (content = Korean error message), or ("", "") when file is None.
    """
    if file is None:
        return "", ""
    try:
        if file.name.endswith('.parquet'):
            df = pd.read_parquet(file.name, engine='pyarrow')
            content = df.head(10).to_markdown(index=False)
            return content, "parquet"
        # BUG FIX: gr.File(type="filepath") hands over a path-holding object,
        # not an open handle, so file.read() raised AttributeError. Open the
        # path ourselves and decode as UTF-8 (matching the old decode step).
        with open(file.name, 'rb') as fh:
            content = fh.read().decode('utf-8')
        return content, "text"
    except Exception as e:
        return f"ํŒŒ์ผ์„ ์ฝ๋Š” ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}", "error"
def format_history(history):
    """Convert chatbot [(user, assistant), ...] pairs into role/content message dicts.

    Assistant entries that are falsy (None/empty) are skipped, so a pending
    turn contributes only its user message.
    """
    return [
        entry
        for user_text, assistant_text in history
        for entry in (
            [{"role": "user", "content": user_text}]
            + ([{"role": "assistant", "content": assistant_text}] if assistant_text else [])
        )
    ]
def chat(message, history, uploaded_file, model_name, system_message="", max_tokens=4000, temperature=0.7, top_p=0.9):
    """Streaming chat handler for the Gradio UI.

    Builds a system prompt (optionally embedding the uploaded file's content),
    sends the conversation to the selected model, and yields
    ("", updated_history) pairs as tokens stream in, so the textbox clears and
    the chatbot updates incrementally. Errors are yielded into the chat as a
    Korean error message rather than raised.
    """
    # File-analysis-expert system prompt (Korean runtime string, left as-is).
    system_prefix = """๋„ˆ๋Š” ํŒŒ์ผ ๋ถ„์„ ์ „๋ฌธ๊ฐ€์ž…๋‹ˆ๋‹ค. ์—…๋กœ๋“œ๋œ ํŒŒ์ผ์˜ ๋‚ด์šฉ์„ ๊นŠ์ด ์žˆ๊ฒŒ ๋ถ„์„ํ•˜์—ฌ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ด€์ ์—์„œ ์„ค๋ช…ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:
1. ํŒŒ์ผ์˜ ์ „๋ฐ˜์ ์ธ ๊ตฌ์กฐ์™€ ๊ตฌ์„ฑ
2. ์ฃผ์š” ๋‚ด์šฉ๊ณผ ํŒจํ„ด ๋ถ„์„
3. ๋ฐ์ดํ„ฐ์˜ ํŠน์ง•๊ณผ ์˜๋ฏธ
4. ์ž ์žฌ์  ํ™œ์šฉ ๋ฐฉ์•ˆ
5. ์ฃผ์˜ํ•ด์•ผ ํ•  ์ ์ด๋‚˜ ๊ฐœ์„  ๊ฐ€๋Šฅํ•œ ๋ถ€๋ถ„
์ „๋ฌธ๊ฐ€์  ๊ด€์ ์—์„œ ์ƒ์„ธํ•˜๊ณ  ๊ตฌ์กฐ์ ์ธ ๋ถ„์„์„ ์ œ๊ณตํ•˜๋˜, ์ดํ•ดํ•˜๊ธฐ ์‰ฝ๊ฒŒ ์„ค๋ช…ํ•˜์„ธ์š”. ๋ถ„์„ ๊ฒฐ๊ณผ๋Š” Markdown ํ˜•์‹์œผ๋กœ ์ž‘์„ฑํ•˜๊ณ , ๊ฐ€๋Šฅํ•œ ํ•œ ๊ตฌ์ฒด์ ์ธ ์˜ˆ์‹œ๋ฅผ ํฌํ•จํ•˜์„ธ์š”."""
    if uploaded_file:
        content, file_type = read_uploaded_file(uploaded_file)
        if file_type == "error":
            # Surface the read error as the assistant reply and stop.
            yield "", history + [[message, content]]
            return
        # Analyze the file's content and get a one-line structural summary.
        file_summary = analyze_file_content(content, file_type)
        # Embed the full file content into the system message (fenced as
        # markdown for parquet previews, plain fence otherwise).
        if file_type == 'parquet':
            system_message += f"\n\nํŒŒ์ผ ๋‚ด์šฉ:\n```markdown\n{content}\n```"
        else:
            system_message += f"\n\nํŒŒ์ผ ๋‚ด์šฉ:\n```\n{content}\n```"
        # Sentinel message sent by the file_upload.change auto-analysis binding;
        # replace it with a detailed analysis prompt that includes the summary.
        if message == "ํŒŒ์ผ ๋ถ„์„์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค.":
            message = f"""[๊ตฌ์กฐ ๋ถ„์„] {file_summary}
๋‹ค์Œ ๊ด€์ ์—์„œ ์ƒ์„ธ ๋ถ„์„์„ ์ œ๊ณตํ•ด์ฃผ์„ธ์š”:
1. ํŒŒ์ผ์˜ ์ „๋ฐ˜์ ์ธ ๊ตฌ์กฐ์™€ ํ˜•์‹
2. ์ฃผ์š” ๋‚ด์šฉ ๋ฐ ๊ตฌ์„ฑ์š”์†Œ ๋ถ„์„
3. ๋ฐ์ดํ„ฐ/๋‚ด์šฉ์˜ ํŠน์ง•๊ณผ ํŒจํ„ด
4. ํ’ˆ์งˆ ๋ฐ ์™„์„ฑ๋„ ํ‰๊ฐ€
5. ๊ฐœ์„  ๊ฐ€๋Šฅํ•œ ๋ถ€๋ถ„ ์ œ์•ˆ
6. ์‹ค์ œ ํ™œ์šฉ ๋ฐฉ์•ˆ ๋ฐ ์ถ”์ฒœ์‚ฌํ•ญ"""
    # Assemble system prompt + prior turns + the new user message.
    messages = [{"role": "system", "content": f"{system_prefix} {system_message}"}]
    messages.extend(format_history(history))
    messages.append({"role": "user", "content": message})
    try:
        client = get_client(model_name)
        partial_message = ""
        for msg in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # NOTE(review): .get() assumes the stream delta is dict-like; recent
            # huggingface_hub versions return dataclass-style delta objects where
            # this would need `msg.choices[0].delta.content` — confirm against the
            # installed huggingface_hub version.
            token = msg.choices[0].delta.get('content', None)
            if token:
                partial_message += token
                # Yield the accumulated reply so the chatbot updates live.
                yield "", history + [[message, partial_message]]
    except Exception as e:
        error_msg = f"์ถ”๋ก  ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
        yield "", history + [[message, error_msg]]
# Hide the default Gradio footer.
css = """
footer {visibility: hidden}
"""

# UI layout: chat column (left) + model/file/settings column (right).
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=600)
            msg = gr.Textbox(
                label="๋ฉ”์‹œ์ง€๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”",
                show_label=False,
                placeholder="๋ฉ”์‹œ์ง€๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...",
                container=False
            )
            clear = gr.ClearButton([msg, chatbot])
        with gr.Column(scale=1):
            model_name = gr.Radio(
                choices=list(LLM_MODELS.keys()),
                value="Default",
                label="LLM ๋ชจ๋ธ ์„ ํƒ",
                info="์‚ฌ์šฉํ•  LLM ๋ชจ๋ธ์„ ์„ ํƒํ•˜์„ธ์š”"
            )
            file_upload = gr.File(
                label="ํŒŒ์ผ ์—…๋กœ๋“œ (ํ…์ŠคํŠธ, ์ฝ”๋“œ, ๋ฐ์ดํ„ฐ ํŒŒ์ผ)",
                file_types=["text", ".parquet"],
                type="filepath"
            )
            with gr.Accordion("๊ณ ๊ธ‰ ์„ค์ •", open=False):
                system_message = gr.Textbox(label="System Message", value="")
                max_tokens = gr.Slider(minimum=1, maximum=8000, value=4000, label="Max Tokens")
                temperature = gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")
                top_p = gr.Slider(minimum=0, maximum=1, value=0.9, label="Top P")

    # Event bindings: submitting the textbox streams a chat response.
    msg.submit(
        chat,
        inputs=[msg, chatbot, file_upload, model_name, system_message, max_tokens, temperature, top_p],
        outputs=[msg, chatbot],
        queue=True
    ).then(
        lambda: gr.update(interactive=True),
        None,
        [msg]
    )

    # Auto-analyze on file upload: the sentinel message "ํŒŒ์ผ ๋ถ„์„์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค."
    # triggers chat()'s detailed-analysis prompt.
    # NOTE(review): an inline gr.Textbox(value=...) is passed as an input
    # component here rather than a rendered one — confirm this supplies the
    # sentinel value in the deployed Gradio version.
    file_upload.change(
        chat,
        inputs=[gr.Textbox(value="ํŒŒ์ผ ๋ถ„์„์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค."), chatbot, file_upload, model_name, system_message, max_tokens, temperature, top_p],
        outputs=[msg, chatbot],
        queue=True
    )

    # Clickable example prompts for the message box.
    gr.Examples(
        examples=[
            ["ํŒŒ์ผ์˜ ์ „๋ฐ˜์ ์ธ ๊ตฌ์กฐ์™€ ํŠน์ง•์„ ์ž์„ธํžˆ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”."],
            ["์ด ํŒŒ์ผ์˜ ์ฃผ์š” ํŒจํ„ด๊ณผ ํŠน์ง•์„ ๋ถ„์„ํ•ด์ฃผ์„ธ์š”."],
            ["ํŒŒ์ผ์˜ ํ’ˆ์งˆ๊ณผ ๊ฐœ์„  ๊ฐ€๋Šฅํ•œ ๋ถ€๋ถ„์„ ํ‰๊ฐ€ํ•ด์ฃผ์„ธ์š”."],
            ["์ด ํŒŒ์ผ์„ ์‹ค์ œ๋กœ ์–ด๋–ป๊ฒŒ ํ™œ์šฉํ•  ์ˆ˜ ์žˆ์„๊นŒ์š”?"],
            ["ํŒŒ์ผ์˜ ์ฃผ์š” ๋‚ด์šฉ์„ ์š”์•ฝํ•˜๊ณ  ํ•ต์‹ฌ ์ธ์‚ฌ์ดํŠธ๋ฅผ ๋„์ถœํ•ด์ฃผ์„ธ์š”."],
            ["์ด์ „ ๋ถ„์„์„ ์ด์–ด์„œ ๋” ์ž์„ธํžˆ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”."],
        ],
        inputs=msg,
    )

if __name__ == "__main__":
    demo.launch()