ChuckNorris commited on
Commit
a22630b
1 Parent(s): 402eb1c

Initial commit

Browse files
Files changed (3) hide show
  1. src/test.ipynb +0 -0
  2. src/tools.py +5 -5
  3. src/web_app.py +191 -91
src/test.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
src/tools.py CHANGED
@@ -50,16 +50,16 @@ def data_split(data: pd.DataFrame, treatment: pd.DataFrame, target: pd.DataFrame
50
 
51
 
52
  def filter_by_newbie(data: pd.DataFrame, newbie_filter: str) -> pd.DataFrame:
53
- if newbie_filter == 'Всем':
54
  return data
55
- elif newbie_filter == 'Только новым':
56
  return data[data['newbie'] == 1]
57
- elif newbie_filter == 'Только старым':
58
  return data[data['newbie'] == 0]
59
 
60
 
61
  def filter_by_channel(data: pd.DataFrame, channel_filter: str) -> pd.DataFrame:
62
- if channel_filter == 'Всем':
63
  return data
64
  if channel_filter == 'Phone':
65
  return data[data['channel'] == channel_filter]
@@ -70,7 +70,7 @@ def filter_by_channel(data: pd.DataFrame, channel_filter: str) -> pd.DataFrame:
70
 
71
 
72
  def filter_by_mens(data: pd.DataFrame, mens_filter: str) -> pd.DataFrame:
73
- if mens_filter == 'Любые товары':
74
  return data
75
  if mens_filter == 'Мужские':
76
  return data[data['mens'] == 1]
 
50
 
51
 
52
  def filter_by_newbie(data: pd.DataFrame, newbie_filter: str) -> pd.DataFrame:
53
+ if newbie_filter == 'Все':
54
  return data
55
+ elif newbie_filter == 'Только новые':
56
  return data[data['newbie'] == 1]
57
+ elif newbie_filter == 'Только старые':
58
  return data[data['newbie'] == 0]
59
 
60
 
61
  def filter_by_channel(data: pd.DataFrame, channel_filter: str) -> pd.DataFrame:
62
+ if channel_filter == 'Все':
63
  return data
64
  if channel_filter == 'Phone':
65
  return data[data['channel'] == channel_filter]
 
70
 
71
 
72
  def filter_by_mens(data: pd.DataFrame, mens_filter: str) -> pd.DataFrame:
73
+ if mens_filter == 'Любые':
74
  return data
75
  if mens_filter == 'Мужские':
76
  return data[data['mens'] == 1]
src/web_app.py CHANGED
@@ -1,17 +1,52 @@
 
1
  import pandas as pd
2
-
 
 
 
3
  import streamlit as st
 
4
 
5
  import tools
6
 
 
7
  dataset, target, treatment = tools.get_data()
8
 
9
- data_train, data_test, treatment_train, treatment_test, target_train, target_test = tools.data_split(dataset, treatment, target)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  if 'filter_data' not in st.session_state.keys():
12
  st.session_state.filter_data = True
13
 
14
-
15
  st.title('Uplift lab')
16
 
17
  st.markdown(
@@ -35,11 +70,11 @@ st.markdown(
35
  """
36
  )
37
  refresh = st.button('Обновить выборку')
38
- title_subsample = data_test.sample(7)
39
  if refresh:
40
- title_subsample = data_test.sample(7)
41
  st.dataframe(title_subsample, width=700)
42
- st.write(f"Всего записей: {data_test.shape[0]}")
43
 
44
  st.write('Описание данных')
45
  st.markdown(
@@ -63,118 +98,183 @@ st.write("Для того, чтобы лучше понять на какую а
63
 
64
  with st.expander('Развернуть блок анализа данных'):
65
 
66
- st.plotly_chart(tools.get_newbie_plot(data_test), use_container_width=True)
67
  st.write(f'В данных примерно одинаковое количество новых и "старых клиентов". '
68
- f'Отношение новых клиентов к старым: {(data_test["newbie"] == 1).sum() / (data_test["newbie"] == 0).sum():.2f}')
69
 
70
- st.plotly_chart(tools.get_zipcode_plot(data_test), use_container_width=True)
71
- tmp_res = data_test.zip_code.value_counts(normalize=True) * 100
72
  st.write(f'Большинство клиентов из пригорода: {tmp_res["Surburban"]:.2f}%, из города: {tmp_res["Urban"]:.2f}% и из села: {tmp_res["Rural"]:.2f}%')
73
 
74
- tmp_res = data_test.channel.value_counts(normalize=True) * 100
75
- st.plotly_chart(tools.get_channel_plot(data_test), use_container_width=True)
76
  st.write(f'В прошлом году почти одинаковое количество клиентов покупало товары через телефон и сайт, {tmp_res["Phone"]:.2f}% и {tmp_res["Web"]:.2f}% соответственно,'
77
- f' а {tmp_res["Multichannel"]:.2f}% клиентов покупали товары воспользовавшись двумя платформами.')
78
 
79
- tmp_res = data_test.history_segment.value_counts(normalize=True) * 100
80
- st.plotly_chart(tools.get_history_segment_plot(data_test), use_container_width=True)
81
  st.write(f'Как мы видим, большинство пользователей относится к сегменту \$0-\$100 ({tmp_res[0]:.2f}%), второй и '
82
- f'третий по количеству пользователей сегменты \$100-\$200 ({tmp_res[1]:.2f}%) и \$200-\$350 ({tmp_res[2]:.2f}%).')
83
  st.write(f'К сегментам \$350-\$500 и \$500-\$750 относится {tmp_res[3]:.2f}% и {tmp_res[4]:.2f}% пользователей соответственно.')
84
  st.write(f'Меньше всего пользователей в сегментах \$750-\$1.000 ({tmp_res[-2]:.2f}%) и \$1.000+ ({tmp_res[-1]:.2f}%).')
85
 
86
- tmp_res = list(data_test.recency.value_counts(normalize=True) * 100)
87
- st.plotly_chart(tools.get_recency_plot(data_test), use_container_width=True)
88
  st.write(f'Большинство клиентов являются активными клиентами платформы, и совершали покупки в течение месяца ({tmp_res[0]:.2f}%)')
89
  st.write('Также заметно, что 9 и 10 месяцев назад, много клиентов совершали покупки. Это может свидетельствовать о проведении'
90
- 'рекламной кампании в это время или чего-то еще.')
91
  st.write('Также интересно понаблюдать за долями новых клиентов в данном распределении.')
92
 
93
- st.plotly_chart(tools.get_history_plot(data_test), use_container_width=True)
94
  st.markdown('_График интерактивный. Двойной клик вернет в начальное состояние._')
95
  st.write('Абсолютное большинство клиентов тратят \$25-\$35 на покупки, но есть и малая доля тех, кто тратит более \$3.000')
96
  st.write('Интересный факт: все покупки более \$500 совершают только новые клиенты')
97
 
98
  filters = {}
99
 
100
- st.subheader('Выберем клиентов, которым отправим рекламу.')
101
- newbie_filter = st.radio('Каким клиентам отправим рекламу?', options=['Всем', 'Только новым', 'Только старым'])
102
- filters['newbie_filter'] = newbie_filter
103
-
104
- channel_filter = st.radio('Канал, по которому клиент покупал в прошлом году', options=['Всем', 'Phone', 'Web', 'Multichannel'])
105
- filters['channel_filter'] = channel_filter
106
-
107
- mens_filter = st.radio('Клиенты, приобретавшие', options=['Любые товары', 'Мужские', 'Женские'])
108
- filters['mens_filter'] = mens_filter
109
-
110
- st.write('Выберите класс клиентов, по объему денег, потраченных в прошлом году (history segments)')
111
- filters['history_segments'] = {}
112
- first_group = st.checkbox('$0-$100', value=True)
113
- if first_group:
114
- filters['history_segments']['1) $0 - $100'] = True
115
- second_group = st.checkbox('$100-$200', value=True)
116
- if second_group:
117
- filters['history_segments']['2) $100 - $200'] = True
118
- third_group = st.checkbox('$200-$350', value=True)
119
- if third_group:
120
- filters['history_segments']['3) $200 - $350'] = True
121
- fourth_group = st.checkbox('$350-$500', value=True)
122
- if fourth_group:
123
- filters['history_segments']['4) $350 - $500'] = True
124
- fifth_group = st.checkbox('$500-$750', value=True)
125
- if fifth_group:
126
- filters['history_segments']['5) $500 - $750'] = True
127
- sixth_group = st.checkbox('$750-$1.000', value=True)
128
- if sixth_group:
129
- filters['history_segments']['6) $750 - $1,000'] = True
130
- seventh_group = st.checkbox('$1.000+', value=True)
131
- if seventh_group:
132
- filters['history_segments']['7) $1,000 +'] = True
133
-
134
- st.write('Каких пользователей по почтовому коду выберем')
135
- filters['zip_code'] = {}
136
- surburban = st.checkbox('Surburban', value=True)
137
- if surburban:
138
- filters['zip_code']['surburban'] = True
139
- urban = st.checkbox('Urban', value=True)
140
- if urban:
141
- filters['zip_code']['urban'] = True
142
- rural = st.checkbox('Rural', value=True)
143
- if rural:
144
- filters['zip_code']['rural'] = True
145
-
146
- recency = st.slider(label='Месяцев с момента покупки', min_value=int(data_test.recency.min()), max_value=int(data_test.recency.max()), value=(int(data_test.recency.min()), int(data_test.recency.max())))
147
- filters['recency'] = recency
148
-
149
- disabled = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
151
  st.error('Необходимо выбрать хотя бы один класс')
152
- disabled = True
153
  elif not surburban and not urban and not rural:
154
  st.error('Необходимо выбрать хотя бы один почтовый индекс')
155
- disabled = True
156
-
157
-
158
- if not disabled:
159
- filtered_dataset = tools.filter_data(data_test, filters)
160
- if filtered_dataset is None:
161
- st.error('Не найдено пользователей для данных фильтров. Попробуйте изменить фильтры.')
162
- # значение uplift для записей тех клиентов, который выбрал пользователь равен 1
163
- import numpy as np
164
- uplift = [1 for _ in filtered_dataset.index]
165
- target_filtered = target_test.loc[filtered_dataset.index]
166
- treatment_filtered = treatment_test.loc[filtered_dataset.index]
 
167
  sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
168
  example = filtered_dataset.sample(sample_size)
169
- st.write('Пример пользователей, которым будет отправлена реклама')
170
  st.dataframe(example)
171
- st.info(f'Количество клиентов, которым реклама будет отправлена: _**{filtered_dataset.shape[0]}**_ ({filtered_dataset.shape[0] / data_train.shape[0] * 100 :.2f}% от вс��х клиентов)')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
 
173
 
174
- send_promo = st.button('Отправить рекламу и посмотреть результат', disabled=disabled)
175
- if send_promo:
176
- st.write(tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered))
177
- # st.write(tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered))
 
 
 
 
 
 
 
 
 
 
178
 
179
- # st.write('Если известно, на какой процент пользователей необходимо воздействовать, укажите это ниже')
180
- # st.slider(label='Процент пользователей', min_value=0, max_value=100, value=100)
 
 
 
1
+ import catboost
2
  import pandas as pd
3
+ import os
4
+ from sklift.metrics import uplift_at_k, uplift_by_percentile, qini_auc_score
5
+ from sklift.viz import plot_qini_curve, plot_uplift_curve
6
+ from sklift.models import SoloModel, TwoModels, ClassTransformation
7
  import streamlit as st
8
+ import catboost
9
 
10
  import tools
11
 
12
+
13
  dataset, target, treatment = tools.get_data()
14
 
15
+ ct_cbc_model = catboost.CatBoostClassifier()
16
+ ct_cbc_model.load_model('src/models/ct_cbc.cbm')
17
+
18
+ sm_cbc_model = catboost.CatBoostClassifier()
19
+ sm_cbc_model.load_model('src/models/sm_cbc.cbm')
20
+
21
+ tm_ctrl_cbc_model = catboost.CatBoostClassifier()
22
+ tm_ctrl_cbc_model.load_model('src/models/tm_ctrl_cbc.cbm')
23
+ tm_trmnt_cbc_model = catboost.CatBoostClassifier()
24
+ tm_ctrl_cbc_model.load_model('src/models/tm_trmnt_cbc.cbm')
25
+
26
+ tm_dependend_ctrl_cbc = catboost.CatBoostClassifier()
27
+ tm_ctrl_cbc_model.load_model('src/models/tm_dependend_ctrl_cbc.cbm')
28
+ tm_dependend_trmntl_cbc = catboost.CatBoostClassifier()
29
+ tm_dependend_trmntl_cbc.load_model('src/models/tm_dependend_trmnt_cbc.cbm')
30
+
31
+ data_train_index = pd.read_csv('data/data_train_index.csv')
32
+ data_test_index = pd.read_csv('data/data_test_index.csv')
33
+ treatment_train_index = pd.read_csv('data/treatment_train_index.csv')
34
+ treatment_test_index = pd.read_csv('data/treatment_test_index.csv')
35
+ target_train_index = pd.read_csv('data/target_train_index.csv')
36
+ target_test_index = pd.read_csv('data/target_test_index.csv')
37
+
38
+
39
+ # фиксируем выборки, чтобы результат работы ML был предсказуем
40
+ data_train = dataset.loc[data_train_index['0']]
41
+ data_test = dataset.loc[data_test_index['0']]
42
+ treatment_train = treatment.loc[treatment_train_index['0']]
43
+ treatment_test = treatment.loc[treatment_test_index['0']]
44
+ target_train = target.loc[target_train_index['0']]
45
+ target_test = target.loc[target_test_index['0']]
46
 
47
  if 'filter_data' not in st.session_state.keys():
48
  st.session_state.filter_data = True
49
 
 
50
  st.title('Uplift lab')
51
 
52
  st.markdown(
 
70
  """
71
  )
72
  refresh = st.button('Обновить выборку')
73
+ title_subsample = data_train.sample(7)
74
  if refresh:
75
+ title_subsample = data_train.sample(7)
76
  st.dataframe(title_subsample, width=700)
77
+ st.write(f"Всего записей: {data_train.shape[0]}")
78
 
79
  st.write('Описание данных')
80
  st.markdown(
 
98
 
99
  with st.expander('Развернуть блок анализа данных'):
100
 
101
+ st.plotly_chart(tools.get_newbie_plot(data_train), use_container_width=True)
102
  st.write(f'В данных примерно одинаковое количество новых и "старых клиентов". '
103
+ f'Отношение новых клиентов к старым: {(data_train["newbie"] == 1).sum() / (data_train["newbie"] == 0).sum():.2f}')
104
 
105
+ st.plotly_chart(tools.get_zipcode_plot(data_train), use_container_width=True)
106
+ tmp_res = data_train.zip_code.value_counts(normalize=True) * 100
107
  st.write(f'Большинство клиентов из пригорода: {tmp_res["Surburban"]:.2f}%, из города: {tmp_res["Urban"]:.2f}% и из села: {tmp_res["Rural"]:.2f}%')
108
 
109
+ tmp_res = data_train.channel.value_counts(normalize=True) * 100
110
+ st.plotly_chart(tools.get_channel_plot(data_train), use_container_width=True)
111
  st.write(f'В прошлом году почти одинаковое количество клиентов покупало товары через телефон и сайт, {tmp_res["Phone"]:.2f}% и {tmp_res["Web"]:.2f}% соответственно,'
112
+ f' а {tmp_res["Multichannel"]:.2f}% клиентов покупали товары воспользовавшись двумя платформами.')
113
 
114
+ tmp_res = data_train.history_segment.value_counts(normalize=True) * 100
115
+ st.plotly_chart(tools.get_history_segment_plot(data_train), use_container_width=True)
116
  st.write(f'Как мы видим, большинство пользователей относится к сегменту \$0-\$100 ({tmp_res[0]:.2f}%), второй и '
117
+ f'третий по количеству пользователей сегменты \$100-\$200 ({tmp_res[1]:.2f}%) и \$200-\$350 ({tmp_res[2]:.2f}%).')
118
  st.write(f'К сегментам \$350-\$500 и \$500-\$750 относится {tmp_res[3]:.2f}% и {tmp_res[4]:.2f}% пользователей соответственно.')
119
  st.write(f'Меньше всего пользователей в сегментах \$750-\$1.000 ({tmp_res[-2]:.2f}%) и \$1.000+ ({tmp_res[-1]:.2f}%).')
120
 
121
+ tmp_res = list(data_train.recency.value_counts(normalize=True) * 100)
122
+ st.plotly_chart(tools.get_recency_plot(data_train), use_container_width=True)
123
  st.write(f'Большинство клиентов являются активными клиентами платформы, и совершали покупки в течение месяца ({tmp_res[0]:.2f}%)')
124
  st.write('Также заметно, что 9 и 10 месяцев назад, много клиентов совершали покупки. Это может свидетельствовать о проведении'
125
+ 'рекламной кампании в это время или чего-то еще.')
126
  st.write('Также интересно понаблюдать за долями новых клиентов в данном распределении.')
127
 
128
+ st.plotly_chart(tools.get_history_plot(data_train), use_container_width=True)
129
  st.markdown('_График интерактивный. Двойной клик вернет в начальное состояние._')
130
  st.write('Абсолютное большинство клиентов тратят \$25-\$35 на покупки, но есть и малая доля тех, кто тратит более \$3.000')
131
  st.write('Интересный факт: все покупки более \$500 совершают только новые клиенты')
132
 
133
  filters = {}
134
 
135
+ with st.form(key='filter-clients'):
136
+ st.subheader('Выберем клиентов, которым отправим рекламу.')
137
+
138
+ col1, col2, col3 = st.columns(3)
139
+
140
+ channel_filter = col1.radio('Канал покупки прошлом году', options=['Все', 'Phone', 'Web', 'Multichannel'])
141
+ filters['channel_filter'] = channel_filter
142
+
143
+ newbie_filter = col2.radio('Тип клиента', options=['Все', 'Только новые', 'Только старые'])
144
+ filters['newbie_filter'] = newbie_filter
145
+
146
+ mens_filter = col3.radio('Клиенты, приобретавшие товары', options=['Любые', 'Мужские', 'Женские'])
147
+ filters['mens_filter'] = mens_filter
148
+
149
+ filters['history_segments'] = {}
150
+
151
+ col1, col2 = st.columns(2)
152
+
153
+ with col1:
154
+ st.write('Класс клиентов по объему денег, потраченных в прошлом году (history segments)')
155
+ first_group = st.checkbox('$0-$100', value=True)
156
+ if first_group:
157
+ filters['history_segments']['1) $0 - $100'] = True
158
+ second_group = st.checkbox('$100-$200', value=True)
159
+ if second_group:
160
+ filters['history_segments']['2) $100 - $200'] = True
161
+ third_group = st.checkbox('$200-$350', value=True)
162
+ if third_group:
163
+ filters['history_segments']['3) $200 - $350'] = True
164
+ fourth_group = st.checkbox('$350-$500', value=True)
165
+ if fourth_group:
166
+ filters['history_segments']['4) $350 - $500'] = True
167
+ fifth_group = st.checkbox('$500-$750', value=True)
168
+ if fifth_group:
169
+ filters['history_segments']['5) $500 - $750'] = True
170
+ sixth_group = st.checkbox('$750-$1.000', value=True)
171
+ if sixth_group:
172
+ filters['history_segments']['6) $750 - $1,000'] = True
173
+ seventh_group = st.checkbox('$1.000+', value=True)
174
+ if seventh_group:
175
+ filters['history_segments']['7) $1,000 +'] = True
176
+
177
+ with col2:
178
+ st.write('Каких пользователей по почтовому коду выберем')
179
+ filters['zip_code'] = {}
180
+ surburban = st.checkbox('Surburban', value=True)
181
+ if surburban:
182
+ filters['zip_code']['surburban'] = True
183
+ urban = st.checkbox('Urban', value=True)
184
+ if urban:
185
+ filters['zip_code']['urban'] = True
186
+ rural = st.checkbox('Rural', value=True)
187
+ if rural:
188
+ filters['zip_code']['rural'] = True
189
+
190
+ recency = st.slider(label='Месяцев с момента покупки', min_value=int(data_test.recency.min()), max_value=int(data_test.recency.max()), value=(int(data_test.recency.min()), int(data_test.recency.max())))
191
+ filters['recency'] = recency
192
+
193
+ st.write('Если известно на какой процент аудитории необходимо повлиять, измените значение')
194
+ k = st.slider(label='Процент аудитории', min_value=1, max_value=100, value=100)
195
+
196
+ filter_form_submit_button = st.form_submit_button('Применить фильтр')
197
+
198
+
199
  if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
200
  st.error('Необходимо выбрать хотя бы один класс')
201
+ st.stop()
202
  elif not surburban and not urban and not rural:
203
  st.error('Необходимо выбрать хотя бы один почтовый индекс')
204
+ st.stop()
205
+
206
+ filtered_dataset = tools.filter_data(data_test, filters)
207
+ if filtered_dataset is None:
208
+ st.error('Не найдено пользователей для данных фильтров. Попробуйте изменить фильтры.')
209
+ st.stop()
210
+
211
+ # значение uplift для записей тех клиентов, который выбрал пользователь равен 1
212
+ uplift = [1 for _ in filtered_dataset.index]
213
+ target_filtered = target_test.loc[filtered_dataset.index]
214
+ treatment_filtered = treatment_test.loc[filtered_dataset.index]
215
+
216
+ with st.expander(label='Посмотреть пример пользователей, которым будет отправлена реклама'):
217
  sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
218
  example = filtered_dataset.sample(sample_size)
 
219
  st.dataframe(example)
220
+ res = st.button('Обновить')
221
+
222
+
223
+ with st.form(key='user_metricks'):
224
+ user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
225
+ user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
226
+ user_metric_qini_auc_score = qini_auc_score(target_filtered, uplift, treatment_filtered)
227
+ user_metric_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered)
228
+ col1, col2, col3 = st.columns(3)
229
+ col1.metric(label=f'Uplift для {k}% пользователей', value=f'{user_metric_uplift_at_k:.4f}')
230
+ col2.metric(label=f'Qini AUC score', value=f'{user_metric_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя')
231
+ col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
232
+ st.write('Uplift по процентилям')
233
+ st.write(user_metric_uplift_by_percentile)
234
+
235
+ st.form_submit_button('Обновить графики', help='При изменении флагов')
236
+ perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
237
+ st.pyplot(plot_qini_curve(target_filtered, uplift, treatment_filtered, perfect=perfect_qini).figure_)
238
+ prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
239
+ st.pyplot(plot_uplift_curve(target_filtered, uplift, treatment_filtered, perfect=prefect_uplift).figure_)
240
+
241
+
242
+ show_ml_reasons = st.checkbox('Показать решения с помощью ML')
243
+ if show_ml_reasons:
244
+ with st.expander('Решение с помощью CatBoost'):
245
+ tm_ctrl = TwoModels(
246
+ estimator_trmnt=tm_dependend_trmntl_cbc,
247
+ estimator_ctrl=tm_dependend_ctrl_cbc,
248
+ method='ddr_control'
249
+ )
250
+
251
+
252
+ tm_ctrl = tm_ctrl.fit(
253
+ data_train, target_train, treatment_train,
254
+ estimator_trmnt_fit_params={
255
+ 'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']},
256
+ estimator_ctrl_fit_params={
257
+ 'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']}
258
+ )
259
 
260
+ uplift_tm_ctrl = tm_ctrl.predict(data_test)
261
 
262
+ tm_ctrl_score = uplift_at_k(y_true=target_test, uplift=uplift_tm_ctrl, treatment=treatment_test,
263
+ strategy='by_group', k=k)
264
+ # считаем метрики для ML
265
+ catboost_uplift_at_k = uplift_at_k(target_filtered, uplift_tm_ctrl, treatment_filtered, strategy='overall', k=k)
266
+ catboost_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift_tm_ctrl, treatment_filtered)
267
+ catboost_qini_auc_score = qini_auc_score(target_filtered, uplift_tm_ctrl, treatment_filtered)
268
+ catboost_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift_tm_ctrl, treatment_filtered)
269
+ # отображаем метрики
270
+ col1, col2, col3 = st.columns(3)
271
+ col1.metric(label=f'Uplift для {k}% пользователей', value=f'{catboost_uplift_at_k:.4f}', delta=f'{catboost_uplift_at_k - user_metric_uplift_at_k:.4f}')
272
+ col2.metric(label=f'Qini AUC score', value=f'{catboost_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя', delta=f'{catboost_qini_auc_score - user_metric_qini_auc_score:.4f}')
273
+ col3.metric(label=f'Weighted average uplift', value=f'{catboost_weighted_average_uplift:.4f}', delta=f'{catboost_weighted_average_uplift - user_metric_weighted_average_uplift:.4f}')
274
+ st.write('Uplift по процентилям')
275
+ st.write(catboost_uplift_by_percentile)
276
 
277
+ perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
278
+ st.pyplot(plot_qini_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=perfect_qini).figure_)
279
+ prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
280
+ st.pyplot(plot_uplift_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=prefect_uplift).figure_)