jdowling commited on
Commit
6ed1571
Β·
1 Parent(s): 8e05d33

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +211 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import pickle
4
+ import joblib
5
+
6
+ import hopsworks
7
+ import streamlit as st
8
+ from geopy import distance
9
+
10
+ import plotly.express as px
11
+ import folium
12
+ from streamlit_folium import st_folium
13
+
14
+ from functions import *
15
+
16
+
17
+
18
+ def print_fancy_header(text, font_size=22, color="#ff5f27"):
19
+ res = f'<span style="color:{color}; font-size: {font_size}px;">{text}</span>'
20
+ st.markdown(res, unsafe_allow_html=True)
21
+
22
+
23
+ # I want to cache this so streamlit would run much faster after restart (it restarts a lot)
24
+ @st.cache_data()
25
+ def get_feature_view():
26
+ st.write("Getting the Feature View...")
27
+ feature_view = fs.get_feature_view(
28
+ name = 'air_quality_fv',
29
+ version = 1
30
+ )
31
+ st.write("βœ… Success!")
32
+
33
+ return feature_view
34
+
35
+
36
+ @st.cache_data()
37
+ def get_batch_data_from_fs(td_version, date_threshold):
38
+ st.write(f"Retrieving the Batch data since {date_threshold}")
39
+ feature_view.init_batch_scoring(training_dataset_version=td_version)
40
+
41
+ batch_data = feature_view.get_batch_data(start_time=date_threshold)
42
+ return batch_data
43
+
44
+
45
+ @st.cache_data()
46
+ def download_model(name="air_quality_xgboost_model",
47
+ version=1):
48
+ mr = project.get_model_registry()
49
+ retrieved_model = mr.get_model(
50
+ name="air_quality_xgboost_model",
51
+ version=1
52
+ )
53
+ saved_model_dir = retrieved_model.download()
54
+ return saved_model_dir
55
+
56
+
57
+
58
+ def plot_pm2_5(df):
59
+ # create figure with plotly express
60
+ fig = px.line(df, x='date', y='pm2_5', color='city_name')
61
+
62
+ # customize line colors and styles
63
+ fig.update_traces(mode='lines+markers')
64
+ fig.update_layout({
65
+ 'plot_bgcolor': 'rgba(0, 0, 0, 0)',
66
+ 'paper_bgcolor': 'rgba(0, 0, 0, 0)',
67
+ 'legend_title': 'City',
68
+ 'legend_font': {'size': 12},
69
+ 'legend_bgcolor': 'rgba(0, 0, 0, 0)',
70
+ 'xaxis': {'title': 'Date'},
71
+ 'yaxis': {'title': 'PM2.5'},
72
+ 'shapes': [{
73
+ 'type': 'line',
74
+ 'x0': datetime.datetime.now().strftime('%Y-%m-%d'),
75
+ 'y0': 0,
76
+ 'x1': datetime.datetime.now().strftime('%Y-%m-%d'),
77
+ 'y1': df['pm2_5'].max(),
78
+ 'line': {'color': 'red', 'width': 2, 'dash': 'dashdot'}
79
+ }]
80
+ })
81
+
82
+ # show plot
83
+ st.plotly_chart(fig, use_container_width=True)
84
+
85
+
86
+ with open('target_cities.json') as json_file:
87
+ target_cities = json.load(json_file)
88
+
89
+
90
+ #########################
91
+ st.title('🌫 Air Quality Prediction 🌦')
92
+
93
+ st.write(3 * "-")
94
+ print_fancy_header('\nπŸ“‘ Connecting to Hopsworks Feature Store...')
95
+
96
+ st.write("Logging... ")
97
+ # (Attention! If the app has stopped at this step,
98
+ # please enter your Hopsworks API Key in the commmand prompt.)
99
+ project = hopsworks.login()
100
+ fs = project.get_feature_store()
101
+ st.write("βœ… Logged in successfully!")
102
+
103
+ feature_view = get_feature_view()
104
+
105
+ # I am going to load data for of last 60 days (for feature engineering)
106
+ today = datetime.date.today()
107
+ date_threshold = today
108
+ #- datetime.timedelta(days=60)
109
+
110
+ st.write(3 * "-")
111
+ print_fancy_header('\n☁️ Retriving batch data from Feature Store...')
112
+ batch_data = get_batch_data_from_fs(td_version=1,
113
+ date_threshold=date_threshold)
114
+
115
+ st.write("Batch data:")
116
+ st.write(batch_data.sample(5))
117
+
118
+ # +
119
+ saved_model_dir = download_model(
120
+ name="air_quality_xgboost_model",
121
+ version=1
122
+ )
123
+
124
+ pipeline = joblib.load(saved_model_dir + "/xgboost_pipeline.pkl")
125
+ st.write("\n")
126
+ st.write("βœ… Model was downloaded and cached.")
127
+ # -
128
+
129
+ st.write(3 * '-')
130
+ st.write("\n")
131
+ print_fancy_header(text="πŸ– Select the cities using the form below. \
132
+ Click the 'Submit' button at the bottom of the form to continue.",
133
+ font_size=22)
134
+ dict_for_streamlit = {}
135
+ for continent in target_cities:
136
+ for city_name, coords in target_cities[continent].items():
137
+ dict_for_streamlit[city_name] = coords
138
+ selected_cities_full_list = []
139
+
140
+ with st.form(key="user_inputs"):
141
+ print_fancy_header(text='\nπŸ—Ί Here you can choose cities from the drop-down menu',
142
+ font_size=20, color="#00FFFF")
143
+
144
+ cities_multiselect = st.multiselect(label='',
145
+ options=dict_for_streamlit.keys())
146
+ selected_cities_full_list.extend(cities_multiselect)
147
+ st.write("_" * 3)
148
+ print_fancy_header(text="\nπŸ“Œ To add a city using the interactive map, click somewhere \
149
+ (for the coordinates to appear)",
150
+ font_size=20, color="#00FFFF")
151
+
152
+ my_map = folium.Map(location=[42.57, -44.092], zoom_start=2)
153
+ # Add markers for each city
154
+ for city_name, coords in dict_for_streamlit.items():
155
+ folium.CircleMarker(
156
+ location=coords
157
+ ).add_to(my_map)
158
+
159
+ my_map.add_child(folium.LatLngPopup())
160
+ res_map = st_folium(my_map, width=640, height=480)
161
+
162
+ try:
163
+ new_lat, new_long = res_map["last_clicked"]["lat"], res_map["last_clicked"]["lng"]
164
+
165
+ # Calculate the distance between the clicked location and each city
166
+ distances = {city: distance.distance(coord, (new_lat, new_long)).km for city, coord in dict_for_streamlit.items()}
167
+
168
+ # Find the city with the minimum distance and print its name
169
+ nearest_city = min(distances, key=distances.get)
170
+ print_fancy_header(text=f"You have selected {nearest_city} using map", font_size=18, color="#52fa23")
171
+
172
+ selected_cities_full_list.append(nearest_city)
173
+ # st.write(label_encoder.transform([nearest_city])[0])
174
+
175
+ except Exception as err:
176
+ print(err)
177
+ pass
178
+
179
+ submit_button = st.form_submit_button(label='Submit')
180
+
181
+ # +
182
+
183
+ if submit_button:
184
+ st.write('Selected cities:', selected_cities_full_list)
185
+
186
+ st.write(3*'-')
187
+
188
+ dataset = batch_data
189
+
190
+ dataset = dataset.sort_values(by=["city_name", "date"])
191
+
192
+ st.write("\n")
193
+ print_fancy_header(text='\n🧠 Predicting PM2.5 for selected cities...',
194
+ font_size=18, color="#FDF4F5")
195
+ st.write("")
196
+ preds = pd.DataFrame(columns=dataset.columns)
197
+ for city_name in selected_cities_full_list:
198
+ st.write(f"\t * {city_name}...")
199
+ features = dataset.loc[dataset['city_name'] == city_name]
200
+ print(features.head())
201
+ features['pm2_5'] = pipeline.predict(features)
202
+ preds = pd.concat([preds, features])
203
+
204
+ st.write("")
205
+ print_fancy_header(text="πŸ“ˆResults πŸ“‰",
206
+ font_size=22)
207
+ plot_pm2_5(preds[preds['city_name'].isin(selected_cities_full_list)])
208
+
209
+ st.write(3 * "-")
210
+ st.subheader('\nπŸŽ‰ πŸ“ˆ 🀝 App Finished Successfully 🀝 πŸ“ˆ πŸŽ‰')
211
+ st.button("Re-run")
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ hopsworks==3.0.*
2
+ geopy
3
+ python-dotenv
4
+ streamlit
5
+ streamlit-folium