# luna-playground / app.py
# terryyz's picture
# Update app.py
# 3410eb4 verified
import json
import os
import shutil
import requests
import spaces
import torch
import gradio as gr
from huggingface_hub import Repository
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
# Markdown footer rendered at the bottom of the page via `gr.Markdown(FORMATS)`.
FORMATS = """## Model Formats
The model is pretrained on code and specifically learns to use APIs from the unknown libraries.
"""

# Always bind `device`, falling back to the CPU when no GPU is visible.
# `generate` calls `model.to(device)`; binding the name only on CUDA hosts
# would turn the CPU case into a NameError instead of showing the notice below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if device.type == "cpu":
    FORMATS += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
# Optional access token read from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Base checkpoint: used for the tokenizer and as the model the prefix adapters wrap.
CHECKPOINT_URL = "Salesforce/codegen-350M-mono"

# Repository ids of the prefix adapters (loaded through `PeftModel.from_pretrained`).
SQLMODEL_PREFIX_URL = "luna-code/sqlmodel-codegen-350M-mono-prefix"
SFEPY_PREFIX_URL = "luna-code/sfepy-codegen-350M-mono-prefix"
MEGENGINE_PREFIX_URL = "luna-code/megengine-codegen-350M-mono-prefix"
MAIN_EVO_PREFIX_URL = "luna-code/codegen-350M-mono-evo-prefix"
LANGCHAIN_PREFIX_URL = "luna-code/langchain-codegen-350M-mono-prefix"
LLAMAINDEX_PREFIX_URL = "luna-code/llamaindex-codegen-350M-mono-prefix"
DSPY_PREFIX_URL = "luna-code/dspy-codegen-350M-mono-prefix"
CS_EVO_PREFIX_URL = "luna-code/cs-codegen-350M-mono-evo-prefix"

# Repository ids of the "fft" checkpoints (loaded as standalone causal-LM models).
SQLMODEL_FFT_URL = "luna-code/sqlmodel-codegen-350M-mono-fft"
SFEPY_FFT_URL = "luna-code/sfepy-codegen-350M-mono-fft"
MEGENGINE_FFT_URL = "luna-code/megengine-codegen-350M-mono-fft"
MAIN_EVO_FFT_URL = "luna-code/codegen-350M-mono-evo-fft"
MAIN_FD_FFT_URL = "luna-code/codegen-350M-mono-fd-fft"
def _load_full_model(repo_id):
    """Load a complete causal-LM checkpoint from *repo_id* with `device_map="auto"`."""
    return AutoModelForCausalLM.from_pretrained(repo_id, device_map="auto")


def _load_prefix_adapter(repo_id):
    """Wrap the shared `basemodel` with the PEFT adapter stored at *repo_id*."""
    return PeftModel.from_pretrained(basemodel, repo_id, device_map="auto")


# Everything is loaded once at import time; the load order is kept as-is.
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_URL)
basemodel = _load_full_model(CHECKPOINT_URL)

sql_prefix = _load_prefix_adapter(SQLMODEL_PREFIX_URL)
sfepy_prefix = _load_prefix_adapter(SFEPY_PREFIX_URL)
megengine_prefix = _load_prefix_adapter(MEGENGINE_PREFIX_URL)
main_evo_prefix = _load_prefix_adapter(MAIN_EVO_PREFIX_URL)

sqlmodel_fft = _load_full_model(SQLMODEL_FFT_URL)
sfepy_fft = _load_full_model(SFEPY_FFT_URL)
megengine_fft = _load_full_model(MEGENGINE_FFT_URL)
main_evo_fft = _load_full_model(MAIN_EVO_FFT_URL)
main_fd_fft = _load_full_model(MAIN_FD_FFT_URL)

langchain_prefix = _load_prefix_adapter(LANGCHAIN_PREFIX_URL)
llamaindex_prefix = _load_prefix_adapter(LLAMAINDEX_PREFIX_URL)
dspy_prefix = _load_prefix_adapter(DSPY_PREFIX_URL)
cs_evo_prefix = _load_prefix_adapter(CS_EVO_PREFIX_URL)
# Registry from model label to loaded model object. `generate` builds its
# lookup keys as "<library> Prefix" / "<library> FFT" or uses the fixed
# "Main ..." / "CS ..." labels, so the key spellings must not change.
model_map = {"Base": basemodel}
model_map.update(
    {
        "SQLModel Prefix": sql_prefix,
        "SfePy Prefix": sfepy_prefix,
        "MegEngine Prefix": megengine_prefix,
        "Main Evo Prefix": main_evo_prefix,
    }
)
model_map.update(
    {
        "SQLModel FFT": sqlmodel_fft,
        "SfePy FFT": sfepy_fft,
        "MegEngine FFT": megengine_fft,
        "Main Evo FFT": main_evo_fft,
        "Main FD FFT": main_fd_fft,
    }
)
model_map.update(
    {
        "LangChain Prefix": langchain_prefix,
        "LlamaIndex Prefix": llamaindex_prefix,
        "DSPy Prefix": dspy_prefix,
        "CS Evo Prefix": cs_evo_prefix,
    }
)
# Font preference list handed to the theme: a Google font first, then generic fallbacks.
_theme_fonts = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]

# Monochrome Gradio theme passed to `gr.Blocks(theme=theme, ...)`.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=_theme_fonts,
)
# Libraries whose "Evo"/"FFT" variants are served by the per-library FFT
# models and the shared "Main ..." models.
_EVO_MAIN_LIBRARIES = ("SQLModel", "SfePy", "MegEngine")
# Libraries whose "Evo Prefix" variant is served by the "CS Evo Prefix" model.
_EVO_CS_LIBRARIES = ("LangChain", "LlamaIndex", "DSPy")


def _select_model(library, method):
    """Return the model for a (library, method) pair, or None when no such model exists."""
    if method == "Base":
        return basemodel
    if method == "Prefix":
        return model_map.get(library + " Prefix")

    label = None
    if library in _EVO_MAIN_LIBRARIES:
        label = {
            "Evo Prefix": "Main Evo Prefix",
            "FFT": library + " FFT",
            "Evo FFT": "Main Evo FFT",
            "Full Data FFT": "Main FD FFT",
        }.get(method)
    elif library in _EVO_CS_LIBRARIES and method == "Evo Prefix":
        label = "CS Evo Prefix"

    return model_map.get(label) if label is not None else None


@spaces.GPU
def generate(
    prompt, temperature=0.6, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, library="LangChain", method="Prefix"
):
    """Sample a continuation of *prompt* from the model chosen by *library* and *method*.

    Args:
        prompt: Text to continue.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens; coerced to an int >= 1.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty forwarded to `model.generate`.
        library: Library label, e.g. "SQLModel" or "LangChain".
        method: Model variant: "Base", "Prefix", "Evo Prefix", "FFT",
            "Evo FFT" or "Full Data FFT".

    Returns:
        The decoded newly generated tokens (prompt excluded), or "" when the
        prompt is empty or the library/method combination has no model.
    """
    # Nothing to condition on: skip the model call and return empty output.
    if not prompt:
        return ""

    model = _select_model(library, method)
    if model is None:
        # Unsupported combination (e.g. "FFT" for LangChain). Previously this
        # path left `model` unbound and crashed on the next statement.
        return ""

    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),
        # The UI slider can deliver 0 (or a float); generation needs a positive int.
        max_new_tokens=max(int(max_new_tokens), 1),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
    )

    model.to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(**inputs, **generate_kwargs)

    # Drop the echoed prompt tokens so only the continuation is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
# Prompts offered through `gr.Examples`; each entry fills the input textbox.
examples = [
    # Python: comment-driven completion.
    (
        "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n"
        "# Train a logistic regression model, predict the labels on the test set "
        "and compute the accuracy score"
    ),
    # JavaScript: function-body completion.
    (
        "// Returns every other value in the array as a new array.\n"
        "function everyOther(arr) {"
    ),
    # Plain-text prompt.
    "Poor English: She no went to the market. Corrected English:",
    # Python snippet containing a <FILL_HERE> placeholder.
    (
        "def alternating(list1, list2):\n"
        " results = []\n"
        " for i in range(min(len(list1), len(list2))):\n"
        " results.append(list1[i])\n"
        " results.append(list2[i])\n"
        " if len(list1) > len(list2):\n"
        " <FILL_HERE>\n"
        " else:\n"
        " results.extend(list2[i+1:])\n"
        " return results"
    ),
]
def process_example(args):
    """Run `generate` on an example prompt and return the full generated text.

    `generate` returns a plain string. Looping over that result walks it
    character by character, so keeping only the last loop value would return
    a single character and fail outright on an empty result.
    """
    return generate(args)
# Monospace font for the prompt textbox (the component with elem_id "q-input").
monospace_css = """
#q-input textarea {
font-family: monospace, 'Consolas', Courier, monospace;
}
"""

# Page stylesheet, concatenated in order: hide the ".generating" indicator,
# share-button styles, the monospace rule, then black container text.
css = "".join(
    (
        ".generating {visibility: hidden}",
        share_btn_css,
        monospace_css,
        ".gradio-container {color: black}",
    )
)
# HTML header shown at the top of the page via `gr.Markdown(description)`.
# The base-model link is absolute: a bare "Salesforce/..." href is resolved
# relative to whatever page hosts the demo and does not reach the model page.
description = """
<div style="text-align: center;">
<h1> 🌙 LUNA Models Playground</h1>
</div>
<div style="text-align: left;">
<p>This is a demo to generate text and code with unknown libraries. The supported base model is <a href="https://huggingface.co/Salesforce/codegen-350M-mono" style='color: #e6b800;'>CodeGen-350M-mono</a></p>
<p>The supported libraries are:</p>
<ul>
<li><a href="https://sqlmodel.tiangolo.com" style='color: #e6b800;'>SQLModel</a></li>
<li><a href="https://sfepy.org" style='color: #e6b800;'>SfePy</a></li>
<li><a href="https://megengine.org" style='color: #e6b800;'>MegEngine</a></li>
<li><a href="https://www.langchain.com/" style='color: #e6b800;'>LangChain</a></li>
<li><a href="https://www.llamaindex.ai/" style='color: #e6b800;'>LlamaIndex</a></li>
<li><a href="https://dspy-docs.vercel.app/" style='color: #e6b800;'>DSPy</a></li>
</ul>
<p><b>Please note:</b> These models are not designed for instruction purposes.</p>
</div>
"""
# Markdown disclaimer rendered via `gr.Markdown(disclaimer)`.
# The link targets are absolute URLs: scheme-less targets such as
# "spaces/bigcode/..." or "hf.co/bigcode" are treated as paths relative to the app.
# NOTE(review): the text refers to the BigCode license and model card while the
# app loads Salesforce CodeGen checkpoints; confirm the wording is intended.
disclaimer = """⚠️<b>Any use or sharing of this demo constitutes your acceptance of the BigCode [OpenRAIL-M](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) License Agreement and the use restrictions included within.</b>\
<br>**Intended Use**: this app and its [supporting model](https://huggingface.co/bigcode) are provided for demonstration purposes; not to serve as replacement for human expertise. For more details on the model's limitations in terms of factuality and biases, see the [model card.](https://hf.co/bigcode)"""
# Page layout and event wiring.
# NOTE(review): the nesting below was reconstructed from the statement order
# (the source had lost its indentation); confirm it matches the intended layout.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        gr.Markdown(description)

        # Model selection: target library and model variant.
        with gr.Row():
            library = gr.Dropdown(
                ["SQLModel", "SfePy", "MegEngine", "LangChain", "LlamaIndex", "DSPy"],
                value="LlamaIndex",
                label="Library",
                info="Choose a library from the list",
            )
        with gr.Row():
            method = gr.Dropdown(
                ["Base", "Prefix", "Evo Prefix", "FFT", "Evo FFT", "Full Data FFT"],
                value="Prefix",
                label="Model",
                info="Choose an expert from the list",
            )

        with gr.Row():
            with gr.Column():
                # Prompt input, trigger button and generated output.
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    lines=5,
                    label="Input",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30, label="Output")

                # Sampling parameters forwarded to `generate`.
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    # Bounds: 0 new tokens is not a valid generation
                                    # request, and the CodeGen-350M checkpoints have a
                                    # 2048-position context, so larger values cannot
                                    # be honoured.
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=256,
                                        minimum=64,
                                        maximum=2048,
                                        step=64,
                                        interactive=True,
                                        info="The maximum numbers of new tokens",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )

                gr.Markdown(disclaimer)

                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_html, visible=True)
                    loading_icon = gr.HTML(loading_icon_html, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )

                # Examples only fill the input box; results are not pre-computed.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )
                gr.Markdown(FORMATS)

    # The input order must match the positional parameters of `generate`.
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, library, method],
        outputs=[output]
    )
    # NOTE(review): registered with no function and no inputs/outputs, so the
    # click currently does nothing; `share_js` is imported but never attached.
    share_button.click(None, [], [])

demo.queue().launch(debug=True)