Spaces:
Sleeping
Sleeping
import os

import torch
from torch.utils.data import random_split, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger

from src.trainer import LitTrainer
def argmax(a):
    """Return the index of the largest element of the indexable sequence ``a``.

    When several elements share the maximum, the lowest such index is
    returned. An empty ``a`` makes ``max`` raise ``ValueError``.
    """
    positions = range(len(a))
    # Rank every position by the element stored there.
    return max(positions, key=a.__getitem__)
def get_dataloaders(dataset, test_data, num_workers=8, train_fraction=0.8):
    """Split ``dataset`` into train/validation parts and wrap everything in loaders.

    Args:
        dataset: Sized dataset that is randomly split into a training part
            and a validation part.
        test_data: Dataset wrapped unchanged in the third loader.
        num_workers: Worker count handed to every ``DataLoader``. Defaults
            to 8, the value that used to be hard-coded ("for 8 CPU cores").
        train_fraction: Share of ``dataset`` used for training; the rest
            becomes the validation part. Defaults to 0.8.

    Returns:
        A ``(train_loader, validate_loader, test_loader)`` tuple. No batch
        size or shuffling option is passed, so the loaders use the
        ``DataLoader`` defaults.

    Raises:
        ValueError: If ``train_fraction`` is not within ``[0, 1]``.
    """
    if not 0 <= train_fraction <= 1:
        raise ValueError(
            f"train_fraction must be between 0 and 1, got {train_fraction}"
        )

    train_size = round(len(dataset) * train_fraction)
    # Derive the remainder by subtraction so both sizes always sum to len(dataset).
    validate_size = len(dataset) - train_size
    train_data, validate_data = random_split(dataset, [train_size, validate_size])

    return (
        DataLoader(train_data, num_workers=num_workers),
        DataLoader(validate_data, num_workers=num_workers),
        DataLoader(test_data, num_workers=num_workers),
    )
def train_loop(net, batch, loss_fn, optim, device="cuda"):
    """Run one optimisation step on a single batch.

    Args:
        net: Model called on the batch inputs.
        batch: ``(x, y)`` pair; both tensors are moved to ``device``.
        loss_fn: Callable invoked as ``loss_fn(y_pred, y)``; its result must
            support ``backward()`` and ``item()``.
        optim: Optimiser whose gradients are cleared before the backward
            pass and which is stepped afterwards.
        device: Device the batch is moved to. Defaults to ``"cuda"``.

    Returns:
        ``(loss_value, truth_count)``: the loss as a Python number and a
        boolean tensor (shaped like ``y``) telling whether the index of the
        highest output equals the target.
    """
    x, y = batch
    x = x.to(device)
    y = y.to(device)

    # The whole output is collapsed into a single row of scores.
    # NOTE(review): this presumably assumes one sample per batch -- confirm.
    y_pred = net(x).reshape(1, -1)
    loss = loss_fn(y_pred, y)

    # Tensor.argmax finds the index in one call; a Python-level max() over
    # the tensor would index and compare it element by element.
    truth_count = y_pred.flatten().argmax() == y

    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss.item(), truth_count
def train_net_manually(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10, device="cuda"):
    """Train ``net`` with a hand-written epoch loop, then save its weights.

    Args:
        net: Model to train; its ``state_dict()`` is saved at the end.
        optim: Optimiser forwarded to ``train_loop``.
        loss_fn: Loss callable forwarded to ``train_loop``.
        train_loader: Iterable of batches that also supports ``len()``.
        validate_loader: Accepted for interface compatibility; not used by
            this function.
        epochs: Number of passes over ``train_loader``. Defaults to 10.
        device: Device forwarded to ``train_loop``. Defaults to ``"cuda"``.

    Side effects:
        Prints progress for every epoch and every 1000th batch, and writes
        the weights to ``checkpoints/pytorch/version_1.pt``.
    """
    checkpoint_path = "checkpoints/pytorch/version_1.pt"
    # Create the target directory before training so a bad path fails
    # immediately instead of after the last epoch.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    batch_count = len(train_loader)
    for i in range(epochs):
        print(f"Epoch: {i}")
        epoch_loss = 0
        epoch_truth_count = 0
        for idx, batch in enumerate(train_loader):
            loss, truth_count = train_loop(net, batch, loss_fn, optim, device)
            epoch_loss += loss
            # The per-batch result may be a one-element tensor; fold it into
            # a plain int so the accuracy below prints as a number.
            epoch_truth_count += int(truth_count)
            if idx % 1000 == 0:
                print(f"Loss: {loss} ({idx} / {batch_count} x {i})")
        print(f"Epoch Loss: {epoch_loss}")
        # Accuracy is correct predictions per batch; max() guards an empty loader.
        print(f"Epoch Accuracy: {epoch_truth_count / max(batch_count, 1)}")

    # NOTE(review): the original indentation was lost; saving once after the
    # final epoch is assumed here -- confirm it was not meant to run per epoch.
    torch.save(net.state_dict(), checkpoint_path)
def train_net_lightning(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10):
    """Fit ``net`` through a PyTorch Lightning ``Trainer`` with MLflow logging.

    Args:
        net: Model wrapped in the project's ``LitTrainer`` module.
        optim: Stored on the wrapper as its ``optim`` attribute.
        loss_fn: Stored on the wrapper as its ``loss`` attribute.
        train_loader: Training loader passed to ``Trainer.fit``.
        validate_loader: Optional validation loader passed to ``Trainer.fit``.
        epochs: Value used for the trainer's ``max_epochs``. Defaults to 10.
    """
    mlflow_logger = MLFlowLogger(
        experiment_name="lightning_logs",
        tracking_uri="file:./ml-runs",
    )

    # Wrap the bare model and attach the training pieces the wrapper expects.
    lit_module = LitTrainer(net)
    lit_module.optim = optim
    lit_module.loss = loss_fn

    trainer_options = {
        "limit_train_batches": 100,
        "max_epochs": epochs,
        "default_root_dir": "../checkpoints",
        "logger": mlflow_logger,
    }
    pl.Trainer(**trainer_options).fit(lit_module, train_loader, validate_loader)