ssboost commited on
Commit
91ea164
1 Parent(s): 96fbe3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -50,6 +50,9 @@ def load_img(source, output_type="pil"):
50
 
51
  def process(image):
52
  image_size = image.size
 
 
 
53
  input_images = transform_image(image).unsqueeze(0).to("cpu") # GPU -> CPU로 변경
54
  # Prediction
55
  with torch.no_grad():
@@ -57,8 +60,10 @@ def process(image):
57
  pred = preds[0].squeeze()
58
  pred_pil = transforms.ToPILImage()(pred)
59
  mask = pred_pil.resize(image_size)
60
- image.putalpha(mask)
61
- return image
 
 
62
 
63
  def fn(image):
64
  im = load_img(image, output_type="pil")
 
50
 
51
  def process(image):
52
  image_size = image.size
53
+ # RGBA 이미지를 RGB로 변환
54
+ if image.mode == 'RGBA':
55
+ image = image.convert('RGB')
56
  input_images = transform_image(image).unsqueeze(0).to("cpu") # GPU -> CPU로 변경
57
  # Prediction
58
  with torch.no_grad():
 
60
  pred = preds[0].squeeze()
61
  pred_pil = transforms.ToPILImage()(pred)
62
  mask = pred_pil.resize(image_size)
63
+ # 결과 이미지에 알파 채널 추가
64
+ result_image = image.copy()
65
+ result_image.putalpha(mask)
66
+ return result_image
67
 
68
  def fn(image):
69
  im = load_img(image, output_type="pil")