multimodalart HF staff commited on
Commit
7b0ac85
1 Parent(s): fc8546c

Update sketch_helper.py

Browse files
Files changed (1) hide show
  1. sketch_helper.py +22 -1
sketch_helper.py CHANGED
@@ -12,7 +12,7 @@ def get_high_freq_colors(image):
12
  high_freq_colors = [c for c in sorted_colors if c[0] > max(2, mean_freq)] # Ignore colors that occur very few times (less than 2) or less than half the average frequency
13
  return high_freq_colors
14
 
15
- def color_quantization(image, n_colors):
16
  # Get color histogram
17
  hist, _ = np.histogramdd(image.reshape(-1, 3), bins=(256, 256, 256), range=((0, 256), (0, 256), (0, 256)))
18
  # Get most frequent colors
@@ -24,6 +24,27 @@ def color_quantization(image, n_colors):
24
  labels = np.argmin(dists, axis=1)
25
  return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def create_binary_matrix(img_arr, target_color):
28
  # Create mask of pixels with target color
29
  mask = np.all(img_arr == target_color, axis=-1)
 
12
  high_freq_colors = [c for c in sorted_colors if c[0] > max(2, mean_freq)] # Ignore colors that occur very few times (less than 2) or less than half the average frequency
13
  return high_freq_colors
14
 
15
+ def color_quantization_old(image, n_colors):
16
  # Get color histogram
17
  hist, _ = np.histogramdd(image.reshape(-1, 3), bins=(256, 256, 256), range=((0, 256), (0, 256), (0, 256)))
18
  # Get most frequent colors
 
24
  labels = np.argmin(dists, axis=1)
25
  return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
26
 
27
+ def color_quantization(image, n_colors=8, rounds=1):
28
+ h, w = image.shape[:2]
29
+ samples = np.zeros([h*w,3], dtype=np.float32)
30
+ count = 0
31
+
32
+ for x in range(h):
33
+ for y in range(w):
34
+ samples[count] = image[x][y]
35
+ count += 1
36
+
37
+ compactness, labels, centers = cv2.kmeans(samples,
38
+ clusters,
39
+ None,
40
+ (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10000, 0.0001),
41
+ rounds,
42
+ cv2.KMEANS_RANDOM_CENTERS)
43
+
44
+ centers = np.uint8(centers)
45
+ res = centers[labels.flatten()]
46
+ return res.reshape((image.shape))
47
+
48
  def create_binary_matrix(img_arr, target_color):
49
  # Create mask of pixels with target color
50
  mask = np.all(img_arr == target_color, axis=-1)