Spaces:
Running
on
Zero
Running
on
Zero
bluestyle97
commited on
Commit
•
5bb9a7d
1
Parent(s):
e9133ba
Update freesplatter/utils/infer_util.py
Browse files- freesplatter/utils/infer_util.py +32 -17
freesplatter/utils/infer_util.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import importlib
|
3 |
import imageio
|
4 |
import torch
|
|
|
5 |
import numpy as np
|
6 |
import PIL.Image
|
7 |
from PIL import Image
|
@@ -67,10 +68,36 @@ def get_obj_from_str(string, reload=False):
|
|
67 |
# return image
|
68 |
|
69 |
|
70 |
-
@torch.inference_mode()
|
71 |
-
def remove_background(
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
force: bool = False,
|
75 |
**rembg_kwargs,
|
76 |
) -> PIL.Image.Image:
|
@@ -79,19 +106,7 @@ def remove_background(
|
|
79 |
do_remove = False
|
80 |
do_remove = do_remove or force
|
81 |
if do_remove:
|
82 |
-
|
83 |
-
transforms.Resize((1024, 1024)),
|
84 |
-
transforms.ToTensor(),
|
85 |
-
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
86 |
-
])
|
87 |
-
image = image.convert('RGB')
|
88 |
-
input_images = transform_image(image).unsqueeze(0).to(rembg.device)
|
89 |
-
with torch.no_grad():
|
90 |
-
preds = rembg(input_images)[-1].sigmoid().cpu()
|
91 |
-
pred = preds[0].squeeze()
|
92 |
-
pred_pil = transforms.ToPILImage()(pred)
|
93 |
-
mask = pred_pil.resize(image.size)
|
94 |
-
image.putalpha(mask)
|
95 |
return image
|
96 |
|
97 |
|
|
|
2 |
import importlib
|
3 |
import imageio
|
4 |
import torch
|
5 |
+
import rembg
|
6 |
import numpy as np
|
7 |
import PIL.Image
|
8 |
from PIL import Image
|
|
|
68 |
# return image
|
69 |
|
70 |
|
71 |
+
# @torch.inference_mode()
|
72 |
+
# def remove_background(
|
73 |
+
# image: PIL.Image.Image,
|
74 |
+
# rembg: Any = None,
|
75 |
+
# force: bool = False,
|
76 |
+
# **rembg_kwargs,
|
77 |
+
# ) -> PIL.Image.Image:
|
78 |
+
# do_remove = True
|
79 |
+
# if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
|
80 |
+
# do_remove = False
|
81 |
+
# do_remove = do_remove or force
|
82 |
+
# if do_remove:
|
83 |
+
# transform_image = transforms.Compose([
|
84 |
+
# transforms.Resize((1024, 1024)),
|
85 |
+
# transforms.ToTensor(),
|
86 |
+
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
87 |
+
# ])
|
88 |
+
# image = image.convert('RGB')
|
89 |
+
# input_images = transform_image(image).unsqueeze(0).to(rembg.device)
|
90 |
+
# with torch.no_grad():
|
91 |
+
# preds = rembg(input_images)[-1].sigmoid().cpu()
|
92 |
+
# pred = preds[0].squeeze()
|
93 |
+
# pred_pil = transforms.ToPILImage()(pred)
|
94 |
+
# mask = pred_pil.resize(image.size)
|
95 |
+
# image.putalpha(mask)
|
96 |
+
# return image
|
97 |
+
|
98 |
+
|
99 |
+
def remove_background(image: PIL.Image.Image,
|
100 |
+
rembg_session: Any = None,
|
101 |
force: bool = False,
|
102 |
**rembg_kwargs,
|
103 |
) -> PIL.Image.Image:
|
|
|
106 |
do_remove = False
|
107 |
do_remove = do_remove or force
|
108 |
if do_remove:
|
109 |
+
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
return image
|
111 |
|
112 |
|