anodev commited on
Commit
f84a891
·
verified ·
1 Parent(s): 18bdd95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -69
app.py CHANGED
@@ -1,90 +1,139 @@
1
  import os
2
- import imageio
3
- from PIL import Image
 
 
4
  import gradio as gr
 
 
 
 
 
 
 
5
  import cv2
6
  import numpy as np
7
- import paddlehub as hub
8
  import onnxruntime
 
 
 
 
9
 
10
- # Download and setup models
11
- os.system("wget https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx")
12
- os.system("pip install onnxruntime imageio")
13
- os.makedirs("data", exist_ok=True)
14
- os.makedirs("dataout", exist_ok=True)
 
 
 
15
 
16
- # Load LaMa ONNX model
17
- sess_options = onnxruntime.SessionOptions()
18
- lama_model = onnxruntime.InferenceSession('lama_fp32.onnx', sess_options=sess_options)
 
19
 
20
- # Load U^2-Net model for automatic masking
21
- u2net_model = hub.Module(name='U2Net')
22
 
23
- # --- Helper Functions ---
 
24
 
25
- def prepare_image(image, target_size=(512, 512)):
26
- """Resizes and preprocesses image for LaMa model."""
27
- if isinstance(image, Image.Image):
28
- image = image.resize(target_size)
29
- image = np.array(image)
30
- elif isinstance(image, np.ndarray):
31
- image = cv2.resize(image, target_size)
 
 
 
32
  else:
33
- raise ValueError("Input image should be either PIL Image or numpy array!")
34
-
35
- # Normalize to [0, 1] and convert to CHW format
36
- image = image.astype(np.float32) / 255.0
37
- if image.ndim == 3:
38
- image = np.transpose(image, (2, 0, 1))
39
- elif image.ndim == 2:
40
- image = image[np.newaxis, ...]
41
- return image[np.newaxis, ...] # Add batch dimension
42
-
43
- def generate_mask(image, method="automatic"):
44
- """Generates mask from image using U^2-Net or user input."""
45
- if method == "automatic":
46
- input_size = 320 # Adjust based on U^2-Net requirements
47
- result = u2net_model.Segmentation(
48
- images=[cv2.cvtColor(image, cv2.COLOR_RGB2BGR)],
49
- paths=None,
50
- batch_size=1,
51
- input_size=input_size,
52
- output_dir='output',
53
- visualization=False
54
- )
55
- mask = Image.fromarray(result[0]['mask'])
56
- mask = mask.resize((512, 512)) # Resize to match LaMa input
57
- mask.save("./data/data_mask.png")
58
- else: # "manual"
59
- mask = imageio.imread("./data/data_mask.png")
60
- mask = Image.fromarray(mask).convert("L") # Ensure grayscale
61
- mask = mask.resize((512, 512))
62
- return prepare_image(mask, (512, 512))
63
-
64
- def inpaint_image(image, mask):
65
- """Performs inpainting using the LaMa model."""
66
- outputs = lama_model.run(None, {'image': image, 'mask': mask})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  output = outputs[0][0]
 
68
  output = output.transpose(1, 2, 0)
69
- output = (output * 255).astype(np.uint8)
70
- return Image.fromarray(output)
 
 
 
71
 
72
- # --- Gradio Interface ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- def process_image(input_image, mask_option):
75
- """Main function for Gradio interface."""
76
- imageio.imwrite("./data/data.png", input_image)
77
 
78
- image = prepare_image(input_image)
79
- mask = generate_mask(input_image, method=mask_option)
80
-
81
- inpainted_image = inpaint_image(image, mask)
82
- inpainted_image = inpainted_image.resize(Image.open("./data/data.png").size)
83
- inpainted_image.save("./dataout/data_mask.png")
84
- return "./dataout/data_mask.png", "./data/data_mask.png"
85
 
86
  iface = gr.Interface(
87
- fn=process_image,
88
  inputs=[
89
  gr.Image(label="Input Image", type="numpy"),
90
  gr.Radio(choices=["automatic", ],
 
1
  import os
2
+ os.system("wget https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx")
3
+ os.system("pip install onnxruntime imageio")
4
+ import cv2
5
+ import paddlehub as hub
6
  import gradio as gr
7
+ import torch
8
+ from PIL import Image, ImageOps
9
+ import numpy as np
10
+ import imageio
11
+ os.mkdir("data")
12
+ os.mkdir("dataout")
13
+ model = hub.Module(name='U2Net')
14
  import cv2
15
  import numpy as np
 
16
  import onnxruntime
17
+ import torch
18
+ from PIL import Image
19
+ sess_options = onnxruntime.SessionOptions()
20
+ rmodel = onnxruntime.InferenceSession('lama_fp32.onnx', sess_options=sess_options)
21
 
22
+ # Source https://github.com/advimman/lama
23
+ def get_image(image):
24
+ if isinstance(image, Image.Image):
25
+ img = np.array(image)
26
+ elif isinstance(image, np.ndarray):
27
+ img = image.copy()
28
+ else:
29
+ raise Exception("Input image should be either PIL Image or numpy array!")
30
 
31
+ if img.ndim == 3:
32
+ img = np.transpose(img, (2, 0, 1)) # chw
33
+ elif img.ndim == 2:
34
+ img = img[np.newaxis, ...]
35
 
36
+ assert img.ndim == 3
 
37
 
38
+ img = img.astype(np.float32) / 255
39
+ return img
40
 
41
+
42
+ def ceil_modulo(x, mod):
43
+ if x % mod == 0:
44
+ return x
45
+ return (x // mod + 1) * mod
46
+
47
+
48
+ def scale_image(img, factor, interpolation=cv2.INTER_AREA):
49
+ if img.shape[0] == 1:
50
+ img = img[0]
51
  else:
52
+ img = np.transpose(img, (1, 2, 0))
53
+
54
+ img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
55
+
56
+ if img.ndim == 2:
57
+ img = img[None, ...]
58
+ else:
59
+ img = np.transpose(img, (2, 0, 1))
60
+ return img
61
+
62
+
63
+ def pad_img_to_modulo(img, mod):
64
+ channels, height, width = img.shape
65
+ out_height = ceil_modulo(height, mod)
66
+ out_width = ceil_modulo(width, mod)
67
+ return np.pad(
68
+ img,
69
+ ((0, 0), (0, out_height - height), (0, out_width - width)),
70
+ mode="symmetric",
71
+ )
72
+
73
+
74
+ def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
75
+ out_image = get_image(image)
76
+ out_mask = get_image(mask)
77
+
78
+ if scale_factor is not None:
79
+ out_image = scale_image(out_image, scale_factor)
80
+ out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)
81
+
82
+ if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
83
+ out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
84
+ out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)
85
+
86
+ out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
87
+ out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)
88
+
89
+ out_mask = (out_mask > 0) * 1
90
+
91
+ return out_image, out_mask
92
+
93
+
94
+ def predict(jpg, msk):
95
+
96
+
97
+ imagex = Image.open(jpg)
98
+ mask = Image.open(msk).convert("L")
99
+
100
+ image, mask = prepare_img_and_mask(imagex.resize((512, 512)), mask.resize((512, 512)), 'cpu')
101
+ # Run the model
102
+ outputs = rmodel.run(None, {'image': image.numpy().astype(np.float32), 'mask': mask.numpy().astype(np.float32)})
103
+
104
  output = outputs[0][0]
105
+ # Postprocess the outputs
106
  output = output.transpose(1, 2, 0)
107
+ output = output.astype(np.uint8)
108
+ output = Image.fromarray(output)
109
+ output = output.resize(imagex.size)
110
+ output.save("/home/user/app/dataout/data_mask.png")
111
+
112
 
113
+ def infer(img,option):
114
+ print(type(img))
115
+ print(type(img["image"]))
116
+ print(type(img["mask"]))
117
+ imageio.imwrite("./data/data.png", img["image"])
118
+ if option == "automatic (U2net)":
119
+ result = model.Segmentation(
120
+ images=[cv2.cvtColor(img["image"], cv2.COLOR_RGB2BGR)],
121
+ paths=None,
122
+ batch_size=1,
123
+ input_size=320,
124
+ output_dir='output',
125
+ visualization=True)
126
+ im = Image.fromarray(result[0]['mask'])
127
+ im.save("./data/data_mask.png")
128
+ else:
129
+ imageio.imwrite("./data/data_mask.png", img["mask"])
130
+ predict("./data/data.png", "./data/data_mask.png")
131
+ return "./dataout/data_mask.png","./data/data_mask.png"
132
 
 
 
 
133
 
 
 
 
 
 
 
 
134
 
135
  iface = gr.Interface(
136
+ fn=infer,
137
  inputs=[
138
  gr.Image(label="Input Image", type="numpy"),
139
  gr.Radio(choices=["automatic", ],