bluestyle97 commited on
Commit
5bb9a7d
1 Parent(s): e9133ba

Update freesplatter/utils/infer_util.py

Browse files
Files changed (1) hide show
  1. freesplatter/utils/infer_util.py +32 -17
freesplatter/utils/infer_util.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import importlib
3
  import imageio
4
  import torch
 
5
  import numpy as np
6
  import PIL.Image
7
  from PIL import Image
@@ -67,10 +68,36 @@ def get_obj_from_str(string, reload=False):
67
  # return image
68
 
69
 
70
- @torch.inference_mode()
71
- def remove_background(
72
- image: PIL.Image.Image,
73
- rembg: Any = None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  force: bool = False,
75
  **rembg_kwargs,
76
  ) -> PIL.Image.Image:
@@ -79,19 +106,7 @@ def remove_background(
79
  do_remove = False
80
  do_remove = do_remove or force
81
  if do_remove:
82
- transform_image = transforms.Compose([
83
- transforms.Resize((1024, 1024)),
84
- transforms.ToTensor(),
85
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
86
- ])
87
- image = image.convert('RGB')
88
- input_images = transform_image(image).unsqueeze(0).to(rembg.device)
89
- with torch.no_grad():
90
- preds = rembg(input_images)[-1].sigmoid().cpu()
91
- pred = preds[0].squeeze()
92
- pred_pil = transforms.ToPILImage()(pred)
93
- mask = pred_pil.resize(image.size)
94
- image.putalpha(mask)
95
  return image
96
 
97
 
 
2
  import importlib
3
  import imageio
4
  import torch
5
+ import rembg
6
  import numpy as np
7
  import PIL.Image
8
  from PIL import Image
 
68
  # return image
69
 
70
 
71
+ # @torch.inference_mode()
72
+ # def remove_background(
73
+ # image: PIL.Image.Image,
74
+ # rembg: Any = None,
75
+ # force: bool = False,
76
+ # **rembg_kwargs,
77
+ # ) -> PIL.Image.Image:
78
+ # do_remove = True
79
+ # if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
80
+ # do_remove = False
81
+ # do_remove = do_remove or force
82
+ # if do_remove:
83
+ # transform_image = transforms.Compose([
84
+ # transforms.Resize((1024, 1024)),
85
+ # transforms.ToTensor(),
86
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
87
+ # ])
88
+ # image = image.convert('RGB')
89
+ # input_images = transform_image(image).unsqueeze(0).to(rembg.device)
90
+ # with torch.no_grad():
91
+ # preds = rembg(input_images)[-1].sigmoid().cpu()
92
+ # pred = preds[0].squeeze()
93
+ # pred_pil = transforms.ToPILImage()(pred)
94
+ # mask = pred_pil.resize(image.size)
95
+ # image.putalpha(mask)
96
+ # return image
97
+
98
+
99
+ def remove_background(image: PIL.Image.Image,
100
+ rembg_session: Any = None,
101
  force: bool = False,
102
  **rembg_kwargs,
103
  ) -> PIL.Image.Image:
 
106
  do_remove = False
107
  do_remove = do_remove or force
108
  if do_remove:
109
+ image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
110
  return image
111
 
112