Emil25 commited on
Commit
f6cf9b2
·
verified ·
1 Parent(s): f8449b3

Upload 5 files

Browse files
img/taxi_img.png ADDED
main.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import plotly.graph_objects as go
4
+ from geopy.geocoders import Nominatim
5
+ import pandas as pd
6
+ from datetime import datetime
7
+ import holidays
8
+ import numpy as np
9
+ from sklearn.preprocessing import MinMaxScaler
10
+ import pickle
11
+ import xgboost as xgb
12
+
13
+ # Setting up the page configuration for Streamlit App
14
+ st.set_page_config(
15
+ page_title="Taxi",
16
+ # layout="wide",
17
+ initial_sidebar_state="expanded"
18
+ )
19
+
20
+ # Load the XGBoost model
21
+ #@st.cache_data()
22
+ def get_model():
23
+ model = pickle.load(open("models/model_xgb.pkl", "rb"))
24
+ return model
25
+
26
+ # Function to make prediction using the model and input data
27
+ def make_prediction(data):
28
+ model = get_model()
29
+ best_features = ['vendor_id', 'passenger_count', 'pickup_longitude', 'pickup_latitude',
30
+ 'dropoff_longitude', 'dropoff_latitude', 'store_and_fwd_flag',
31
+ 'pickup_hour', 'pickup_holiday', 'total_distance', 'total_travel_time',
32
+ 'number_of_steps', 'haversine_distance', 'temperature',
33
+ 'pickup_day_of_week_1', 'pickup_day_of_week_2', 'pickup_day_of_week_3',
34
+ 'pickup_day_of_week_4', 'pickup_day_of_week_5', 'pickup_day_of_week_6',
35
+ 'geo_cluster_1', 'geo_cluster_3', 'geo_cluster_5', 'geo_cluster_7',
36
+ 'geo_cluster_9']
37
+ data_matrix = xgb.DMatrix(data, feature_names=best_features)
38
+ return model.predict(data_matrix)
39
+
40
+
41
+ def get_coordinates(address):
42
+ # Создание экземпляра геокодера
43
+ geolocator = Nominatim(user_agent="my_app")
44
+
45
+ # Получение координат по адресу
46
+ location = geolocator.geocode(address)
47
+
48
+ # Вывод широты и долготы
49
+ return (location.longitude, location.latitude)
50
+
51
+
52
+ def show_map(lon_from, lat_from, lon_to, lat_to):
53
+ # Создание карты
54
+ fig = go.Figure(go.Scattermapbox(
55
+ mode = "markers",
56
+ marker = {'size': 15, 'color': 'red'}
57
+ ))
58
+
59
+ # Добавление флажков для точек
60
+ fig.add_trace(go.Scattermapbox(
61
+ mode = "markers",
62
+ lon = [lon_from, lon_to],
63
+ lat = [lat_from, lat_to],
64
+ marker = go.scattermapbox.Marker(
65
+ size=25,
66
+ color='red'
67
+ )
68
+ ))
69
+
70
+ # Добавление линии между точками
71
+ fig.add_trace(go.Scattermapbox(
72
+ mode = "lines",
73
+ lon = [lon_from, lon_to],
74
+ lat = [lat_from, lat_to],
75
+ line = dict(width=2, color='green')
76
+ ))
77
+
78
+ # Настройка отображения карты
79
+ fig.update_layout(
80
+ mapbox = {
81
+ 'style': "open-street-map", # Стиль карты
82
+ 'center': {'lon': (lon_from + lon_to) / 2, 'lat': (lat_from + lat_to) / 2}, # Центр карты
83
+ 'zoom': 9, # Уровень масштабирования карты
84
+ },
85
+ showlegend = False,
86
+ height = 600, # Изменение высоты карты
87
+ width = 1200 # Изменение ширины карты
88
+ )
89
+
90
+ # Отображение карты
91
+ return fig
92
+
93
+
94
+ # Get total distance
95
+ def get_total_distance(start_longitude, start_latitude, end_longitude, end_latitude):
96
+ # Construct the URL for sending a request to the public OSRM server
97
+ url = f"http://router.project-osrm.org/route/v1/driving/{start_longitude},{start_latitude};{end_longitude},{end_latitude}?overview=false"
98
+
99
+ # Send a GET request to the OSRM server
100
+ response = requests.get(url)
101
+
102
+ # Process the response from the server
103
+ if response.status_code == 200:
104
+ data = response.json()
105
+ total_distance = data["routes"][0]["distance"] # Total distance in meters
106
+ total_travel_time = data["routes"][0]["duration"] # Total travel time in seconds
107
+ number_of_steps = len(data["routes"][0]["legs"][0]["steps"]) # Number of steps in the
108
+ return total_distance, total_travel_time, number_of_steps
109
+
110
+
111
+ # Get Harversine distance
112
+ def get_haversine_distance(lat1, lng1, lat2, lng2):
113
+ # Convert angles to radians
114
+ lat1, lng1, lat2, lng2 = map(np.radians, (lat1, lng1, lat2, lng2))
115
+ # Earth's radius in kilometers
116
+ EARTH_RADIUS = 6371
117
+ # Calculate the shortest distance h using the Haversine formula
118
+ lat_delta = lat2 - lat1
119
+ lng_delta = lng2 - lng1
120
+ d = np.sin(lat_delta * 0.5) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(lng_delta * 0.5) ** 2
121
+ h = 2 * EARTH_RADIUS * np.arcsin(np.sqrt(d))
122
+ return h
123
+
124
+
125
+ # User input features
126
+ def user_input_features(lon_from, lat_from, lon_to, lat_to, passenger_count, temperature):
127
+ current_time = datetime.now()
128
+ pickup_hour= current_time.hour
129
+ today = datetime.today()
130
+ pickup_holiday = 1 if today in holidays.USA() else 0
131
+ total_distance, total_travel_time, number_of_steps = get_total_distance(lon_from, lat_from, lon_to, lat_to)
132
+ haversine_distance = get_haversine_distance(lat_from, lon_from, lat_to, lon_to)
133
+ weekday_number = current_time.weekday()
134
+
135
+ data = {'vendor_id': 1,
136
+ 'passenger_count': passenger_count,
137
+ 'pickup_longitude': lon_from,
138
+ 'pickup_latitude': lat_from,
139
+ 'dropoff_longitude': lon_to,
140
+ 'dropoff_latitude': lat_to,
141
+ 'store_and_fwd_flag': 0.0,
142
+ 'pickup_hour': pickup_hour,
143
+ 'pickup_holiday': pickup_holiday,
144
+ 'total_distance': total_distance,
145
+ 'total_travel_time': total_travel_time,
146
+ 'number_of_steps': number_of_steps,
147
+ 'haversine_distance': haversine_distance,
148
+ 'temperature': temperature,
149
+ 'pickup_day_of_week_1': 1 if weekday_number == 1 else 0,
150
+ 'pickup_day_of_week_2': 1 if weekday_number == 2 else 0,
151
+ 'pickup_day_of_week_3': 1 if weekday_number == 3 else 0,
152
+ 'pickup_day_of_week_4': 1 if weekday_number == 4 else 0,
153
+ 'pickup_day_of_week_5': 1 if weekday_number == 5 else 0,
154
+ 'pickup_day_of_week_6': 1 if weekday_number == 6 else 0,
155
+ 'geo_cluster_1':1,
156
+ 'geo_cluster_3':0,
157
+ 'geo_cluster_5':0,
158
+ 'geo_cluster_7':0,
159
+ 'geo_cluster_9':0
160
+ }
161
+ features = pd.DataFrame(data, index=[0])
162
+ return features
163
+
164
+
165
+ # Scale the input data using a pre-trained MinMaxScaler
166
+ def min_max_scaler(data):
167
+ scaler = pickle.load(open("models/min_max_scaler.pkl", "rb"))
168
+ data_scaled = scaler.transform(data)
169
+ return data_scaled
170
+
171
+ # Main function
172
+ def main():
173
+
174
+ if 'btn_predict' not in st.session_state:
175
+ st.session_state['btn_predict'] = False
176
+
177
+ # Sidebar
178
+ st.sidebar.markdown(''' # New York City Taxi Trip Duration''')
179
+ st.sidebar.image("img/taxi_img.png")
180
+ address_from = st.sidebar.text_input("Откуда:", value="New York, Liberty Island")
181
+ address_to = st.sidebar.text_input("Куда:", value="New York, 20 W 34th St")
182
+ passenger_count = st.sidebar.slider("Количество пассажиров", 1, 4, 1)
183
+ temperature = st.sidebar.slider("Temperature (C)", -20, 40, 15)
184
+ st.session_state['btn_predict'] = st.sidebar.button('Start')
185
+
186
+ if st.session_state['btn_predict']:
187
+ lon_from, lat_from = get_coordinates(address_from)
188
+ lon_to, lat_to = get_coordinates(address_to)
189
+ st.plotly_chart(show_map(lon_from, lat_from, lon_to, lat_to))
190
+ user_data = user_input_features(lon_from, lat_from, lon_to, lat_to, passenger_count, temperature)
191
+ # st.write(user_data)
192
+ data_scaled = min_max_scaler(user_data)
193
+ trip_duration = np.exp(make_prediction(data_scaled)) - 1
194
+ trip_duration = round(float(trip_duration) / 60)
195
+ st.markdown(f"""
196
+ <div style='background-color: lightgreen; padding: 10px;'>
197
+ <h2 style='color: black; text-align: center;'>Длительность поездки составит: {trip_duration} мин.</h2>
198
+ </div>
199
+ """, unsafe_allow_html=True)
200
+
201
+ # Running the main function
202
+ if __name__ == "__main__":
203
+ main()
models/min_max_scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2974581d9e870affcb1eaa0eb86290630840e6262eddd70d604f8415a789493a
3
+ size 2036
models/model_xgb.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83d6e22ee287b9a2e15efefbebb5376b89815f6b06212eb48a338ad745a046bf
3
+ size 1330844
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ requests
3
+ plotly
4
+ geopy
5
+ datetime
6
+ pandas
7
+ holidays
8
+ scikit-learn
9
+ xgboost