ArmelR committed
Commit ce01d23
1 Parent(s): a9778e2

Update app.py

Files changed (1):
  1. app.py +30 -15
app.py CHANGED
@@ -13,6 +13,7 @@ from share_btn import community_icon_html, loading_icon_html, share_js, share_btn
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 API_URL_G = "https://api-inference.huggingface.co/models/ArmelR/starcoder-gradio-v0/"
+API_URL_S = "https://api-inference.huggingface.co/models/HuggingFaceH4/starcoderbase-finetuned-oasst1"
 
 with open("./HHH_prompt_short.txt", "r") as f:
     HHH_PROMPT = f.read() + "\n\n"
@@ -53,6 +54,10 @@ client_g = Client(
     API_URL_G, headers={"Authorization": f"Bearer {HF_TOKEN}"},
 )
 
+client_s = Client(
+    API_URL_S, headers={"Authorization": f"Bearer {HF_TOKEN}"},
+)
+
 def generate(
     prompt,
     temperature=0.9,
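
For reference, a minimal sketch of how either streaming client can be exercised on its own, assuming the text_generation package's Client API; the model URL is the one added above, while the prompt, token placeholder, and variable names are illustrative only:

    from text_generation import Client

    # Hypothetical standalone usage; replace <HF_TOKEN> with a real token.
    client_s = Client(
        "https://api-inference.huggingface.co/models/HuggingFaceH4/starcoderbase-finetuned-oasst1",
        headers={"Authorization": "Bearer <HF_TOKEN>"},
    )
    # generate_stream yields one event per generated token; each event
    # exposes the token text via event.token.text.
    for event in client_s.generate_stream("Human: Hello!\n\nAssistant:", max_new_tokens=32):
        if not event.token.special:
            print(event.token.text, end="")
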
@@ -60,7 +65,7 @@ def generate(
     top_p=0.95,
     repetition_penalty=1.0,
     chat_mode="TA prompt",
-    version=None,
+    version="StarChat-alpha",
 ):
 
     temperature = float(temperature)
@@ -90,14 +95,19 @@ def generate(
     chat_prompt = prompt + "\n\nAnswer:"
     prompt = base_prompt + chat_prompt
 
-    stream = client_g.generate_stream(prompt, **generate_kwargs)
+    if version == "StarCoder-gradio":
+        stream = client_g.generate_stream(prompt, **generate_kwargs)
+    elif version == "StarChat-alpha":
+        stream = client_s.generate_stream(prompt, **generate_kwargs)
+    else:
+        raise ValueError(f"Unknown version: {version}")
 
     output = ""
     previous_token = ""
 
     for response in stream:
         if (
-            (response.token.text in ["Question:", "-----"]
+            (response.token.text in ["Human", "-----", "Question:"]
             and previous_token in ["\n", "-----"])
             or response.token.text == "<|endoftext|>"
         ):
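
The streaming loop in the hunk above stops once a new speaker or section label ("Human", "Question:", "-----") starts a fresh line, or when the model emits its end-of-text token. A self-contained sketch of that stop condition (the function name is illustrative, not from app.py):

    def should_stop(token_text: str, previous_token: str) -> bool:
        # End the turn when a speaker/section label begins a new line,
        # or when the model emits its end-of-text token.
        return (
            (token_text in ["Human", "-----", "Question:"]
             and previous_token in ["\n", "-----"])
            or token_text == "<|endoftext|>"
        )

    assert should_stop("<|endoftext|>", "")
    assert should_stop("Question:", "\n")
    assert not should_stop("Question:", "tokens")
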
@@ -121,12 +131,17 @@ def bot(
     top_p=0.95,
     repetition_penalty=1.0,
     chat_mode=None,
-    version=None,
+    version="StarChat-alpha",
 ):
     # concat history of prompts with answers; for the last turn, whose answer is still empty, only add the prompt
-    prompt = "\n".join(
-        [f"Question: {prompt}\n\nAnswer: {answer}" for prompt, answer in history[:-1]] + [f"\nQuestion: {history[-1][0]}"]
-    )
+    if version == "StarCoder-gradio":
+        prompt = "\n".join(
+            [f"Question: {prompt}\n\nAnswer: {answer}" for prompt, answer in history[:-1]] + [f"\nQuestion: {history[-1][0]}"]
+        )
+    else:
+        prompt = "\n".join(
+            [f"Human: {prompt}\n\nAssistant: {answer}" for prompt, answer in history[:-1]] + [f"\nHuman: {history[-1][0]}"]
+        )
 
     bot_message = generate(
         prompt,
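
The two prompt formats built in bot() differ only in their speaker labels. A worked example on a hypothetical two-turn history:

    # The last entry's answer is still empty: only its question is appended.
    history = [["What does gr.Blocks do?", "It builds custom Gradio layouts."],
               ["Show an example.", ""]]

    # StarCoder-gradio format: Question:/Answer: labels.
    gradio_prompt = "\n".join(
        [f"Question: {q}\n\nAnswer: {a}" for q, a in history[:-1]]
        + [f"\nQuestion: {history[-1][0]}"]
    )
    # -> Question: What does gr.Blocks do?
    #
    #    Answer: It builds custom Gradio layouts.
    #
    #    Question: Show an example.

    # StarChat-alpha format: Human:/Assistant: labels.
    starchat_prompt = "\n".join(
        [f"Human: {q}\n\nAssistant: {a}" for q, a in history[:-1]]
        + [f"\nHuman: {history[-1][0]}"]
    )
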
@@ -210,12 +225,12 @@ _Note:_ this is an internal chat playground - **please do not share**. The deplo
     interactive=True,
     info="Penalize repeated tokens",
 )
-#version = gr.Dropdown(
-#    ["StarCoderBase", "StarCoder"],
-#    value="StarCoderBase",
-#    label="Version",
-#    info="",
-#)
+version = gr.Dropdown(
+    ["StarCoder-gradio", "StarChat-alpha"],
+    value="StarChat-alpha",
+    label="Version",
+    info="",
+)
 with column_1:
     # output = gr.Code(elem_id="q-output")
     # add visible=False and update it if chat_mode is True
@@ -251,7 +266,7 @@ _Note:_ this is an internal chat playground - **please do not share**. The deplo
     user, [instruction, chatbot], [instruction, chatbot], queue=False
 ).then(
     bot,
-    [chatbot, temperature, max_new_tokens, top_p, repetition_penalty, chat_mode],
+    [chatbot, temperature, max_new_tokens, top_p, repetition_penalty, chat_mode, version],
     chatbot,
 )
 
@@ -259,7 +274,7 @@ _Note:_ this is an internal chat playground - **please do not share**. The deplo
     user, [instruction, chatbot], [instruction, chatbot], queue=False
 ).then(
     bot,
-    [chatbot, temperature, max_new_tokens, top_p, repetition_penalty, chat_mode],
+    [chatbot, temperature, max_new_tokens, top_p, repetition_penalty, chat_mode, version],
     chatbot,
 )
 clear.click(lambda: None, None, chatbot, queue=False)
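
Gradio passes each component in an event's inputs list positionally to the handler, which is why version has to be appended to both input lists above. A minimal, self-contained sketch of that wiring (a hypothetical demo, not app.py itself):

    import gradio as gr

    def respond(message, version):
        # The dropdown's current value arrives as a plain string.
        return f"[{version}] {message}"

    with gr.Blocks() as demo:
        version = gr.Dropdown(
            ["StarCoder-gradio", "StarChat-alpha"],
            value="StarChat-alpha",
            label="Version",
        )
        box = gr.Textbox(label="Message")
        out = gr.Textbox(label="Output")
        box.submit(respond, [box, version], out)

    # demo.launch()  # uncomment to serve the demo locally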
 