# Source: RemBG / rembg/session_base.py (uploaded by KenjieDec, commit ef928a1)
from typing import Dict, List, Tuple
import numpy as np
import onnxruntime as ort
from PIL import Image
from PIL.Image import Image as PILImage
class BaseSession:
    """Base class for rembg sessions that wrap an ONNX Runtime inference session.

    Subclasses implement :meth:`predict`; :meth:`normalize` builds the input
    feed dict for the wrapped session from a PIL image.
    """

    def __init__(self, model_name: str, inner_session: ort.InferenceSession):
        """Store the model name and the ONNX Runtime session.

        Args:
            model_name: Identifier of the model this session runs.
            inner_session: The ONNX Runtime session used for inference.
        """
        self.model_name = model_name
        self.inner_session = inner_session

    def normalize(
        self,
        img: PILImage,
        mean: Tuple[float, float, float],
        std: Tuple[float, float, float],
        size: Tuple[int, int],
    ) -> Dict[str, np.ndarray]:
        """Resize and normalize ``img`` into an input feed for the session.

        The image is converted to RGB and resized to ``size`` with Lanczos
        resampling. It is then scaled by its own maximum pixel value, so the
        brightest value becomes 1.0. Finally each channel is standardized
        with ``mean`` and ``std``.

        Args:
            img: Input image; converted to RGB first.
            mean: Per-channel (R, G, B) values subtracted after scaling.
            std: Per-channel (R, G, B) values divided by after subtraction.
            size: Target ``(width, height)`` passed to ``PIL.Image.resize``.

        Returns:
            A single-entry dict mapping the name of the session's first input
            to a float32 array of shape ``(1, 3, height, width)``.
        """
        im = img.convert("RGB").resize(size, Image.LANCZOS)
        im_ary = np.array(im)

        # Scale by the image's own maximum (not a fixed 255). Clamp the divisor
        # so an all-black image (max == 0) yields zeros instead of NaNs.
        im_ary = im_ary / max(float(np.max(im_ary)), 1e-6)

        # Per-channel standardization: mean/std broadcast over the last
        # (channel) axis of the (height, width, 3) array.
        normalized = (im_ary - np.asarray(mean, dtype=np.float64)) / np.asarray(
            std, dtype=np.float64
        )

        # HWC -> CHW, then prepend a batch axis of size 1.
        tensor = np.expand_dims(normalized.transpose((2, 0, 1)), 0).astype(
            np.float32
        )

        return {self.inner_session.get_inputs()[0].name: tensor}

    def predict(self, img: PILImage) -> List[PILImage]:
        """Run the model on ``img`` and return a list of output images.

        The meaning of the returned images is defined by each subclass.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override it.
        """
        raise NotImplementedError