Spaces: Running on Zero
import torch
import torch.nn as nn
from safetensors.torch import load_file
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_loader import PASCALSDataset
from model import U2Net
# Run on the first CUDA device when one is visible to torch, otherwise on CPU.
# Every tensor and the model are moved onto this device elsewhere in the script.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)
def load_model(model, model_path):
    """Fill *model* in place with the weights stored in a safetensors file.

    The tensors are read straight onto the module-level ``device`` type
    (``'cuda'`` or ``'cpu'``). They are copied into *model* with
    ``load_state_dict`` (strict key matching by default). The model is then
    switched to evaluation mode. Nothing is returned.

    NOTE(review): if *model* is wrapped in ``nn.DataParallel``, the checkpoint
    keys must carry the ``module.`` prefix. Confirm the saved files match.
    """
    model.load_state_dict(load_file(model_path, device=device.type))
    model.eval()
def eval(model, loader, criterion):
    """Return the mean per-batch loss of *model* over *loader*.

    The model is put in evaluation mode and run under ``torch.no_grad()``.
    For each batch, *criterion* is applied to every element the model returns
    against the same target masks, and those losses are summed. The summed
    batch losses are then averaged over the number of batches actually seen.

    The function name shadows the builtin ``eval`` inside this module. It is
    kept unchanged so existing callers keep working.

    Args:
        model: module whose forward pass returns an iterable of output
            tensors. Each element is assumed to be shape-compatible with
            ``masks`` (TODO confirm).
        loader: iterable of ``(images, masks)`` batches. It does not need to
            support ``len()``.
        criterion: callable ``criterion(output, masks)`` returning a scalar
            tensor.

    Returns:
        Mean summed loss per batch, as a Python float.

    Raises:
        ValueError: if *loader* yields no batches. This replaces an opaque
            division by zero.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Testing'):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            # Generator instead of a temporary list. sum() starts from int 0
            # and becomes a tensor on the first addition.
            batch_loss = sum(criterion(output, masks) for output in outputs)
            total_loss += batch_loss.item()
            # Count batches as they arrive instead of calling len(loader).
            # len() raises TypeError for loaders over an IterableDataset
            # without __len__. For map-style loaders the count is identical.
            num_batches += 1
    if num_batches == 0:
        raise ValueError('loader yielded no batches; cannot compute a mean loss')
    return total_loss / num_batches
if __name__ == '__main__':
    # Script entry point: load a saved U2Net checkpoint, evaluate it on the
    # PASCALSDataset 'eval' split, and print the mean loss.
    batch_size = 1
    # Interactive choice -> checkpoint file under results/.
    # Presumably 'b' = best and 'f' = final checkpoint, judging by the
    # filenames (confirm).
    checkpoints = {
        'b': 'best-u2net-duts-msra.safetensors',
        'f': 'u2net-duts-msra.safetensors',
    }
    # Normalise the answer so input such as ' B' is accepted. Reject anything
    # else instead of silently falling back to the 'f' checkpoint on a typo.
    model_type = input('Model type [b,f]: ').strip().lower()
    if model_type not in checkpoints:
        raise SystemExit(f"Unknown model type {model_type!r}; expected 'b' or 'f'.")
    model_name = checkpoints[model_type]
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    model = U2Net().to(device)
    # Wrap before loading so the parameter names carry DataParallel's
    # 'module.' prefix. The checkpoint is presumably saved from a wrapped
    # model (confirm).
    model = nn.DataParallel(model)
    load_model(model, f'results/{model_name}')
    loader = DataLoader(PASCALSDataset(split='eval'), batch_size=batch_size, shuffle=False)
    loss = eval(model, loader, loss_fn)
    print(f'Loss: {loss:.4f}')