# Author: Rens Dimmendaal
# Commit c5d7102: "init"
# File-viewer metadata (scraped): raw | history | blame | contribute | delete; "No virus"; 2.53 kB
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule
from sklearn.tree import DecisionTreeRegressor
def img2df(img_array):
    """Flatten an image array into a long-format pixel table.

    Each row describes one pixel, in row-major (C) order:
    ``dim0`` is the pixel's row index, ``dim1`` its column index, and
    ``value0``, ``value1``, ... hold its values along the remaining axes
    flattened together (a 2-D array yields the single column ``value0``).
    """
    height, width = img_array.shape[0], img_array.shape[1]
    row_idx, col_idx = np.indices((height, width))
    pixels = img_array.reshape((height * width, -1))
    columns = {"dim0": row_idx.ravel(), "dim1": col_idx.ravel()}
    columns.update(
        (f"value{channel}", pixels[:, channel])
        for channel in range(pixels.shape[1])
    )
    return pd.DataFrame(columns)
def normalize(img):
    """Linearly rescale ``img`` so its minimum maps to 0 and its maximum to 1.

    Args:
        img: Array-like of numbers; converted with ``np.asarray``.

    Returns:
        A floating-point ndarray of the same shape. Floating inputs keep
        their dtype; integer and boolean inputs are promoted to float64.
        A constant input (zero range) yields all zeros instead of NaN.
    """
    arr = np.asarray(img)
    # Integer arithmetic can wrap (int8: 127 - (-128) overflows) and bool
    # arrays do not support subtraction, so compute in float64 for those.
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        # Constant input: avoid 0/0 and map everything to the bottom of the range.
        return np.zeros_like(arr)
    return (arr - lo) / span
def df2xy(df):
    """Split a pixel table (as built by ``img2df``) into inputs and targets.

    Returns ``(x, y)``. ``x`` is the DataFrame of every column whose name
    starts with ``"dim"``. ``y`` holds the columns whose names start with
    ``"value"``: a DataFrame when there are several, or a flat 1-D ndarray
    when there is exactly one.
    """
    features = df[[name for name in df if name.startswith("dim")]]
    targets = df[[name for name in df if name.startswith("value")]]
    if targets.shape[1] != 1:
        return features, targets
    # A lone target column is flattened to 1-D (single-output layout).
    return features, targets.to_numpy().ravel()
def _tree_window_features(x_raw, add_cartesian, add_rotation, add_polar):
    """Assemble the per-pixel feature frame used by ``tree_window``.

    ``x_raw`` holds pixel coordinates in columns ``dim0`` (row) and ``dim1``
    (column). Each enabled feature family contributes uniquely named string
    columns; scikit-learn rejects frames that mix string and integer names.
    """
    sets = []
    if add_cartesian:
        sets.append(x_raw)
    if add_rotation:
        # Any non-zero angle in degrees (negative included) enables the
        # rotated coordinates; False/0 disables them.
        theta = np.radians(add_rotation)
        c, s = np.cos(theta), np.sin(theta)
        rotation = np.array(((c, -s), (s, c)))
        sets.append(
            pd.DataFrame(
                x_raw.to_numpy() @ rotation,
                columns=["rot0", "rot1"],
                index=x_raw.index,
            )
        )
    if add_polar:
        # Radius and angle are both measured about the mean pixel position,
        # so the two polar features share one origin.
        centred = x_raw - x_raw.mean()
        sets.append(
            pd.DataFrame(
                {
                    "polar_r": np.sqrt((centred**2).sum(axis=1)),
                    "polar_theta": np.arctan2(centred["dim1"], centred["dim0"]),
                },
                index=x_raw.index,
            )
        )
    if not sets:
        raise ValueError(
            "tree_window needs at least one of add_cartesian, add_rotation "
            "or add_polar to be enabled"
        )
    return pd.concat(sets, axis=1)


def tree_window(
    img,
    add_cartesian=True,
    add_rotation=False,
    add_polar=False,
    depths=(2, 6),
    figsize=(36, 36),
):
    """Plot decision-tree approximations of an image next to the original.

    Every pixel becomes one sample. Its features are derived from its
    (row, column) position and its targets are its channel values. One
    ``DecisionTreeRegressor`` is fitted per entry of ``depths``, and its
    in-sample prediction is drawn in its own panel. The last panel shows
    ``img`` itself.

    Args:
        img: Image ndarray of shape (H, W) or (H, W, C), with C being 3 or 4
            (the shapes ``imshow`` accepts).
        add_cartesian: Include the raw (row, column) coordinates as features.
        add_rotation: Angle in degrees. When non-zero, also include the
            coordinates rotated by this angle. ``False``/``0`` disables it.
        add_polar: Include each pixel's radius and angle about the mean
            pixel position.
        depths: ``max_depth`` of each fitted tree, one panel per value.
        figsize: Matplotlib figure size in inches.

    Returns:
        The matplotlib ``Figure`` with ``len(depths) + 1`` panels.

    Raises:
        ValueError: If no feature family is enabled.
    """
    df = img2df(img)
    x_raw, y = df2xy(df)
    x = _tree_window_features(x_raw, add_cartesian, add_rotation, add_polar)

    # df2xy returns a 2-D DataFrame for multi-channel images and a flat array
    # for single-channel ones, which are drawn with a gray colormap.
    imshow_kwargs = {} if len(y.shape) == 2 else {"cmap": "gray"}
    integer_img = np.issubdtype(img.dtype, np.integer)

    # squeeze=False keeps `axes` indexable even when `depths` is empty.
    fig, axes = plt.subplots(
        ncols=len(depths) + 1, figsize=figsize, squeeze=False
    )
    axes = axes.ravel()
    for ax, depth in zip(axes, depths):
        model = DecisionTreeRegressor(max_depth=depth)
        model.fit(x, y)
        pred = model.predict(x).reshape(img.shape)
        if integer_img:
            # Tree outputs are floats, and imshow treats float RGB(A) data as
            # [0, 1] and clips 0-255 values. Leaf values are means of the
            # targets, so they fit back into the input's integer dtype.
            pred = pred.round().astype(img.dtype)
        ax.imshow(pred, **imshow_kwargs)
        ax.set_axis_off()
    axes[-1].imshow(img, **imshow_kwargs)
    axes[-1].set_axis_off()
    return fig
def treeify(img, max_depth):
    """Approximate an image with one decision tree over pixel coordinates.

    The tree is fitted on (row, column) -> pixel values, and its in-sample
    prediction is turned back into an image.

    Args:
        img: Anything ``np.array`` can convert to an image array, e.g. a
            ``PIL.Image``.
        max_depth: ``max_depth`` passed to ``DecisionTreeRegressor``.

    Returns:
        ``(pred, score)``. ``pred`` is a ``PIL.Image`` built from the
        predictions, rounded, clipped to [0, 255] and cast to uint8.
        ``score`` is ``model.score(x, y)``, the in-sample R^2.
    """
    img_array = np.array(img)
    df = img2df(img_array)
    x, y = df2xy(df)
    model = DecisionTreeRegressor(max_depth=max_depth)
    model.fit(x, y)
    pred_array = model.predict(x).reshape(img_array.shape).round()
    # Clip before the uint8 cast. Casting out-of-range floats wraps, which
    # matters for non-8-bit inputs (e.g. 16/32-bit integer images). This
    # is a no-op for uint8 images.
    pred_array = np.clip(pred_array, 0, 255).astype("uint8")
    pred = PIL.Image.fromarray(pred_array)
    score = model.score(x, y)
    return pred, score