from typing import Dict, Any
from io import BytesIO
import base64
import os

import torch
from PIL import Image
from torchvision.transforms import Compose, Normalize, functional

from model import ISNetDIS


def process_image(image: torch.Tensor):
    """Normalize an image tensor and add a leading batch dimension.

    Each of the three channels is shifted by 0.5 and scaled by 0.5, then the
    result is unsqueezed at dim 0 so it can be fed to the model as a batch
    of one.
    """
    pipe = Compose([Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    img = pipe(image)
    return torch.unsqueeze(img, 0)


def get_model(device="cpu"):
    """Build an ``ISNetDIS`` model and load ``isnet.pth`` from this file's directory.

    The weights are mapped onto *device*, the model is moved there and put
    into eval mode before being returned.
    """
    model = ISNetDIS()
    weight_pth = os.path.join(os.path.dirname(__file__), "isnet.pth")
    weights = torch.load(weight_pth, map_location=device)
    model.load_state_dict(weights)
    model.to(device)
    model.eval()
    return model


class EndpointHandler():
    """Callable that cuts the foreground out of a base64-encoded image."""

    def __init__(self, path=""):
        # Load the weights once; every request reuses this instance.
        self._model = get_model()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run segmentation on the submitted image and return an RGBA PNG.

        Args:
            data: Request payload. The image is read from
                ``data["inputs"]["image"]`` (or ``data["image"]`` when there
                is no ``"inputs"`` key) as a base64 string. Note that the
                ``"inputs"`` key is popped from *data*.

        Returns:
            ``{"status": 200, "image": <base64 PNG>}``. Pixels outside the
            predicted mask are painted white and given alpha 0.

        Raises:
            KeyError: if the payload carries no ``"image"`` entry.
        """
        inputs = data.pop("inputs", data)
        image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
        # The normalization in process_image is defined for exactly three
        # channels, so RGBA / grayscale / palette uploads are converted first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        t = functional.pil_to_tensor(image).float().divide(255.0)
        arr = process_image(t)

        # Reuse the model loaded in __init__ instead of re-reading the weights
        # on every request, and skip autograd bookkeeping during inference.
        with torch.no_grad():
            v = self._model(arr)[0]
        # assumes the first returned map is (batch, C, H, W) with H, W matching
        # the input image — TODO confirm against ISNetDIS.forward
        pred_val = v[0][0, :, :, :]

        # Min-max scale the prediction; a constant map would divide by zero,
        # so it is treated as "no foreground" instead of producing NaNs.
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        span = ma - mi
        if span > 0:
            pred_val = (pred_val - mi) / span
        else:
            pred_val = torch.zeros_like(pred_val)

        msk = torch.gt(pred_val, 0.1)
        # Keep the original colour inside the mask, white elsewhere.
        w = torch.where(msk, t, 1.0)
        # Append the mask as an alpha channel; cast explicitly so the
        # concatenation does not rely on bool/float type promotion.
        w = torch.cat([w, msk.to(w.dtype)], dim=0)
        # w is already (channels, H, W); squeezing it could collapse a
        # size-1 height or width, so it is passed through unchanged.
        img2 = functional.to_pil_image(w)

        stream = BytesIO()
        img2.save(stream, format="png")
        res = {
            "status": 200,
            "image": base64.b64encode(stream.getvalue()).decode("utf8"),
        }
        return res


if __name__ == "__main__":
    h = EndpointHandler()
    # NOTE(review): an empty payload has no 'image' entry, so this call raises
    # KeyError; pass {"inputs": {"image": <base64>}} to exercise the handler.
    v = h({})
    print(v)