Spaces:
Build error
Build error
"""Plot pandas.DataFrame with DBSCAN clustering.""" | |
# pylint: disable=invalid-name, too-many-arguments, unused-import | |
import numpy as np # noqa | |
import pandas as pd | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from sklearn.cluster import DBSCAN | |
from logzero import logger # noqa | |
# from radiobee.cmat2tset import cmat2tset | |
# Disabled snippet, parked inside a throwaway string bound to `_` so it never
# runs: it would turn on interactive mode when `get_ipython` is present in
# globals() (i.e. in an IPython session) and switch to the 'Agg' backend
# otherwise.
_ = """ | |
if "get_ipython" in globals(): | |
plt.ion() | |
else: | |
plt.switch_backend('Agg') | |
# """ | |
# fastlid.set_languages = ["en", "zh"] | |
# fmt: off | |
def plot_df(
    df_: pd.DataFrame,
    # cmat: np.ndarray,
    eps: float = 10,
    min_samples: int = 6,
    xlabel: str = "",
    ylabel: str = "",
    xlim: int = 0,
    ylim: int = 0,
    backend: str = "TkAgg",
) -> plt:
    # fmt: on
    """Plot a three-column DataFrame with DBSCAN clustering.

    Rows that DBSCAN assigns to a cluster are drawn as a scatter plot coloured
    by the third column; rows labelled as noise (label -1) are drawn as red
    "x" markers.

    Args:
        df_: pandas.DataFrame with exactly three columns, interpreted
            positionally as ["x", "y", "cos"]. The caller's frame is left
            untouched; relabelling happens on a copy.
        eps: ``eps`` passed to ``sklearn.cluster.DBSCAN``.
        min_samples: ``min_samples`` passed to ``sklearn.cluster.DBSCAN``.
        xlabel: x-axis label; when empty, ``str(xlim)`` is used.
        ylabel: y-axis label; when empty, ``str(ylim)`` is used.
        xlim: only used to build the default x label; when falsy it defaults
            to ``len(df_)``. The axis limits themselves are not changed.
        ylim: only used to build the default y label; when falsy it defaults
            to the maximum of the "y" column.
        backend: matplotlib backend used while drawing. The backend that was
            active on entry is restored before returning, even on error.

    Returns:
        matplotlib.pyplot: the pyplot module, for possible use in gradio.

    Raises:
        ValueError: if ``df_`` does not have exactly three columns.

    Example (notes kept from the original author)::

        plot_df(pd.DataFrame(cmat2tset(smat), columns=['x', 'y', 'cos']))

        df_ = pd.DataFrame(cmat2tset(smat), columns=['x', 'y', 'cos'])
        # sort 'x', axis 0 changes, index regenerated
        df_s = df_.sort_values('x', axis=0, ignore_index=True)
        # sorting does not seem to impact clustering
        DBSCAN(1.5, min_samples=3).fit(df_).labels_
        DBSCAN(1.5, min_samples=3).fit(df_s).labels_
    """
    if df_.shape[1] != 3:
        logger.error(" shape mismatch: %s, expected (x, 3)", df_.shape)
        # ValueError subclasses Exception, so existing handlers keep working.
        raise ValueError(" df_.shape[1] not equal to 3 ")

    # Relabel a copy: assigning to df_.columns directly would rename the
    # columns of the caller's DataFrame as a hidden side effect.
    df_ = df_.copy()
    df_.columns = ["x", "y", "cos"]

    if not xlim:
        xlim = len(df_)
    if not ylim:
        ylim = df_["y"].max()
    if not xlabel:
        xlabel = str(xlim)
    if not ylabel:
        ylabel = str(ylim)

    backend_saved = matplotlib.get_backend()
    switched = backend_saved != backend
    if switched:
        plt.switch_backend(backend)

    try:
        sns.set()
        sns.set_style("darkgrid")

        fig = plt.figure(figsize=(13, 8))
        gs = fig.add_gridspec(1, 1, wspace=0.4, hspace=0.58)
        ax0 = fig.add_subplot(gs[0, 0])

        # DBSCAN labels noise points -1; everything else belongs to a cluster.
        labels = DBSCAN(min_samples=min_samples, eps=eps).fit(df_).labels_
        clustered = labels > -1

        # Skip a layer entirely when it has no rows instead of asking pandas
        # to draw an empty scatter.
        if clustered.any():
            df_[clustered].plot.scatter(
                "x", "y", c="cos", cmap="viridis_r", ax=ax0
            )
        if not clustered.all():
            df_[~clustered].plot.scatter(
                "x", "y", c="r", marker="x", alpha=0.6, ax=ax0
            )

        ax0.set_xlabel(xlabel)
        ax0.set_ylabel(ylabel)
        ax0.set_title("max cos ('x': outliers)")
    finally:
        # Restore the caller's backend even when plotting fails.
        # NOTE(review): some matplotlib versions close open figures on
        # switch_backend -- confirm the figure survives this restore.
        if switched:
            plt.switch_backend(backend_saved)

    return plt
# Scratch copy of plot_df's keyword defaults, parked inside a throwaway string
# bound to `_`; nothing in the visible code reads it.
_ = """ | |
eps: float = 10 | |
min_samples: int = 6 | |
xlabel: str = "" | |
ylabel: str = "" | |
xlim: int = 0 | |
ylim: int = 0 | |
""" | |