|
import gradio as gr |
|
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM |
|
from datasets import load_dataset |
|
import traceback |
|
|
|
def _select_eval_split(dataset):
    """Return ``(train_split, eval_split)`` from a dataset dict.

    Not every Hub dataset ships a ``validation`` split. Prefer
    ``validation``, fall back to ``test``, and as a last resort hold out 10%
    of ``train`` (fixed seed so reruns evaluate on the same rows).
    """
    if "validation" in dataset:
        return dataset["train"], dataset["validation"]
    if "test" in dataset:
        return dataset["train"], dataset["test"]
    held_out = dataset["train"].train_test_split(test_size=0.1, seed=42)
    return held_out["train"], held_out["test"]


def _epoch_strategy_kwargs():
    """Build ``TrainingArguments`` kwargs that evaluate and save once per epoch.

    Newer transformers releases renamed ``evaluation_strategy`` to
    ``eval_strategy``. Pick whichever name the installed version accepts.
    The save strategy is set to match the eval strategy because
    ``load_best_model_at_end=True`` requires the two to agree.
    """
    import inspect

    accepted = inspect.signature(TrainingArguments).parameters
    eval_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
    return {eval_key: "epoch", "save_strategy": "epoch"}


def fine_tune_model(model_name, dataset_name, hub_id, num_epochs, batch_size, lr, grad):
    """Fine-tune a Hub model on a Hub dataset and push the result to the Hub.

    Args:
        model_name: Hub ID of the pretrained model, loaded with
            ``AutoModelForSeq2SeqLM`` plus its matching tokenizer.
        dataset_name: Hub ID of a dataset with a ``train`` split and a
            ``text`` column.
        hub_id: Hub repository ID the trained model is pushed to.
        num_epochs: Number of training epochs.
        batch_size: Per-device batch size for training and evaluation
            (coerced to ``int``).
        lr: Learning rate, passed to ``TrainingArguments`` unchanged.
        grad: Gradient accumulation steps (coerced to ``int``).

    Returns:
        ``'DONE!'`` on success. Otherwise, a string with the error message and
        its traceback. This is a Gradio callback, so failures are reported to
        the UI rather than raised.
    """
    try:
        dataset = load_dataset(dataset_name)

        # NOTE(review): num_labels is presumably ignored by a seq2seq LM head,
        # and an integer ``label`` column (as in classification datasets) is
        # likely not a usable seq2seq target, which expects token-id labels.
        # Confirm whether AutoModelForSequenceClassification was intended.
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, num_labels=2)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        def tokenize_function(examples):
            return tokenizer(examples['text'], padding="max_length", truncation=True)

        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_split, eval_split = _select_eval_split(tokenized_datasets)

        training_args = TrainingArguments(
            output_dir='./results',
            learning_rate=lr,
            # UI sliders may deliver floats; these arguments must be ints.
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            num_train_epochs=num_epochs,
            weight_decay=0.01,
            gradient_accumulation_steps=int(grad),
            load_best_model_at_end=True,
            # No compute_metrics is supplied, so eval loss is the only metric
            # the Trainer can use to pick the best checkpoint (lower is better).
            metric_for_best_model="loss",
            greater_is_better=False,
            logging_dir='./logs',
            logging_steps=10,
            push_to_hub=True,
            hub_model_id=hub_id,
            **_epoch_strategy_kwargs(),
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_split,
            eval_dataset=eval_split,
        )

        trainer.train()
        trainer.push_to_hub(commit_message="Training complete!")
    except Exception as e:
        # UI boundary: surface the failure in the Gradio output box.
        return f"An error occurred: {str(e)}, TB: {traceback.format_exc()}"

    return 'DONE!'
|
# NOTE(review): The string literal below is disabled draft code for an
# inference callback; it is a bare expression and is never executed.
# If revived, two points need attention:
# - It reads module-level ``model`` and ``tokenizer`` names, which this
#   file does not define (fine_tune_model keeps them local).
# - The tokenizer returns a mapping, so the call presumably needs to be
#   ``model(**inputs)`` rather than ``model(inputs)``.
'''

# Define Gradio interface

def predict(text):

inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

outputs = model(inputs)

predictions = outputs.logits.argmax(dim=-1)

return "Positive" if predictions.item() == 1 else "Negative"

'''
|
|
|
def _fine_tune_from_ui(model_name, dataset_name, hub_id, num_epochs, batch_size, lr, grad):
    """Translate raw Gradio widget values into ``fine_tune_model`` arguments.

    The learning-rate slider is labelled in units of 1e-5, so its value is
    scaled here before training. Batch size and gradient-accumulation steps
    are coerced to ``int`` because slider values may arrive as floats.
    """
    return fine_tune_model(
        model_name,
        dataset_name,
        hub_id,
        num_epochs,
        int(batch_size),
        lr * 1e-5,
        int(grad),
    )


try:
    iface = gr.Interface(
        fn=_fine_tune_from_ui,
        inputs=[
            gr.Textbox(label="Model Name (e.g., 'google/t5-efficient-tiny-nh8')"),
            gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
            gr.Textbox(label="HF hub to push to after training"),
            gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs"),
            gr.Slider(minimum=1, maximum=16, value=4, step=1, label="Batch Size"),
            # Value is in units of 1e-5; _fine_tune_from_ui applies the scale.
            gr.Slider(minimum=1, maximum=1000, value=50, label="Learning Rate (e-5)"),
            # Passed through as a plain step count (no scaling applied).
            gr.Slider(minimum=1, maximum=100, value=1, step=1, label="Gradient Accumulation Steps"),
        ],
        outputs="text",
        title="Fine-Tune Hugging Face Model",
        description="This interface allows you to fine-tune a Hugging Face model on a specified dataset.",
    )

    iface.launch()
except Exception as e:
    print(f"An error occurred: {str(e)}, TB: {traceback.format_exc()}")
|
|
|
|