BiRefNet / utils.py
ZhengPeng7's picture
Move all BiRefNet github codes to the first level directory.
45e89f1
raw
history blame
No virus
2.75 kB
import logging
import os
import torch
from torchvision import transforms
import numpy as np
import random
import cv2
from PIL import Image
def path_to_image(path, size=(1024, 1024), color_type=['rgb', 'gray'][0]):
    """Read an image file with OpenCV and return it as a PIL image.

    Args:
        path: Location of the image file handed to ``cv2.imread``.
        size: Target size passed to ``cv2.resize``; any falsy value
            (``None``, ``()``) skips resizing.
        color_type: ``'rgb'`` or ``'gray'`` (case-insensitive). The default
            expression evaluates to ``'rgb'``.

    Returns:
        A PIL image in mode ``'RGB'`` or ``'L'``. For an unsupported
        ``color_type`` a hint is printed and ``None`` is returned.

    Raises:
        FileNotFoundError: If OpenCV could not load anything from ``path``.
    """
    mode = color_type.lower()
    if mode == 'rgb':
        image = cv2.imread(path)
    elif mode == 'gray':
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        print('Select the color_type to return, either to RGB or gray image.')
        return

    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising; fail here with the offending path instead of letting
    # cv2.resize / cv2.cvtColor blow up with an unrelated-looking error.
    if image is None:
        raise FileNotFoundError('Cannot read image from path: {}'.format(path))

    if size:
        image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)

    if mode == 'rgb':
        # OpenCV loads colour images in BGR channel order; PIL expects RGB.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert('RGB')
    else:
        image = Image.fromarray(image).convert('L')
    return image
def check_state_dict(state_dict, unwanted_prefix='_orig_mod.'):
    """Strip ``unwanted_prefix`` from every key of ``state_dict`` that has it.

    The mapping is modified in place: each matching entry is popped and
    re-inserted under the shortened key. The same mapping object is returned.
    """
    prefix_len = len(unwanted_prefix)
    # Iterate over a snapshot of the keys because the mapping is mutated.
    for key in tuple(state_dict.keys()):
        if not key.startswith(unwanted_prefix):
            continue
        state_dict[key[prefix_len:]] = state_dict.pop(key)
    return state_dict
def generate_smoothed_gt(gts, epsilon=0.001):
    """Apply label smoothing to ground-truth values.

    Computes ``(1 - epsilon) * gts + epsilon / 2``, so a value of 0 becomes
    ``epsilon / 2`` and a value of 1 becomes ``1 - epsilon / 2``.

    Args:
        gts: Ground-truth values; anything supporting scalar multiplication
            and addition works (presumably a tensor of masks in [0, 1] --
            confirm against callers).
        epsilon: Smoothing strength. Defaults to 0.001, the value that was
            previously hard-coded.

    Returns:
        The smoothed values, of whatever type the arithmetic on ``gts`` yields.
    """
    new_gts = (1 - epsilon) * gts + epsilon / 2
    return new_gts
class Logger():
    """Thin wrapper around the shared ``'BiRefNet'`` logger.

    Every record is written both to a log file (opened in ``"w"`` mode, so an
    existing file is truncated) and to a ``logging.StreamHandler``.
    """

    def __init__(self, path="log.txt"):
        """Attach a file handler for ``path`` and a stream handler to the logger."""
        # logging.getLogger returns the same object for the same name, so all
        # Logger instances share this underlying logger.
        self.logger = logging.getLogger('BiRefNet')
        self.file_handler = logging.FileHandler(path, "w")
        self.stdout_handler = logging.StreamHandler()
        self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        self.logger.addHandler(self.file_handler)
        self.logger.addHandler(self.stdout_handler)
        self.logger.setLevel(logging.INFO)
        # Keep records from also reaching the root logger's handlers.
        self.logger.propagate = False

    def info(self, txt):
        """Log ``txt`` at INFO level."""
        self.logger.info(txt)

    def close(self):
        """Detach this instance's handlers from the shared logger and close them.

        Detaching matters because the underlying logger outlives this object:
        handlers left attached would keep receiving records and would stack up
        with those of any Logger created later, duplicating every message.
        """
        for handler in (self.file_handler, self.stdout_handler):
            self.logger.removeHandler(handler)
            handler.close()
class AverageMeter:
    """Track the latest value and the running weighted mean of a metric.

    Attributes are ``val`` (last value passed to ``update``), ``sum``
    (accumulated ``val * n``), ``count`` (accumulated ``n``) and ``avg``
    (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0.0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def save_checkpoint(state, path, filename="latest.pth"):
    """Serialize ``state`` with ``torch.save`` to ``filename`` inside ``path``."""
    target = os.path.join(path, filename)
    torch.save(state, target)
def save_tensor_img(tenor_im, path):
    """Convert a tensor to a PIL image and save it at ``path``.

    The tensor is copied to the CPU and its dimension 0 is squeezed (a no-op
    unless that dimension has size 1) before conversion via
    ``transforms.ToPILImage``.
    """
    image_tensor = tenor_im.cpu().clone().squeeze(0)
    pil_image = transforms.ToPILImage()(image_tensor)
    pil_image.save(path)
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch RNGs with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: the default one, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True