openfree committed
Commit 83ee74c
1 Parent(s): aca4005

Update app.py

Files changed (1):
  app.py (+43 −51)
app.py CHANGED
@@ -1,11 +1,9 @@
 import gradio as gr
-from transformers import pipeline
+from huggingface_hub import InferenceClient
 import os
-from typing import List, Tuple, Generator
-import concurrent.futures
+from typing import List, Tuple
 
 # Hugging Face token setup
-os.environ["TOKENIZERS_PARALLELISM"] = "false"  # suppress warning messages
 HF_TOKEN = os.getenv("HF_TOKEN")
 
 # Available LLM models
@@ -26,18 +24,11 @@ DEFAULT_MODELS = [
     "mistralai/Mistral-Nemo-Instruct-2407"
 ]
 
-# Pipeline initialization
-pipes = {}
-for model_name in LLM_MODELS.values():
-    try:
-        pipes[model_name] = pipeline(
-            "text-generation",
-            model=model_name,
-            token=HF_TOKEN,
-            device_map="auto"
-        )
-    except Exception as e:
-        print(f"Failed to load model {model_name}: {str(e)}")
+# Initialize clients with token
+clients = {
+    model: InferenceClient(model, token=HF_TOKEN)
+    for model in LLM_MODELS.values()
+}
 
 def process_file(file) -> str:
     if file is None:
@@ -46,7 +37,15 @@ def process_file(file) -> str:
         return file.read().decode('utf-8')
     return f"Uploaded file: {file.name}"
 
-def format_messages(message: str, history: List[Tuple[str, str]], system_message: str) -> List[dict]:
+def respond_single(
+    client,
+    message: str,
+    history: List[Tuple[str, str]],
+    system_message: str,
+    max_tokens: int,
+    temperature: float,
+    top_p: float,
+):
     messages = [{"role": "system", "content": system_message}]
 
     for user, assistant in history:
@@ -56,35 +55,18 @@ def format_messages(message: str, history: List[Tuple[str, str]], system_message
             messages.append({"role": "assistant", "content": assistant})
 
     messages.append({"role": "user", "content": message})
-    return messages
-
-def generate_response(
-    pipe,
-    messages: List[dict],
-    max_tokens: int,
-    temperature: float,
-    top_p: float
-) -> Generator[str, None, None]:
+
+    response = ""
     try:
-        formatted_prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
-
-        response = pipe(
-            formatted_prompt,
+        for msg in client.text_generation(
+            prompt=message,
             max_new_tokens=max_tokens,
+            stream=True,
             temperature=temperature,
             top_p=top_p,
-            do_sample=True,
-            pad_token_id=50256,
-            num_return_sequences=1,
-            streaming=True
-        )
-
-        generated_text = ""
-        for output in response:
-            new_text = output[0]['generated_text'][len(formatted_prompt):].strip()
-            generated_text = new_text
-            yield generated_text
-
+        ):
+            response += msg
+            yield response
     except Exception as e:
         yield f"Error: {str(e)}"
 
@@ -99,7 +81,7 @@ def respond_all(
     max_tokens: int,
     temperature: float,
     top_p: float,
-) -> Tuple[Generator[str, None, None], Generator[str, None, None], Generator[str, None, None]]:
+):
     if file:
         file_content = process_file(file)
         message = f"{message}\n\nFile content:\n{file_content}"
@@ -107,16 +89,25 @@ def respond_all(
     while len(selected_models) < 3:
         selected_models.append(selected_models[-1])
 
-    def generate(pipe, history):
-        messages = format_messages(message, history, system_message)
-        return generate_response(pipe, messages, max_tokens, temperature, top_p)
+    def generate(client, history):
+        return respond_single(
+            client,
+            message,
+            history,
+            system_message,
+            max_tokens,
+            temperature,
+            top_p,
+        )
 
     return (
-        generate(pipes[selected_models[0]], history1),
-        generate(pipes[selected_models[1]], history2),
-        generate(pipes[selected_models[2]], history3),
+        generate(clients[selected_models[0]], history1),
+        generate(clients[selected_models[1]], history2),
+        generate(clients[selected_models[2]], history3),
     )
 
+
+
 css = """
 footer {
     visibility: hidden;
@@ -126,6 +117,7 @@ footer {
 
 
 with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
+
     with gr.Row():
         model_choices = gr.Checkboxgroup(
             choices=list(LLM_MODELS.values()),
@@ -212,7 +204,7 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
     )
 
 if __name__ == "__main__":
-    # Check that the Hugging Face token is set
     if not HF_TOKEN:
        print("Warning: HF_TOKEN environment variable is not set")
-    demo.launch()
+    demo.launch()
+
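For context on the change above: the commit drops the locally loaded transformers pipelines and instead calls Hugging Face Inference endpoints through `huggingface_hub.InferenceClient`, streaming tokens back to the UI. Below is a minimal standalone sketch of that streaming call pattern, assuming `huggingface_hub` is installed and `HF_TOKEN` is set; the model id, prompt, and sampling values are illustrative placeholders, not values taken from the app.

```python
# Minimal sketch of the streaming pattern used by respond_single, run outside Gradio.
# Assumptions: huggingface_hub is installed, HF_TOKEN is exported, and the model id
# below (one of the models listed in the diff) is only an example.
import os

from huggingface_hub import InferenceClient

client = InferenceClient(
    "mistralai/Mistral-Nemo-Instruct-2407",
    token=os.getenv("HF_TOKEN"),
)

response = ""
# With stream=True (and the default details=False), text_generation yields plain
# text chunks; accumulating them mirrors the growing `response` string that
# respond_single yields back to the Gradio chat.
for chunk in client.text_generation(
    prompt="Explain what a context window is in one sentence.",  # placeholder prompt
    max_new_tokens=128,
    stream=True,
    temperature=0.7,
    top_p=0.95,
):
    response += chunk
    print(chunk, end="", flush=True)
```

Because `text_generation` yields incremental text chunks when `stream=True`, `respond_single` can yield the growing response on every chunk and the three chat panels update progressively instead of waiting for a full generation.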
 
 
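One detail worth noting in the new `respond_single`: it still assembles a chat-style `messages` list (system prompt, history, latest user turn) but the streaming call only forwards the raw `message` as `prompt`, so the system prompt and history are not sent to the model. A hedged sketch of one way such a list could be forwarded instead, via `InferenceClient.chat_completion`; this is not part of the commit, assumes a `huggingface_hub` release that provides `chat_completion`, and uses an illustrative model id and placeholder messages.

```python
# Hedged sketch (not part of this commit): sending the assembled chat messages
# via chat_completion instead of text_generation. Assumes a huggingface_hub
# version that exposes InferenceClient.chat_completion; model id and message
# contents are placeholders.
import os

from huggingface_hub import InferenceClient

client = InferenceClient(
    "mistralai/Mistral-Nemo-Instruct-2407",
    token=os.getenv("HF_TOKEN"),
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},  # placeholder
    {"role": "user", "content": "Summarize the uploaded file in three bullets."},
]

response = ""
# When streaming, each chunk's delta carries the next piece of generated text.
for chunk in client.chat_completion(
    messages,
    max_tokens=128,
    stream=True,
    temperature=0.7,
    top_p=0.95,
):
    response += chunk.choices[0].delta.content or ""

print(response)
```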