Kims12 commited on
Commit
c231a45
·
verified ·
1 Parent(s): ed9da8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -27
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  import os
@@ -5,63 +6,91 @@ import os
5
  # Cohere Command R+ 모델 ID 정의
6
  COHERE_MODEL = "CohereForAI/c4ai-command-r-plus-08-2024"
7
 
8
- # Hugging Face API 토큰을 환경 변수에서 가져옴
9
- def get_client():
 
 
 
10
  hf_token = os.getenv("HF_TOKEN")
11
  if not hf_token:
12
  raise ValueError("HuggingFace API 토큰이 필요합니다.")
13
- return InferenceClient(COHERE_MODEL, token=hf_token)
14
 
15
- def generate_blog(tone: str, ref1: str, ref2: str, ref3: str):
16
- """
17
- Cohere Command R+ 모델을 사용하여 블로그 글을 생성하는 함수
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  """
19
- system_message = f"""반드시 한글로 답변할 것.
20
- 너는 최고의 비서이며 요청에 따라 주어진 말투를 사용하여 블로그를 작성한다.
21
- 말투: {tone}.
22
  """
23
- question = f"참조글 1: {ref1}\n참조글 2: {ref2}\n참조글 3: {ref3}\n블로그 글을 생성하라."
 
 
 
 
 
 
24
 
25
  try:
26
- client = get_client()
27
- response = client.chat_completion(
28
- messages=[
29
  {"role": "system", "content": system_message},
30
  {"role": "user", "content": question}
31
  ],
32
- max_tokens=4000,
33
- temperature=0.7,
34
- top_p=0.95
35
  )
36
- return response.choices[0].message.content
 
37
  except Exception as e:
38
  return f"오류가 발생했습니다: {str(e)}"
39
 
40
- # Gradio 인터페이스 구축
41
  with gr.Blocks() as demo:
42
  gr.Markdown("# 블로그 생성기")
43
 
44
  with gr.Row():
45
  tone = gr.Radio(
46
- label="말투 바꾸기",
47
  choices=["친근하게", "일반적인", "전문적인"],
48
- value="일반적인"
 
49
  )
50
 
51
- ref1 = gr.Textbox(label="참조글 1", lines=3)
52
- ref2 = gr.Textbox(label="참조글 2", lines=3)
53
- ref3 = gr.Textbox(label="참조글 3", lines=3)
54
 
55
  output = gr.Textbox(label="생성된 블로그 글", lines=10, interactive=False)
56
 
57
- generate_button = gr.Button("생성하기")
 
 
 
 
 
 
 
 
 
 
58
 
59
  generate_button.click(
60
- fn=generate_blog,
61
- inputs=[tone, ref1, ref2, ref3],
62
  outputs=output
63
  )
64
 
65
  if __name__ == "__main__":
66
  demo.launch()
67
-
 
1
+ # app.py
2
  import gradio as gr
3
  from huggingface_hub import InferenceClient
4
  import os
 
6
  # Cohere Command R+ 모델 ID 정의
7
  COHERE_MODEL = "CohereForAI/c4ai-command-r-plus-08-2024"
8
 
9
+ def get_client(model_name):
10
+ """
11
+ 모델 이름에 맞춰 InferenceClient 생성.
12
+ 토큰은 환경 변수에서 가져옴.
13
+ """
14
  hf_token = os.getenv("HF_TOKEN")
15
  if not hf_token:
16
  raise ValueError("HuggingFace API 토큰이 필요합니다.")
 
17
 
18
+ if model_name == "Cohere Command R+":
19
+ model_id = COHERE_MODEL
20
+ else:
21
+ raise ValueError("유효하지 않은 모델 이름입니다.")
22
+ return InferenceClient(model_id, token=hf_token)
23
+
24
+ def respond_cohere_qna(
25
+ tone: str,
26
+ reference1: str,
27
+ reference2: str,
28
+ reference3: str,
29
+ system_message: str,
30
+ max_tokens: int,
31
+ temperature: float,
32
+ top_p: float
33
+ ):
34
  """
35
+ Cohere Command R+ 모델을 이용해 블로그 생성 함수.
 
 
36
  """
37
+ model_name = "Cohere Command R+"
38
+ try:
39
+ client = get_client(model_name)
40
+ except ValueError as e:
41
+ return f"오류: {str(e)}"
42
+
43
+ question = f"말투: {tone} \n\n 참조글1: {reference1} \n\n 참조글2: {reference2} \n\n 참조글3: {reference3}"
44
 
45
  try:
46
+ response_full = client.chat_completion(
47
+ [
 
48
  {"role": "system", "content": system_message},
49
  {"role": "user", "content": question}
50
  ],
51
+ max_tokens=max_tokens,
52
+ temperature=temperature,
53
+ top_p=top_p,
54
  )
55
+ assistant_message = response_full.choices[0].message.content
56
+ return assistant_message
57
  except Exception as e:
58
  return f"오류가 발생했습니다: {str(e)}"
59
 
60
+ # Gradio UI 설정
61
  with gr.Blocks() as demo:
62
  gr.Markdown("# 블로그 생성기")
63
 
64
  with gr.Row():
65
  tone = gr.Radio(
 
66
  choices=["친근하게", "일반적인", "전문적인"],
67
+ value="일반적인",
68
+ label="말투바꾸기"
69
  )
70
 
71
+ reference1 = gr.Textbox(label="참조글 1", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 1")
72
+ reference2 = gr.Textbox(label="참조글 2", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 2")
73
+ reference3 = gr.Textbox(label="참조글 3", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 3")
74
 
75
  output = gr.Textbox(label="생성된 블로그 글", lines=10, interactive=False)
76
 
77
+ with gr.Accordion("고급 설정 (Cohere)", open=False):
78
+ system_message = gr.Textbox(
79
+ value="""반드시 한글로 답변할 것.\n너는 블로그 작성을 도와주는 비서이다.\n사용자의 요구사항을 정확히 반영하여 작성하라.""",
80
+ label="System Message",
81
+ lines=3
82
+ )
83
+ max_tokens = gr.Slider(minimum=100, maximum=5000, value=2000, step=100, label="Max Tokens")
84
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
85
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
86
+
87
+ generate_button = gr.Button("생성")
88
 
89
  generate_button.click(
90
+ fn=respond_cohere_qna,
91
+ inputs=[tone, reference1, reference2, reference3, system_message, max_tokens, temperature, top_p],
92
  outputs=output
93
  )
94
 
95
  if __name__ == "__main__":
96
  demo.launch()