Spaces:
Runtime error
Runtime error
ChuckNorris
commited on
Commit
•
e46d7eb
1
Parent(s):
938a6c1
Initial commit
Browse files- src/tools.py +7 -0
- src/web_app.py +54 -47
src/tools.py
CHANGED
@@ -115,6 +115,13 @@ def filter_by_recency(data: pd.DataFrame, recency_filter: list) -> pd.DataFrame:
|
|
115 |
|
116 |
|
117 |
def filter_data(data: pd.DataFrame, filters: dict) -> pd.DataFrame or None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
data = filter_by_newbie(data, filters['newbie_filter'])
|
119 |
if data.shape[0] == 0:
|
120 |
return None
|
|
|
115 |
|
116 |
|
117 |
def filter_data(data: pd.DataFrame, filters: dict) -> pd.DataFrame or None:
|
118 |
+
"""
|
119 |
+
Filter data by user filters
|
120 |
+
|
121 |
+
:param data: filtered data
|
122 |
+
:param filters: dict of filters
|
123 |
+
:return: filtered data
|
124 |
+
"""
|
125 |
data = filter_by_newbie(data, filters['newbie_filter'])
|
126 |
if data.shape[0] == 0:
|
127 |
return None
|
src/web_app.py
CHANGED
@@ -9,25 +9,27 @@ import catboost
|
|
9 |
|
10 |
import tools
|
11 |
|
12 |
-
|
13 |
dataset, target, treatment = tools.get_data()
|
14 |
|
|
|
15 |
ct_cbc_model = catboost.CatBoostClassifier()
|
16 |
ct_cbc_model.load_model('src/models/ct_cbc.cbm')
|
17 |
-
|
18 |
sm_cbc_model = catboost.CatBoostClassifier()
|
19 |
sm_cbc_model.load_model('src/models/sm_cbc.cbm')
|
20 |
-
|
21 |
tm_ctrl_cbc_model = catboost.CatBoostClassifier()
|
22 |
tm_ctrl_cbc_model.load_model('src/models/tm_ctrl_cbc.cbm')
|
23 |
tm_trmnt_cbc_model = catboost.CatBoostClassifier()
|
24 |
tm_ctrl_cbc_model.load_model('src/models/tm_trmnt_cbc.cbm')
|
25 |
-
|
26 |
tm_dependend_ctrl_cbc = catboost.CatBoostClassifier()
|
27 |
tm_ctrl_cbc_model.load_model('src/models/tm_dependend_ctrl_cbc.cbm')
|
28 |
tm_dependend_trmntl_cbc = catboost.CatBoostClassifier()
|
29 |
tm_dependend_trmntl_cbc.load_model('src/models/tm_dependend_trmnt_cbc.cbm')
|
30 |
|
|
|
31 |
data_train_index = pd.read_csv('data/data_train_index.csv')
|
32 |
data_test_index = pd.read_csv('data/data_test_index.csv')
|
33 |
treatment_train_index = pd.read_csv('data/treatment_train_index.csv')
|
@@ -35,7 +37,6 @@ treatment_test_index = pd.read_csv('data/treatment_test_index.csv')
|
|
35 |
target_train_index = pd.read_csv('data/target_train_index.csv')
|
36 |
target_test_index = pd.read_csv('data/target_test_index.csv')
|
37 |
|
38 |
-
|
39 |
# фиксируем выборки, чтобы результат работы ML был предсказуем
|
40 |
data_train = dataset.loc[data_train_index['0']]
|
41 |
data_test = dataset.loc[data_test_index['0']]
|
@@ -44,9 +45,6 @@ treatment_test = treatment.loc[treatment_test_index['0']]
|
|
44 |
target_train = target.loc[target_train_index['0']]
|
45 |
target_test = target.loc[target_test_index['0']]
|
46 |
|
47 |
-
if 'filter_data' not in st.session_state.keys():
|
48 |
-
st.session_state.filter_data = True
|
49 |
-
|
50 |
st.title('Uplift lab')
|
51 |
|
52 |
st.markdown(
|
@@ -69,6 +67,7 @@ st.markdown(
|
|
69 |
Пример данных приведен ниже.
|
70 |
"""
|
71 |
)
|
|
|
72 |
refresh = st.button('Обновить выборку')
|
73 |
title_subsample = data_train.sample(7)
|
74 |
if refresh:
|
@@ -132,6 +131,7 @@ with st.expander('Развернуть блок анализа данных'):
|
|
132 |
|
133 |
filters = {}
|
134 |
|
|
|
135 |
with st.form(key='filter-clients'):
|
136 |
st.subheader('Выберем клиентов, которым отправим рекламу.')
|
137 |
|
@@ -195,7 +195,7 @@ with st.form(key='filter-clients'):
|
|
195 |
|
196 |
filter_form_submit_button = st.form_submit_button('Применить фильтр')
|
197 |
|
198 |
-
|
199 |
if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
|
200 |
st.error('Необходимо выбрать хотя бы один класс')
|
201 |
st.stop()
|
@@ -203,7 +203,10 @@ elif not surburban and not urban and not rural:
|
|
203 |
st.error('Необходимо выбрать хотя бы один почтовый индекс')
|
204 |
st.stop()
|
205 |
|
|
|
206 |
filtered_dataset = tools.filter_data(data_test, filters)
|
|
|
|
|
207 |
if filtered_dataset is None:
|
208 |
st.error('Не найдено пользователей для данных фильтров. Попробуйте изменить фильтры.')
|
209 |
st.stop()
|
@@ -213,25 +216,27 @@ uplift = [1 for _ in filtered_dataset.index]
|
|
213 |
target_filtered = target_test.loc[filtered_dataset.index]
|
214 |
treatment_filtered = treatment_test.loc[filtered_dataset.index]
|
215 |
|
|
|
216 |
with st.expander(label='Посмотреть пример пользователей, которым будет отправлена реклама'):
|
217 |
sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
|
218 |
example = filtered_dataset.sample(sample_size)
|
219 |
st.dataframe(example)
|
220 |
res = st.button('Обновить')
|
221 |
|
222 |
-
|
223 |
with st.form(key='user_metricks'):
|
|
|
224 |
user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
|
225 |
user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
|
226 |
user_metric_qini_auc_score = qini_auc_score(target_filtered, uplift, treatment_filtered)
|
227 |
user_metric_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered)
|
|
|
228 |
col1, col2, col3 = st.columns(3)
|
229 |
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{user_metric_uplift_at_k:.4f}')
|
230 |
col2.metric(label=f'Qini AUC score', value=f'{user_metric_qini_auc_score:.4f}', help='Всегда будет 0 для пользова��еля')
|
231 |
col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
|
232 |
st.write('Uplift по процентилям')
|
233 |
st.write(user_metric_uplift_by_percentile)
|
234 |
-
|
235 |
st.form_submit_button('Обновить графики', help='При изменении флагов')
|
236 |
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
237 |
st.pyplot(plot_qini_curve(target_filtered, uplift, treatment_filtered, perfect=perfect_qini).figure_)
|
@@ -242,39 +247,41 @@ with st.form(key='user_metricks'):
|
|
242 |
show_ml_reasons = st.checkbox('Показать решения с помощью ML')
|
243 |
if show_ml_reasons:
|
244 |
with st.expander('Решение с помощью CatBoost'):
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
|
|
|
|
|
9 |
|
10 |
import tools
|
11 |
|
12 |
+
# загрузим датасет
|
13 |
dataset, target, treatment = tools.get_data()
|
14 |
|
15 |
+
# загрузим модель для ClassTransform
|
16 |
ct_cbc_model = catboost.CatBoostClassifier()
|
17 |
ct_cbc_model.load_model('src/models/ct_cbc.cbm')
|
18 |
+
# загрузим модель для SingleMod
|
19 |
sm_cbc_model = catboost.CatBoostClassifier()
|
20 |
sm_cbc_model.load_model('src/models/sm_cbc.cbm')
|
21 |
+
# загрузим модели для независимого класификатора
|
22 |
tm_ctrl_cbc_model = catboost.CatBoostClassifier()
|
23 |
tm_ctrl_cbc_model.load_model('src/models/tm_ctrl_cbc.cbm')
|
24 |
tm_trmnt_cbc_model = catboost.CatBoostClassifier()
|
25 |
tm_ctrl_cbc_model.load_model('src/models/tm_trmnt_cbc.cbm')
|
26 |
+
# загрузим модели для зависимого класификатора
|
27 |
tm_dependend_ctrl_cbc = catboost.CatBoostClassifier()
|
28 |
tm_ctrl_cbc_model.load_model('src/models/tm_dependend_ctrl_cbc.cbm')
|
29 |
tm_dependend_trmntl_cbc = catboost.CatBoostClassifier()
|
30 |
tm_dependend_trmntl_cbc.load_model('src/models/tm_dependend_trmnt_cbc.cbm')
|
31 |
|
32 |
+
# загрузим данные
|
33 |
data_train_index = pd.read_csv('data/data_train_index.csv')
|
34 |
data_test_index = pd.read_csv('data/data_test_index.csv')
|
35 |
treatment_train_index = pd.read_csv('data/treatment_train_index.csv')
|
|
|
37 |
target_train_index = pd.read_csv('data/target_train_index.csv')
|
38 |
target_test_index = pd.read_csv('data/target_test_index.csv')
|
39 |
|
|
|
40 |
# фиксируем выборки, чтобы результат работы ML был предсказуем
|
41 |
data_train = dataset.loc[data_train_index['0']]
|
42 |
data_test = dataset.loc[data_test_index['0']]
|
|
|
45 |
target_train = target.loc[target_train_index['0']]
|
46 |
target_test = target.loc[target_test_index['0']]
|
47 |
|
|
|
|
|
|
|
48 |
st.title('Uplift lab')
|
49 |
|
50 |
st.markdown(
|
|
|
67 |
Пример данных приведен ниже.
|
68 |
"""
|
69 |
)
|
70 |
+
|
71 |
refresh = st.button('Обновить выборку')
|
72 |
title_subsample = data_train.sample(7)
|
73 |
if refresh:
|
|
|
131 |
|
132 |
filters = {}
|
133 |
|
134 |
+
# блок фильтров
|
135 |
with st.form(key='filter-clients'):
|
136 |
st.subheader('Выберем клиентов, которым отправим рекламу.')
|
137 |
|
|
|
195 |
|
196 |
filter_form_submit_button = st.form_submit_button('Применить фильтр')
|
197 |
|
198 |
+
# проверка корректности заполнения форм
|
199 |
if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
|
200 |
st.error('Необходимо выбрать хотя бы один класс')
|
201 |
st.stop()
|
|
|
203 |
st.error('Необходимо выбрать хотя бы один почтовый индекс')
|
204 |
st.stop()
|
205 |
|
206 |
+
# фильтруем тестовые данные по пользовательскому выбору
|
207 |
filtered_dataset = tools.filter_data(data_test, filters)
|
208 |
+
|
209 |
+
# проверяем, что данные отфильтровались
|
210 |
if filtered_dataset is None:
|
211 |
st.error('Не найдено пользователей для данных фильтров. Попробуйте изменить фильтры.')
|
212 |
st.stop()
|
|
|
216 |
target_filtered = target_test.loc[filtered_dataset.index]
|
217 |
treatment_filtered = treatment_test.loc[filtered_dataset.index]
|
218 |
|
219 |
+
# блок с демонстрацией отфильтрованных данных
|
220 |
with st.expander(label='Посмотреть пример пользователей, которым будет отправлена реклама'):
|
221 |
sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
|
222 |
example = filtered_dataset.sample(sample_size)
|
223 |
st.dataframe(example)
|
224 |
res = st.button('Обновить')
|
225 |
|
|
|
226 |
with st.form(key='user_metricks'):
|
227 |
+
# считаем метрики для пользователя
|
228 |
user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
|
229 |
user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
|
230 |
user_metric_qini_auc_score = qini_auc_score(target_filtered, uplift, treatment_filtered)
|
231 |
user_metric_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered)
|
232 |
+
# отображаем метрики
|
233 |
col1, col2, col3 = st.columns(3)
|
234 |
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{user_metric_uplift_at_k:.4f}')
|
235 |
col2.metric(label=f'Qini AUC score', value=f'{user_metric_qini_auc_score:.4f}', help='Всегда будет 0 для пользова��еля')
|
236 |
col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
|
237 |
st.write('Uplift по процентилям')
|
238 |
st.write(user_metric_uplift_by_percentile)
|
239 |
+
# отображаем графики
|
240 |
st.form_submit_button('Обновить графики', help='При изменении флагов')
|
241 |
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
242 |
st.pyplot(plot_qini_curve(target_filtered, uplift, treatment_filtered, perfect=perfect_qini).figure_)
|
|
|
247 |
show_ml_reasons = st.checkbox('Показать решения с помощью ML')
|
248 |
if show_ml_reasons:
|
249 |
with st.expander('Решение с помощью CatBoost'):
|
250 |
+
with st.form(key='catboost_metricks'):
|
251 |
+
|
252 |
+
tm_ctrl = TwoModels(
|
253 |
+
estimator_trmnt=tm_dependend_trmntl_cbc,
|
254 |
+
estimator_ctrl=tm_dependend_ctrl_cbc,
|
255 |
+
method='ddr_control'
|
256 |
+
)
|
257 |
+
|
258 |
+
tm_ctrl = tm_ctrl.fit(
|
259 |
+
data_train, target_train, treatment_train,
|
260 |
+
estimator_trmnt_fit_params={
|
261 |
+
'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']},
|
262 |
+
estimator_ctrl_fit_params={
|
263 |
+
'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']}
|
264 |
+
)
|
265 |
+
|
266 |
+
uplift_tm_ctrl = tm_ctrl.predict(filtered_dataset)
|
267 |
+
|
268 |
+
tm_ctrl_score = uplift_at_k(y_true=target_filtered, uplift=uplift_tm_ctrl, treatment=treatment_filtered,
|
269 |
+
strategy='by_group', k=k)
|
270 |
+
# считаем метрики для ML
|
271 |
+
catboost_uplift_at_k = uplift_at_k(target_filtered, uplift_tm_ctrl, treatment_filtered, strategy='overall', k=k)
|
272 |
+
catboost_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift_tm_ctrl, treatment_filtered)
|
273 |
+
catboost_qini_auc_score = qini_auc_score(target_filtered, uplift_tm_ctrl, treatment_filtered)
|
274 |
+
catboost_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift_tm_ctrl, treatment_filtered)
|
275 |
+
# отображаем метрики
|
276 |
+
col1, col2, col3 = st.columns(3)
|
277 |
+
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{catboost_uplift_at_k:.4f}', delta=f'{catboost_uplift_at_k - user_metric_uplift_at_k:.4f}')
|
278 |
+
col2.metric(label=f'Qini AUC score', value=f'{catboost_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя', delta=f'{catboost_qini_auc_score - user_metric_qini_auc_score:.4f}')
|
279 |
+
col3.metric(label=f'Weighted average uplift', value=f'{catboost_weighted_average_uplift:.4f}', delta=f'{catboost_weighted_average_uplift - user_metric_weighted_average_uplift:.4f}')
|
280 |
+
st.write('Uplift по процентилям')
|
281 |
+
st.write(catboost_uplift_by_percentile)
|
282 |
+
|
283 |
+
st.form_submit_button('Обновить графики', help='При изменении флагов')
|
284 |
+
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
285 |
+
st.pyplot(plot_qini_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=perfect_qini).figure_)
|
286 |
+
prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
|
287 |
+
st.pyplot(plot_uplift_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=prefect_uplift).figure_)
|