Rooni committed on
Commit
0819d7f
1 Parent(s): ad1fd1c

Create train_model.py

Files changed (1)
  1. train_model.py +45 -0
train_model.py ADDED
@@ -0,0 +1,45 @@
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
+ from datasets import load_dataset
+
+ # Load the ImageNet dataset
+ dataset = load_dataset("imagenet-1k")
+
+ # Initialize the model and tokenizer
+ model_name = "gpt2"
+ model = GPT2LMHeadModel.from_pretrained(model_name)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+ # GPT-2 ships without a padding token; reuse EOS so padding=True works below
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Preprocessing: tokenize the "image" and "caption" fields and use the
+ # caption token IDs as labels for language-model training
+ def preprocess_data(examples):
+     inputs = examples["image"]
+     targets = examples["caption"]
+     inputs = tokenizer(inputs, padding=True, truncation=True, max_length=512, return_tensors="pt")
+     targets = tokenizer(targets, padding=True, truncation=True, max_length=512, return_tensors="pt")
+     inputs["labels"] = targets["input_ids"]
+     return inputs
+
+ # Apply the preprocessing to the dataset
+ dataset = dataset.map(preprocess_data, batched=True)
+
+ # Define the training arguments
+ training_args = TrainingArguments(
+     output_dir="./model",
+     num_train_epochs=5,
+     per_device_train_batch_size=4,
+     per_device_eval_batch_size=4,
+     warmup_steps=500,
+     weight_decay=0.01,
+     logging_dir="./logs",
+     logging_steps=100,
+     evaluation_strategy="epoch",
+ )
+
+ # Create the trainer and train the model
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=dataset["train"],
+     eval_dataset=dataset["validation"],
+     data_collator=None,
+ )
+ trainer.train()
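
For a quick smoke test after training, a minimal sketch of loading the fine-tuned weights and sampling from them could look like the lines below. This is an assumption-laden example, not part of the commit: it presumes the weights were explicitly written to ./model (the output_dir above), e.g. with trainer.save_model("./model"), or that from_pretrained is pointed at one of the checkpoint-* folders Trainer creates there; the prompt string is purely illustrative.

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    # Assumption: fine-tuned weights were saved to ./model (the output_dir used above),
    # e.g. via trainer.save_model("./model"); a folder like ./model/checkpoint-500 works too.
    model = GPT2LMHeadModel.from_pretrained("./model")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    prompt = "A photo of"  # illustrative prompt, not taken from the original script
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=30, pad_token_id=tokenizer.eos_token_id)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))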