OVAWARE committed on
Commit d47023e · verified · 1 Parent(s): 772b83d

Create train.py

Files changed (1)
  1. train.py +257 -0
train.py ADDED
@@ -0,0 +1,257 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ from PIL import Image
+ import json
+ import os
+ import subprocess
+ from transformers import BertTokenizer, BertModel
+ import wandb
+
+ # Hyperparameters
+ LATENT_DIM = 128
+ HIDDEN_DIM = 256
+
+ # Custom dataset
+ class Text2ImageDataset(Dataset):
+     def __init__(self, image_dir, metadata_file):
+         self.image_dir = image_dir
+         with open(metadata_file, 'r') as f:
+             self.metadata = json.load(f)
+         self.transform = transforms.Compose([
+             transforms.ToTensor(),
+             # Normalize all four RGBA channels to [-1, 1] to match the decoder's Tanh output
+             transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5))
+         ])
+
+     def __len__(self):
+         return len(self.metadata)
+
+     def __getitem__(self, idx):
+         # Each metadata entry is expected to provide 'file_name' and 'description'
+         item = self.metadata[idx]
+         image_path = os.path.join(self.image_dir, item['file_name'])
+
+         try:
+             image = Image.open(image_path).convert('RGBA')
+         except FileNotFoundError:
+             # Note: returning None breaks DataLoader's default collate;
+             # missing files should be removed from the metadata beforehand
+             print(f"Image not found: {image_path}")
+             return None, None
+         except Exception as e:
+             print(f"Error loading image {image_path}: {e}")
+             return None, None
+
+         image = self.transform(image)
+         prompt = str(item['description'])
+         return image, prompt
+
+ # Text encoder
+ class TextEncoder(nn.Module):
+     def __init__(self, hidden_size, output_size):
+         super(TextEncoder, self).__init__()
+         self.bert = BertModel.from_pretrained('bert-base-uncased')
+         self.fc = nn.Linear(self.bert.config.hidden_size, output_size)
+
+     def forward(self, input_ids, attention_mask):
+         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+         # Project the [CLS] token embedding down to the conditioning dimension
+         return self.fc(outputs.last_hidden_state[:, 0, :])
+
+ # CVAE model
+ class CVAE(nn.Module):
+     def __init__(self, text_encoder):
+         super(CVAE, self).__init__()
+         self.text_encoder = text_encoder
+
+         # Encoder: expects 16x16 RGBA images, downsampled to 4x4 by the two stride-2 convolutions
+         self.encoder = nn.Sequential(
+             nn.Conv2d(4, 32, 3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, 64, 3, stride=2, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(64, 128, 3, stride=2, padding=1),
+             nn.ReLU(),
+             nn.Flatten(),
+             nn.Linear(128 * 4 * 4, HIDDEN_DIM)
+         )
+
+         # Image features and text conditioning are concatenated before projecting to the latent space
+         self.fc_mu = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)
+         self.fc_logvar = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)
+
+         # Decoder: upsamples 4x4 back to 16x16 RGBA
+         self.decoder_input = nn.Linear(LATENT_DIM + HIDDEN_DIM, 128 * 4 * 4)
+         self.decoder = nn.Sequential(
+             nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
+             nn.ReLU(),
+             nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, 4, 3, stride=1, padding=1),
+             nn.Tanh()
+         )
+
+     def encode(self, x, c):
+         x = self.encoder(x)
+         x = torch.cat([x, c], dim=1)
+         mu = self.fc_mu(x)
+         logvar = self.fc_logvar(x)
+         return mu, logvar
+
+     def decode(self, z, c):
+         z = torch.cat([z, c], dim=1)
+         x = self.decoder_input(z)
+         x = x.view(-1, 128, 4, 4)
+         return self.decoder(x)
+
+     def reparameterize(self, mu, logvar):
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
+
+     def forward(self, x, c):
+         mu, logvar = self.encode(x, c)
+         z = self.reparameterize(mu, logvar)
+         return self.decode(z, c), mu, logvar
+
+ # Loss function: summed reconstruction error (MSE, not BCE) plus KL divergence
+ def loss_function(recon_x, x, mu, logvar):
+     recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
+     KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+     return recon_loss + KLD
+
+ # Training function
+ def train(model, train_loader, optimizer, device, tokenizer):
+     model.train()
+     train_loss = 0
+     for batch_idx, (data, prompt) in enumerate(train_loader):
+         data = data.to(device)
+         optimizer.zero_grad()
+
+         # Tokenize the batch of prompts for BERT
+         encoded_input = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
+         input_ids = encoded_input['input_ids'].to(device)
+         attention_mask = encoded_input['attention_mask'].to(device)
+
+         text_encoding = model.text_encoder(input_ids, attention_mask)
+
+         recon_batch, mu, logvar = model(data, text_encoding)
+         loss = loss_function(recon_batch, data, mu, logvar)
+         loss.backward()
+         train_loss += loss.item()
+         optimizer.step()
+
+         # Log batch-level metrics
+         wandb.log({
+             "batch_loss": loss.item(),
+             "batch_reconstruction_loss": nn.functional.mse_loss(recon_batch, data, reduction='mean').item(),
+             "batch_kl_divergence": (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / data.size(0)).item()
+         })
+
+     avg_loss = train_loss / len(train_loader.dataset)
+     return avg_loss
+
+ # Main function
+ def main():
+     NUM_EPOCHS = 500
+     BATCH_SIZE = 128
+     LEARNING_RATE = 1e-4
+
+     # Checkpoint and logging hyperparameters
+     SAVE_INTERVAL = 25        # Save the model every SAVE_INTERVAL epochs
+     SAVE_INTERVAL_IMAGE = 1   # Save a generated image every SAVE_INTERVAL_IMAGE epochs
+     PROJECT_NAME = "BitRoss"
+     MODEL_NAME = "BitRoss"
+     SAVE_DIR = "/models/BitRoss/"
+
+     if not os.path.exists(SAVE_DIR):
+         os.makedirs(SAVE_DIR)
+
+     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+     DATA_DIR = "./trainingData/"
+     METADATA_FILE = "./trainingData/metadata.json"
+
+     # Initialize wandb
+     wandb.init(project=PROJECT_NAME, config={
+         "LATENT_DIM": LATENT_DIM,
+         "HIDDEN_DIM": HIDDEN_DIM,
+         "NUM_EPOCHS": NUM_EPOCHS,
+         "BATCH_SIZE": BATCH_SIZE,
+         "LEARNING_RATE": LEARNING_RATE,
+         "SAVE_INTERVAL": SAVE_INTERVAL,
+         "MODEL_NAME": MODEL_NAME
+     })
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     dataset = Text2ImageDataset(DATA_DIR, METADATA_FILE)
+     train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+     text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
+     model = CVAE(text_encoder).to(device)
+     optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+     # Log model architecture
+     wandb.watch(model, log="all", log_freq=100)
+
+     for epoch in range(1, NUM_EPOCHS + 1):
+         train_loss = train(model, train_loader, optimizer, device, tokenizer)
+         print(f'Epoch {epoch}, Loss: {train_loss:.4f}')
+
+         # Log epoch-level metrics
+         wandb.log({
+             "epoch": epoch,
+             "train_loss": train_loss,
+         })
+
+         # Generate and log a sample image every SAVE_INTERVAL_IMAGE epochs
+         if epoch % SAVE_INTERVAL_IMAGE == 0:
+             output_image = f"{SAVE_DIR}output_epoch_{epoch}.png"
+
+             # Generate an image using the current model state
+             from generate import generate_image
+             prompt = "A blue sword made of diamond"  # You can change this prompt as needed
+             generated_image = generate_image(model, prompt, device)
+             generated_image.save(output_image)
+
+             # Upload the generated image to wandb
+             wandb.log({
+                 "generated_image": wandb.Image(output_image, caption=f"Generated at epoch {epoch} with prompt {prompt}")
+             })
+
+         # Save a checkpoint every SAVE_INTERVAL epochs
+         if epoch % SAVE_INTERVAL == 0:
+             model_save_path = f"{SAVE_DIR}{MODEL_NAME}_epoch_{epoch}.pth"
+             torch.save(model.state_dict(), model_save_path)
+             print(f"Model saved to {model_save_path}")
+
+         # Log sample reconstructions every 10 epochs
+         if epoch % 10 == 0:
+             model.eval()
+             with torch.no_grad():
+                 sample_data, sample_prompt = next(iter(train_loader))
+                 sample_data = sample_data[:4].to(device)  # Take the first 4 samples
+                 encoded_input = tokenizer(list(sample_prompt[:4]), padding=True, truncation=True, return_tensors="pt")
+                 input_ids = encoded_input['input_ids'].to(device)
+                 attention_mask = encoded_input['attention_mask'].to(device)
+                 text_encoding = model.text_encoder(input_ids, attention_mask)
+                 recon_batch, _, _ = model(sample_data, text_encoding)
+
+             # Denormalize and convert to PIL images
+             original_images = [transforms.ToPILImage()((sample_data[i] * 0.5 + 0.5).cpu()) for i in range(4)]
+             reconstructed_images = [transforms.ToPILImage()((recon_batch[i] * 0.5 + 0.5).cpu()) for i in range(4)]
+
+             wandb.log({
+                 f"original_vs_reconstructed_{i}": [wandb.Image(original_images[i], caption=f"Original {i}"),
+                                                    wandb.Image(reconstructed_images[i], caption=f"Reconstructed {i}")]
+                 for i in range(4)
+             })
+
+     wandb.finish()
+
+ if __name__ == "__main__":
+     main()
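
Note on the training data: Text2ImageDataset indexes images through ./trainingData/metadata.json and reads the 'file_name' and 'description' keys from each entry, but that file is not part of this commit. The snippet below is only a sketch of the layout the loader appears to assume; the example file names and prompts are invented for illustration and the real schema may differ.

# Sketch of the metadata layout Text2ImageDataset expects (assumed, not part of this commit):
# a JSON list in which each entry names an image file inside DATA_DIR and gives its text prompt.
import json

example_metadata = [
    {"file_name": "diamond_sword.png", "description": "A blue sword made of diamond"},      # hypothetical entry
    {"file_name": "golden_apple.png", "description": "A golden apple with a shiny glow"},   # hypothetical entry
]

with open("trainingData/metadata.json", "w") as f:
    json.dump(example_metadata, f, indent=2)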
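
The flattened size 128 * 4 * 4 in the encoder, together with the two stride-2 convolutions, implies the model is built around 16x16 RGBA textures. A quick shape sanity check along those lines, with random tensors standing in for a real image batch and text encoding, could look like this (assuming train.py is importable from the working directory):

import torch

from train import CVAE, TextEncoder, HIDDEN_DIM

text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
model = CVAE(text_encoder)

images = torch.randn(2, 4, 16, 16)         # batch of 2 RGBA 16x16 images
conditioning = torch.randn(2, HIDDEN_DIM)  # stand-in for the BERT-based text encoding

recon, mu, logvar = model(images, conditioning)
print(recon.shape, mu.shape, logvar.shape)  # expected: (2, 4, 16, 16), (2, 128), (2, 128)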
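
main() also imports generate_image from a generate module that is not included in this commit. As a rough illustration only, a compatible helper could encode the prompt with the same BERT tokenizer, sample a latent vector from the prior, and decode it; the function name and the (model, prompt, device) signature mirror the call site in train.py, while everything else below is an assumption.

# generate.py — hypothetical sketch of the generate_image helper imported by train.py.
import torch
from torchvision import transforms
from transformers import BertTokenizer

LATENT_DIM = 128  # must match LATENT_DIM in train.py

_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def generate_image(model, prompt, device):
    model.eval()
    with torch.no_grad():
        encoded = _tokenizer([prompt], padding=True, truncation=True, return_tensors="pt")
        text_encoding = model.text_encoder(encoded['input_ids'].to(device),
                                           encoded['attention_mask'].to(device))
        z = torch.randn(1, LATENT_DIM, device=device)  # sample a latent vector from the prior
        image_tensor = model.decode(z, text_encoding)[0]
        # Undo the [-1, 1] normalization before converting to a PIL image
        image_tensor = (image_tensor * 0.5 + 0.5).clamp(0, 1).cpu()
    return transforms.ToPILImage()(image_tensor)

With a sketch like this saved as generate.py next to train.py, the per-epoch image logging in main() should run end to end, though the actual module used by OVAWARE may differ.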