not-lain commited on
Commit
95b6884
1 Parent(s): 74fddf5

endpoint support for the API.

Browse files
Files changed (2) hide show
  1. handler.py +54 -0
  2. requirements.txt +13 -0
handler.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import base64
3
+ from io import BytesIO
4
+ import torch
5
+ from loadimg import load_img
6
+ from torchvision import transforms
7
+ from transformers import AutoModelForImageSegmentation
8
+
9
+ torch.set_float32_matmul_precision(["high", "highest"][0])
10
+
11
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
12
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
13
+ )
14
+ birefnet.to("cuda")
15
+
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ transform_image = transforms.Compose(
19
+ [
20
+ transforms.Resize((1024, 1024)),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
23
+ ]
24
+ )
25
+
26
+ class EndpointHandler():
27
+ def __init__(self, path=""):
28
+ self.birefnet = AutoModelForImageSegmentation.from_pretrained(
29
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
30
+ )
31
+ self.birefnet.to(device)
32
+
33
+ def __call__(self, data: Dict[str, Any]):
34
+ """
35
+ data args:
36
+ inputs (:obj: `str`)
37
+ date (:obj: `str`)
38
+ Return:
39
+ A :obj:`list` | `dict`: will be serialized and returned
40
+ """
41
+ image = load_img(data["inputs"]).convert("RGB")
42
+ image_size = image.size
43
+ input_images = transform_image(image).unsqueeze(0).to("cuda")
44
+ # Prediction
45
+ with torch.no_grad():
46
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
47
+ pred = preds[0].squeeze()
48
+ pred_pil = transforms.ToPILImage()(pred)
49
+ mask = pred_pil.resize(image_size)
50
+ image.putalpha(mask)
51
+ # buffered = BytesIO()
52
+ # image.save(buffered, format="JPEG")
53
+ # img_str = base64.b64encode(buffered.getvalue())
54
+ return image
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ loadimg
2
+ spaces
3
+ torch
4
+ torchvision
5
+ opencv-python
6
+ tqdm
7
+ timm
8
+ prettytable
9
+ scipy
10
+ scikit-image
11
+ kornia
12
+ transformers
13
+ huggingface_hub