Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
from huggingface_hub import InferenceClient
|
3 |
import os
|
@@ -5,63 +6,91 @@ import os
|
|
5 |
# Cohere Command R+ 모델 ID 정의
|
6 |
COHERE_MODEL = "CohereForAI/c4ai-command-r-plus-08-2024"
|
7 |
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
10 |
hf_token = os.getenv("HF_TOKEN")
|
11 |
if not hf_token:
|
12 |
raise ValueError("HuggingFace API 토큰이 필요합니다.")
|
13 |
-
return InferenceClient(COHERE_MODEL, token=hf_token)
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
"""
|
19 |
-
|
20 |
-
너는 최고의 비서이며 요청에 따라 주어진 말투를 사용하여 블로그를 작성한다.
|
21 |
-
말투: {tone}.
|
22 |
"""
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
try:
|
26 |
-
|
27 |
-
|
28 |
-
messages=[
|
29 |
{"role": "system", "content": system_message},
|
30 |
{"role": "user", "content": question}
|
31 |
],
|
32 |
-
max_tokens=
|
33 |
-
temperature=
|
34 |
-
top_p=
|
35 |
)
|
36 |
-
|
|
|
37 |
except Exception as e:
|
38 |
return f"오류가 발생했습니다: {str(e)}"
|
39 |
|
40 |
-
# Gradio
|
41 |
with gr.Blocks() as demo:
|
42 |
gr.Markdown("# 블로그 생성기")
|
43 |
|
44 |
with gr.Row():
|
45 |
tone = gr.Radio(
|
46 |
-
label="말투 바꾸기",
|
47 |
choices=["친근하게", "일반적인", "전문적인"],
|
48 |
-
value="일반적인"
|
|
|
49 |
)
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
|
55 |
output = gr.Textbox(label="생성된 블로그 글", lines=10, interactive=False)
|
56 |
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
generate_button.click(
|
60 |
-
fn=
|
61 |
-
inputs=[tone,
|
62 |
outputs=output
|
63 |
)
|
64 |
|
65 |
if __name__ == "__main__":
|
66 |
demo.launch()
|
67 |
-
|
|
|
1 |
+
# app.py
|
2 |
import gradio as gr
|
3 |
from huggingface_hub import InferenceClient
|
4 |
import os
|
|
|
6 |
# Cohere Command R+ 모델 ID 정의
|
7 |
COHERE_MODEL = "CohereForAI/c4ai-command-r-plus-08-2024"
|
8 |
|
9 |
+
def get_client(model_name):
|
10 |
+
"""
|
11 |
+
모델 이름에 맞춰 InferenceClient 생성.
|
12 |
+
토큰은 환경 변수에서 가져옴.
|
13 |
+
"""
|
14 |
hf_token = os.getenv("HF_TOKEN")
|
15 |
if not hf_token:
|
16 |
raise ValueError("HuggingFace API 토큰이 필요합니다.")
|
|
|
17 |
|
18 |
+
if model_name == "Cohere Command R+":
|
19 |
+
model_id = COHERE_MODEL
|
20 |
+
else:
|
21 |
+
raise ValueError("유효하지 않은 모델 이름입니다.")
|
22 |
+
return InferenceClient(model_id, token=hf_token)
|
23 |
+
|
24 |
+
def respond_cohere_qna(
|
25 |
+
tone: str,
|
26 |
+
reference1: str,
|
27 |
+
reference2: str,
|
28 |
+
reference3: str,
|
29 |
+
system_message: str,
|
30 |
+
max_tokens: int,
|
31 |
+
temperature: float,
|
32 |
+
top_p: float
|
33 |
+
):
|
34 |
"""
|
35 |
+
Cohere Command R+ 모델을 이용해 블로그 생성 함수.
|
|
|
|
|
36 |
"""
|
37 |
+
model_name = "Cohere Command R+"
|
38 |
+
try:
|
39 |
+
client = get_client(model_name)
|
40 |
+
except ValueError as e:
|
41 |
+
return f"오류: {str(e)}"
|
42 |
+
|
43 |
+
question = f"말투: {tone} \n\n 참조글1: {reference1} \n\n 참조글2: {reference2} \n\n 참조글3: {reference3}"
|
44 |
|
45 |
try:
|
46 |
+
response_full = client.chat_completion(
|
47 |
+
[
|
|
|
48 |
{"role": "system", "content": system_message},
|
49 |
{"role": "user", "content": question}
|
50 |
],
|
51 |
+
max_tokens=max_tokens,
|
52 |
+
temperature=temperature,
|
53 |
+
top_p=top_p,
|
54 |
)
|
55 |
+
assistant_message = response_full.choices[0].message.content
|
56 |
+
return assistant_message
|
57 |
except Exception as e:
|
58 |
return f"오류가 발생했습니다: {str(e)}"
|
59 |
|
60 |
+
# Gradio UI 설정
|
61 |
with gr.Blocks() as demo:
|
62 |
gr.Markdown("# 블로그 생성기")
|
63 |
|
64 |
with gr.Row():
|
65 |
tone = gr.Radio(
|
|
|
66 |
choices=["친근하게", "일반적인", "전문적인"],
|
67 |
+
value="일반적인",
|
68 |
+
label="말투바꾸기"
|
69 |
)
|
70 |
|
71 |
+
reference1 = gr.Textbox(label="참조글 1", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 1")
|
72 |
+
reference2 = gr.Textbox(label="참조글 2", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 2")
|
73 |
+
reference3 = gr.Textbox(label="참조글 3", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 3")
|
74 |
|
75 |
output = gr.Textbox(label="생성된 블로그 글", lines=10, interactive=False)
|
76 |
|
77 |
+
with gr.Accordion("고급 설정 (Cohere)", open=False):
|
78 |
+
system_message = gr.Textbox(
|
79 |
+
value="""반드시 한글로 답변할 것.\n너는 블로그 작성을 도와주는 비서이다.\n사용자의 요구사항을 정확히 반영하여 작성하라.""",
|
80 |
+
label="System Message",
|
81 |
+
lines=3
|
82 |
+
)
|
83 |
+
max_tokens = gr.Slider(minimum=100, maximum=5000, value=2000, step=100, label="Max Tokens")
|
84 |
+
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
|
85 |
+
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
|
86 |
+
|
87 |
+
generate_button = gr.Button("생성")
|
88 |
|
89 |
generate_button.click(
|
90 |
+
fn=respond_cohere_qna,
|
91 |
+
inputs=[tone, reference1, reference2, reference3, system_message, max_tokens, temperature, top_p],
|
92 |
outputs=output
|
93 |
)
|
94 |
|
95 |
if __name__ == "__main__":
|
96 |
demo.launch()
|
|