frodos's picture
Fix handler
99e5048
raw
history blame contribute delete
No virus
1.68 kB
from typing import Dict, Any
from io import BytesIO
import base64
from model import ISNetDIS
import torch
import os
from PIL import Image
from torchvision.transforms import Compose, Normalize, functional
def process_image(image: torch.Tensor):
    """Normalize an image tensor channel-wise and prepend a batch axis.

    Every channel is shifted by 0.5 and divided by 0.5, so values in [0, 1]
    land in [-1, 1].

    Args:
        image: floating-point CHW tensor; the 3-element mean/std means a
            3-channel image is expected (presumably scaled to [0, 1] by the
            caller — confirm at call sites).

    Returns:
        The normalized tensor with a leading batch dimension of size 1.
    """
    normalize = Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    return normalize(image).unsqueeze(0)
def get_model(device="cpu"):
    """Construct an ISNetDIS network and load the weights shipped beside this file.

    Args:
        device: used both as ``map_location`` for ``torch.load`` and as the
            target of ``Module.to``; defaults to ``"cpu"``.

    Returns:
        The network, moved to ``device`` and switched to evaluation mode.
    """
    net = ISNetDIS()
    # The checkpoint is looked up in the same directory as this source file.
    checkpoint = os.path.join(os.path.dirname(__file__), "isnet.pth")
    # NOTE(review): torch.load unpickles the file, so it must come from a
    # trusted source (here it is bundled with the repository).
    net.load_state_dict(torch.load(checkpoint, map_location=device))
    net.to(device)
    net.eval()
    return net
class EndpointHandler:
    """Request handler that cuts the foreground out of an image with ISNetDIS.

    The network is loaded once in ``__init__`` and reused for every call.
    """

    def __init__(self, path=""):
        # ``path`` is accepted for the endpoint runtime's calling convention but
        # is not used: get_model() locates the weights relative to this module.
        self._model = get_model()

    @staticmethod
    def _decode_image(inputs) -> Image.Image:
        """Decode the base64 image carried by *inputs* into an RGB PIL image.

        *inputs* may be the base64 string itself or a dict with an ``"image"``
        key holding it.

        Raises:
            KeyError: if no base64-encoded image can be found.
        """
        if isinstance(inputs, (str, bytes)):
            encoded = inputs
        elif isinstance(inputs, dict) and "image" in inputs:
            encoded = inputs["image"]
        else:
            raise KeyError("request payload must provide a base64-encoded 'image'")
        with Image.open(BytesIO(base64.b64decode(encoded))) as src:
            # Force exactly 3 channels: process_image normalizes with a
            # 3-channel mean/std, and the cut-out appends one alpha channel.
            return src.convert("RGB")

    @staticmethod
    def _encode_png(img: Image.Image) -> str:
        """Serialize *img* as PNG and return the bytes base64-encoded as text."""
        stream = BytesIO()
        img.save(stream, format="png")
        return base64.b64encode(stream.getvalue()).decode("utf8")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Remove the background of a base64-encoded image.

        Args:
            data: request payload. The image is taken from ``data["inputs"]``
                when that key exists (either the base64 string itself or a dict
                with an ``"image"`` key), otherwise from ``data["image"]``.
                The payload is not modified.

        Returns:
            ``{"status": 200, "image": <base64 PNG>}``. The PNG is RGBA: pixels
            whose min-max-rescaled prediction exceeds 0.1 keep their colour and
            are opaque; all other pixels are white and fully transparent.

        Raises:
            KeyError: if the payload carries no image.
        """
        inputs = data.get("inputs", data)
        image = self._decode_image(inputs)
        t = functional.pil_to_tensor(image).float().divide(255.0)
        arr = process_image(t)

        # Reuse the model loaded in __init__ (reloading weights per request is
        # very slow) and skip autograd bookkeeping during inference.
        with torch.no_grad():
            v = self._model(arr)[0]
        # NOTE(review): assumes v[0] is the primary prediction map, shaped
        # (1, 1, H, W) at the input resolution — confirm against model.ISNetDIS.
        pred_val = v[0][0, :, :, :]

        # Min-max rescale to [0, 1]. A constant map would divide by zero and
        # yield NaN, so it becomes all zeros (an empty mask) instead.
        mi = torch.min(pred_val)
        span = torch.max(pred_val) - mi
        if span > 0:
            pred_val = (pred_val - mi) / span
        else:
            pred_val = torch.zeros_like(pred_val)

        msk = torch.gt(pred_val, 0.1)
        # Foreground keeps its colour; background is painted white.
        w = torch.where(msk, t, 1)
        # Append the mask as alpha (explicit cast instead of relying on
        # bool/float promotion inside torch.cat).
        rgba = torch.cat([w, msk.to(w.dtype)], dim=0)
        # rgba is already CHW; squeezing would drop H or W for 1-pixel-tall or
        # 1-pixel-wide images and corrupt the output.
        return {"status": 200, "image": self._encode_png(functional.to_pil_image(rgba))}
if __name__ == "__main__":
    # Manual smoke test:  python <this file> INPUT_IMAGE [OUTPUT_PNG]
    # The handler needs a real base64 'image' field; an empty payload raises KeyError.
    import sys

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} INPUT_IMAGE [OUTPUT_PNG]")
    with open(sys.argv[1], "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("ascii")
    h = EndpointHandler()
    v = h({"inputs": {"image": encoded}})
    if len(sys.argv) > 2:
        # Write the decoded PNG instead of dumping a huge base64 string.
        with open(sys.argv[2], "wb") as out:
            out.write(base64.b64decode(v["image"]))
        print(f"status={v['status']} -> wrote {sys.argv[2]}")
    else:
        print(v)