|
|
|
|
|
import spaces |
|
import gradio as gr |
|
import json |
|
|
|
|
|
|
|
|
|
|
|
from server import submit_weights |
|
|
|
|
|
|
|
from utilities.setup import get_files |
|
from utilities.templates import prompt_template |
|
|
|
|
|
|
|
# Application configuration, loaded once at import time through the project
# helper get_files.json_cfg() (presumably parsed from a JSON config file —
# confirm in utilities.setup). main() reads from it:
#   conf['model']['choices']  -> list of selectable model names
#   conf['model']['general']  -> default values for the setup/training fields
#   conf['model']['peft'] / conf['model']['sft'] -> advanced-tuning settings
#                                                   shown as JSON text
conf = get_files.json_cfg()
|
|
|
class update_visibility:
    """Callbacks that show/hide dataset widgets when the dataset-source radio changes.

    Each callback takes the currently selected radio label and returns a
    Gradio update that changes only the ``visible`` property of its output
    component. Because ``gr.update`` is not tied to a component class, the
    same helper works for the textbox, the button and the upload button.
    """

    # Radio labels these callbacks react to; they must match the choices of
    # the "Choose Dataset" radio built in main().
    HUB_DATASET = "Hugging Face Hub Dataset"
    UPLOAD_DATASET = "Upload Your Own"

    @staticmethod
    def _visible_when(radio, option):
        """Return an update that makes the output visible iff ``radio == option``."""
        return gr.update(visible=(radio == option))

    @staticmethod
    def textbox_vis(radio):
        """Show the Hub dataset-name textbox only when the Hub source is selected.

        Args:
            radio: Currently selected label of the dataset-source radio.

        Returns:
            A Gradio update dict setting ``visible`` on the output component.
        """
        return update_visibility._visible_when(radio, update_visibility.HUB_DATASET)

    @staticmethod
    def textbox_button_vis(radio):
        """Show the Hub dataset load button only when the Hub source is selected.

        Args:
            radio: Currently selected label of the dataset-source radio.

        Returns:
            A Gradio update dict setting ``visible`` on the output component.
        """
        return update_visibility._visible_when(radio, update_visibility.HUB_DATASET)

    @staticmethod
    def upload_vis(radio):
        """Show the file upload button only when "Upload Your Own" is selected.

        Args:
            radio: Currently selected label of the dataset-source radio.

        Returns:
            A Gradio update dict setting ``visible`` on the output component.
        """
        return update_visibility._visible_when(radio, update_visibility.UPLOAD_DATASET)
|
@spaces.GPU
def train(
    model_name,
    inject_prompt,
    dataset_predefined,
    peft,
    sft,
    max_seq_length,
    random_seed,
    num_epochs,
    max_steps,
    data_field,
    repository,
    model_out_name,
):
    """Handler for the "Start Fine Tuning" button (currently a placeholder).

    The arguments arrive in the order listed in main()'s ``tune_btn.click``
    inputs. Only ``model_name`` and ``inject_prompt`` are used right now.
    The dataset name, the PEFT/SFT JSON text, the training hyper-parameters,
    and the output repository/model name are accepted so the UI wiring is
    already in place, but they are not yet consumed.

    Decorated with ``spaces.GPU`` (Hugging Face Spaces GPU helper); presumably
    this requests a GPU for the call — confirm against the ``spaces`` package.

    Returns:
        str: A greeting that echoes the selected model and prompt template,
        shown in the "Output" textbox.
    """
    message = f"Hello!! Using model: {model_name} with template: {inject_prompt}"
    return message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Build the fine-tuning Gradio UI and launch it.

    The UI has four tabs:
      * About: renders README.md.
      * Basic Setup: model choice, output repository/name, HF token.
      * Upload Data: dataset source (Hub name or uploaded file) and the
        prompt template.
      * Train Model: training hyper-parameters, PEFT/SFT JSON, and the
        "Start Fine Tuning" button.

    Defaults come from the module-level ``conf`` dict. All event listeners
    are registered inside the ``gr.Blocks`` context, then ``demo.launch()``
    starts the app.
    """
    general_cfg = conf['model']['general']

    with gr.Blocks() as demo:

        with gr.Tabs():

            with gr.TabItem("About"):
                gr.Markdown(get_files.load_markdown_file("README.md"))

            with gr.TabItem("Basic Setup"):
                gr.Markdown("# Select Model and Input details")

                modelnames = conf['model']['choices']
                model_name = gr.Dropdown(label="Supported Models",
                                         choices=modelnames,
                                         # Avoid an IndexError while building the UI
                                         # if the configured model list is empty.
                                         value=modelnames[0] if modelnames else None)

                repository = gr.Textbox(label="Your User Name",
                                        value=general_cfg["repository"])
                model_out_name = gr.Textbox(label="Your Model Output Name",
                                            value=general_cfg["model_name"])
                # NOTE(review): the token is collected but is not passed to
                # train() below; wire it through once Hub pushing is implemented.
                hf_token = gr.Textbox(label="Your Huggingface Token",
                                      type='password',
                                      value='')

            with gr.TabItem("Upload Data"):
                gr.Markdown("# Dataset Selection and Upload")

                dataset_choice = gr.Radio(label="Choose Dataset",
                                          choices=["Hugging Face Hub Dataset", "Upload Your Own"],
                                          value="Hugging Face Hub Dataset")
                dataset_predefined = gr.Textbox(label="Hugging Face Hub Training Dataset",
                                                value='yahma/alpaca-cleaned',
                                                visible=True)
                # Loads the Hub dataset named in the textbox above (no file upload),
                # so it must not reuse the upload button's caption.
                dataset_predefined_load = gr.Button("Load Hub Dataset")

                dataset_uploaded_load = gr.UploadButton(label="Upload Dataset (.csv, .jsonl, or .txt)",
                                                        file_types=[".csv", ".jsonl", ".txt"],
                                                        visible=False)

                # Preview area filled by whichever dataset loader runs last.
                data_snippet = gr.Markdown()

                # Show only the widgets that match the selected dataset source.
                dataset_choice.change(update_visibility.textbox_vis,
                                      dataset_choice,
                                      dataset_predefined)
                dataset_choice.change(update_visibility.upload_vis,
                                      dataset_choice,
                                      dataset_uploaded_load)
                dataset_choice.change(update_visibility.textbox_button_vis,
                                      dataset_choice,
                                      dataset_predefined_load)

                inject_prompt = gr.Textbox(label="Prompt Template",
                                           value=prompt_template())

                dataset_predefined_load.click(fn=get_files.predefined_dataset,
                                              inputs=dataset_predefined,
                                              outputs=data_snippet)

                # `upload` fires after the user has picked a file; `click` would
                # fire on the button press itself, before any file is available.
                dataset_uploaded_load.upload(fn=get_files.uploaded_dataset,
                                             inputs=dataset_uploaded_load,
                                             outputs=data_snippet)

            with gr.TabItem("Train Model"):
                gr.Markdown("# Model Parameter Selection")

                data_field = gr.Textbox(label="Dataset Training Field Name",
                                        value=general_cfg["dataset_text_field"])
                max_seq_length = gr.Textbox(label="Maximum sequence length",
                                            value=general_cfg["max_seq_length"])
                random_seed = gr.Textbox(label="Seed",
                                         value=general_cfg["seed"])
                num_epochs = gr.Textbox(label="Training Epochs",
                                        value=general_cfg["num_train_epochs"])
                max_steps = gr.Textbox(label="Maximum steps",
                                       value=general_cfg["max_steps"])

                with gr.Accordion("Advanced Tuning", open=False):
                    # PEFT / SFT settings are edited as pretty-printed JSON text.
                    peft = gr.Textbox(label="PEFT Parameters (json)",
                                      value=json.dumps(dict(conf['model']['peft']), indent=4))
                    sft = gr.Textbox(label="SFT Parameters (json)",
                                     value=json.dumps(dict(conf['model']['sft']), indent=4))

                tune_btn = gr.Button("Start Fine Tuning")
                gr.Markdown("### Model Progress")
                output = gr.Textbox(label="Output")

        # NOTE(review): train() always receives the Hub dataset name; an
        # uploaded file (dataset_uploaded_load) is never passed to training.
        tune_btn.click(fn=train,
                       inputs=[model_name,
                               inject_prompt,
                               dataset_predefined,
                               peft,
                               sft,
                               max_seq_length,
                               random_seed,
                               num_epochs,
                               max_steps,
                               data_field,
                               repository,
                               model_out_name,
                               ],
                       outputs=output)

    demo.launch()
|
|
|
|
|
|
|
# Build and launch the UI only when this file is run as a script, not when
# it is imported.
if __name__ == "__main__":

    main()