import os

import gradio as gr
import matplotlib.pyplot as plt
from datasets import load_dataset
from matplotlib.figure import Figure
from transformers import (
    AutoImageProcessor,
    Trainer,
    TrainingArguments,
    ViTForImageClassification,
)
# NOTE(review): Hub dataset id kept from the original script. ImageNet is
# published on the Hub as the gated "imagenet-1k" — TODO confirm this id.
_DATASET_NAME = "imagenet"
_CHECKPOINT = "google/vit-base-patch16-224"
_OUTPUT_DIR = "vit-finetuned"
_LOSS_PLOT_PATH = "loss_graph.png"


def _save_schedule(save_every_epochs, total_epochs):
    """Translate "checkpoint every N epochs" into ``TrainingArguments`` kwargs.

    ``save_steps`` counts optimizer *steps*, not epochs, so the epoch count
    cannot be passed straight through. The mapping is:

    * every epoch              -> ``save_strategy="epoch"``
    * every N < total epochs   -> ``save_strategy="steps"`` with ``save_steps``
      given as a float ratio (N / total) of the total training steps
    * N >= total epochs        -> no intermediate checkpoints; the caller saves
      the final model explicitly after training.

    Args:
        save_every_epochs: Desired checkpoint interval in epochs (values below
            1 are treated as 1).
        total_epochs: Total number of training epochs.

    Returns:
        A dict of keyword arguments to splat into ``TrainingArguments``.
    """
    every = max(1, int(save_every_epochs))
    total = max(1, int(total_epochs))
    if every == 1:
        return {"save_strategy": "epoch"}
    if every >= total:
        return {"save_strategy": "no"}
    # A float save_steps in (0, 1) is read as a fraction of total steps.
    return {"save_strategy": "steps", "save_steps": every / total}


def _loss_curve(log_history):
    """Extract ``(epochs, losses)`` from a Trainer ``state.log_history`` list.

    Only entries that carry a ``"loss"`` key are kept. Other entries (for
    example evaluation logs or an end-of-training summary) are skipped. When
    an entry has no ``"epoch"``, its position in the list is used instead.
    """
    points = [
        (entry.get("epoch", index), entry["loss"])
        for index, entry in enumerate(log_history)
        if "loss" in entry
    ]
    return [x for x, _ in points], [y for _, y in points]


def _make_transform(image_processor):
    """Build a ``Dataset.with_transform`` callback that yields model inputs.

    The callback converts every image to RGB, so grayscale and CMYK images
    also work. It runs them through ``image_processor`` to get
    ``pixel_values`` and renames ``label`` to ``labels``, the keyword the
    model's forward pass uses for the loss.
    """

    def transform(batch):
        images = [image.convert("RGB") for image in batch["image"]]
        return {
            "pixel_values": image_processor(images, return_tensors="pt")["pixel_values"],
            "labels": batch["label"],
        }

    return transform


def finetune_model(epochs, save_at_num_epoch, learning_rate):
    """Fine-tune ViT on the dataset's train split and plot its training loss.

    Args:
        epochs: Number of training epochs.
        save_at_num_epoch: Checkpoint interval in epochs. Values at or above
            ``epochs`` mean only the final model is saved.
        learning_rate: Optimizer learning rate.

    Returns:
        ``(status_message, loss_plot_path)``: one value for each of the two
        Gradio outputs the button is wired to.
    """
    # NOTE(review): assumes the split has "image" and "label" columns — TODO confirm.
    train_dataset = load_dataset(_DATASET_NAME, split="train")
    image_processor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
    train_dataset = train_dataset.with_transform(_make_transform(image_processor))

    model = ViTForImageClassification.from_pretrained(_CHECKPOINT)

    training_args = TrainingArguments(
        output_dir=_OUTPUT_DIR,
        num_train_epochs=epochs,
        learning_rate=learning_rate,
        # Trainer would otherwise strip the raw "image" column, because the
        # model's forward() does not accept it, before the transform sees it.
        remove_unused_columns=False,
        **_save_schedule(save_at_num_epoch, epochs),
    )

    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
    # train() returns a TrainOutput; the per-step loss lives in trainer.state.
    trainer.train()

    model.save_pretrained(_OUTPUT_DIR)
    # Save the preprocessing config too, so the output directory can be
    # reloaded for inference on its own.
    image_processor.save_pretrained(_OUTPUT_DIR)

    epoch_axis, losses = _loss_curve(trainer.state.log_history)
    # An explicit Figure avoids pyplot's global state, which is not safe in a
    # server thread and leaks if figures are never closed.
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.plot(epoch_axis, losses)
    ax.set_title("Model Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    fig.savefig(_LOSS_PLOT_PATH)

    return "Fine-tuning complete!", _LOSS_PLOT_PATH
with gr.Blocks() as demo:
    gr.Markdown("# Fine-Tune a Model")

    with gr.Column():
        epochs = gr.Slider(label="Epochs", minimum=1, maximum=10, value=3, step=1)
        # `maximum` must be a number, not another component. It starts at the
        # default epoch count and is kept in sync by the epochs.change handler.
        save_at_num_epoch = gr.Slider(
            label="Save Model Every X Epochs", minimum=1, maximum=3, value=1, step=1
        )
        # With Slider's default step of 1, no value in [1e-5, 1e-3] other than
        # the minimum could be selected.
        learning_rate = gr.Slider(
            label="Learning Rate", minimum=1e-5, maximum=1e-3, value=2e-5, step=1e-5
        )
        run_button = gr.Button("Fine-Tune Model")

    status = gr.Textbox(label="Fine-Tuning Status")
    loss_graph = gr.Image(label="Loss Graph")

    # Cap the checkpoint-interval slider at the currently selected epoch count.
    epochs.change(
        lambda n: gr.update(maximum=n), inputs=epochs, outputs=save_at_num_epoch
    )
    run_button.click(
        finetune_model,
        inputs=[epochs, save_at_num_epoch, learning_rate],
        outputs=[status, loss_graph],
    )


if __name__ == "__main__":
    demo.launch()