PhotographerAlpha7 commited on
Commit
1118487
1 Parent(s): 8a9727b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from datasets import Dataset
4
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
5
+ import torch
6
+ import os
7
+ import matplotlib.pyplot as plt
8
+
9
+ def train_model(data_file, model_name, epochs, batch_size, learning_rate):
10
+ try:
11
+ if not data_file.name.endswith('.csv'):
12
+ return "Invalid file format. Please upload a CSV file.", None
13
+
14
+ df = pd.read_csv(data_file.name)
15
+
16
+ if 'prompt' not in df.columns or 'description' not in df.columns:
17
+ return "CSV file must contain 'prompt' and 'description' columns.", None
18
+
19
+ df['text'] = df['prompt'] + ': ' + df['description']
20
+ dataset = Dataset.from_pandas(df[['text']])
21
+
22
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
23
+ model = GPT2LMHeadModel.from_pretrained(model_name)
24
+
25
+ if tokenizer.pad_token is None:
26
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
27
+ tokenizer.pad_token = tokenizer.eos_token
28
+
29
+ def tokenize_function(examples):
30
+ tokens = tokenizer(examples['text'], padding="max_length", truncation=True, max_length=128)
31
+ tokens['labels'] = tokens['input_ids'].copy()
32
+ return tokens
33
+
34
+ tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=['text'])
35
+
36
+ training_args = TrainingArguments(
37
+ output_dir="./results",
38
+ overwrite_output_dir=True,
39
+ num_train_epochs=int(epochs),
40
+ per_device_train_batch_size=int(batch_size),
41
+ per_device_eval_batch_size=int(batch_size),
42
+ warmup_steps=1000,
43
+ weight_decay=0.01,
44
+ learning_rate=float(learning_rate),
45
+ logging_dir="./logs",
46
+ logging_steps=10,
47
+ save_steps=10_000,
48
+ save_total_limit=2,
49
+ evaluation_strategy="steps",
50
+ eval_steps=500,
51
+ load_best_model_at_end=True,
52
+ metric_for_best_model="eval_loss"
53
+ )
54
+
55
+ trainer = Trainer(
56
+ model=model,
57
+ args=training_args,
58
+ train_dataset=tokenized_datasets,
59
+ eval_dataset=tokenized_datasets,
60
+ )
61
+
62
+ trainer.train()
63
+ eval_results = trainer.evaluate()
64
+
65
+ model.save_pretrained('./fine-tuned-gpt2')
66
+ tokenizer.save_pretrained('./fine-tuned-gpt2')
67
+
68
+ train_loss = [log['loss'] for log in trainer.state.log_history if 'loss' in log]
69
+ eval_loss = [log['eval_loss'] for log in trainer.state.log_history if 'eval_loss' in log]
70
+ plt.plot(train_loss, label='Training Loss')
71
+ plt.plot(eval_loss, label='Validation Loss')
72
+ plt.xlabel('Steps')
73
+ plt.ylabel('Loss')
74
+ plt.title('Training and Validation Loss')
75
+ plt.legend()
76
+ plt.savefig('./training_eval_loss.png')
77
+
78
+ return "Training completed successfully.", './training_eval_loss.png'
79
+ except Exception as e:
80
+ return f"An error occurred: {str(e)}", None
81
+
82
+ def generate_text(prompt, temperature, top_k, max_length, repetition_penalty, use_comma):
83
+ try:
84
+ model_name = "./fine-tuned-gpt2"
85
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
86
+ model = GPT2LMHeadModel.from_pretrained(model_name)
87
+
88
+ if use_comma:
89
+ prompt = prompt.replace('.', ',')
90
+
91
+ inputs = tokenizer(prompt, return_tensors="pt")
92
+ input_ids = inputs.input_ids
93
+ attention_mask = inputs.attention_mask
94
+
95
+ outputs = model.generate(
96
+ input_ids,
97
+ attention_mask=attention_mask,
98
+ max_length=int(max_length),
99
+ temperature=float(temperature) if temperature > 0 else None,
100
+ top_k=int(top_k) if top_k > 0 else None,
101
+ repetition_penalty=float(repetition_penalty),
102
+ num_return_sequences=1,
103
+ pad_token_id=tokenizer.eos_token_id,
104
+ do_sample=True # Assurer que do_sample est activé pour temperature et top_k
105
+ )
106
+
107
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
108
+ except Exception as e:
109
+ return f"An error occurred: {str(e)}"
110
+
111
+ def set_preset(preset):
112
+ if preset == "Default":
113
+ return 5, 8, 3e-5
114
+ elif preset == "Fast Training":
115
+ return 3, 16, 5e-5
116
+ elif preset == "High Accuracy":
117
+ return 10, 4, 1e-5
118
+
119
+ with gr.Blocks() as ui:
120
+ gr.Markdown("# Fine-Tuner | Photographer Alpha7")
121
+
122
+ with gr.Tab("Train Model"):
123
+ with gr.Row():
124
+ data_file = gr.File(label="Upload Data File (CSV)")
125
+ model_name = gr.Textbox(label="Model Name", value="gpt2")
126
+
127
+ with gr.Row():
128
+ preset = gr.Radio(["Default", "Fast Training", "High Accuracy"], label="Preset")
129
+ epochs = gr.Number(label="Epochs", value=5)
130
+ batch_size = gr.Number(label="Batch Size", value=8)
131
+ learning_rate = gr.Number(label="Learning Rate", value=3e-5)
132
+
133
+ preset.change(set_preset, preset, [epochs, batch_size, learning_rate])
134
+
135
+ train_button = gr.Button("Train Model")
136
+ train_output = gr.Textbox(label="Training Output")
137
+ train_graph = gr.Image(label="Training and Validation Loss Graph")
138
+
139
+ train_button.click(train_model, inputs=[data_file, model_name, epochs, batch_size, learning_rate], outputs=[train_output, train_graph])
140
+
141
+ with gr.Tab("Generate Text"):
142
+ with gr.Column():
143
+ with gr.Row():
144
+ generated_text = gr.Textbox(label="Result view")
145
+ with gr.Row():
146
+ prompt = gr.Textbox(label="Prompt")
147
+ generate_button = gr.Button("Generate Text")
148
+ with gr.Column():
149
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7)
150
+ top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=50)
151
+ max_length = gr.Slider(label="Max Length", minimum=10, maximum=1024, value=128)
152
+ repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2)
153
+ use_comma = gr.Checkbox(label="Use Comma", value=True)
154
+
155
+ generate_button.click(generate_text, inputs=[prompt, temperature, top_k, max_length, repetition_penalty, use_comma], outputs=generated_text)
156
+
157
+ ui.launch()