# app.py — GPT-2 fine-tuning and text-generation Gradio app
# (by PhotographerAlpha7, commit 419fd29 verified)
import gradio as gr
import pandas as pd
from datasets import Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
import torch
import os
import matplotlib.pyplot as plt
def train_model(data_file, model_name, epochs, batch_size, learning_rate):
    """Fine-tune a GPT-2 model on a prompt/completion CSV and plot the losses.

    Parameters:
        data_file: uploaded file object (Gradio `gr.File`); its `.name` must
            point to a CSV with 'prompt' and 'completion' columns.
        model_name: Hugging Face model id to start from (e.g. "gpt2").
        epochs: number of training epochs (coerced to int).
        batch_size: per-device train/eval batch size (coerced to int).
        learning_rate: optimizer learning rate (coerced to float).

    Returns:
        (status_message, path_to_loss_plot_or_None). All failures are caught
        and reported in the status message so the UI never sees an exception.
    """
    try:
        if not data_file.name.endswith('.csv'):
            return "Invalid file format. Please upload a CSV file.", None
        df = pd.read_csv(data_file.name)
        if 'prompt' not in df.columns or 'completion' not in df.columns:
            return "CSV file must contain 'prompt' and 'completion' columns.", None
        if df.empty:
            return "CSV file contains no rows.", None
        # Concatenate each pair into a single causal-LM training string.
        df['text'] = df['prompt'].astype(str) + ': ' + df['completion'].astype(str)
        dataset = Dataset.from_pandas(df[['text']])
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        # GPT-2 ships without a pad token. Reuse EOS as the pad token instead
        # of adding a new '[PAD]' special token, which would also require
        # resizing the model's embedding matrix to stay consistent.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        def tokenize_function(examples):
            tokens = tokenizer(examples['text'], padding="max_length",
                               truncation=True, max_length=128)
            # Labels mirror the inputs, but padding positions are set to -100
            # so the cross-entropy loss ignores them instead of training on pad.
            tokens['labels'] = [
                [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
                for ids, attn in zip(tokens['input_ids'], tokens['attention_mask'])
            ]
            return tokens

        tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=['text'])
        training_args = TrainingArguments(
            output_dir="./results",
            overwrite_output_dir=True,
            num_train_epochs=int(epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            warmup_steps=1000,
            weight_decay=0.01,
            learning_rate=float(learning_rate),
            logging_dir="./logs",
            logging_steps=10,
            save_steps=10_000,
            save_total_limit=2,
            evaluation_strategy="steps",
            eval_steps=500,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss"
        )
        # NOTE(review): train and eval use the same split, so eval_loss tracks
        # training fit rather than generalization — kept from the original design.
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets,
            eval_dataset=tokenized_datasets,
        )
        trainer.train()
        trainer.evaluate()
        model.save_pretrained('./fine-tuned-gpt2')
        tokenizer.save_pretrained('./fine-tuned-gpt2')
        # Collect per-step losses from the trainer's log history for plotting.
        train_loss = [log['loss'] for log in trainer.state.log_history if 'loss' in log]
        eval_loss = [log['eval_loss'] for log in trainer.state.log_history if 'eval_loss' in log]
        # Start a fresh figure so repeated training runs don't draw on top of
        # the previous curves, and close it afterwards to release the memory.
        plt.figure()
        plt.plot(train_loss, label='Training Loss')
        plt.plot(eval_loss, label='Validation Loss')
        plt.xlabel('Steps')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.savefig('./training_eval_loss.png')
        plt.close()
        return "Training completed successfully.", './training_eval_loss.png'
    except Exception as e:
        return f"An error occurred: {str(e)}", None
def generate_text(prompt, temperature, top_k, max_length, repetition_penalty, use_comma):
    """Generate a completion from the locally fine-tuned GPT-2 model.

    Parameters:
        prompt: input text to continue.
        temperature: sampling temperature; values <= 0 pass None (model default).
        top_k: top-k sampling cutoff; values <= 0 pass None (model default).
        max_length: maximum total sequence length of the output.
        repetition_penalty: penalty applied to already-generated tokens.
        use_comma: when True, replaces '.' with ',' in the prompt first.

    Returns:
        The decoded generation, or an error string if anything fails
        (e.g. the fine-tuned model directory does not exist yet).
    """
    try:
        model_name = "./fine-tuned-gpt2"
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        # Inference mode: disable dropout so outputs depend only on sampling.
        model.eval()
        if use_comma:
            prompt = prompt.replace('.', ',')
        inputs = tokenizer(prompt, return_tensors="pt")
        # No gradients are needed for generation; skipping autograd bookkeeping
        # saves memory and time.
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=int(max_length),
                temperature=float(temperature) if temperature > 0 else None,
                top_k=int(top_k) if top_k > 0 else None,
                repetition_penalty=float(repetition_penalty),
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True  # sampling must be on for temperature/top_k to apply
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"An error occurred: {str(e)}"
def set_preset(preset):
    """Return (epochs, batch_size, learning_rate) for a named training preset.

    Unknown preset names fall back to the "Default" values so the UI number
    fields are never set to None.
    """
    presets = {
        "Default": (5, 8, 3e-5),
        "Fast Training": (3, 16, 5e-5),
        "High Accuracy": (10, 4, 1e-5),
    }
    return presets.get(preset, presets["Default"])
# Build the two-tab Gradio UI. Component creation order inside each context
# manager determines the on-screen layout, so statement order matters here.
with gr.Blocks() as ui:
    gr.Markdown("# Fine-Tuner | Photographer Alpha7")
    # --- Tab 1: upload a CSV and fine-tune the model ---
    with gr.Tab("Train Model"):
        with gr.Row():
            data_file = gr.File(label="Upload Data File (CSV)")
            model_name = gr.Textbox(label="Model Name", value="gpt2")
        with gr.Row():
            preset = gr.Radio(["Default", "Fast Training", "High Accuracy"], label="Preset")
            epochs = gr.Number(label="Epochs", value=5)
            batch_size = gr.Number(label="Batch Size", value=8)
            learning_rate = gr.Number(label="Learning Rate", value=3e-5)
        # Selecting a preset overwrites the three hyperparameter fields.
        preset.change(set_preset, preset, [epochs, batch_size, learning_rate])
        train_button = gr.Button("Train Model")
        train_output = gr.Textbox(label="Training Output")
        train_graph = gr.Image(label="Training and Validation Loss Graph")
        # train_model returns (status_message, loss_plot_path_or_None).
        train_button.click(train_model, inputs=[data_file, model_name, epochs, batch_size, learning_rate], outputs=[train_output, train_graph])
    # --- Tab 2: generate text with the fine-tuned model ---
    with gr.Tab("Generate Text"):
        with gr.Column():
            with gr.Row():
                generated_text = gr.Textbox(label="Result view")
            with gr.Row():
                prompt = gr.Textbox(label="Prompt")
                generate_button = gr.Button("Generate Text")
            with gr.Column():
                # Sampling controls forwarded verbatim to generate_text.
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7)
                top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=50)
                max_length = gr.Slider(label="Max Length", minimum=10, maximum=1024, value=128)
                repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2)
                use_comma = gr.Checkbox(label="Use Comma", value=True)
        generate_button.click(generate_text, inputs=[prompt, temperature, top_k, max_length, repetition_penalty, use_comma], outputs=generated_text)
# Launch the app (blocks until the server is stopped).
ui.launch()