|
from typing import Dict, Any |
|
from io import BytesIO |
|
import base64 |
|
from model import ISNetDIS |
|
import torch |
|
import os |
|
from PIL import Image |
|
from torchvision.transforms import Compose, Normalize, functional |
|
|
|
|
|
def process_image(image: torch.Tensor):
    """Normalize an image tensor and prepend a batch dimension.

    Each channel is mapped through ``(x - 0.5) / 0.5``. For inputs in
    [0, 1] this yields values in [-1, 1].

    Args:
        image: Float image tensor. The 3-element mean/std implies a
            3-channel ``(C, H, W)`` layout (assumption; confirm with callers).

    Returns:
        The normalized tensor with a leading batch axis of size 1.
    """
    normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    normalized = normalize(image)
    # Add the batch axis the model expects: (C, H, W) -> (1, C, H, W).
    return normalized.unsqueeze(0)
|
|
|
|
|
def get_model(device="cpu"):
    """Build an ISNetDIS network initialised from the bundled weights.

    The state dict is read from ``isnet.pth`` located in the same directory
    as this module.

    Args:
        device: Target device; used both as ``map_location`` for loading
            and as the device the model is moved to.

    Returns:
        The ISNetDIS module on ``device``, switched to evaluation mode.
    """
    net = ISNetDIS()
    module_dir = os.path.dirname(__file__)
    state_dict = torch.load(os.path.join(module_dir, "isnet.pth"), map_location=device)
    net.load_state_dict(state_dict)
    # nn.Module.to() and .eval() both return the module itself.
    return net.to(device).eval()
|
|
|
|
|
class EndpointHandler:
    """Request handler that removes the background from an input image.

    The ISNetDIS model is loaded once when the handler is constructed and
    reused for every request.
    """

    def __init__(self, path=""):
        """Load the segmentation model.

        Args:
            path: Accepted for compatibility with the caller's handler
                signature; not used here. Weights come from ``get_model``.
        """
        self._model = get_model()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Segment the foreground of a base64-encoded image.

        Args:
            data: Request payload. The image is read from
                ``data["inputs"]["image"]``, or from ``data["image"]`` when
                there is no ``"inputs"`` key. The ``"inputs"`` key, if
                present, is popped from ``data``.

        Returns:
            ``{"status": 200, "image": <base64 PNG>}``. The PNG is RGBA:
            foreground pixels keep their colour, background pixels are set
            to white, and the alpha channel is the binary foreground mask.

        Raises:
            KeyError: If no ``"image"`` entry is present.
        """
        inputs = data.pop("inputs", data)

        # Force 3 channels: the normalisation in process_image uses a
        # 3-element mean/std, so grayscale, palette or RGBA images would
        # otherwise fail or be misinterpreted.
        image = Image.open(BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")
        t = functional.pil_to_tensor(image).float().divide(255.0)
        arr = process_image(t)

        # Reuse the model loaded in __init__ instead of reloading weights per
        # request, and skip autograd bookkeeping during inference.
        with torch.no_grad():
            v = self._model(arr)[0]

        # Presumably v[0] is the first (full-resolution) side output with
        # shape (1, 1, H, W); dropping the batch axis leaves (1, H, W).
        # NOTE(review): confirm against ISNetDIS's forward().
        pred_val = v[0][0, :, :, :]

        # Min-max scale to [0, 1]; a constant map would divide by zero.
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        value_range = ma - mi
        if value_range > 0:
            pred_val = (pred_val - mi) / value_range
        else:
            pred_val = torch.zeros_like(pred_val)

        msk = torch.gt(pred_val, 0.1)

        # Keep foreground pixels and paint the background white (1.0). The
        # single-channel mask broadcasts across the 3 colour channels.
        w = torch.where(msk, t, 1)
        # Append the mask as an alpha channel -> (4, H, W). w is already CHW,
        # so no squeeze is needed (squeeze would break 1-pixel-wide images).
        w = torch.cat([w, msk.to(w.dtype)], dim=0)

        img2 = functional.to_pil_image(w)

        stream = BytesIO()
        img2.save(stream, format="png")

        res = {"status": 200,
               "image": base64.b64encode(stream.getvalue()).decode("utf8")
               }

        return res
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test:  python <this_file> INPUT_IMAGE [OUTPUT_PNG]
    # The original called h({}), which always raised KeyError('image').
    import sys

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} INPUT_IMAGE [OUTPUT_PNG]")

    with open(sys.argv[1], "rb") as src:
        encoded = base64.b64encode(src.read()).decode("utf8")

    h = EndpointHandler()
    v = h({"inputs": {"image": encoded}})

    if len(sys.argv) > 2:
        with open(sys.argv[2], "wb") as dst:
            dst.write(base64.b64decode(v["image"]))

    # Avoid dumping the (potentially huge) base64 payload to the console.
    print({"status": v["status"], "image_base64_chars": len(v["image"])})