# Uploaded by jamino30 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit ecf0440 (verified).
import os
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from safetensors.torch import save_file
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader, ConcatDataset
from tqdm import tqdm

from data_loader import DUTSDataset, MSRADataset
from model import U2Net
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scaler = GradScaler()
class DiceLoss(nn.Module):
    """Soft Dice loss computed on raw logits.

    Applies a sigmoid to the logits, then flattens predictions and targets
    across every dimension (batch included). A single Dice score is therefore
    computed for the whole batch, and the loss is ``1 - dice``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` for logits ``inputs`` against ``targets``.

        Args:
            inputs: logits of any shape; passed through ``torch.sigmoid``.
            targets: tensor with the same number of elements as ``inputs``
                (presumably binary masks in [0, 1] -- confirm with the data
                loader).
            smooth: additive constant in the numerator and denominator. It
                keeps the ratio defined when prediction and target are both
                empty.

        Returns:
            A scalar tensor.
        """
        # reshape rather than view: view raises on non-contiguous tensors
        # (e.g. permuted/transposed model outputs or masks).
        probs = torch.sigmoid(inputs).reshape(-1)
        flat_targets = targets.reshape(-1)
        intersection = (probs * flat_targets).sum()
        dice = (2. * intersection + smooth) / (probs.sum() + flat_targets.sum() + smooth)
        return 1 - dice
def train_one_epoch(model, loader, criterion, optimizer):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Batches are moved to the module-level ``device``, and gradients go through
    the module-level ``scaler``. Each batch's loss is the sum of ``criterion``
    applied to every output returned by ``model``, each compared against the
    same ``masks``. Presumably these are U2Net's fused and side outputs; confirm
    against ``model.U2Net``.

    Returns:
        The average of the per-batch loss values (a Python float).
    """
    model.train()
    # Autocast follows the active device rather than a hard-coded 'cuda'.
    # On CPU it stays disabled, which matches the fp32 behaviour the
    # hard-coded version fell back to (with a warning) on CPU-only machines.
    use_amp = device.type == 'cuda'
    running_loss = 0.
    for images, masks in tqdm(loader, desc='Training', leave=False):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            loss = sum(criterion(output, masks) for output in outputs)
        # Backward and step happen outside autocast, as the AMP recipe requires.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    return running_loss / len(loader)
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` with gradient tracking disabled.

    Batches are moved to the module-level ``device``. Each batch's loss is the
    sum of ``criterion`` over every output the model returns. Runs in full
    precision (no autocast).

    Returns:
        The mean of the per-batch loss values (a Python float).
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch_images, batch_masks in tqdm(loader, desc='Validating', leave=False):
            batch_images = batch_images.to(device, non_blocking=True)
            batch_masks = batch_masks.to(device, non_blocking=True)
            predictions = model(batch_images)
            batch_total = sum(criterion(pred, batch_masks) for pred in predictions)
            batch_losses.append(batch_total.item())
    return sum(batch_losses) / len(loader)
def save(model, model_name, losses):
    """Write the model weights and the loss history under ``results/``.

    Args:
        model: module whose ``state_dict()`` is written with safetensors to
            ``results/{model_name}.safetensors``. NOTE(review): this script
            wraps the model in ``DataParallel``, so the keys carry a
            ``module.`` prefix. Confirm the loading side strips it.
        model_name: stem of the weights file name.
        losses: object pickled to ``results/loss.txt``. Despite the ``.txt``
            extension, the file holds binary pickle data, not text.
    """
    # This is called from a ``finally`` clause. Creating the directory keeps a
    # missing ``results/`` from raising here and masking the original error.
    os.makedirs('results', exist_ok=True)
    save_file(model.state_dict(), f'results/{model_name}.safetensors')
    with open('results/loss.txt', 'wb') as f:
        pickle.dump(losses, f)
if __name__ == '__main__':
    # Hyperparameters.
    batch_size = 40
    valid_batch_size = 80
    epochs = 200
    lr = 1e-3

    # Combined objective on logits: alpha * BCE + (1 - alpha) * Dice.
    loss_fn_bce = nn.BCEWithLogitsLoss(reduction='mean')
    loss_fn_dice = DiceLoss()
    alpha = 0.6

    def loss_fn(output, mask):
        return alpha * loss_fn_bce(output, mask) + (1 - alpha) * loss_fn_dice(output, mask)

    model_name = 'u2net-duts-msra'
    model = U2Net()
    # NOTE: DataParallel prefixes state_dict keys with "module.", so every
    # checkpoint written below carries that prefix.
    model = torch.nn.parallel.DataParallel(model.to(device))
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    train_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
        batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=8, persistent_workers=True,
    )
    valid_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
        batch_size=valid_batch_size, shuffle=False, pin_memory=True,
        num_workers=8, persistent_workers=True,
    )

    # Checkpoints are written into results/. Create it up front so the first
    # best-model save does not fail with FileNotFoundError.
    os.makedirs('results', exist_ok=True)

    best_val_loss = float('inf')
    losses = {'train': [], 'val': []}

    # Training loop. The finally clause writes the final weights and loss
    # history even if training is interrupted (e.g. Ctrl-C) or raises.
    try:
        for epoch in range(epochs):
            train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
            val_loss = validate(model, valid_loader, loss_fn)
            losses['train'].append(train_loss)
            losses['val'].append(val_loss)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_file(model.state_dict(), f'results/best-{model_name}.safetensors')
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
    finally:
        save(model, model_name, losses)