"""Evaluate a trained U2Net checkpoint on the PASCAL-S ``eval`` split.

Run as a script: it asks which checkpoint to load ('b' for the "best"
checkpoint, anything else for the final one) and prints the mean per-batch
loss, where each batch's loss is summed over every output of the model.
"""

from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from safetensors.torch import load_file

from data_loader import PASCALSDataset
from model import U2Net

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)


def load_model(model, model_path):
    """Load safetensors weights into *model* in place and switch it to eval mode.

    Args:
        model: Module to populate. ``load_state_dict`` is strict by default, so
            the checkpoint's keys must match ``model.state_dict()`` exactly
            (e.g. an ``nn.DataParallel`` wrapper expects ``module.``-prefixed keys).
        model_path: Path to the ``.safetensors`` checkpoint file.
    """
    state_dict = load_file(model_path, device=device.type)
    model.load_state_dict(state_dict)
    model.eval()


def eval(model, loader, criterion):  # noqa: A001 - name kept for existing callers; shadows builtin eval
    """Return the mean per-batch loss of *model* over *loader*.

    For each ``(images, masks)`` batch, the model's outputs are iterated and
    ``criterion(output, masks)`` is summed over all of them; the result is
    averaged over the number of batches actually processed.

    Args:
        model: Module to evaluate; it is put into eval mode.
        loader: Iterable yielding ``(images, masks)`` pairs.
        criterion: Loss callable applied to each model output against ``masks``.

    Returns:
        The mean batch loss as a Python float.

    Raises:
        ValueError: If *loader* yields no batches.
    """
    model.eval()
    running_loss = 0.
    num_batches = 0
    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Testing'):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            # Deep supervision: every returned output contributes to the loss.
            loss = sum(criterion(output, masks) for output in outputs)
            running_loss += loss.item()
            num_batches += 1
    # Count processed batches instead of len(loader): avoids a bare
    # ZeroDivisionError on an empty loader and works for unsized iterables.
    if num_batches == 0:
        raise ValueError('eval(): loader yielded no batches')
    return running_loss / num_batches


if __name__ == '__main__':
    batch_size = 1
    # Normalize the answer so ' b' or 'B' still selects the best checkpoint.
    model_type = input('Model type [b,f]: ').strip().lower()
    model_name = (
        'best-u2net-duts-msra.safetensors'
        if model_type == 'b'
        else 'u2net-duts-msra.safetensors'
    )
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')

    model = U2Net().to(device)
    # Wrapped before loading; presumably the checkpoint was saved from a
    # DataParallel model (keys prefixed with 'module.') — confirm.
    model = nn.DataParallel(model)
    load_model(model, f'results/{model_name}')

    loader = DataLoader(PASCALSDataset(split='eval'), batch_size=batch_size, shuffle=False)
    loss = eval(model, loader, loss_fn)
    print(f'Loss: {loss:.4f}')