Spaces:
Runtime error
Runtime error
from typing import Any | |
import pandas as pd | |
from sklearn.model_selection import train_test_split | |
from sklift.datasets import fetch_hillstrom | |
from catboost import CatBoostClassifier | |
import sklearn | |
import streamlit as st | |
import plotly.express as px | |
import plotly.graph_objects as go | |
def get_data() -> tuple[Any, Any, Any]: | |
# получаем датасет | |
dataset = fetch_hillstrom(target_col='visit') | |
dataset, target, treatment = dataset['data'], dataset['target'], dataset['treatment'] | |
# выбираем два сегмента | |
dataset = dataset[treatment != 'Mens E-Mail'] | |
target = target[treatment != 'Mens E-Mail'] | |
treatment = treatment[treatment != 'Mens E-Mail'].map({ | |
'Womens E-Mail': 1, | |
'No E-Mail': 0 | |
}) | |
return dataset, target, treatment | |
def data_split(data, treatment, target) -> tuple[Any, Any, Any, Any, Any, Any]: | |
# склеиваем threatment и target для дальнейшей стратификации по ним | |
stratify_cols = pd.concat([treatment, target], axis=1) | |
# сплитим датасет | |
X_train, X_val, trmnt_train, trmnt_val, y_train, y_val = train_test_split( | |
data, | |
treatment, | |
target, | |
stratify=stratify_cols, | |
test_size=0.3, | |
random_state=42 | |
) | |
return X_train, X_val, trmnt_train, trmnt_val, y_train, y_val | |
def get_newbie_plot(data): | |
fig = px.histogram( | |
data['newbie'], | |
color=data['newbie'], | |
title='Распределение клиентов по флагу Newbie' | |
) | |
fig.update_xaxes( | |
title='', | |
ticktext=['"Старые" клиенты', 'Новые клиенты'], | |
tickvals=[0, 1] | |
) | |
fig.update_yaxes( | |
title='Количество' | |
) | |
fig.update_layout( | |
showlegend=False, | |
bargap=0.3 | |
) | |
fig.update_traces(hovertemplate="Количество клиентов: %{y}") | |
return fig | |
def get_zipcode_plot(data): | |
fig = px.histogram( | |
data['zip_code'], | |
color=data['zip_code'], | |
title='Распределение клиентов по флагу zip_code' | |
) | |
fig.update_xaxes( | |
title='', | |
categoryorder='total descending' | |
) | |
fig.update_yaxes( | |
title='Количество' | |
) | |
fig.update_layout( | |
showlegend=False, | |
bargap=0.3 | |
) | |
fig.update_traces(hovertemplate="Количество клиентов: %{y}") | |
return fig | |
def get_channel_plot(data): | |
fig = px.histogram( | |
data['channel'], | |
color=data['channel'], | |
title='Распределение клиентов по флагу zip_code' | |
) | |
fig.update_xaxes( | |
title='', | |
categoryorder='total descending' | |
) | |
fig.update_yaxes( | |
title='Количество' | |
) | |
fig.update_layout( | |
showlegend=False, | |
bargap=0.3 | |
) | |
fig.update_traces(hovertemplate="Количество клиентов: %{y}") | |
return fig | |
def get_history_segment_plot(data): | |
fig = px.histogram( | |
data['history_segment'], | |
color=data['history_segment'], | |
title='Распределение клиентов по флагу history_segment' | |
) | |
fig.update_xaxes( | |
title='', | |
categoryorder='total descending' | |
) | |
fig.update_yaxes( | |
title='Количество' | |
) | |
fig.update_layout( | |
showlegend=False, | |
bargap=0.3 | |
) | |
fig.update_traces(hovertemplate="Количество клиентов: %{y}") | |
return fig | |