zhang-ziang commited on
Commit
b03b419
·
1 Parent(s): c1fa1ed
Files changed (2) hide show
  1. app.py +70 -6
  2. requirements.txt +1 -1
app.py CHANGED
@@ -8,6 +8,9 @@ import os
8
  import matplotlib.pyplot as plt
9
  import io
10
  from PIL import Image
 
 
 
11
 
12
  from huggingface_hub import hf_hub_download
13
  ckpt_path = hf_hub_download(repo_id="Viglong/OriNet", filename="celarge/dino_weight.pt", repo_type="model", cache_dir='./')
@@ -30,6 +33,68 @@ dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
30
  print('weight loaded')
31
  val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def get_3angle(image):
35
 
@@ -80,10 +145,8 @@ def figure_to_img(fig):
80
  image = Image.open(buf).copy()
81
  return image
82
 
83
- # def generate_mutimodal(title, context, img):
84
- # return f"Title:{title}\nContext:{context}\n...{img}"
85
-
86
- def generate_mutimodal(img):
87
  angles = get_3angle(img)
88
 
89
  fig, ax = plt.subplots(figsize=(8, 8))
@@ -123,9 +186,10 @@ def generate_mutimodal(img):
123
 
124
  server = gr.Interface(
125
  flagging_mode='never',
126
- fn=generate_mutimodal,
127
  inputs=[
128
- gr.Image(height=512, width=512, label="upload your image")
 
129
  ],
130
  outputs=[
131
  gr.Image(height=512, width=512, label="result image"),
 
8
  import matplotlib.pyplot as plt
9
  import io
10
  from PIL import Image
11
+ import rembg
12
+ from typing import Any
13
+
14
 
15
  from huggingface_hub import hf_hub_download
16
  ckpt_path = hf_hub_download(repo_id="Viglong/OriNet", filename="celarge/dino_weight.pt", repo_type="model", cache_dir='./')
 
33
  print('weight loaded')
34
  val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
35
 
36
+ def background_preprocess(input_image, do_remove_background):
37
+
38
+ rembg_session = rembg.new_session() if do_remove_background else None
39
+
40
+ if do_remove_background:
41
+ input_image = remove_background(input_image, rembg_session)
42
+ input_image = resize_foreground(input_image, 0.85)
43
+
44
+ return input_image
45
+
46
+ def resize_foreground(
47
+ image: Image,
48
+ ratio: float,
49
+ ) -> Image:
50
+ image = np.array(image)
51
+ assert image.shape[-1] == 4
52
+ alpha = np.where(image[..., 3] > 0)
53
+ y1, y2, x1, x2 = (
54
+ alpha[0].min(),
55
+ alpha[0].max(),
56
+ alpha[1].min(),
57
+ alpha[1].max(),
58
+ )
59
+ # crop the foreground
60
+ fg = image[y1:y2, x1:x2]
61
+ # pad to square
62
+ size = max(fg.shape[0], fg.shape[1])
63
+ ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
64
+ ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
65
+ new_image = np.pad(
66
+ fg,
67
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
68
+ mode="constant",
69
+ constant_values=((0, 0), (0, 0), (0, 0)),
70
+ )
71
+
72
+ # compute padding according to the ratio
73
+ new_size = int(new_image.shape[0] / ratio)
74
+ # pad to size, double side
75
+ ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
76
+ ph1, pw1 = new_size - size - ph0, new_size - size - pw0
77
+ new_image = np.pad(
78
+ new_image,
79
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
80
+ mode="constant",
81
+ constant_values=((0, 0), (0, 0), (0, 0)),
82
+ )
83
+ new_image = Image.fromarray(new_image)
84
+ return new_image
85
+
86
+ def remove_background(image: Image,
87
+ rembg_session: Any = None,
88
+ force: bool = False,
89
+ **rembg_kwargs,
90
+ ) -> Image:
91
+ do_remove = True
92
+ if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
93
+ do_remove = False
94
+ do_remove = do_remove or force
95
+ if do_remove:
96
+ image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
97
+ return image
98
 
99
  def get_3angle(image):
100
 
 
145
  image = Image.open(buf).copy()
146
  return image
147
 
148
+ def infer_func(img, do_rm_bkg):
149
+ img = background_preprocess(img, do_rm_bkg)
 
 
150
  angles = get_3angle(img)
151
 
152
  fig, ax = plt.subplots(figsize=(8, 8))
 
186
 
187
  server = gr.Interface(
188
  flagging_mode='never',
189
+ fn=infer_func,
190
  inputs=[
191
+ gr.Image(height=512, width=512, label="upload your image"),
192
+ gr.Checkbox(label="Remove Background", value=True)
193
  ],
194
  outputs=[
195
  gr.Image(height=512, width=512, label="result image"),
requirements.txt CHANGED
@@ -5,4 +5,4 @@ pillow==10.2.0
5
  huggingface-hub==0.26.5
6
  gradio==5.9.0
7
  numpy==1.26.4
8
-
 
5
  huggingface-hub==0.26.5
6
  gradio==5.9.0
7
  numpy==1.26.4
8
+ rembg