nopperl commited on
Commit
39c0f4e
1 Parent(s): 3455245

Add application

Browse files
Files changed (6) hide show
  1. README.md +12 -0
  2. app.py +30 -0
  3. examples/01.png +0 -0
  4. examples/02.png +0 -0
  5. onnx_inference.py +245 -0
  6. requirements.txt +7 -0
README.md CHANGED
@@ -8,6 +8,18 @@ sdk_version: 4.8.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ preload_from_hub:
12
+ - nopperl/marked-lineart-vectorizer model.onnx
13
+ datasets:
14
+ - kmewhort/tu-berlin-svgs
15
+ tags:
16
+ - image-vectorization
17
+ - sketch
18
+ - sketch-synthesis
19
+ - svg
20
+ - vector-image
21
+ - line-drawing
22
+ - line-art
23
  ---
24
 
25
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import basename, splitext
2
+
3
+ import gradio as gr
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ from onnx_inference import vectorize_image
7
+
8
+
9
+ MODEL_PATH = hf_hub_download("nopperl/marked-lineart-vectorizer", "model.onnx")
10
+
11
+
12
+ def predict(input_image_path, threshold, stroke_width):
13
+ output_filepath = splitext(basename(input_image_path))[0] + ".svg"
14
+ for recons_img in vectorize_image(input_image_path, model=MODEL_PATH, output=output_filepath, threshold_ratio=threshold, stroke_width=stroke_width):
15
+ yield recons_img
16
+ yield output_filepath
17
+
18
+
19
+ interface = gr.Interface(
20
+ predict,
21
+ inputs=[gr.Image(sources="upload", type="filepath"), gr.Slider(minimum=0.1, maximum=0.9, value=0.1, label="threshold"), gr.Slider(minimum=0.1, maximum=4.0, value=0.512, label="stroke_width")],
22
+ outputs=gr.Image(),
23
+ description="Demo for a model that converts raster line-art images into vector images iteratively. The model is trained on black-and-white line-art images, hence it won't work with other images. Inference time will be quite slow due to a lack of GPU resources. More information at https://github.com/nopperl/marked-lineart-vectorization.",
24
+ examples = [
25
+ ["examples/01.png", 0.1, 0.512],
26
+ ["examples/02.png", 0.1, 0.512]
27
+ ],
28
+ analytics_enabled=False
29
+ )
30
+ interface.launch()
examples/01.png ADDED
examples/02.png ADDED
onnx_inference.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from argparse import ArgumentParser
3
+ from io import BytesIO
4
+ from os import listdir, makedirs
5
+ from os.path import basename, isdir, join, splitext
6
+ from random import randint
7
+ from typing import Union
8
+
9
+ from cairosvg import svg2png
10
+ import numpy as np
11
+ from imageio.v3 import imread, imwrite
12
+ from skimage.transform import rescale
13
+ from svgpathtools import CubicBezier, Line, QuadraticBezier, disvg, wsvg
14
+
15
+ import onnx
16
+ import onnxruntime as ort
17
+
18
+
19
+ def raster_bezier_hard(all_points, image_width=128, image_height=128, stroke_width=2., colors=None, white_background=True, mark=None):
20
+ if colors is None:
21
+ colors = [[0., 0., 0., 1.]] * len(all_points)
22
+ elif colors is list and colors[0] is not list:
23
+ colors = [colors] * len(all_points)
24
+ else:
25
+ colors = np.array(colors)
26
+ colors[:, :3] *= 255
27
+ colors = ["rgb(" + ",".join(map(str, color[:3])) + ")" for color in colors]
28
+ background_color = "white" if white_background else None
29
+ all_points = all_points + 0
30
+ all_points[:, :, 0] *= image_width
31
+ all_points[:, :, 1] *= image_height
32
+ bezier_curves = [numpy_to_bezier(points) for points in all_points]
33
+ attributes = [{"stroke": colors[i], "stroke-width": str(stroke_width), "fill": "none"} for i in range(len(bezier_curves))]
34
+ if mark is not None:
35
+ mark = mark + 0
36
+ mark[0] *= image_width
37
+ mark[1] *= image_height
38
+ mark_points = np.vstack([mark - stroke_width, mark + stroke_width])
39
+ mark_path = numpy_to_bezier(mark_points)
40
+ mark_attr = {"stroke": "blue", "stroke-width": str(stroke_width * 2), "fill": "blue"}
41
+ bezier_curves.append(mark_path)
42
+ attributes.append(mark_attr)
43
+ svg_attributes = {"width": f"{image_width}px", "height": f"{image_height}px"}
44
+ svg_string = disvg(bezier_curves, attributes=attributes, svg_attributes=svg_attributes, paths2Drawing=True).tostring()
45
+ png_string = svg2png(bytestring=svg_string, background_color=background_color)
46
+ image = imread(BytesIO(png_string), extension=".png")
47
+ output = image.astype("float32")
48
+ output /= 255
49
+ output = np.moveaxis(output, 2, 0)
50
+ return output, all_points
51
+
52
+ def diff_remaining_img(raster_img: np.ndarray, recons_img: np.ndarray):
53
+ remaining_img = raster_img.copy()
54
+ tmp_remaining_img = remaining_img.copy()
55
+ tmp_remaining_img[tmp_remaining_img < 1] = 0.
56
+ recons_img[recons_img < 1] = 0.
57
+ same_mask = (tmp_remaining_img == recons_img).copy()
58
+ remaining_img[same_mask] = 1
59
+ return remaining_img
60
+
61
+
62
+ def place_point_on_img(image, point):
63
+ if np.any(point == point.astype(int)):
64
+ point_idx_start = point.astype(int)
65
+ point_idx_end = point.astype(int) + 1
66
+ else:
67
+ point_idx_start = np.floor(point).astype(int)
68
+ point_idx_end = np.ceil(point).astype(int)
69
+ if image.shape[0] == 3:
70
+ image[0, point_idx_start[1]:point_idx_end[1], point_idx_start[0]:point_idx_end[0]] = 0
71
+ image[1, point_idx_start[1]:point_idx_end[1], point_idx_start[0]:point_idx_end[0]] = 0
72
+ image[2, point_idx_start[1]:point_idx_end[1], point_idx_start[0]:point_idx_end[0]] = 1
73
+ else:
74
+ image[0, point_idx_start[1]:point_idx_end[1], point_idx_start[0]:point_idx_end[0]] = 0.5
75
+ return image
76
+
77
+
78
+ def rgb_to_grayscale(image: np.ndarray):
79
+ image = image[0] * .2989 + image[1] *.587 + image[2] *.114
80
+ return image
81
+
82
+
83
+ def sample_black_pixel(image: np.ndarray):
84
+ image = rgb_to_grayscale(image.copy())
85
+ black_indices = np.argwhere(~np.isclose(image, np.ones_like(image, dtype="float32"), atol=0.5) != 0)
86
+ black_idx = black_indices[randint(0, len(black_indices) - 1)].astype("float32")
87
+ black_idx[0] /= image.shape[0]
88
+ black_idx[1] /= image.shape[1]
89
+ black_idx = black_idx[[1, 0]]
90
+ return black_idx
91
+
92
+
93
+ def numpy_to_bezier(points: np.ndarray):
94
+ if len(points) == 2:
95
+ return Line(*(complex(point[0], point[1]) for point in points))
96
+ elif len(points) == 3:
97
+ return QuadraticBezier(*(complex(point[0], point[1]) for point in points))
98
+ elif len(points) == 4:
99
+ return CubicBezier(*(complex(point[0], point[1]) for point in points))
100
+
101
+
102
+ def center_on_point(image, point, new_width=None, new_height=None):
103
+ _, height, width = image.shape
104
+ if new_width is None:
105
+ new_width = width
106
+ if new_height is None:
107
+ new_height = height
108
+ half_width = round(width / 2)
109
+ half_height = round(height / 2)
110
+ point = point.copy()
111
+ point[0] *= width
112
+ point[1] *= height
113
+ point = point.round().astype(int)
114
+ top=half_height - (half_height - point[1])
115
+ left=half_width - (half_width - point[0])
116
+ padded = np.pad(image, ((0, 0), (half_height, half_height), (half_width, half_width)), constant_values=1)
117
+ cropped = padded[:, top:top+new_height, left:left+new_width]
118
+ return cropped
119
+
120
+
121
+ def reverse_center_on_point(paths, point):
122
+ for i in range(len(paths)):
123
+ paths[i, :, 0] -= 0.5 - point[i, 0]
124
+ paths[i, :, 1] -= 0.5 - point[i, 1]
125
+
126
+
127
+ def save_as_svg(curves: np.ndarray, filename, img_width, img_height, stroke_width=2.0):
128
+ svg_paths = [numpy_to_bezier(curve) for curve in curves]
129
+ output_attributes = [{"stroke": "black", "stroke-width": stroke_width, "stroke-linecap": "round", "fill": "none"}] * len(svg_paths)
130
+ svg_attributes = {"width": f"{img_width}px", "height": f"{img_height}px"}
131
+ wsvg(svg_paths, attributes=output_attributes, svg_attributes=svg_attributes, filename=filename)
132
+
133
+
134
+ def save_as_png(filename: str, image: np.ndarray):
135
+ image = np.moveaxis(image.copy(), 0, 2)
136
+ image *= 255
137
+ imwrite(filename, image.round().astype("uint8"))
138
+
139
+
140
+ def setup_model(model_path):
141
+ model = onnx.load(model_path)
142
+ onnx.checker.check_model(model)
143
+ ort_sess = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
144
+ return ort_sess
145
+
146
+
147
+ def vectorize_image(input_image_path, model: Union[str, ort.InferenceSession], output=None, threshold_ratio=0.1, stroke_width=0.512, width=512, height=512, binarization_threshold=0, force_grayscale=False):
148
+ if type(model) is str:
149
+ ort_sess = setup_model(model)
150
+ elif type(model) is ort.InferenceSession:
151
+ ort_sess = model
152
+ else:
153
+ raise ValueError("Invalid value for the model argument")
154
+
155
+ # Get dimensions expected by the model
156
+ _, channels, height, width = ort_sess.get_inputs()[0].shape
157
+ input_image = imread(input_image_path, pilmode="RGB") / 255
158
+ original_height, original_width, _ = input_image.shape
159
+ # scale and white pad image to dimensions expected by the model
160
+ if original_height >= original_width:
161
+ scale = height / original_height
162
+ else:
163
+ scale = width / original_width
164
+ print(f"Rescale factor: {scale}")
165
+ input_image = rescale(input_image, scale, channel_axis=2, order=5)
166
+ scaled_height, scaled_width = input_image.shape[:2]
167
+ raster_img = np.ones((height, width, channels), dtype="float32")
168
+ raster_img[:input_image.shape[0], :input_image.shape[1]] = input_image
169
+ # convert CHW
170
+ raster_img = np.moveaxis(raster_img, 2, 0)
171
+ if binarization_threshold > 0:
172
+ raster_img[raster_img < binarization_threshold] = 0.
173
+ width = raster_img.shape[2]
174
+ height = raster_img.shape[1]
175
+ curve_pixels = (raster_img < .5).sum()
176
+ threshold = curve_pixels * threshold_ratio
177
+ print(f"Reconstruction candidate pixels: {curve_pixels}")
178
+ print(f"Reconstruction threshold: {threshold.astype(int)}")
179
+ recons_points = None
180
+ recons_img = np.ones_like(raster_img, dtype="float32")
181
+ remaining_img = raster_img.copy()
182
+ while (remaining_img < .5).sum() > threshold:
183
+ remaining_img = diff_remaining_img(raster_img, recons_img)
184
+ try:
185
+ mark = sample_black_pixel(remaining_img)
186
+ except ValueError:
187
+ break
188
+ centered_img = remaining_img.copy()
189
+ mark_real = mark.copy()
190
+ mark_real[0] *= width
191
+ mark_real[1] *= height
192
+ centered_img = place_point_on_img(centered_img, mark_real)
193
+ centered_img = center_on_point(centered_img, mark)
194
+ result = ort_sess.run(None, {"marked_raster_image": np.expand_dims(centered_img, 0)})
195
+ points = result[0]
196
+ reverse_center_on_point(points, np.expand_dims(mark, 0))
197
+ points = np.expand_dims(points, 1)
198
+ if recons_points is None:
199
+ recons_points = points
200
+ else:
201
+ recons_points = np.concatenate((recons_points, points), axis=1)
202
+ recons_img, _ = raster_bezier_hard(recons_points.squeeze(0), image_width=width, image_height=height, stroke_width=stroke_width)
203
+ yield np.moveaxis(recons_img, 0, 2)
204
+
205
+ output_filepath = splitext(basename(input_image_path))[0] + ".svg"
206
+ if output is not None:
207
+ if isdir(output):
208
+ makedirs(output, exist_ok=True)
209
+ output_filepath = join(output, output_filepath)
210
+ elif type(output) is str and output.endswith(".svg"):
211
+ output_filepath = output
212
+ recons_points = recons_points.squeeze(0)
213
+ recons_points[:, :, 0] *= width * (1 / scale)
214
+ recons_points[:, :, 1] *= height * (1 / scale)
215
+ save_as_svg(recons_points, output_filepath, original_width, original_height, stroke_width=stroke_width)
216
+
217
+
218
+ def main():
219
+ parser = ArgumentParser(description="Inference script for the marked curve reconstruction model in ONNX format.")
220
+ parser.add_argument("model", metavar="FIlE", help="path to the *.onnx file")
221
+ parser.add_argument("-i", "--input_images", nargs="*", metavar="FILE", help="one or multiple paths to raster images that should be vectorized.")
222
+ parser.add_argument("-d", "--input_dir", metavar="DIR", help="path to a directory of raster images that should be vectorized.")
223
+ parser.add_argument("-o", "--output", help="optional output directory or file")
224
+ parser.add_argument("--threshold_ratio", "-t", default=0.1, type=float, help="The ratio of black pixels which need to be reconstructed before the algorithm terminates")
225
+ parser.add_argument("--stroke_width", "-r", default=0.512, type=float, help="stroke width if it should be different from the one specified in the model")
226
+ parser.add_argument("--seed", "-s", default=1234, help="Fixed random number generation seed. Set to negative number to deactivate")
227
+ parser.add_argument("-b", "--binarization_threshold", default=0., type=float, help="Set to a value in (0,1) to binarize the image.")
228
+
229
+ args = parser.parse_args()
230
+
231
+ if args.seed >= 0:
232
+ np.random.seed(args.seed)
233
+ if args.input_images is not None:
234
+ input_images = args.input_images
235
+ elif args.input_dir is not None and isdir(args.input_dir):
236
+ input_images = [join(args.input_dir, f) for f in listdir(args.input_dir)]
237
+ else:
238
+ print("-i or -d need to be passed")
239
+ exit(1)
240
+ for input_image in input_images:
241
+ vectorize_image(input_image, args.model, output=args.output, threshold_ratio=args.threshold_ratio, stroke_width=args.stroke_width, binarization_threshold=args.binarization_threshold, force_grayscale=False)
242
+
243
+
244
+ if __name__ == "__main__":
245
+ main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ onnx==1.14.0
2
+ onnxruntime==1.15.1
3
+ imageio==2.31.1
4
+ svgpathtools==1.6.1
5
+ cairosvg==2.7.0
6
+ scikit-image==0.21.0
7
+ huggingface_hub==0.20.1