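"""Train U2Net for salient-object detection on DUTS + MSRA.

Trains with mixed precision (torch.amp) and a combined BCE + Dice loss applied
to every output map, saving weights as safetensors under results/.
"""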
import os
import pickle

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torch.amp import autocast, GradScaler
from safetensors.torch import save_file

from data_loader import DUTSDataset, MSRADataset
from model import U2Net
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scaler = GradScaler()
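# Dice loss on sigmoid probabilities: dice = (2*|X∩Y| + s) / (|X| + |Y| + s),
# loss = 1 - dice. The smoothing term s avoids division by zero and softens
# gradients when masks are empty.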
class DiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = torch.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice
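# One pass over the training set with automatic mixed precision. The forward
# pass and loss run under autocast; GradScaler scales the loss before
# backward() so small fp16 gradients don't underflow, then unscales before the
# optimizer step. U2Net returns several prediction maps (the fused output plus
# side outputs), so the criterion is summed over all of them (deep supervision).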
def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = 0.
    for images, masks in tqdm(loader, desc='Training', leave=False):
        images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast(device_type='cuda'):
            outputs = model(images)
            loss = sum(criterion(output, masks) for output in outputs)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    return running_loss / len(loader)
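# Average validation loss over the whole loader, with gradients disabled.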
def validate(model, loader, criterion):
    model.eval()
    running_loss = 0.
    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Validating', leave=False):
            images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
            outputs = model(images)
            loss = sum(criterion(output, masks) for output in outputs)
            running_loss += loss.item()
    avg_loss = running_loss / len(loader)
    return avg_loss
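# Persist weights as safetensors and the loss history as a pickle (note the
# file is binary despite the .txt extension).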
def save(model, model_name, losses):
    # unwrap DataParallel (if present) so the checkpoint keys match a plain U2Net
    state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
    save_file(state, f'results/{model_name}.safetensors')
    with open('results/loss.txt', 'wb') as f:
        pickle.dump(losses, f)
if __name__ == '__main__':
    batch_size = 40
    valid_batch_size = 80
    epochs = 200
    lr = 1e-3
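    # Combined loss: alpha-weighted BCE (on raw logits) plus Dice. DiceLoss
    # applies its own sigmoid, so both terms expect unnormalised model outputs.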
    loss_fn_bce = nn.BCEWithLogitsLoss(reduction='mean')
    loss_fn_dice = DiceLoss()
    alpha = 0.6
    loss_fn = lambda o, m: alpha * loss_fn_bce(o, m) + (1 - alpha) * loss_fn_dice(o, m)
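    # DataParallel splits each batch across all visible GPUs; the wrapped
    # model's state_dict keys gain a 'module.' prefix, which the save paths
    # below strip before exporting.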
    model_name = 'u2net-duts-msra'
    model = U2Net()
    model = nn.DataParallel(model.to(device))
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    train_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
        batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=8, persistent_workers=True
    )
    valid_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
        batch_size=valid_batch_size, shuffle=False, pin_memory=True,
        num_workers=8, persistent_workers=True
    )
    best_val_loss = float('inf')
    losses = {'train': [], 'val': []}
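    # Ensure the checkpoint directory exists up front; without this, the first
    # save_file() call would fail if results/ is missing.
    os.makedirs('results', exist_ok=True)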
    # training loop: checkpoint on each new best validation loss, and always
    # save the final weights and loss history on exit (even on interruption)
    try:
        for epoch in range(epochs):
            train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
            val_loss = validate(model, valid_loader, loss_fn)
            losses['train'].append(train_loss)
            losses['val'].append(val_loss)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # unwrap DataParallel here as well so the best checkpoint loads into a plain U2Net
                state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
                save_file(state, f'results/best-{model_name}.safetensors')
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
    finally:
        save(model, model_name, losses)
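# For reference, a minimal sketch of loading the exported checkpoint back for
# inference (assuming the same U2Net definition from model.py):
#
#     from safetensors.torch import load_file
#     model = U2Net()
#     model.load_state_dict(load_file('results/best-u2net-duts-msra.safetensors'))
#     model.eval()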