ChuckNorris committed on
Commit
e46d7eb
1 Parent(s): 938a6c1

Initial commit

Browse files
Files changed (2) hide show
  1. src/tools.py +7 -0
  2. src/web_app.py +54 -47
src/tools.py CHANGED
@@ -115,6 +115,13 @@ def filter_by_recency(data: pd.DataFrame, recency_filter: list) -> pd.DataFrame:
115
 
116
 
117
  def filter_data(data: pd.DataFrame, filters: dict) -> pd.DataFrame or None:
 
 
 
 
 
 
 
118
  data = filter_by_newbie(data, filters['newbie_filter'])
119
  if data.shape[0] == 0:
120
  return None
 
115
 
116
 
117
  def filter_data(data: pd.DataFrame, filters: dict) -> pd.DataFrame or None:
118
+ """
119
+ Filter data by user filters
120
+
121
+ :param data: data to filter
122
+ :param filters: dict of filters
123
+ :return: filtered data
124
+ """
125
  data = filter_by_newbie(data, filters['newbie_filter'])
126
  if data.shape[0] == 0:
127
  return None
src/web_app.py CHANGED
@@ -9,25 +9,27 @@ import catboost
9
 
10
  import tools
11
 
12
-
13
  dataset, target, treatment = tools.get_data()
14
 
 
15
  ct_cbc_model = catboost.CatBoostClassifier()
16
  ct_cbc_model.load_model('src/models/ct_cbc.cbm')
17
-
18
  sm_cbc_model = catboost.CatBoostClassifier()
19
  sm_cbc_model.load_model('src/models/sm_cbc.cbm')
20
-
21
  tm_ctrl_cbc_model = catboost.CatBoostClassifier()
22
  tm_ctrl_cbc_model.load_model('src/models/tm_ctrl_cbc.cbm')
23
  tm_trmnt_cbc_model = catboost.CatBoostClassifier()
24
  tm_ctrl_cbc_model.load_model('src/models/tm_trmnt_cbc.cbm')
25
-
26
  tm_dependend_ctrl_cbc = catboost.CatBoostClassifier()
27
  tm_ctrl_cbc_model.load_model('src/models/tm_dependend_ctrl_cbc.cbm')
28
  tm_dependend_trmntl_cbc = catboost.CatBoostClassifier()
29
  tm_dependend_trmntl_cbc.load_model('src/models/tm_dependend_trmnt_cbc.cbm')
30
 
 
31
  data_train_index = pd.read_csv('data/data_train_index.csv')
32
  data_test_index = pd.read_csv('data/data_test_index.csv')
33
  treatment_train_index = pd.read_csv('data/treatment_train_index.csv')
@@ -35,7 +37,6 @@ treatment_test_index = pd.read_csv('data/treatment_test_index.csv')
35
  target_train_index = pd.read_csv('data/target_train_index.csv')
36
  target_test_index = pd.read_csv('data/target_test_index.csv')
37
 
38
-
39
  # фиксируем выборки, чтобы результат работы ML был предсказуем
40
  data_train = dataset.loc[data_train_index['0']]
41
  data_test = dataset.loc[data_test_index['0']]
@@ -44,9 +45,6 @@ treatment_test = treatment.loc[treatment_test_index['0']]
44
  target_train = target.loc[target_train_index['0']]
45
  target_test = target.loc[target_test_index['0']]
46
 
47
- if 'filter_data' not in st.session_state.keys():
48
- st.session_state.filter_data = True
49
-
50
  st.title('Uplift lab')
51
 
52
  st.markdown(
@@ -69,6 +67,7 @@ st.markdown(
69
  Пример данных приведен ниже.
70
  """
71
  )
 
72
  refresh = st.button('Обновить выборку')
73
  title_subsample = data_train.sample(7)
74
  if refresh:
@@ -132,6 +131,7 @@ with st.expander('Развернуть блок анализа данных'):
132
 
133
  filters = {}
134
 
 
135
  with st.form(key='filter-clients'):
136
  st.subheader('Выберем клиентов, которым отправим рекламу.')
137
 
@@ -195,7 +195,7 @@ with st.form(key='filter-clients'):
195
 
196
  filter_form_submit_button = st.form_submit_button('Применить фильтр')
197
 
198
-
199
  if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
200
  st.error('Необходимо выбрать хотя бы один класс')
201
  st.stop()
@@ -203,7 +203,10 @@ elif not surburban and not urban and not rural:
203
  st.error('Необходимо выбрать хотя бы один почтовый индекс')
204
  st.stop()
205
 
 
206
  filtered_dataset = tools.filter_data(data_test, filters)
 
 
207
  if filtered_dataset is None:
208
  st.error('Не найдено пользователей для данных фильтров. Попробуйте изменить фильтры.')
209
  st.stop()
@@ -213,25 +216,27 @@ uplift = [1 for _ in filtered_dataset.index]
213
  target_filtered = target_test.loc[filtered_dataset.index]
214
  treatment_filtered = treatment_test.loc[filtered_dataset.index]
215
 
 
216
  with st.expander(label='Посмотреть пример пользователей, которым будет отправлена реклама'):
217
  sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
218
  example = filtered_dataset.sample(sample_size)
219
  st.dataframe(example)
220
  res = st.button('Обновить')
221
 
222
-
223
  with st.form(key='user_metricks'):
 
224
  user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
225
  user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
226
  user_metric_qini_auc_score = qini_auc_score(target_filtered, uplift, treatment_filtered)
227
  user_metric_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered)
 
228
  col1, col2, col3 = st.columns(3)
229
  col1.metric(label=f'Uplift для {k}% пользователей', value=f'{user_metric_uplift_at_k:.4f}')
230
  col2.metric(label=f'Qini AUC score', value=f'{user_metric_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя')
231
  col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
232
  st.write('Uplift по процентилям')
233
  st.write(user_metric_uplift_by_percentile)
234
-
235
  st.form_submit_button('Обновить графики', help='При изменении флагов')
236
  perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
237
  st.pyplot(plot_qini_curve(target_filtered, uplift, treatment_filtered, perfect=perfect_qini).figure_)
@@ -242,39 +247,41 @@ with st.form(key='user_metricks'):
242
  show_ml_reasons = st.checkbox('Показать решения с помощью ML')
243
  if show_ml_reasons:
244
  with st.expander('Решение с помощью CatBoost'):
245
- tm_ctrl = TwoModels(
246
- estimator_trmnt=tm_dependend_trmntl_cbc,
247
- estimator_ctrl=tm_dependend_ctrl_cbc,
248
- method='ddr_control'
249
- )
250
-
251
-
252
- tm_ctrl = tm_ctrl.fit(
253
- data_train, target_train, treatment_train,
254
- estimator_trmnt_fit_params={
255
- 'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']},
256
- estimator_ctrl_fit_params={
257
- 'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']}
258
- )
259
-
260
- uplift_tm_ctrl = tm_ctrl.predict(data_test)
261
-
262
- tm_ctrl_score = uplift_at_k(y_true=target_test, uplift=uplift_tm_ctrl, treatment=treatment_test,
263
- strategy='by_group', k=k)
264
- # считаем метрики для ML
265
- catboost_uplift_at_k = uplift_at_k(target_filtered, uplift_tm_ctrl, treatment_filtered, strategy='overall', k=k)
266
- catboost_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift_tm_ctrl, treatment_filtered)
267
- catboost_qini_auc_score = qini_auc_score(target_filtered, uplift_tm_ctrl, treatment_filtered)
268
- catboost_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift_tm_ctrl, treatment_filtered)
269
- # отображаем метрики
270
- col1, col2, col3 = st.columns(3)
271
- col1.metric(label=f'Uplift для {k}% пользователей', value=f'{catboost_uplift_at_k:.4f}', delta=f'{catboost_uplift_at_k - user_metric_uplift_at_k:.4f}')
272
- col2.metric(label=f'Qini AUC score', value=f'{catboost_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя', delta=f'{catboost_qini_auc_score - user_metric_qini_auc_score:.4f}')
273
- col3.metric(label=f'Weighted average uplift', value=f'{catboost_weighted_average_uplift:.4f}', delta=f'{catboost_weighted_average_uplift - user_metric_weighted_average_uplift:.4f}')
274
- st.write('Uplift по процентилям')
275
- st.write(catboost_uplift_by_percentile)
276
-
277
- perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
278
- st.pyplot(plot_qini_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=perfect_qini).figure_)
279
- prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
280
- st.pyplot(plot_uplift_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=prefect_uplift).figure_)
 
 
 
9
 
10
  import tools
11
 
12
+ # загрузим датасет
13
  dataset, target, treatment = tools.get_data()
14
 
15
+ # загрузим модель для ClassTransform
16
  ct_cbc_model = catboost.CatBoostClassifier()
17
  ct_cbc_model.load_model('src/models/ct_cbc.cbm')
18
+ # загрузим модель для SingleMod
19
  sm_cbc_model = catboost.CatBoostClassifier()
20
  sm_cbc_model.load_model('src/models/sm_cbc.cbm')
21
+ # загрузим модели для независимого классификатора
22
  tm_ctrl_cbc_model = catboost.CatBoostClassifier()
23
  tm_ctrl_cbc_model.load_model('src/models/tm_ctrl_cbc.cbm')
24
  tm_trmnt_cbc_model = catboost.CatBoostClassifier()
25
  tm_ctrl_cbc_model.load_model('src/models/tm_trmnt_cbc.cbm')
26
+ # загрузим модели для зависимого классификатора
27
  tm_dependend_ctrl_cbc = catboost.CatBoostClassifier()
28
  tm_ctrl_cbc_model.load_model('src/models/tm_dependend_ctrl_cbc.cbm')
29
  tm_dependend_trmntl_cbc = catboost.CatBoostClassifier()
30
  tm_dependend_trmntl_cbc.load_model('src/models/tm_dependend_trmnt_cbc.cbm')
31
 
32
+ # загрузим данные
33
  data_train_index = pd.read_csv('data/data_train_index.csv')
34
  data_test_index = pd.read_csv('data/data_test_index.csv')
35
  treatment_train_index = pd.read_csv('data/treatment_train_index.csv')
 
37
  target_train_index = pd.read_csv('data/target_train_index.csv')
38
  target_test_index = pd.read_csv('data/target_test_index.csv')
39
 
 
40
  # фиксируем выборки, чтобы результат работы ML был предсказуем
41
  data_train = dataset.loc[data_train_index['0']]
42
  data_test = dataset.loc[data_test_index['0']]
 
45
  target_train = target.loc[target_train_index['0']]
46
  target_test = target.loc[target_test_index['0']]
47
 
 
 
 
48
  st.title('Uplift lab')
49
 
50
  st.markdown(
 
67
  Пример данных приведен ниже.
68
  """
69
  )
70
+
71
  refresh = st.button('Обновить выборку')
72
  title_subsample = data_train.sample(7)
73
  if refresh:
 
131
 
132
  filters = {}
133
 
134
+ # блок фильтров
135
  with st.form(key='filter-clients'):
136
  st.subheader('Выберем клиентов, которым отправим рекламу.')
137
 
 
195
 
196
  filter_form_submit_button = st.form_submit_button('Применить фильтр')
197
 
198
+ # проверка корректности заполнения форм
199
  if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
200
  st.error('Необходимо выбрать хотя бы один класс')
201
  st.stop()
 
203
  st.error('Необходимо выбрать хотя бы один почтовый индекс')
204
  st.stop()
205
 
206
+ # фильтруем тестовые данные по пользовательскому выбору
207
  filtered_dataset = tools.filter_data(data_test, filters)
208
+
209
+ # проверяем, что данные отфильтровались
210
  if filtered_dataset is None:
211
  st.error('Не найдено пользователей для данных фильтров. Попробуйте изменить фильтры.')
212
  st.stop()
 
216
  target_filtered = target_test.loc[filtered_dataset.index]
217
  treatment_filtered = treatment_test.loc[filtered_dataset.index]
218
 
219
+ # блок с демонстрацией отфильтрованных данных
220
  with st.expander(label='Посмотреть пример пользователей, которым будет отправлена реклама'):
221
  sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
222
  example = filtered_dataset.sample(sample_size)
223
  st.dataframe(example)
224
  res = st.button('Обновить')
225
 
 
226
  with st.form(key='user_metricks'):
227
+ # считаем метрики для пользователя
228
  user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
229
  user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
230
  user_metric_qini_auc_score = qini_auc_score(target_filtered, uplift, treatment_filtered)
231
  user_metric_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered)
232
+ # отображаем метрики
233
  col1, col2, col3 = st.columns(3)
234
  col1.metric(label=f'Uplift для {k}% пользователей', value=f'{user_metric_uplift_at_k:.4f}')
235
  col2.metric(label=f'Qini AUC score', value=f'{user_metric_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя')
236
  col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
237
  st.write('Uplift по процентилям')
238
  st.write(user_metric_uplift_by_percentile)
239
+ # отображаем графики
240
  st.form_submit_button('Обновить графики', help='При изменении флагов')
241
  perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
242
  st.pyplot(plot_qini_curve(target_filtered, uplift, treatment_filtered, perfect=perfect_qini).figure_)
 
247
  show_ml_reasons = st.checkbox('Показать решения с помощью ML')
248
  if show_ml_reasons:
249
  with st.expander('Решение с помощью CatBoost'):
250
+ with st.form(key='catboost_metricks'):
251
+
252
+ tm_ctrl = TwoModels(
253
+ estimator_trmnt=tm_dependend_trmntl_cbc,
254
+ estimator_ctrl=tm_dependend_ctrl_cbc,
255
+ method='ddr_control'
256
+ )
257
+
258
+ tm_ctrl = tm_ctrl.fit(
259
+ data_train, target_train, treatment_train,
260
+ estimator_trmnt_fit_params={
261
+ 'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']},
262
+ estimator_ctrl_fit_params={
263
+ 'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']}
264
+ )
265
+
266
+ uplift_tm_ctrl = tm_ctrl.predict(filtered_dataset)
267
+
268
+ tm_ctrl_score = uplift_at_k(y_true=target_filtered, uplift=uplift_tm_ctrl, treatment=treatment_filtered,
269
+ strategy='by_group', k=k)
270
+ # считаем метрики для ML
271
+ catboost_uplift_at_k = uplift_at_k(target_filtered, uplift_tm_ctrl, treatment_filtered, strategy='overall', k=k)
272
+ catboost_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift_tm_ctrl, treatment_filtered)
273
+ catboost_qini_auc_score = qini_auc_score(target_filtered, uplift_tm_ctrl, treatment_filtered)
274
+ catboost_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift_tm_ctrl, treatment_filtered)
275
+ # отображаем метрики
276
+ col1, col2, col3 = st.columns(3)
277
+ col1.metric(label=f'Uplift для {k}% пользователей', value=f'{catboost_uplift_at_k:.4f}', delta=f'{catboost_uplift_at_k - user_metric_uplift_at_k:.4f}')
278
+ col2.metric(label=f'Qini AUC score', value=f'{catboost_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя', delta=f'{catboost_qini_auc_score - user_metric_qini_auc_score:.4f}')
279
+ col3.metric(label=f'Weighted average uplift', value=f'{catboost_weighted_average_uplift:.4f}', delta=f'{catboost_weighted_average_uplift - user_metric_weighted_average_uplift:.4f}')
280
+ st.write('Uplift по процентилям')
281
+ st.write(catboost_uplift_by_percentile)
282
+
283
+ st.form_submit_button('Обновить графики', help='При изменении флагов')
284
+ perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
285
+ st.pyplot(plot_qini_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=perfect_qini).figure_)
286
+ prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
287
+ st.pyplot(plot_uplift_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=prefect_uplift).figure_)