xxxpo13 committed
Commit fee70e0 · verified · 1 Parent(s): b5225b3

Create utils.py

Files changed (1): utils.py +96 -0
utils.py ADDED
@@ -0,0 +1,96 @@
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.nn.parallel import DistributedDataParallel as DDP

# Set your model class here (for demonstration, we'll create a simple CNN)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Two conv -> ReLU -> 2x2 max-pool stages shrink 28x28 MNIST inputs to 7x7
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = x.view(x.size(0), -1)  # Flatten to (batch, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def init_distributed_mode():
    # Launchers such as torchrun set RANK and WORLD_SIZE in the environment
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
        # Pin this process to a single GPU on its node
        torch.cuda.set_device(rank % torch.cuda.device_count())
        print(f"Initialized distributed mode: rank {rank}, world size {world_size}")
    else:
        print("Not using distributed mode")
        rank = 0
        world_size = 1
    return rank, world_size

def main():
    # Initialize the distributed mode
    rank, world_size = init_distributed_mode()
    # Map the global rank to a GPU index on this node
    local_rank = rank % torch.cuda.device_count() if torch.cuda.is_available() else 0

    # Set up data transformations
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Load dataset; DistributedSampler requires an initialized process group,
    # so fall back to a plain shuffling DataLoader in single-process runs
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    if world_size > 1:
        train_sampler = DistributedSampler(train_dataset)
        train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
    else:
        train_sampler = None
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Initialize model
    model = SimpleCNN()
    device = torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Wrap the model with DDP; device_ids must use the local GPU index,
    # since the global rank can exceed the per-node GPU count on multi-node jobs
    if world_size > 1:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # Set up the optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    for epoch in range(10):  # Train for 10 epochs
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)  # Reshuffle each rank's shard every epoch
        running_loss = 0.0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        if rank == 0:  # Only print from the main process
            print(f'Epoch [{epoch + 1}/10], Loss: {running_loss / len(train_loader):.4f}')

    # Clean up distributed training
    if world_size > 1:
        dist.destroy_process_group()


if __name__ == '__main__':
    main()
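
The script trains but never saves weights. A common companion to this DDP pattern is rank-0-only checkpointing; the sketch below is an assumption layered on the commit (the save_checkpoint helper and the checkpoint.pt path are not in the original), reusing the torch and dist imports already at the top of utils.py and the model, rank, and world_size names from main():

def save_checkpoint(model, rank, world_size, path='checkpoint.pt'):
    # Hypothetical helper, not part of the commit. After each optimizer step
    # DDP keeps every rank's weights identical, so only rank 0 needs to write.
    if rank == 0:
        # Unwrap the DDP container when running distributed
        state_dict = model.module.state_dict() if world_size > 1 else model.state_dict()
        torch.save(state_dict, path)
    if world_size > 1:
        dist.barrier()  # Keep other ranks from exiting before the write finishes

For launching, init_distributed_mode reads RANK and WORLD_SIZE from the environment, which torchrun sets automatically (along with MASTER_ADDR and MASTER_PORT, needed by init_process_group): a single-node run on, say, two GPUs would be torchrun --nproc_per_node=2 utils.py, while a plain python utils.py takes the single-process fallback path.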