Kevin Fink
committed on
Commit · 038610e
Parent(s): d253497
init
Browse files
- app.py +5 -2
- requirements.txt +1 -0
app.py
CHANGED
@@ -3,11 +3,13 @@ import gradio as gr
 from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM
 from datasets import load_dataset
 import traceback
+from huggingface_hub import login
 from peft import get_peft_model, LoraConfig
 
 @spaces.GPU
-def fine_tune_model(model_name, dataset_name, hub_id, num_epochs, batch_size, lr, grad):
+def fine_tune_model(model_name, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
     try:
+        login(api_key)
         lora_config = LoraConfig(
             r=16,  # Rank of the low-rank adaptation
             lora_alpha=32,  # Scaling factor
@@ -25,7 +27,7 @@ def fine_tune_model(model_name, dataset_name, hub_id, num_epochs, batch_size, lr
         # Tokenize the dataset
         def tokenize_function(examples):
             max_length = 256
-            return tokenizer(examples['text'], truncation=True)
+            return tokenizer(examples['text'], max_length=max_length, truncation=True)
 
         tokenized_datasets = dataset.map(tokenize_function, batched=True)
 
@@ -82,6 +84,7 @@ try:
         gr.Textbox(label="Model Name (e.g., 'google/t5-efficient-tiny-nh8')"),
         gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
         gr.Textbox(label="HF hub to push to after training"),
+        gr.Textbox(label="HF API token"),
         gr.Slider(minimum=1, maximum=10, value=3, label="Number of Epochs"),
         gr.Slider(minimum=1, maximum=16, value=4, label="Batch Size"),
         gr.Slider(minimum=1, maximum=1000, value=50, label="Learning Rate (e-5)"),
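Beyond threading the new `api_key` argument through to `login()`, this commit fixes a subtle tokenization bug: the previous revision assigned `max_length = 256` but never passed it to the tokenizer, so `truncation=True` capped inputs at the tokenizer's own default maximum rather than the intended 256 tokens. A minimal sketch of the difference, assuming the `google/t5-efficient-tiny-nh8` checkpoint used as the example model in the UI:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/t5-efficient-tiny-nh8")
long_text = "word " * 1000  # well past 256 tokens

# Old behavior: no max_length, so truncation falls back to the
# tokenizer's model_max_length (512 for T5-style tokenizers).
old = tokenizer(long_text, truncation=True)

# New behavior: inputs are capped at the intended 256 tokens.
new = tokenizer(long_text, max_length=256, truncation=True)

print(len(old["input_ids"]), len(new["input_ids"]))  # e.g. 512 vs. 256
```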
requirements.txt
CHANGED
@@ -2,3 +2,4 @@ spaces
 transformers
 datasets
 peft
+huggingface_hub
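The new `huggingface_hub` requirement backs the `login(api_key)` call added to `fine_tune_model`. A minimal sketch of the flow the two new UI fields enable, with placeholder values standing in for the Gradio inputs (the push step is illustrative; the trainer and push code are outside this diff):

```python
from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM

api_key = "hf_..."        # value from the new "HF API token" textbox
hub_id = "user/my-model"  # value from the "HF hub to push to after training" textbox

# Registers the token for this process, so subsequent Hub calls
# (push_to_hub, upload_file, ...) are authenticated.
login(api_key)

model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-efficient-tiny-nh8")
model.push_to_hub(hub_id)  # push is authenticated with the stored token
```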