Spaces:
Runtime error
Runtime error
import sys | |
import os | |
import logging as log | |
from typing import Generator | |
import gradio as gr | |
from gradio.themes.utils import sizes | |
from text_generation import Client | |
from src.request import StarCoderRequest, StarCoderRequestConfig | |
# todo: remove and replace by the actual js file instead | |
from src.share_btn import (share_js) | |
from src.utils import ( | |
get_file_as_string, | |
get_sections, | |
get_url_from_env_or_default_path, | |
preview | |
) | |
from constants import ( | |
FIM_MIDDLE, | |
FIM_PREFIX, | |
FIM_SUFFIX, | |
END_OF_TEXT, | |
MIN_TEMPERATURE, | |
) | |
from settings import ( | |
DEFAULT_PORT, | |
DEFAULT_STARCODER_API_PATH, | |
DEFAULT_STARCODER_BASE_API_PATH, | |
) | |
# Hugging Face API token, sent as a bearer token to the inference endpoints.
HF_TOKEN = os.getenv("HF_TOKEN")

# Without a token the app cannot reach the models, so stop right away:
# write the instructions to stderr and exit with status 1 instead of
# raising an exception with a traceback.
if not HF_TOKEN:
    ERR_MSG = """
Please set the HF_TOKEN environment variable with your Hugging Face API token.
You can get one by signing up at https://huggingface.co/join and then visiting
https://huggingface.co/settings/tokens."""
    sys.stderr.write(ERR_MSG + "\n")
    raise SystemExit(1)
# Resolve both inference endpoint URLs through get_url_from_env_or_default_path
# (environment variable name + fallback path from settings), then echo them at
# startup. ofuscate=True presumably masks the token in that output — confirm in
# src.utils.preview.
API_URL_STAR, API_URL_BASE = (
    get_url_from_env_or_default_path(env_name, default_path)
    for env_name, default_path in (
        ("STARCODER_API", DEFAULT_STARCODER_API_PATH),
        ("STARCODER_BASE_API", DEFAULT_STARCODER_BASE_API_PATH),
    )
)
preview("StarCoder Model URL", API_URL_STAR)
preview("StarCoderBase Model URL", API_URL_BASE)
preview("HF Token", HF_TOKEN, ofuscate=True)
# Static assets shipped with the app, each read into a string at import time.
STATIC_PATH = "static"
FORMATS, CSS, community_icon_svg, loading_icon_svg = (
    get_file_as_string(file_name, path=STATIC_PATH)
    for file_name in (
        "formats.md",  # rendered as markdown at the bottom of the page
        "styles.css",  # passed as css= to gr.Blocks
        "community_icon.svg",
        "loading_icon.svg",
    )
)

# todo: evaluate making STATIC_PATH the default path instead of the current one
# README.md is read from get_file_as_string's default location, not STATIC_PATH.
README = get_file_as_string("README.md")

# Split the README on "---" separators and keep the first three sections.
# Presumably: Space manifest, page description, disclaimer — verify against
# README.md and src.utils.get_sections.
readme_sections = get_sections(README, "---")
manifest, description, disclaimer = readme_sections[:3]
# Page theme: Monochrome base with indigo/blue/slate hues, small corner radius,
# large text, and the Rubik web font with system sans-serif fallbacks.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Rubik"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size=sizes.radius_sm,
    text_size=sizes.text_lg,
)
# Both inference clients share one Authorization header built from HF_TOKEN.
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
client_star, client_base = (
    Client(endpoint, headers=HEADERS) for endpoint in (API_URL_STAR, API_URL_BASE)
)
def get_tokens_collector(request: StarCoderRequest) -> Generator[str, None, None]:
    """Stream the text of each generated token for ``request``.

    The client is picked from ``request.settings.version``. ``"StarCoder"`` uses
    ``client_star``, and any other value uses ``client_base``. Streamed tokens
    whose id is 0 are skipped. Id 0 is presumably the end-of-text token; TODO
    confirm against the StarCoder vocabulary.

    Args:
        request: prompt plus generation settings; ``settings.kwargs()`` is
            forwarded unchanged to ``Client.generate_stream``.

    Yields:
        The text of every kept token, in stream order.
    """
    if request.settings.version == "StarCoder":
        model_client = client_star
    else:
        model_client = client_base
    token_stream = model_client.generate_stream(request.prompt, **request.settings.kwargs())
    yield from (chunk.token.text for chunk in token_stream if chunk.token.id != 0)
def get_tokens_accumulator(request: StarCoderRequest) -> Generator[str, None, None]:
    """Stream the cumulative output text for ``request``.

    The text starts from ``request.prefix`` in FIM mode and from
    ``request.prompt`` otherwise. After every streamed token the whole text so
    far is yielded. Once the stream ends, FIM mode appends ``request.suffix``
    and yields again. The last yield is always the final text plus a trailing
    newline.

    Yields:
        Successively longer snapshots of the full output text.
    """
    text = request.prefix if request.fim_mode else request.prompt
    for piece in get_tokens_collector(request=request):
        text = text + piece
        yield text
    if request.fim_mode:
        # Close the fill-in-the-middle gap with the original suffix.
        text = text + request.suffix
        yield text
    # One extra trailing line at the very end.
    yield text + '\n'
def get_tokens_linker(request: StarCoderRequest) -> str:
    """Run ``request`` to completion and return only the generated text.

    This is the plain concatenation of what ``get_tokens_collector`` streams.
    Unlike ``get_tokens_accumulator``, no prompt/prefix is prepended and no
    FIM suffix or trailing newline is appended.
    """
    pieces = list(get_tokens_collector(request))
    return "".join(pieces)
def _build_request(
    prompt,
    temperature,
    max_new_tokens,
    top_p,
    repetition_penalty,
    version,
) -> "StarCoderRequest":
    """Bundle a prompt and the UI generation settings into a StarCoderRequest."""
    return StarCoderRequest(
        prompt=prompt,
        settings=StarCoderRequestConfig(
            version=version,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        ),
    )


def generate(
    prompt: str,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    version="StarCoder",
) -> Generator[str, None, None]:
    """Stream the growing output for the "Generate" button.

    Args:
        prompt: code typed by the user.
        temperature, max_new_tokens, top_p, repetition_penalty: sampling
            settings forwarded to ``StarCoderRequestConfig``. The defaults
            here only apply when called without UI values.
        version: ``"StarCoder"`` or ``"StarCoderBase"``; selects the client.

    Yields:
        Cumulative output snapshots from ``get_tokens_accumulator``.
    """
    request = _build_request(
        prompt, temperature, max_new_tokens, top_p, repetition_penalty, version
    )
    yield from get_tokens_accumulator(request)


def process_example(
    prompt: str,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    version="StarCoder",
) -> Generator[str, None, None]:
    """Produce the output for an example prompt in a single update.

    Takes the same arguments as ``generate``. Yields exactly once: the full
    generated text from ``get_tokens_linker``.

    NOTE(review): this is the completion only. Unlike ``generate``, the prompt
    and FIM suffix are not included. Confirm this is what examples should show.
    """
    request = _build_request(
        prompt, temperature, max_new_tokens, top_p, repetition_penalty, version
    )
    # Yield the whole string once. ``yield from`` on a str emits it character
    # by character, and each yield replaces the output component's value, so
    # only the last character would remain visible.
    yield get_tokens_linker(request)
# Prompts offered as clickable examples under the editor. The last one contains
# the <FILL_HERE> marker, presumably to demonstrate fill-in-the-middle mode.
# todo: move it into the README too
examples = [
    (
        "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n"
        "# Train a logistic regression model, predict the labels on the test set and compute the accuracy score"
    ),
    "// Returns every other value in the array as a new array.\nfunction everyOther(arr) {",
    (
        "def alternating(list1, list2):\n"
        " results = []\n"
        " for i in range(min(len(list1), len(list2))):\n"
        " results.append(list1[i])\n"
        " results.append(list2[i])\n"
        " if len(list1) > len(list2):\n"
        " <FILL_HERE>\n"
        " else:\n"
        " results.extend(list2[i+1:])\n"
        " return results"
    ),
]
# UI layout. Inside a gr.Blocks context, components are placed in the order
# they are created, so the statement order below *is* the page layout.
with gr.Blocks(theme=theme, analytics_enabled=False, css=CSS) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # Prompt editor, submit button and the code output panel.
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30)
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                # Create both columns first, then fill each in turn.
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    # NOTE(review): minimum=0.0 reaches the request as-is
                                    # here; presumably StarCoderRequestConfig clamps it
                                    # (MIN_TEMPERATURE is imported) — verify.
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    # NOTE(review): minimum=0 allows requesting 0 new
                                    # tokens; the inference server may reject that — confirm.
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=256,
                                        minimum=0,
                                        maximum=8192,
                                        step=64,
                                        interactive=True,
                                        info="The maximum numbers of new tokens",
                                    )
                                with column_2:
                                    # NOTE(review): maximum=1 allows top_p == 1.0; confirm
                                    # the client/server accept it or that it is clamped.
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        # Chooses which client get_tokens_collector uses.
                        version = gr.Dropdown(
                            ["StarCoderBase", "StarCoder"],
                            value="StarCoder",
                            label="Version",
                            info="",
                        )
                gr.Markdown(disclaimer)
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_svg, visible=True)
                    loading_icon = gr.HTML(loading_icon_svg, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )
                # With cache_examples=False (and Gradio's default run_on_click=False),
                # clicking an example presumably only fills the prompt box and
                # process_example is not run — confirm for the pinned Gradio version.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )
                gr.Markdown(FORMATS)

    # Stream generate()'s cumulative output into the code panel. max_batch_size
    # is not passed: Gradio only uses it together with batch=True, and
    # generate() handles one prompt per call.
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
        outputs=[output],
        show_progress=True,
    )
    # Browser-only handler (fn=None): runs the share_js snippet client-side.
    share_button.click(None, [], [], _js=share_js)

# Queue events with up to 16 processed concurrently, then serve on DEFAULT_PORT.
demo.queue(concurrency_count=16).launch(debug=True, server_port=DEFAULT_PORT)