Spaces:
Runtime error
Runtime error
ChuckNorris
commited on
Commit
•
a22630b
1
Parent(s):
402eb1c
Initial commit
Browse files- src/test.ipynb +0 -0
- src/tools.py +5 -5
- src/web_app.py +191 -91
src/test.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
src/tools.py
CHANGED
@@ -50,16 +50,16 @@ def data_split(data: pd.DataFrame, treatment: pd.DataFrame, target: pd.DataFrame
|
|
50 |
|
51 |
|
52 |
def filter_by_newbie(data: pd.DataFrame, newbie_filter: str) -> pd.DataFrame:
|
53 |
-
if newbie_filter == '
|
54 |
return data
|
55 |
-
elif newbie_filter == 'Только
|
56 |
return data[data['newbie'] == 1]
|
57 |
-
elif newbie_filter == 'Только
|
58 |
return data[data['newbie'] == 0]
|
59 |
|
60 |
|
61 |
def filter_by_channel(data: pd.DataFrame, channel_filter: str) -> pd.DataFrame:
|
62 |
-
if channel_filter == '
|
63 |
return data
|
64 |
if channel_filter == 'Phone':
|
65 |
return data[data['channel'] == channel_filter]
|
@@ -70,7 +70,7 @@ def filter_by_channel(data: pd.DataFrame, channel_filter: str) -> pd.DataFrame:
|
|
70 |
|
71 |
|
72 |
def filter_by_mens(data: pd.DataFrame, mens_filter: str) -> pd.DataFrame:
|
73 |
-
if mens_filter == 'Любые
|
74 |
return data
|
75 |
if mens_filter == 'Мужские':
|
76 |
return data[data['mens'] == 1]
|
|
|
50 |
|
51 |
|
52 |
def filter_by_newbie(data: pd.DataFrame, newbie_filter: str) -> pd.DataFrame:
|
53 |
+
if newbie_filter == 'Все':
|
54 |
return data
|
55 |
+
elif newbie_filter == 'Только новые':
|
56 |
return data[data['newbie'] == 1]
|
57 |
+
elif newbie_filter == 'Только старые':
|
58 |
return data[data['newbie'] == 0]
|
59 |
|
60 |
|
61 |
def filter_by_channel(data: pd.DataFrame, channel_filter: str) -> pd.DataFrame:
|
62 |
+
if channel_filter == 'Все':
|
63 |
return data
|
64 |
if channel_filter == 'Phone':
|
65 |
return data[data['channel'] == channel_filter]
|
|
|
70 |
|
71 |
|
72 |
def filter_by_mens(data: pd.DataFrame, mens_filter: str) -> pd.DataFrame:
|
73 |
+
if mens_filter == 'Любые':
|
74 |
return data
|
75 |
if mens_filter == 'Мужские':
|
76 |
return data[data['mens'] == 1]
|
src/web_app.py
CHANGED
@@ -1,17 +1,52 @@
|
|
|
|
1 |
import pandas as pd
|
2 |
-
|
|
|
|
|
|
|
3 |
import streamlit as st
|
|
|
4 |
|
5 |
import tools
|
6 |
|
|
|
7 |
dataset, target, treatment = tools.get_data()
|
8 |
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
if 'filter_data' not in st.session_state.keys():
|
12 |
st.session_state.filter_data = True
|
13 |
|
14 |
-
|
15 |
st.title('Uplift lab')
|
16 |
|
17 |
st.markdown(
|
@@ -35,11 +70,11 @@ st.markdown(
|
|
35 |
"""
|
36 |
)
|
37 |
refresh = st.button('Обновить выборку')
|
38 |
-
title_subsample =
|
39 |
if refresh:
|
40 |
-
title_subsample =
|
41 |
st.dataframe(title_subsample, width=700)
|
42 |
-
st.write(f"Всего записей: {
|
43 |
|
44 |
st.write('Описание данных')
|
45 |
st.markdown(
|
@@ -63,118 +98,183 @@ st.write("Для того, чтобы лучше понять на какую а
|
|
63 |
|
64 |
with st.expander('Развернуть блок анализа данных'):
|
65 |
|
66 |
-
st.plotly_chart(tools.get_newbie_plot(
|
67 |
st.write(f'В данных примерно одинаковое количество новых и "старых клиентов". '
|
68 |
-
f'Отношение новых клиентов к старым: {(
|
69 |
|
70 |
-
st.plotly_chart(tools.get_zipcode_plot(
|
71 |
-
tmp_res =
|
72 |
st.write(f'Большинство клиентов из пригорода: {tmp_res["Surburban"]:.2f}%, из города: {tmp_res["Urban"]:.2f}% и из села: {tmp_res["Rural"]:.2f}%')
|
73 |
|
74 |
-
tmp_res =
|
75 |
-
st.plotly_chart(tools.get_channel_plot(
|
76 |
st.write(f'В прошлом году почти одинаковое количество клиентов покупало товары через телефон и сайт, {tmp_res["Phone"]:.2f}% и {tmp_res["Web"]:.2f}% соответственно,'
|
77 |
-
|
78 |
|
79 |
-
tmp_res =
|
80 |
-
st.plotly_chart(tools.get_history_segment_plot(
|
81 |
st.write(f'Как мы видим, большинство пользователей относится к сегменту \$0-\$100 ({tmp_res[0]:.2f}%), второй и '
|
82 |
-
|
83 |
st.write(f'К сегментам \$350-\$500 и \$500-\$750 относится {tmp_res[3]:.2f}% и {tmp_res[4]:.2f}% пользователей соответственно.')
|
84 |
st.write(f'Меньше всего пользователей в сегментах \$750-\$1.000 ({tmp_res[-2]:.2f}%) и \$1.000+ ({tmp_res[-1]:.2f}%).')
|
85 |
|
86 |
-
tmp_res = list(
|
87 |
-
st.plotly_chart(tools.get_recency_plot(
|
88 |
st.write(f'Большинство клиентов являются активными клиентами платформы, и совершали покупки в течение месяца ({tmp_res[0]:.2f}%)')
|
89 |
st.write('Также заметно, что 9 и 10 месяцев назад, много клиентов совершали покупки. Это может свидетельствовать о проведении'
|
90 |
-
|
91 |
st.write('Также интересно понаблюдать за долями новых клиентов в данном распределении.')
|
92 |
|
93 |
-
st.plotly_chart(tools.get_history_plot(
|
94 |
st.markdown('_График интерактивный. Двойной клик вернет в начальное состояние._')
|
95 |
st.write('Абсолютное большинство клиентов тратят \$25-\$35 на покупки, но есть и малая доля тех, кто тратит более \$3.000')
|
96 |
st.write('Интересный факт: все покупки более \$500 совершают только новые клиенты')
|
97 |
|
98 |
filters = {}
|
99 |
|
100 |
-
st.
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
filters['history_segments']
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
filters['
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
|
151 |
st.error('Необходимо выбрать хотя бы один класс')
|
152 |
-
|
153 |
elif not surburban and not urban and not rural:
|
154 |
st.error('Необходимо выбрать хотя бы один почтовый индекс')
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
if
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
167 |
sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
|
168 |
example = filtered_dataset.sample(sample_size)
|
169 |
-
st.write('Пример пользователей, которым будет отправлена реклама')
|
170 |
st.dataframe(example)
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
|
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
-
|
180 |
-
|
|
|
|
|
|
1 |
+
import catboost
|
2 |
import pandas as pd
|
3 |
+
import os
|
4 |
+
from sklift.metrics import uplift_at_k, uplift_by_percentile, qini_auc_score
|
5 |
+
from sklift.viz import plot_qini_curve, plot_uplift_curve
|
6 |
+
from sklift.models import SoloModel, TwoModels, ClassTransformation
|
7 |
import streamlit as st
|
8 |
+
import catboost
|
9 |
|
10 |
import tools
|
11 |
|
12 |
+
|
13 |
dataset, target, treatment = tools.get_data()
|
14 |
|
15 |
+
ct_cbc_model = catboost.CatBoostClassifier()
|
16 |
+
ct_cbc_model.load_model('src/models/ct_cbc.cbm')
|
17 |
+
|
18 |
+
sm_cbc_model = catboost.CatBoostClassifier()
|
19 |
+
sm_cbc_model.load_model('src/models/sm_cbc.cbm')
|
20 |
+
|
21 |
+
tm_ctrl_cbc_model = catboost.CatBoostClassifier()
|
22 |
+
tm_ctrl_cbc_model.load_model('src/models/tm_ctrl_cbc.cbm')
|
23 |
+
tm_trmnt_cbc_model = catboost.CatBoostClassifier()
|
24 |
+
tm_ctrl_cbc_model.load_model('src/models/tm_trmnt_cbc.cbm')
|
25 |
+
|
26 |
+
tm_dependend_ctrl_cbc = catboost.CatBoostClassifier()
|
27 |
+
tm_ctrl_cbc_model.load_model('src/models/tm_dependend_ctrl_cbc.cbm')
|
28 |
+
tm_dependend_trmntl_cbc = catboost.CatBoostClassifier()
|
29 |
+
tm_dependend_trmntl_cbc.load_model('src/models/tm_dependend_trmnt_cbc.cbm')
|
30 |
+
|
31 |
+
data_train_index = pd.read_csv('data/data_train_index.csv')
|
32 |
+
data_test_index = pd.read_csv('data/data_test_index.csv')
|
33 |
+
treatment_train_index = pd.read_csv('data/treatment_train_index.csv')
|
34 |
+
treatment_test_index = pd.read_csv('data/treatment_test_index.csv')
|
35 |
+
target_train_index = pd.read_csv('data/target_train_index.csv')
|
36 |
+
target_test_index = pd.read_csv('data/target_test_index.csv')
|
37 |
+
|
38 |
+
|
39 |
+
# фиксируем выборки, чтобы результат работы ML был предсказуем
|
40 |
+
data_train = dataset.loc[data_train_index['0']]
|
41 |
+
data_test = dataset.loc[data_test_index['0']]
|
42 |
+
treatment_train = treatment.loc[treatment_train_index['0']]
|
43 |
+
treatment_test = treatment.loc[treatment_test_index['0']]
|
44 |
+
target_train = target.loc[target_train_index['0']]
|
45 |
+
target_test = target.loc[target_test_index['0']]
|
46 |
|
47 |
if 'filter_data' not in st.session_state.keys():
|
48 |
st.session_state.filter_data = True
|
49 |
|
|
|
50 |
st.title('Uplift lab')
|
51 |
|
52 |
st.markdown(
|
|
|
70 |
"""
|
71 |
)
|
72 |
refresh = st.button('Обновить выборку')
|
73 |
+
title_subsample = data_train.sample(7)
|
74 |
if refresh:
|
75 |
+
title_subsample = data_train.sample(7)
|
76 |
st.dataframe(title_subsample, width=700)
|
77 |
+
st.write(f"Всего записей: {data_train.shape[0]}")
|
78 |
|
79 |
st.write('Описание данных')
|
80 |
st.markdown(
|
|
|
98 |
|
99 |
with st.expander('Развернуть блок анализа данных'):
|
100 |
|
101 |
+
st.plotly_chart(tools.get_newbie_plot(data_train), use_container_width=True)
|
102 |
st.write(f'В данных примерно одинаковое количество новых и "старых клиентов". '
|
103 |
+
f'Отношение новых клиентов к старым: {(data_train["newbie"] == 1).sum() / (data_train["newbie"] == 0).sum():.2f}')
|
104 |
|
105 |
+
st.plotly_chart(tools.get_zipcode_plot(data_train), use_container_width=True)
|
106 |
+
tmp_res = data_train.zip_code.value_counts(normalize=True) * 100
|
107 |
st.write(f'Большинство клиентов из пригорода: {tmp_res["Surburban"]:.2f}%, из города: {tmp_res["Urban"]:.2f}% и из села: {tmp_res["Rural"]:.2f}%')
|
108 |
|
109 |
+
tmp_res = data_train.channel.value_counts(normalize=True) * 100
|
110 |
+
st.plotly_chart(tools.get_channel_plot(data_train), use_container_width=True)
|
111 |
st.write(f'В прошлом году почти одинаковое количество клиентов покупало товары через телефон и сайт, {tmp_res["Phone"]:.2f}% и {tmp_res["Web"]:.2f}% соответственно,'
|
112 |
+
f' а {tmp_res["Multichannel"]:.2f}% клиентов покупали товары воспользовавшись двумя платформами.')
|
113 |
|
114 |
+
tmp_res = data_train.history_segment.value_counts(normalize=True) * 100
|
115 |
+
st.plotly_chart(tools.get_history_segment_plot(data_train), use_container_width=True)
|
116 |
st.write(f'Как мы видим, большинство пользователей относится к сегменту \$0-\$100 ({tmp_res[0]:.2f}%), второй и '
|
117 |
+
f'третий по количеству пользователей сегменты \$100-\$200 ({tmp_res[1]:.2f}%) и \$200-\$350 ({tmp_res[2]:.2f}%).')
|
118 |
st.write(f'К сегментам \$350-\$500 и \$500-\$750 относится {tmp_res[3]:.2f}% и {tmp_res[4]:.2f}% пользователей соответственно.')
|
119 |
st.write(f'Меньше всего пользователей в сегментах \$750-\$1.000 ({tmp_res[-2]:.2f}%) и \$1.000+ ({tmp_res[-1]:.2f}%).')
|
120 |
|
121 |
+
tmp_res = list(data_train.recency.value_counts(normalize=True) * 100)
|
122 |
+
st.plotly_chart(tools.get_recency_plot(data_train), use_container_width=True)
|
123 |
st.write(f'Большинство клиентов являются активными клиентами платформы, и совершали покупки в течение месяца ({tmp_res[0]:.2f}%)')
|
124 |
st.write('Также заметно, что 9 и 10 месяцев назад, много клиентов совершали покупки. Это может свидетельствовать о проведении'
|
125 |
+
'рекламной кампании в это время или чего-то еще.')
|
126 |
st.write('Также интересно понаблюдать за долями новых клиентов в данном распределении.')
|
127 |
|
128 |
+
st.plotly_chart(tools.get_history_plot(data_train), use_container_width=True)
|
129 |
st.markdown('_График интерактивный. Двойной клик вернет в начальное состояние._')
|
130 |
st.write('Абсолютное большинство клиентов тратят \$25-\$35 на покупки, но есть и малая доля тех, кто тратит более \$3.000')
|
131 |
st.write('Интересный факт: все покупки более \$500 совершают только новые клиенты')
|
132 |
|
133 |
filters = {}
|
134 |
|
135 |
+
with st.form(key='filter-clients'):
|
136 |
+
st.subheader('Выберем клиентов, которым отправим рекламу.')
|
137 |
+
|
138 |
+
col1, col2, col3 = st.columns(3)
|
139 |
+
|
140 |
+
channel_filter = col1.radio('Канал покупки прошлом году', options=['Все', 'Phone', 'Web', 'Multichannel'])
|
141 |
+
filters['channel_filter'] = channel_filter
|
142 |
+
|
143 |
+
newbie_filter = col2.radio('Тип клиента', options=['Все', 'Только новые', 'Только старые'])
|
144 |
+
filters['newbie_filter'] = newbie_filter
|
145 |
+
|
146 |
+
mens_filter = col3.radio('Клиенты, приобретавшие товары', options=['Любые', 'Мужские', 'Женские'])
|
147 |
+
filters['mens_filter'] = mens_filter
|
148 |
+
|
149 |
+
filters['history_segments'] = {}
|
150 |
+
|
151 |
+
col1, col2 = st.columns(2)
|
152 |
+
|
153 |
+
with col1:
|
154 |
+
st.write('Класс клиентов по объему денег, потраченных в прошлом году (history segments)')
|
155 |
+
first_group = st.checkbox('$0-$100', value=True)
|
156 |
+
if first_group:
|
157 |
+
filters['history_segments']['1) $0 - $100'] = True
|
158 |
+
second_group = st.checkbox('$100-$200', value=True)
|
159 |
+
if second_group:
|
160 |
+
filters['history_segments']['2) $100 - $200'] = True
|
161 |
+
third_group = st.checkbox('$200-$350', value=True)
|
162 |
+
if third_group:
|
163 |
+
filters['history_segments']['3) $200 - $350'] = True
|
164 |
+
fourth_group = st.checkbox('$350-$500', value=True)
|
165 |
+
if fourth_group:
|
166 |
+
filters['history_segments']['4) $350 - $500'] = True
|
167 |
+
fifth_group = st.checkbox('$500-$750', value=True)
|
168 |
+
if fifth_group:
|
169 |
+
filters['history_segments']['5) $500 - $750'] = True
|
170 |
+
sixth_group = st.checkbox('$750-$1.000', value=True)
|
171 |
+
if sixth_group:
|
172 |
+
filters['history_segments']['6) $750 - $1,000'] = True
|
173 |
+
seventh_group = st.checkbox('$1.000+', value=True)
|
174 |
+
if seventh_group:
|
175 |
+
filters['history_segments']['7) $1,000 +'] = True
|
176 |
+
|
177 |
+
with col2:
|
178 |
+
st.write('Каких пользователей по почтовому коду выберем')
|
179 |
+
filters['zip_code'] = {}
|
180 |
+
surburban = st.checkbox('Surburban', value=True)
|
181 |
+
if surburban:
|
182 |
+
filters['zip_code']['surburban'] = True
|
183 |
+
urban = st.checkbox('Urban', value=True)
|
184 |
+
if urban:
|
185 |
+
filters['zip_code']['urban'] = True
|
186 |
+
rural = st.checkbox('Rural', value=True)
|
187 |
+
if rural:
|
188 |
+
filters['zip_code']['rural'] = True
|
189 |
+
|
190 |
+
recency = st.slider(label='Месяцев с момента покупки', min_value=int(data_test.recency.min()), max_value=int(data_test.recency.max()), value=(int(data_test.recency.min()), int(data_test.recency.max())))
|
191 |
+
filters['recency'] = recency
|
192 |
+
|
193 |
+
st.write('Если известно на какой процент аудитории необходимо повлиять, измените значение')
|
194 |
+
k = st.slider(label='Процент аудитории', min_value=1, max_value=100, value=100)
|
195 |
+
|
196 |
+
filter_form_submit_button = st.form_submit_button('Применить фильтр')
|
197 |
+
|
198 |
+
|
199 |
if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
|
200 |
st.error('Необходимо выбрать хотя бы один класс')
|
201 |
+
st.stop()
|
202 |
elif not surburban and not urban and not rural:
|
203 |
st.error('Необходимо выбрать хотя бы один почтовый индекс')
|
204 |
+
st.stop()
|
205 |
+
|
206 |
+
filtered_dataset = tools.filter_data(data_test, filters)
|
207 |
+
if filtered_dataset is None:
|
208 |
+
st.error('Не найдено пользователей для данных фильтров. Попробуйте изменить фильтры.')
|
209 |
+
st.stop()
|
210 |
+
|
211 |
+
# значение uplift для записей тех клиентов, который выбрал пользователь равен 1
|
212 |
+
uplift = [1 for _ in filtered_dataset.index]
|
213 |
+
target_filtered = target_test.loc[filtered_dataset.index]
|
214 |
+
treatment_filtered = treatment_test.loc[filtered_dataset.index]
|
215 |
+
|
216 |
+
with st.expander(label='Посмотреть пример пользователей, которым будет отправлена реклама'):
|
217 |
sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
|
218 |
example = filtered_dataset.sample(sample_size)
|
|
|
219 |
st.dataframe(example)
|
220 |
+
res = st.button('Обновить')
|
221 |
+
|
222 |
+
|
223 |
+
with st.form(key='user_metricks'):
|
224 |
+
user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
|
225 |
+
user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
|
226 |
+
user_metric_qini_auc_score = qini_auc_score(target_filtered, uplift, treatment_filtered)
|
227 |
+
user_metric_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered)
|
228 |
+
col1, col2, col3 = st.columns(3)
|
229 |
+
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{user_metric_uplift_at_k:.4f}')
|
230 |
+
col2.metric(label=f'Qini AUC score', value=f'{user_metric_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя')
|
231 |
+
col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
|
232 |
+
st.write('Uplift по процентилям')
|
233 |
+
st.write(user_metric_uplift_by_percentile)
|
234 |
+
|
235 |
+
st.form_submit_button('Обновить графики', help='При изменении флагов')
|
236 |
+
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
237 |
+
st.pyplot(plot_qini_curve(target_filtered, uplift, treatment_filtered, perfect=perfect_qini).figure_)
|
238 |
+
prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
|
239 |
+
st.pyplot(plot_uplift_curve(target_filtered, uplift, treatment_filtered, perfect=prefect_uplift).figure_)
|
240 |
+
|
241 |
+
|
242 |
+
show_ml_reasons = st.checkbox('Показать решения с помощью ML')
|
243 |
+
if show_ml_reasons:
|
244 |
+
with st.expander('Решение с помощью CatBoost'):
|
245 |
+
tm_ctrl = TwoModels(
|
246 |
+
estimator_trmnt=tm_dependend_trmntl_cbc,
|
247 |
+
estimator_ctrl=tm_dependend_ctrl_cbc,
|
248 |
+
method='ddr_control'
|
249 |
+
)
|
250 |
+
|
251 |
+
|
252 |
+
tm_ctrl = tm_ctrl.fit(
|
253 |
+
data_train, target_train, treatment_train,
|
254 |
+
estimator_trmnt_fit_params={
|
255 |
+
'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']},
|
256 |
+
estimator_ctrl_fit_params={
|
257 |
+
'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']}
|
258 |
+
)
|
259 |
|
260 |
+
uplift_tm_ctrl = tm_ctrl.predict(data_test)
|
261 |
|
262 |
+
tm_ctrl_score = uplift_at_k(y_true=target_test, uplift=uplift_tm_ctrl, treatment=treatment_test,
|
263 |
+
strategy='by_group', k=k)
|
264 |
+
# считаем метрики для ML
|
265 |
+
catboost_uplift_at_k = uplift_at_k(target_filtered, uplift_tm_ctrl, treatment_filtered, strategy='overall', k=k)
|
266 |
+
catboost_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift_tm_ctrl, treatment_filtered)
|
267 |
+
catboost_qini_auc_score = qini_auc_score(target_filtered, uplift_tm_ctrl, treatment_filtered)
|
268 |
+
catboost_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift_tm_ctrl, treatment_filtered)
|
269 |
+
# отображаем метрики
|
270 |
+
col1, col2, col3 = st.columns(3)
|
271 |
+
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{catboost_uplift_at_k:.4f}', delta=f'{catboost_uplift_at_k - user_metric_uplift_at_k:.4f}')
|
272 |
+
col2.metric(label=f'Qini AUC score', value=f'{catboost_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя', delta=f'{catboost_qini_auc_score - user_metric_qini_auc_score:.4f}')
|
273 |
+
col3.metric(label=f'Weighted average uplift', value=f'{catboost_weighted_average_uplift:.4f}', delta=f'{catboost_weighted_average_uplift - user_metric_weighted_average_uplift:.4f}')
|
274 |
+
st.write('Uplift по процентилям')
|
275 |
+
st.write(catboost_uplift_by_percentile)
|
276 |
|
277 |
+
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
278 |
+
st.pyplot(plot_qini_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=perfect_qini).figure_)
|
279 |
+
prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
|
280 |
+
st.pyplot(plot_uplift_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=prefect_uplift).figure_)
|