Spaces:
Runtime error
Runtime error
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
from sklearn.tree import DecisionTreeRegressor | |
import PIL | |
def img2df(img_array):
    """Flatten an image array into a long-format DataFrame, one row per pixel.

    Rows follow row-major (C) order, the same order ``reshape`` uses.

    Args:
        img_array: Array with at least two dimensions. The first two are the
            pixel grid; anything after them is flattened into per-pixel
            channels.

    Returns:
        DataFrame with integer columns ``dim0`` (row index) and ``dim1``
        (column index), followed by ``value0``, ``value1``, ... holding the
        pixel values (a single ``value0`` for a 2-D array). Value dtypes are
        kept from ``img_array``.
    """
    n_rows, n_cols = img_array.shape[0], img_array.shape[1]
    # np.indices gives one (n_rows, n_cols) grid per axis.
    row_grid, col_grid = np.indices((n_rows, n_cols))
    per_pixel = img_array.reshape((n_rows * n_cols, -1))

    columns = {"dim0": row_grid.ravel(), "dim1": col_grid.ravel()}
    columns.update(
        (f"value{channel}", per_pixel[:, channel])
        for channel in range(per_pixel.shape[1])
    )
    return pd.DataFrame(columns)
def normalize(img):
    """Min-max scale ``img`` so its smallest value maps to 0 and its largest to 1.

    A constant input (zero value range) maps to all zeros instead of NaN.

    Args:
        img: Array-like supporting ``min``, ``max``, subtraction and division
            (e.g. a NumPy array).

    Returns:
        The scaled values, with the same shape as ``img``.
    """
    lo = img.min()
    span = img.max() - lo
    # Dividing by a zero range would turn every element into NaN (0/0);
    # substitute 1 so a flat image normalizes to 0.
    safe_span = np.where(span == 0, 1, span)
    return (img - lo) / safe_span
def df2xy(df):
    """Split a pixel DataFrame (as built by ``img2df``) into features and targets.

    Args:
        df: DataFrame whose coordinate columns start with ``"dim"`` and whose
            pixel-value columns start with ``"value"``.

    Returns:
        ``(x, y)`` where ``x`` is the DataFrame of ``dim*`` columns. ``y`` is a
        1-D NumPy array when there is exactly one ``value*`` column, and
        otherwise the DataFrame of ``value*`` columns.
    """
    coord_names = [name for name in df.columns if name.startswith("dim")]
    target_names = [name for name in df.columns if name.startswith("value")]
    features = df[coord_names]
    targets = df[target_names]
    # A single target is flattened to 1-D, the shape scikit-learn expects
    # for single-output regression.
    if len(targets.columns) == 1:
        targets = targets.to_numpy().reshape(-1)
    return features, targets
def _window_features(x_raw, add_cartesian, add_rotation, add_polar):
    """Build the coordinate feature frame requested by ``tree_window``.

    Every column gets a unique string name, so scikit-learn accepts the frame
    whichever feature sets are combined.

    Raises:
        ValueError: If no feature set is enabled.
    """
    sets = []
    if add_cartesian:
        sets.append(x_raw)
    if add_rotation:
        # add_rotation is an angle in degrees; any non-zero angle, negative
        # included, is applied.
        theta = np.radians(add_rotation)
        c, s = np.cos(theta), np.sin(theta)
        rotation = np.array(((c, -s), (s, c)))
        sets.append(
            pd.DataFrame(
                x_raw.to_numpy() @ rotation,
                columns=["rot0", "rot1"],
                index=x_raw.index,
            )
        )
    if add_polar:
        # Radius and angle are both measured about the mean pixel coordinate
        # (the centre of the pixel grid), so the polar frame is consistent.
        centered = x_raw - x_raw.mean()
        sets.append(
            pd.DataFrame(
                {
                    "polar_r": np.sqrt((centered**2).sum(axis=1)),
                    "polar_theta": np.arctan2(
                        centered["dim1"], centered["dim0"]
                    ),
                },
                index=x_raw.index,
            )
        )
    if not sets:
        raise ValueError(
            "tree_window needs at least one of add_cartesian, add_rotation "
            "or add_polar to be enabled"
        )
    return pd.concat(sets, axis=1)


def tree_window(
    img, add_cartesian=True, add_rotation=False, add_polar=False, depths=(2, 6)
):
    """Plot decision-tree approximations of an image at several tree depths.

    A ``DecisionTreeRegressor`` is fitted for each depth. Its inputs are the
    coordinate features of every pixel and its target is the pixel value.
    The predicted images are drawn side by side, followed by the original.

    Args:
        img: Image array (2-D grayscale or ``(H, W, C)``). Anything accepted
            by ``np.asarray``, such as a PIL image, also works.
        add_cartesian: Use the raw ``(dim0, dim1)`` pixel coordinates as
            features.
        add_rotation: Angle in degrees. When non-zero, the coordinates rotated
            by this angle are added as features.
        add_polar: Add radius and angle about the image centre as features.
        depths: ``max_depth`` of each fitted tree; one panel per depth.

    Returns:
        The matplotlib ``Figure`` with ``len(depths) + 1`` panels.

    Raises:
        ValueError: If no feature set is enabled.
    """
    img = np.asarray(img)
    df = img2df(img)
    x_raw, y = df2xy(df)
    x = _window_features(x_raw, add_cartesian, add_rotation, add_polar)

    # df2xy returns a 2-D target only for multi-channel images.
    is_color = y.ndim == 2
    cmap = None if is_color else "gray"

    # squeeze=False keeps `axes` indexable even when `depths` is empty.
    fig, axes = plt.subplots(
        ncols=len(depths) + 1, figsize=(36, 36), squeeze=False
    )
    axes = axes[0]
    for ax, depth in zip(axes, depths):
        model = DecisionTreeRegressor(max_depth=depth)
        model.fit(x, y)
        pred = model.predict(x).reshape(img.shape)
        if is_color and np.issubdtype(img.dtype, np.integer):
            # imshow reads float RGB as 0-1, so 0-255 float predictions would
            # be clipped to near-white. Cast back to the image's integer
            # dtype so the panel renders like the original. The predictions
            # are means of pixel values, so they stay in range.
            pred = pred.round().astype(img.dtype)
        ax.imshow(pred, cmap=cmap)
        ax.set_axis_off()
    axes[-1].imshow(img, cmap=cmap)
    axes[-1].set_axis_off()
    return fig
def treeify(img, max_depth):
    """Approximate an image with a depth-limited decision-tree regressor.

    The tree is fitted on the pixel coordinates (``dim0``, ``dim1``) with the
    pixel values as targets. Its predictions are rendered back into an image.

    Args:
        img: Image convertible with ``np.array`` (presumably a PIL image or a
            NumPy array; TODO confirm against callers).
        max_depth: ``max_depth`` passed to ``DecisionTreeRegressor``.

    Returns:
        ``(pred, score)``. ``pred`` is a ``uint8`` PIL image with the same
        shape as the input. ``score`` is ``model.score`` on the training pixels
        (the R^2 coefficient for scikit-learn regressors).
    """
    # `import PIL` alone does not load the PIL.Image submodule; import it
    # explicitly instead of relying on another library having done so.
    from PIL import Image

    img_array = np.array(img)
    df = img2df(img_array)
    x, y = df2xy(df)
    model = DecisionTreeRegressor(max_depth=max_depth)
    model.fit(x, y)
    pixels = model.predict(x).reshape(img_array.shape).round()
    # For 8-bit input the clip is a no-op. For other value ranges it prevents
    # the uint8 cast from overflowing.
    pixels = np.clip(pixels, 0, 255).astype("uint8")
    pred = Image.fromarray(pixels)
    score = model.score(x, y)
    return pred, score