loubnabnl HF staff commited on
Commit
3925884
1 Parent(s): 06fe7ef

add api endpoints and dropdown for models (#24)

Browse files

- add api endpoints and dropdown for models (7af0e86df124c48b004a2ba74acd084b223b18fa)
- update readme to mention both models (ad689e69933185f3256c3e4e5cec1e4d5e6823a9)

Files changed (1) hide show
  1. app.py +61 -48
app.py CHANGED
@@ -11,7 +11,7 @@ from share_btn import community_icon_html, loading_icon_html, share_js, share_bt
11
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
  API_URL = "https://api-inference.huggingface.co/models/bigcode/starcoder/"
14
-
15
 
16
  FIM_PREFIX = "<fim_prefix>"
17
  FIM_MIDDLE = "<fim_middle>"
@@ -77,10 +77,12 @@ client = Client(
77
  API_URL,
78
  headers={"Authorization": f"Bearer {HF_TOKEN}"},
79
  )
80
-
 
 
81
 
82
  def generate(
83
- prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0
84
  ):
85
 
86
  temperature = float(temperature)
@@ -106,7 +108,10 @@ def generate(
106
  raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
107
  prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
108
 
109
- stream = client.generate_stream(prompt, **generate_kwargs)
 
 
 
110
 
111
  if fim_mode:
112
  output = prefix
@@ -160,7 +165,7 @@ css += share_btn_css + monospace_css + custom_output_css + ".gradio-container {c
160
  description = """
161
  <div style="text-align: center;">
162
  <h1 style='color: black;'> 💫 StarCoder<span style='color: #e6b800;'> - </span>Playground 🪐</h1>
163
- <p style='color: black;'>This is a demo to generate code with <a href="https://huggingface.co/bigcode/starcoder" style='color: #e6b800;'>StarCoder</a>, a 15B parameter model for code generation in 86 programming languages.</p>
164
  </div>
165
  """
166
  disclaimer = """⚠️<b>Any use or sharing of this demo constitues your acceptance of the BigCode [OpenRAIL-M](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) License Agreement and the use restrictions included within.</b>\
@@ -178,48 +183,56 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
178
  )
179
  submit = gr.Button("Generate", variant="primary")
180
  output = gr.Code(elem_id="q-output", lines=30)
181
-
182
- with gr.Accordion("Advanced settings", open=False):
183
- with gr.Row():
184
- column_1, column_2 = gr.Column(), gr.Column()
185
- with column_1:
186
- temperature = gr.Slider(
187
- label="Temperature",
188
- value=0.2,
189
- minimum=0.0,
190
- maximum=1.0,
191
- step=0.05,
192
- interactive=True,
193
- info="Higher values produce more diverse outputs",
194
- )
195
- max_new_tokens = gr.Slider(
196
- label="Max new tokens",
197
- value=256,
198
- minimum=0,
199
- maximum=8192,
200
- step=64,
201
- interactive=True,
202
- info="The maximum numbers of new tokens",
203
- )
204
- with column_2:
205
- top_p = gr.Slider(
206
- label="Top-p (nucleus sampling)",
207
- value=0.90,
208
- minimum=0.0,
209
- maximum=1,
210
- step=0.05,
211
- interactive=True,
212
- info="Higher values sample more low-probability tokens",
213
- )
214
- repetition_penalty = gr.Slider(
215
- label="Repetition penalty",
216
- value=1.2,
217
- minimum=1.0,
218
- maximum=2.0,
219
- step=0.05,
220
- interactive=True,
221
- info="Penalize repeated tokens",
222
- )
 
 
 
 
 
 
 
 
223
  gr.Markdown(disclaimer)
224
  with gr.Group(elem_id="share-btn-container"):
225
  community_icon = gr.HTML(community_icon_html, visible=True)
@@ -238,7 +251,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
238
 
239
  submit.click(
240
  generate,
241
- inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty],
242
  outputs=[output],
243
  )
244
  share_button.click(None, [], [], _js=share_js)
 
11
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
  API_URL = "https://api-inference.huggingface.co/models/bigcode/starcoder/"
14
+ API_URL_BASE ="https://api-inference.huggingface.co/models/bigcode/starcoderbase/"
15
 
16
  FIM_PREFIX = "<fim_prefix>"
17
  FIM_MIDDLE = "<fim_middle>"
 
77
  API_URL,
78
  headers={"Authorization": f"Bearer {HF_TOKEN}"},
79
  )
80
+ client_base = Client(
81
+ API_URL_BASE, headers={"Authorization": f"Bearer {HF_TOKEN}"},
82
+ )
83
 
84
  def generate(
85
+ prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, version="StarCoder",
86
  ):
87
 
88
  temperature = float(temperature)
 
108
  raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
109
  prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
110
 
111
+ if version == "StarCoder":
112
+ stream = client.generate_stream(prompt, **generate_kwargs)
113
+ else:
114
+ stream = client_base.generate_stream(prompt, **generate_kwargs)
115
 
116
  if fim_mode:
117
  output = prefix
 
165
  description = """
166
  <div style="text-align: center;">
167
  <h1 style='color: black;'> 💫 StarCoder<span style='color: #e6b800;'> - </span>Playground 🪐</h1>
168
+ <p style='color: black;'>This is a demo to generate code with <a href="https://huggingface.co/bigcode/starcoder" style='color: #e6b800;'>StarCoder</a> and <a href="https://huggingface.co/bigcode/starcoderbase" style='color: #e6b800;'>StarCoderBase</a>, 15B parameter models for code generation in 86 programming languages.</p>
169
  </div>
170
  """
171
  disclaimer = """⚠️<b>Any use or sharing of this demo constitues your acceptance of the BigCode [OpenRAIL-M](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) License Agreement and the use restrictions included within.</b>\
 
183
  )
184
  submit = gr.Button("Generate", variant="primary")
185
  output = gr.Code(elem_id="q-output", lines=30)
186
+ with gr.Row():
187
+ with gr.Column():
188
+ with gr.Accordion("Advanced settings", open=False):
189
+ with gr.Row():
190
+ column_1, column_2 = gr.Column(), gr.Column()
191
+ with column_1:
192
+ temperature = gr.Slider(
193
+ label="Temperature",
194
+ value=0.2,
195
+ minimum=0.0,
196
+ maximum=1.0,
197
+ step=0.05,
198
+ interactive=True,
199
+ info="Higher values produce more diverse outputs",
200
+ )
201
+ max_new_tokens = gr.Slider(
202
+ label="Max new tokens",
203
+ value=256,
204
+ minimum=0,
205
+ maximum=8192,
206
+ step=64,
207
+ interactive=True,
208
+ info="The maximum numbers of new tokens",
209
+ )
210
+ with column_2:
211
+ top_p = gr.Slider(
212
+ label="Top-p (nucleus sampling)",
213
+ value=0.90,
214
+ minimum=0.0,
215
+ maximum=1,
216
+ step=0.05,
217
+ interactive=True,
218
+ info="Higher values sample more low-probability tokens",
219
+ )
220
+ repetition_penalty = gr.Slider(
221
+ label="Repetition penalty",
222
+ value=1.2,
223
+ minimum=1.0,
224
+ maximum=2.0,
225
+ step=0.05,
226
+ interactive=True,
227
+ info="Penalize repeated tokens",
228
+ )
229
+ with gr.Column():
230
+ version = gr.Dropdown(
231
+ ["StarCoderBase", "StarCoder"],
232
+ value="StarCoder",
233
+ label="Version",
234
+ info="",
235
+ )
236
  gr.Markdown(disclaimer)
237
  with gr.Group(elem_id="share-btn-container"):
238
  community_icon = gr.HTML(community_icon_html, visible=True)
 
251
 
252
  submit.click(
253
  generate,
254
+ inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
255
  outputs=[output],
256
  )
257
  share_button.click(None, [], [], _js=share_js)