from typing import Any
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklift.datasets import fetch_hillstrom
from catboost import CatBoostClassifier
import sklearn
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
@st.experimental_memo
def get_data() -> tuple[Any, Any, Any]:
# получаем датасет
dataset = fetch_hillstrom(target_col='visit')
dataset, target, treatment = dataset['data'], dataset['target'], dataset['treatment']
# выбираем два сегмента
dataset = dataset[treatment != 'Mens E-Mail']
target = target[treatment != 'Mens E-Mail']
treatment = treatment[treatment != 'Mens E-Mail'].map({
'Womens E-Mail': 1,
'No E-Mail': 0
})
return dataset, target, treatment
@st.experimental_memo
def data_split(data, treatment, target) -> tuple[Any, Any, Any, Any, Any, Any]:
# склеиваем threatment и target для дальнейшей стратификации по ним
stratify_cols = pd.concat([treatment, target], axis=1)
# сплитим датасет
X_train, X_val, trmnt_train, trmnt_val, y_train, y_val = train_test_split(
data,
treatment,
target,
stratify=stratify_cols,
test_size=0.3,
random_state=42
)
return X_train, X_val, trmnt_train, trmnt_val, y_train, y_val
def get_newbie_plot(data):
fig = px.histogram(
data['newbie'],
color=data['newbie'],
title='Распределение клиентов по флагу newbie'
)
fig.update_xaxes(
title='',
ticktext=['"Старые" клиенты', 'Новые клиенты'],
tickvals=[0, 1]
)
fig.update_yaxes(
title='Количество клиентов'
)
fig.update_layout(
showlegend=False,
bargap=0.3,
margin=dict(l=20, r=10, t=80, b=10)
)
fig.update_traces(hovertemplate="Количество клиентов: %{y}")
return fig
def get_zipcode_plot(data):
fig = px.histogram(
data['zip_code'],
color=data['newbie'],
title='Распределение клиентов по почтовым индексам'
)
fig.update_xaxes(
title='',
categoryorder='total descending'
)
fig.update_yaxes(
title='Количество клиентов'
)
fig.update_layout(
showlegend=True,
legend_orientation="h",
legend=dict(x=.66, y=.99, title='Новый клиент'),
margin=dict(l=20, r=10, t=80, b=10),
hovermode="x",
bargap=0.3
)
fig.update_traces(hovertemplate="Количество клиентов: %{y}")
return fig
def get_channel_plot(data):
fig = px.histogram(
data['channel'],
color=data['newbie'],
title='Распределение клиентов по каналам покупки товаров'
)
fig.update_xaxes(
title='',
categoryorder='total descending'
)
fig.update_yaxes(
title='Количество клиентов'
)
fig.update_layout(
showlegend=True,
legend_orientation="h",
legend=dict(x=.66, y=.99, title='Новый клиент'),
margin=dict(l=20, r=10, t=80, b=10),
hovermode="x",
bargap=0.3
)
fig.update_traces(hovertemplate="Количество клиентов: %{y}")
return fig
def get_history_segment_plot(data):
fig = px.histogram(
data['history_segment'],
color=data['history_segment'],
title='Распределение клиентов по количеству $, потраченных в прошлом году'
)
fig.update_xaxes(
title='',
categoryorder='total descending',
tickangle=45
)
fig.update_yaxes(
title='Количество клиентов'
)
fig.update_layout(
showlegend=False,
bargap=0.3,
margin=dict(l=20, r=10, t=80, b=10)
)
fig.update_traces(hovertemplate="Количество клиентов: %{y}")
return fig
def get_recency_plot(data):
fig = px.histogram(
data['recency'],
color=data['newbie'],
title='Распределение клиентов по количеству месяцев с последней покупки'
)
fig.update_xaxes(
title='Месяцев после покупки'
)
fig.update_yaxes(
title='Количество клиентов'
)
fig.update_layout(
showlegend=True,
legend_orientation="h",
legend=dict(x=.66, y=.99, title='Новый клиент'),
margin=dict(l=20, r=10, t=80, b=10),
hovermode="x",
bargap=0.3
)
fig.update_traces(hovertemplate="
".join(
[
"Месяцев: %{x}",
"Клиентов: %{y}"
]
)
)
return fig
def get_history_plot(data):
fig = px.histogram(
data['history'],
color=data['newbie'],
title='Распределение клиентов по количеству месяцев с последней покупки'
)
fig.update_xaxes(
title='Месяцев после покупки'
)
fig.update_yaxes(
title='Количество клиентов'
)
fig.update_layout(
showlegend=True,
legend_orientation="h",
legend=dict(x=.66, y=.99, title='Новый клиент'),
margin=dict(l=20, r=10, t=80, b=10),
hovermode="x",
bargap=0.3
)
fig.update_traces(hovertemplate="
".join(
[
'Совершено покупок на: $%{x}',
'Количество клиентов: %{y}'
]
)
)
return fig