uplift_lab / src /web_app.py
ChuckNorris
Initial commit
a22630b
raw
history blame
17.1 kB
import catboost
import pandas as pd
import os
from sklift.metrics import uplift_at_k, uplift_by_percentile, qini_auc_score
from sklift.viz import plot_qini_curve, plot_uplift_curve
from sklift.models import SoloModel, TwoModels, ClassTransformation
import streamlit as st
import catboost
import tools
dataset, target, treatment = tools.get_data()
ct_cbc_model = catboost.CatBoostClassifier()
ct_cbc_model.load_model('src/models/ct_cbc.cbm')
sm_cbc_model = catboost.CatBoostClassifier()
sm_cbc_model.load_model('src/models/sm_cbc.cbm')
tm_ctrl_cbc_model = catboost.CatBoostClassifier()
tm_ctrl_cbc_model.load_model('src/models/tm_ctrl_cbc.cbm')
tm_trmnt_cbc_model = catboost.CatBoostClassifier()
tm_ctrl_cbc_model.load_model('src/models/tm_trmnt_cbc.cbm')
tm_dependend_ctrl_cbc = catboost.CatBoostClassifier()
tm_ctrl_cbc_model.load_model('src/models/tm_dependend_ctrl_cbc.cbm')
tm_dependend_trmntl_cbc = catboost.CatBoostClassifier()
tm_dependend_trmntl_cbc.load_model('src/models/tm_dependend_trmnt_cbc.cbm')
data_train_index = pd.read_csv('data/data_train_index.csv')
data_test_index = pd.read_csv('data/data_test_index.csv')
treatment_train_index = pd.read_csv('data/treatment_train_index.csv')
treatment_test_index = pd.read_csv('data/treatment_test_index.csv')
target_train_index = pd.read_csv('data/target_train_index.csv')
target_test_index = pd.read_csv('data/target_test_index.csv')
# фиксируем выборки, чтобы результат работы ML был предсказуем
data_train = dataset.loc[data_train_index['0']]
data_test = dataset.loc[data_test_index['0']]
treatment_train = treatment.loc[treatment_train_index['0']]
treatment_test = treatment.loc[treatment_test_index['0']]
target_train = target.loc[target_train_index['0']]
target_test = target.loc[target_test_index['0']]
if 'filter_data' not in st.session_state.keys():
st.session_state.filter_data = True
st.title('Uplift lab')
st.markdown(
"""
#### Рассмотрим пример применения одного из подходов прогнозирования _uplift_.
Данные для примера взяты из [_The MineThatData E-Mail Analytics And Data Mining Challenge_](https://blog.minethatdata.com/2008/03/minethatdata-e-mail-analytics-and-data.html)
Этот набор данных содержит 42 693 строк с данными клиентов, которые в последний раз совершали покупки в течение двенадцати месяцев.
Из данных уже отделена тестовая выборка в виде 30% записей клиентов, так что данных в предоставленной выборке будет меньше.
Среди клиентов была проведена рекламная кампания с помощью _email_ рассылки:
- 1/2 клиентов были выбраны случайным образом для получения электронного письма, рекламирующего женскую продукцию;
- С оставшейся 1/2 коммуникацию не проводили.
Для каждого клиента из выборки замерили факт перехода по ссылке в письме, факт совершения покупки и сумму трат за
две недели, следующими после получения письма.
Пример данных приведен ниже.
"""
)
refresh = st.button('Обновить выборку')
title_subsample = data_train.sample(7)
if refresh:
title_subsample = data_train.sample(7)
st.dataframe(title_subsample, width=700)
st.write(f"Всего записей: {data_train.shape[0]}")
st.write('Описание данных')
st.markdown(
"""
| Колонка | Обозначение |
|-------------------|------------------------------------------------------------------------|
| _recency_ | Месяцев с момента последней покупки |
| _history_segment_ | Классификация клиентов в долларах, потраченных в прошлом году |
| _history_ | Фактическая стоимость в долларах, потраченная в прошлом году |
| _mens_ | Флаг 1/0, 1 = клиент приобрел мужские товары в прошлом году |
| _womens_ | Флаг 1/0, 1 = клиент приобрел женские товары в прошлом году |
| _zip_code_ | Классифицирует почтовый индекс как городской, пригородный или сельский |
| _newbie_ | Флаг 1/0, 1 = Новый клиент за последние двенадцать месяцев |
| _channel_ | Описывает каналы, через которые клиент приобрел тоовар в прошлом году |
---
"""
)
st.write("Для того, чтобы лучше понять на какую аудиторию лучше запустить рекламную кампанию, проведем небольшой \
анализ данных")
with st.expander('Развернуть блок анализа данных'):
st.plotly_chart(tools.get_newbie_plot(data_train), use_container_width=True)
st.write(f'В данных примерно одинаковое количество новых и "старых клиентов". '
f'Отношение новых клиентов к старым: {(data_train["newbie"] == 1).sum() / (data_train["newbie"] == 0).sum():.2f}')
st.plotly_chart(tools.get_zipcode_plot(data_train), use_container_width=True)
tmp_res = data_train.zip_code.value_counts(normalize=True) * 100
st.write(f'Большинство клиентов из пригорода: {tmp_res["Surburban"]:.2f}%, из города: {tmp_res["Urban"]:.2f}% и из села: {tmp_res["Rural"]:.2f}%')
tmp_res = data_train.channel.value_counts(normalize=True) * 100
st.plotly_chart(tools.get_channel_plot(data_train), use_container_width=True)
st.write(f'В прошлом году почти одинаковое количество клиентов покупало товары через телефон и сайт, {tmp_res["Phone"]:.2f}% и {tmp_res["Web"]:.2f}% соответственно,'
f' а {tmp_res["Multichannel"]:.2f}% клиентов покупали товары воспользовавшись двумя платформами.')
tmp_res = data_train.history_segment.value_counts(normalize=True) * 100
st.plotly_chart(tools.get_history_segment_plot(data_train), use_container_width=True)
st.write(f'Как мы видим, большинство пользователей относится к сегменту \$0-\$100 ({tmp_res[0]:.2f}%), второй и '
f'третий по количеству пользователей сегменты \$100-\$200 ({tmp_res[1]:.2f}%) и \$200-\$350 ({tmp_res[2]:.2f}%).')
st.write(f'К сегментам \$350-\$500 и \$500-\$750 относится {tmp_res[3]:.2f}% и {tmp_res[4]:.2f}% пользователей соответственно.')
st.write(f'Меньше всего пользователей в сегментах \$750-\$1.000 ({tmp_res[-2]:.2f}%) и \$1.000+ ({tmp_res[-1]:.2f}%).')
tmp_res = list(data_train.recency.value_counts(normalize=True) * 100)
st.plotly_chart(tools.get_recency_plot(data_train), use_container_width=True)
st.write(f'Большинство клиентов являются активными клиентами платформы, и совершали покупки в течение месяца ({tmp_res[0]:.2f}%)')
st.write('Также заметно, что 9 и 10 месяцев назад, много клиентов совершали покупки. Это может свидетельствовать о проведении'
'рекламной кампании в это время или чего-то еще.')
st.write('Также интересно понаблюдать за долями новых клиентов в данном распределении.')
st.plotly_chart(tools.get_history_plot(data_train), use_container_width=True)
st.markdown('_График интерактивный. Двойной клик вернет в начальное состояние._')
st.write('Абсолютное большинство клиентов тратят \$25-\$35 на покупки, но есть и малая доля тех, кто тратит более \$3.000')
st.write('Интересный факт: все покупки более \$500 совершают только новые клиенты')
filters = {}
with st.form(key='filter-clients'):
st.subheader('Выберем клиентов, которым отправим рекламу.')
col1, col2, col3 = st.columns(3)
channel_filter = col1.radio('Канал покупки прошлом году', options=['Все', 'Phone', 'Web', 'Multichannel'])
filters['channel_filter'] = channel_filter
newbie_filter = col2.radio('Тип клиента', options=['Все', 'Только новые', 'Только старые'])
filters['newbie_filter'] = newbie_filter
mens_filter = col3.radio('Клиенты, приобретавшие товары', options=['Любые', 'Мужские', 'Женские'])
filters['mens_filter'] = mens_filter
filters['history_segments'] = {}
col1, col2 = st.columns(2)
with col1:
st.write('Класс клиентов по объему денег, потраченных в прошлом году (history segments)')
first_group = st.checkbox('$0-$100', value=True)
if first_group:
filters['history_segments']['1) $0 - $100'] = True
second_group = st.checkbox('$100-$200', value=True)
if second_group:
filters['history_segments']['2) $100 - $200'] = True
third_group = st.checkbox('$200-$350', value=True)
if third_group:
filters['history_segments']['3) $200 - $350'] = True
fourth_group = st.checkbox('$350-$500', value=True)
if fourth_group:
filters['history_segments']['4) $350 - $500'] = True
fifth_group = st.checkbox('$500-$750', value=True)
if fifth_group:
filters['history_segments']['5) $500 - $750'] = True
sixth_group = st.checkbox('$750-$1.000', value=True)
if sixth_group:
filters['history_segments']['6) $750 - $1,000'] = True
seventh_group = st.checkbox('$1.000+', value=True)
if seventh_group:
filters['history_segments']['7) $1,000 +'] = True
with col2:
st.write('Каких пользователей по почтовому коду выберем')
filters['zip_code'] = {}
surburban = st.checkbox('Surburban', value=True)
if surburban:
filters['zip_code']['surburban'] = True
urban = st.checkbox('Urban', value=True)
if urban:
filters['zip_code']['urban'] = True
rural = st.checkbox('Rural', value=True)
if rural:
filters['zip_code']['rural'] = True
recency = st.slider(label='Месяцев с момента покупки', min_value=int(data_test.recency.min()), max_value=int(data_test.recency.max()), value=(int(data_test.recency.min()), int(data_test.recency.max())))
filters['recency'] = recency
st.write('Если известно на какой процент аудитории необходимо повлиять, измените значение')
k = st.slider(label='Процент аудитории', min_value=1, max_value=100, value=100)
filter_form_submit_button = st.form_submit_button('Применить фильтр')
if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
st.error('Необходимо выбрать хотя бы один класс')
st.stop()
elif not surburban and not urban and not rural:
st.error('Необходимо выбрать хотя бы один почтовый индекс')
st.stop()
filtered_dataset = tools.filter_data(data_test, filters)
if filtered_dataset is None:
st.error('Не найдено пользователей для данных фильтров. Попробуйте изменить фильтры.')
st.stop()
# значение uplift для записей тех клиентов, который выбрал пользователь равен 1
uplift = [1 for _ in filtered_dataset.index]
target_filtered = target_test.loc[filtered_dataset.index]
treatment_filtered = treatment_test.loc[filtered_dataset.index]
with st.expander(label='Посмотреть пример пользователей, которым будет отправлена реклама'):
sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
example = filtered_dataset.sample(sample_size)
st.dataframe(example)
res = st.button('Обновить')
with st.form(key='user_metricks'):
user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
user_metric_qini_auc_score = qini_auc_score(target_filtered, uplift, treatment_filtered)
user_metric_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered)
col1, col2, col3 = st.columns(3)
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{user_metric_uplift_at_k:.4f}')
col2.metric(label=f'Qini AUC score', value=f'{user_metric_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя')
col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
st.write('Uplift по процентилям')
st.write(user_metric_uplift_by_percentile)
st.form_submit_button('Обновить графики', help='При изменении флагов')
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
st.pyplot(plot_qini_curve(target_filtered, uplift, treatment_filtered, perfect=perfect_qini).figure_)
prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
st.pyplot(plot_uplift_curve(target_filtered, uplift, treatment_filtered, perfect=prefect_uplift).figure_)
show_ml_reasons = st.checkbox('Показать решения с помощью ML')
if show_ml_reasons:
with st.expander('Решение с помощью CatBoost'):
tm_ctrl = TwoModels(
estimator_trmnt=tm_dependend_trmntl_cbc,
estimator_ctrl=tm_dependend_ctrl_cbc,
method='ddr_control'
)
tm_ctrl = tm_ctrl.fit(
data_train, target_train, treatment_train,
estimator_trmnt_fit_params={
'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']},
estimator_ctrl_fit_params={
'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']}
)
uplift_tm_ctrl = tm_ctrl.predict(data_test)
tm_ctrl_score = uplift_at_k(y_true=target_test, uplift=uplift_tm_ctrl, treatment=treatment_test,
strategy='by_group', k=k)
# считаем метрики для ML
catboost_uplift_at_k = uplift_at_k(target_filtered, uplift_tm_ctrl, treatment_filtered, strategy='overall', k=k)
catboost_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift_tm_ctrl, treatment_filtered)
catboost_qini_auc_score = qini_auc_score(target_filtered, uplift_tm_ctrl, treatment_filtered)
catboost_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift_tm_ctrl, treatment_filtered)
# отображаем метрики
col1, col2, col3 = st.columns(3)
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{catboost_uplift_at_k:.4f}', delta=f'{catboost_uplift_at_k - user_metric_uplift_at_k:.4f}')
col2.metric(label=f'Qini AUC score', value=f'{catboost_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя', delta=f'{catboost_qini_auc_score - user_metric_qini_auc_score:.4f}')
col3.metric(label=f'Weighted average uplift', value=f'{catboost_weighted_average_uplift:.4f}', delta=f'{catboost_weighted_average_uplift - user_metric_weighted_average_uplift:.4f}')
st.write('Uplift по процентилям')
st.write(catboost_uplift_by_percentile)
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
st.pyplot(plot_qini_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=perfect_qini).figure_)
prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
st.pyplot(plot_uplift_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=prefect_uplift).figure_)