File size: 1,684 Bytes
e9311f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e5048
e9311f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Dict, Any
from io import BytesIO
import base64
from model import ISNetDIS
import torch
import os
from PIL import Image
from torchvision.transforms import Compose, Normalize, functional


def process_image(image: torch.Tensor):
    """Normalise a 3-channel image tensor and prepend a batch dimension.

    Each channel is mapped through ``(x - 0.5) / 0.5``, so inputs scaled to
    ``[0, 1]`` end up in ``[-1, 1]``.

    Args:
        image: Channel-first tensor; ``Normalize`` with three means/stds
            requires exactly three channels.

    Returns:
        The normalised tensor with a new leading batch axis of size 1.
    """
    # A one-stage Compose is equivalent to applying the transform directly.
    normalize = Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    normalized = normalize(image)
    return normalized.unsqueeze(0)


def get_model(device="cpu"):
    """Build an ISNetDIS network from the ``isnet.pth`` checkpoint beside this file.

    Args:
        device: Target device for the weights and the module (default ``"cpu"``).

    Returns:
        The network on ``device``, switched to evaluation mode.
    """
    checkpoint_path = os.path.join(os.path.dirname(__file__), "isnet.pth")
    net = ISNetDIS()
    state_dict = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state_dict)
    net.to(device)
    # Evaluation mode: inference-time behaviour for layers such as dropout/batch-norm.
    net.eval()
    return net


class EndpointHandler():
    """Inference-endpoint handler that cuts the foreground out of an image.

    A request carries a base64-encoded image. The response carries a
    base64-encoded RGBA PNG: pixels predicted as foreground keep their
    colour, all other pixels are set to white, and the alpha channel is the
    binary foreground mask.
    """

    def __init__(self, path=""):
        # ``path`` is accepted for interface compatibility but unused: the
        # weights are resolved relative to this module by ``get_model``.
        # Load the network once; every request reuses it.
        self._model = get_model()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Segment the foreground of one image.

        Args:
            data: Request payload. If it has an ``"inputs"`` key, that entry is
                popped from ``data`` and used as the input mapping; otherwise
                ``data`` itself is used. The mapping's ``"image"`` value must
                be a base64-encoded image file.

        Returns:
            ``{"status": 200, "image": <base64-encoded RGBA PNG>}``.

        Raises:
            KeyError: if the input mapping has no ``"image"`` entry.
        """
        inputs = data.pop("inputs", data)

        # Force three channels: ``process_image`` normalises exactly three,
        # so RGBA / greyscale / palette uploads would otherwise fail.
        raw = BytesIO(base64.b64decode(inputs['image']))
        with Image.open(raw) as image:
            rgb = image.convert("RGB")
        t = functional.pil_to_tensor(rgb).float().divide(255.0)
        arr = process_image(t)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            v = self._model(arr)[0]

        pred_val = v[0][0, :, :, :]
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        span = ma - mi
        if span > 0:
            # Min-max rescale to [0, 1] so the threshold below is relative.
            pred_val = (pred_val - mi) / span
        else:
            # Constant prediction: avoid 0/0 NaNs; everything is background.
            pred_val = torch.zeros_like(pred_val)

        msk = torch.gt(pred_val, 0.1)
        # Foreground keeps its colour, background becomes white (1.0).
        w = torch.where(msk, t, 1)
        # Append the mask as an alpha channel; cast explicitly rather than
        # relying on bool/float type promotion inside ``torch.cat``.
        w = torch.cat([w, msk.to(w.dtype)], dim=0)

        # ``w`` is already (C, H, W); no squeeze, which would collapse
        # 1-pixel-high or 1-pixel-wide images.
        img2 = functional.to_pil_image(w)

        stream = BytesIO()
        img2.save(stream, format="png")
        res = {"status": 200,
               "image": base64.b64encode(stream.getvalue()).decode("utf8")
               }
        return res


if __name__ == "__main__":
    # Manual smoke test: ``python <this file> IMAGE_PATH`` runs one request
    # end to end. The handler needs an "image" entry, so a payload is built
    # from the file given on the command line.
    import sys

    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")
    with open(sys.argv[1], "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("utf8")
    h = EndpointHandler()
    v = h({"inputs": {"image": encoded}})
    print(v)