cutechicken committed on
Commit 6360699 • 1 Parent(s): 241b26a

Update app.py
Files changed (1): app.py (+93 -49)

app.py CHANGED
@@ -6,6 +6,11 @@ import os
 from threading import Thread
 import random
 from datasets import load_dataset
+import gc
+
+# GPU memory management
+torch.cuda.empty_cache()
+gc.collect()
 
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
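
A note on the memory hygiene added above: the torch.cuda.empty_cache() / gc.collect() pair now runs at import time and again in stream_chat's finally block further down. A small helper would keep the two call sites consistent; this is a hypothetical refactor sketch, not part of the commit:

# Hypothetical helper, not in the commit: one home for the cleanup pair.
import gc

import torch

def release_memory() -> None:
    """Free cached CUDA blocks, then run Python garbage collection."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

Worth remembering that empty_cache() only returns cached allocator blocks to the driver; it cannot free tensors that are still referenced.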
@@ -32,58 +37,97 @@ h3 {
 }
 """
 
-# Load the model and tokenizer
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-)
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+# Device setup
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load the model and tokenizer, with error handling
+try:
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        low_cpu_mem_usage=True,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+except Exception as e:
+    print(f"Error while loading the model: {str(e)}")
+    raise
 
-# Load the dataset
-dataset = load_dataset("elyza/ELYZA-tasks-100")
-print(dataset)
-
-split_name = "train" if "train" in dataset else "test"
-examples_list = list(dataset[split_name])
-examples = random.sample(examples_list, 50)
-example_inputs = [[example['input']] for example in examples]
+# Load the dataset, with error handling
+try:
+    dataset = load_dataset("elyza/ELYZA-tasks-100")
+    print(dataset)
+
+    split_name = "train" if "train" in dataset else "test"
+    examples_list = list(dataset[split_name])
+    examples = random.sample(examples_list, 50)
+    example_inputs = [[example['input']] for example in examples]
+except Exception as e:
+    print(f"Error while loading the dataset: {str(e)}")
+    examples = []
+    example_inputs = []
 
+def error_handler(func):
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            print(f"Error in {func.__name__}: {str(e)}")
+            return "Sorry, an error occurred. Please try again later."
+    return wrapper
+
+@error_handler
 @spaces.GPU
 def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
-    print(f'message is - {message}')
-    print(f'history is - {history}')
-    conversation = []
-    for prompt, answer in history:
-        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer(input_ids, return_tensors="pt").to(0)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        inputs,
-        streamer=streamer,
-        top_k=top_k,
-        top_p=top_p,
-        repetition_penalty=penalty,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        eos_token_id=[255001],
-    )
-
-    thread = Thread(target=model.generate, kwargs=generate_kwargs)
-    thread.start()
-
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        yield buffer
-
-chatbot = gr.Chatbot(height=500)
+    try:
+        print(f'message is - {message}')
+        print(f'history is - {history}')
+
+        # Clear GPU memory
+        torch.cuda.empty_cache()
+
+        conversation = []
+        for prompt, answer in history:
+            conversation.extend([
+                {"role": "user", "content": prompt},
+                {"role": "assistant", "content": answer}
+            ])
+        conversation.append({"role": "user", "content": message})
+
+        input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+        inputs = tokenizer(input_ids, return_tensors="pt").to(device)
+
+        streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+
+        generate_kwargs = dict(
+            inputs,
+            streamer=streamer,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=penalty,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=temperature,
+            eos_token_id=[255001],
+        )
+
+        thread = Thread(target=model.generate, kwargs=generate_kwargs)
+        thread.start()
+
+        buffer = ""
+        for new_text in streamer:
+            buffer += new_text
+            yield buffer
+
+    except Exception as e:
+        print(f"Stream chat error: {str(e)}")
+        yield "Sorry, an error occurred while generating the response."
+    finally:
+        # Clean up memory
+        torch.cuda.empty_cache()
+        gc.collect()
+
+chatbot = gr.Chatbot(height=500)
 
 CSS = """
 /* Overall page styling */
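
The body of stream_chat keeps the standard Transformers streaming pattern: model.generate() runs on a worker thread while TextIteratorStreamer yields decoded text as tokens arrive, which is what lets the function yield a growing buffer to Gradio. A minimal self-contained sketch of that pattern, with gpt2 as a stand-in model so the example runs anywhere (the Space itself uses MODEL_ID):

# Minimal sketch of the streaming pattern stream_chat relies on.
# "gpt2" is a stand-in so the example runs anywhere; the Space uses MODEL_ID.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Streaming generation works by", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so it runs on a worker thread;
# the streamer is the thread-safe bridge back to the caller.
Thread(target=lm.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=30)).start()

for piece in streamer:  # yields decoded text chunks until generation ends
    print(piece, end="", flush=True)

Two hedged observations on the surrounding code: apply_chat_template(..., tokenize=False) followed by a second tokenizer(...) call encodes the prompt twice, whereas apply_chat_template can return tensors directly via return_tensors="pt"; and the hardcoded eos_token_id=[255001] presumably denotes the model's end-of-turn token, which would be less brittle if derived with tokenizer.convert_tokens_to_ids on the actual token string.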
@@ -184,21 +228,21 @@ with gr.Blocks(css=CSS) as demo:
         chatbot=chatbot,
         fill_height=True,
         theme="soft",
-        additional_inputs_accordion=gr.Accordion(label="⚙️ 옵션션", open=False, render=False),
+        additional_inputs_accordion=gr.Accordion(label="⚙️ 옵션", open=False, render=False),
         additional_inputs=[
             gr.Slider(
                 minimum=0,
                 maximum=1,
                 step=0.1,
-                value=0.8,
+                value=0.3,
                 label="온도",
                 render=False,
             ),
             gr.Slider(
                 minimum=128,
-                maximum=1000000,
+                maximum=8000,
                 step=1,
-                value=100000,
+                value=4000,
                 label="최대 토큰 수",
                 render=False,
             ),
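
One behavioral note on the new error_handler decorator: stream_chat is a generator function, so the wrapper's try/except only guards the call that creates the generator object; exceptions raised while the stream is being consumed bypass the decorator entirely, and it is the try/except/finally added inside stream_chat that actually handles mid-stream failures. A tiny illustration of those semantics (illustrative code, not from the commit):

def error_handler(func):
    def wrapper(*args, **kwargs):
        try:
            # For a generator function this merely *creates* the generator;
            # nothing inside the body has executed yet.
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
            return "fallback"
    return wrapper

@error_handler
def chat():
    yield "first chunk"
    raise RuntimeError("mid-stream failure")  # not caught by wrapper

g = chat()       # no error here: wrapper just returns the generator
print(next(g))   # -> "first chunk"
# next(g) now raises RuntimeError to the caller, not the fallback string.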
 
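The slider changes in the last hunk replace an impractically large ceiling (1,000,000) and default (100,000) for max new tokens with an 8,000 cap and a 4,000 default, and lower the default temperature from 0.8 to 0.3 for less random answers. If the token budget also needed enforcing at request time, a guard along these lines would do it; this is a hypothetical sketch assuming an 8k total budget, not code from the commit:

# Hypothetical guard, not in the commit: keep prompt + generation within
# an assumed token budget before calling model.generate().
def clamp_max_new_tokens(prompt_tokens: int, requested: int, budget: int = 8000) -> int:
    return max(1, min(requested, budget - prompt_tokens))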