CustomNet / utils /gradio_utils.py
jiangyzy's picture
Update utils/gradio_utils.py
3a54e1f verified
raw
history blame
969 Bytes
from ldm.util import load_and_preprocess
from carvekit.api.high import HiInterface
import spaces
def load_preprocess_model(device="cpu"):
    """Build the CarveKit ``HiInterface`` used for image preprocessing.

    :param device: device string forwarded to ``HiInterface``. Defaults to
        ``'cpu'``, the value this function previously hard-coded, so existing
        zero-argument callers behave exactly as before.
    :return: the configured ``HiInterface`` instance.
    """
    carvekit = HiInterface(
        object_type="object",  # Can be "object" or "hairs-like".
        batch_size_seg=5,
        batch_size_matting=1,
        device=device,
        seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
        matting_mask_size=2048,
        trimap_prob_threshold=231,
        trimap_dilation=30,
        trimap_erosion_iters=5,
        fp16=False,
    )
    return carvekit
@spaces.GPU
def preprocess_image(models, input_im):
    """Run ``load_and_preprocess`` on a single input image.

    :param models: forwarded unchanged as the first argument of
        ``load_and_preprocess`` (presumably the object returned by
        ``load_preprocess_model`` -- confirm against the caller).
    :param input_im: input image (PIL Image).
    :return: the preprocessed image, an (H, W, 3) array.
    """
    return load_and_preprocess(models, input_im)