anodev commited on
Commit
69f51b6
1 Parent(s): 8465efd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -4
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
- os.system("wget https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt")
3
- os.system("pip install imageio")
4
  import cv2
5
  import paddlehub as hub
6
  import gradio as gr
@@ -9,9 +9,106 @@ from PIL import Image, ImageOps
9
  import numpy as np
10
  import imageio
11
  os.mkdir("data")
12
- os.rename("best.ckpt", "models/best.ckpt")
13
  os.mkdir("dataout")
14
  model = hub.Module(name='U2Net')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def infer(img,option):
16
  print(type(img))
17
  print(type(img["image"]))
@@ -29,7 +126,7 @@ def infer(img,option):
29
  im.save("./data/data_mask.png")
30
  else:
31
  imageio.imwrite("./data/data_mask.png", img["mask"])
32
- os.system('python predict.py model.path=/home/user/app/ indir=/home/user/app/data/ outdir=/home/user/app/dataout/ device=cpu')
33
  return "./dataout/data_mask.png","./data/data_mask.png"
34
 
35
  inputs = [gr.Image(tool="sketch", label="Input",type="numpy"),gr.inputs.Radio(choices=["automatic (U2net)","manual"], type="value", default="manual", label="Masking option")]
 
1
  import os
2
+ os.system("wget https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama.onnx")
3
+ os.system("pip install onnxruntime")
4
  import cv2
5
  import paddlehub as hub
6
  import gradio as gr
 
9
  import numpy as np
10
  import imageio
11
  os.mkdir("data")
 
12
  os.mkdir("dataout")
13
  model = hub.Module(name='U2Net')
14
+ import cv2
15
+ import numpy as np
16
+ import onnxruntime
17
+ import torch
18
+ from PIL import Image
19
+
20
+
21
+ # Source https://github.com/advimman/lama
22
+ def get_image(image):
23
+ if isinstance(image, Image.Image):
24
+ img = np.array(image)
25
+ elif isinstance(image, np.ndarray):
26
+ img = image.copy()
27
+ else:
28
+ raise Exception("Input image should be either PIL Image or numpy array!")
29
+
30
+ if img.ndim == 3:
31
+ img = np.transpose(img, (2, 0, 1)) # chw
32
+ elif img.ndim == 2:
33
+ img = img[np.newaxis, ...]
34
+
35
+ assert img.ndim == 3
36
+
37
+ img = img.astype(np.float32) / 255
38
+ return img
39
+
40
+
41
+ def ceil_modulo(x, mod):
42
+ if x % mod == 0:
43
+ return x
44
+ return (x // mod + 1) * mod
45
+
46
+
47
+ def scale_image(img, factor, interpolation=cv2.INTER_AREA):
48
+ if img.shape[0] == 1:
49
+ img = img[0]
50
+ else:
51
+ img = np.transpose(img, (1, 2, 0))
52
+
53
+ img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
54
+
55
+ if img.ndim == 2:
56
+ img = img[None, ...]
57
+ else:
58
+ img = np.transpose(img, (2, 0, 1))
59
+ return img
60
+
61
+
62
+ def pad_img_to_modulo(img, mod):
63
+ channels, height, width = img.shape
64
+ out_height = ceil_modulo(height, mod)
65
+ out_width = ceil_modulo(width, mod)
66
+ return np.pad(
67
+ img,
68
+ ((0, 0), (0, out_height - height), (0, out_width - width)),
69
+ mode="symmetric",
70
+ )
71
+
72
+
73
+ def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
74
+ out_image = get_image(image)
75
+ out_mask = get_image(mask)
76
+
77
+ if scale_factor is not None:
78
+ out_image = scale_image(out_image, scale_factor)
79
+ out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)
80
+
81
+ if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
82
+ out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
83
+ out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)
84
+
85
+ out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
86
+ out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)
87
+
88
+ out_mask = (out_mask > 0) * 1
89
+
90
+ return out_image, out_mask
91
+
92
+
93
+ def predict(jpg, msk):
94
+ sess_options = onnxruntime.SessionOptions()
95
+ model = onnxruntime.InferenceSession('lama.onnx', sess_options=sess_options)
96
+
97
+ image = Image.open(jpg).resize((512, 512))
98
+ mask = Image.open(msk).convert("L").resize((512, 512))
99
+
100
+ image, mask = prepare_img_and_mask(image, mask, 'cpu')
101
+ # Run the model
102
+ outputs = model.run(None, {'l_image_': image.numpy().astype(np.float32), 'l_mask_': mask.numpy().astype(np.float32)})
103
+
104
+ output = outputs[0][0]
105
+ # Postprocess the outputs
106
+ output = output.transpose(1, 2, 0)
107
+ output = output.astype(np.uint8)
108
+ output = Image.fromarray(output)
109
+ output.save("/home/user/app/dataout/data_mask.png")
110
+
111
+
112
  def infer(img,option):
113
  print(type(img))
114
  print(type(img["image"]))
 
126
  im.save("./data/data_mask.png")
127
  else:
128
  imageio.imwrite("./data/data_mask.png", img["mask"])
129
+ predict("./data/data.png", "./data/data_mask.png")
130
  return "./dataout/data_mask.png","./data/data_mask.png"
131
 
132
  inputs = [gr.Image(tool="sketch", label="Input",type="numpy"),gr.inputs.Radio(choices=["automatic (U2net)","manual"], type="value", default="manual", label="Masking option")]