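"""Gradio app for fine-tuning GPT-2 on a prompt/completion CSV and sampling from the result.

The "Train Model" tab fine-tunes a GPT-2 variant with Hugging Face's Trainer and saves a
training/validation loss plot; the "Generate Text" tab samples from the saved checkpoint.
"""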
import gradio as gr
import pandas as pd
from datasets import Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
import os
import matplotlib
matplotlib.use("Agg")  # headless-safe backend; the loss plot is only saved to disk, never shown
import matplotlib.pyplot as plt

def train_model(data_file, model_name, epochs, batch_size, learning_rate):
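    """Fine-tune `model_name` on the uploaded prompt/completion CSV.

    Returns a (status message, loss-plot path) tuple for the Gradio outputs.
    """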
    try:
        if data_file is None:
            return "No file uploaded. Please upload a CSV file.", None
        if not data_file.name.endswith('.csv'):
            return "Invalid file format. Please upload a CSV file.", None
        
        df = pd.read_csv(data_file.name)
        
        if 'prompt' not in df.columns or 'completion' not in df.columns:
            return "CSV file must contain 'prompt' and 'completion' columns.", None
        
        df['text'] = df['prompt'] + ': ' + df['completion']
        dataset = Dataset.from_pandas(df[['text']])
        
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        
        # GPT-2 has no pad token; reuse the EOS token for padding so the
        # embedding matrix does not need to be resized for a new token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        def tokenize_function(examples):
            tokens = tokenizer(examples['text'], padding="max_length", truncation=True, max_length=128)
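            # Causal LM objective: labels mirror input_ids (the model shifts them internally).
            # Note that pad positions are not masked to -100, so loss is also computed on padding.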
            tokens['labels'] = tokens['input_ids'].copy()
            return tokens
        
        tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=['text'])
        
        training_args = TrainingArguments(
            output_dir="./results",
            overwrite_output_dir=True,
            num_train_epochs=int(epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            warmup_steps=1000,  # may exceed the total step count on small datasets
            weight_decay=0.01,
            learning_rate=float(learning_rate),
            logging_dir="./logs",
            logging_steps=10,
            save_steps=10_000,
            save_total_limit=2,
            evaluation_strategy="steps",
            eval_steps=500,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss"
        )
        
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets,
            eval_dataset=tokenized_datasets,  # no held-out split; evaluation reuses the training set
        )
        
        trainer.train()
        eval_results = trainer.evaluate()
        
        model.save_pretrained('./fine-tuned-gpt2')
        tokenizer.save_pretrained('./fine-tuned-gpt2')
        
        train_loss = [log['loss'] for log in trainer.state.log_history if 'loss' in log]
        eval_loss = [log['eval_loss'] for log in trainer.state.log_history if 'eval_loss' in log]
        # Plot on a fresh figure so repeated runs don't overlay stale curves.
        # Note: the two series are sampled at different intervals (logging vs. eval steps).
        plt.figure()
        plt.plot(train_loss, label='Training Loss')
        plt.plot(eval_loss, label='Validation Loss')
        plt.xlabel('Steps')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.savefig('./training_eval_loss.png')
        plt.close()
        
        return "Training completed successfully.", './training_eval_loss.png'
    except Exception as e:
        return f"An error occurred: {str(e)}", None

def generate_text(prompt, temperature, top_k, max_length, repetition_penalty, use_comma):
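    """Generate a single completion for `prompt` using the saved fine-tuned model."""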
    try:
        model_name = "./fine-tuned-gpt2"
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        
        if use_comma:
            # Optional prompt tweak: replace periods with commas (e.g. for tag-style prompts).
            prompt = prompt.replace('.', ',')
        
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask

        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=int(max_length),
            temperature=float(temperature),
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True  # sampling must be enabled for temperature and top_k to take effect
        )
        
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"An error occurred: {str(e)}"

def set_preset(preset):
    """Map a preset name to (epochs, batch_size, learning_rate)."""
    presets = {
        "Default": (5, 8, 3e-5),
        "Fast Training": (3, 16, 5e-5),
        "High Accuracy": (10, 4, 1e-5),
    }
    # Fall back to "Default" so the handler never returns None.
    return presets.get(preset, presets["Default"])

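# ---- Gradio UI: a training tab and a text-generation tab ----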
with gr.Blocks() as ui:
    gr.Markdown("# Fine-Tuner | Photographer Alpha7")
    
    with gr.Tab("Train Model"):
        with gr.Row():
            data_file = gr.File(label="Upload Data File (CSV)")
            model_name = gr.Textbox(label="Model Name", value="gpt2")
        
        with gr.Row():
            preset = gr.Radio(["Default", "Fast Training", "High Accuracy"], label="Preset")
            epochs = gr.Number(label="Epochs", value=5)
            batch_size = gr.Number(label="Batch Size", value=8)
            learning_rate = gr.Number(label="Learning Rate", value=3e-5)
        
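        # Selecting a preset overwrites the epochs / batch size / learning rate fields.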
        preset.change(set_preset, preset, [epochs, batch_size, learning_rate])
        
        train_button = gr.Button("Train Model")
        train_output = gr.Textbox(label="Training Output")
        train_graph = gr.Image(label="Training and Validation Loss Graph")
        
        train_button.click(train_model, inputs=[data_file, model_name, epochs, batch_size, learning_rate], outputs=[train_output, train_graph])
    
    with gr.Tab("Generate Text"):
        with gr.Column():
            with gr.Row():
                generated_text = gr.Textbox(label="Generated Text")
            with gr.Row():
                prompt = gr.Textbox(label="Prompt")
                generate_button = gr.Button("Generate Text") 
        with gr.Column():
            temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7)
            top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=50)
            max_length = gr.Slider(label="Max Length", minimum=10, maximum=1024, value=128)
            repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2)
            use_comma = gr.Checkbox(label="Use Comma", value=True)
        
        generate_button.click(generate_text, inputs=[prompt, temperature, top_k, max_length, repetition_penalty, use_comma], outputs=generated_text)

ui.launch()