# VN-Housing-App / screens/predict.py
# Captured from the Hugging Face file view: author A-New-Day-001,
# commit 46f6438 ("Update screens/predict.py"), 6.07 kB, marked "No virus".
import streamlit as st
import json
from autogluon.multimodal import MultiModalPredictor
import pandas as pd
from geopy.geocoders import GoogleV3
import os
import tempfile
# Certification-status choices; "Không có" means "none".
CERT_STATUS = pd.CategoricalDtype(
    categories=["Không có", "hợp đồng", "sổ đỏ / sổ hồng"], ordered=False
)

# Compass directions used for both the house and the balcony direction;
# "Không có" means "none".
DIRECTION = pd.CategoricalDtype(
    categories=[
        "Không có",
        "Tây - Nam",
        "Đông - Nam",
        "Đông - Bắc",
        "Tây - Bắc",
        "Nam",
        "Tây",
        "Bắc",
        "Đông",
    ],
    ordered=False,
)


def _read_json(path):
    """Return the parsed contents of the JSON file at *path*.

    The file is opened in a ``with`` block so the handle is always closed,
    and decoded as UTF-8 so Vietnamese labels do not depend on the locale.
    """
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _district_dtype(city_district_map):
    """Build a categorical dtype over every district id of every city.

    *city_district_map* maps a city code to ``{district_id_str: name}``;
    the string ids are converted to ``int`` to match the ``DistrictId``
    column. A flat comprehension replaces ``sum(list_of_lists, [])``, which
    re-copies the accumulator on every addition.
    """
    return pd.CategoricalDtype(
        categories=[
            int(district_id)
            for districts in city_district_map.values()
            for district_id in districts
        ],
        ordered=False,
    )


def _compose_location(precise, city_name, district_name):
    """Return ``"<precise>, <city>, <district>"``, omitting an empty *precise*."""
    prefix = f"{precise}, " if precise else ""
    return f"{prefix}{city_name}, {district_name}"


def _build_input_frame(record, city_dtype, district_dtype):
    """Return a one-row DataFrame of *record* cast to the model's input dtypes.

    ``city_dtype`` / ``district_dtype`` depend on the loaded location maps and
    are therefore passed in; every other dtype is fixed.
    """
    return pd.DataFrame([record]).astype(
        {
            "Title": "str",
            "Area": "float",
            "Location": "str",
            "Time stamp": "datetime64[ns]",
            "Certification status": CERT_STATUS,
            "Direction": DIRECTION,
            "Bedrooms": "int",
            "Bathrooms": "int",
            "Front width": "float",
            "Floor": "int",
            "Description": "str",
            "Image URL": "str",
            "Road width": "float",
            "City_code": city_dtype,
            "DistrictId": district_dtype,
            "Lattitude": "float",
            "Longitude": "float",
            "Balcony_Direction": DIRECTION,
        }
    )


def predict_page():
    """Render the house-price estimation page.

    Collects listing attributes through Streamlit widgets, assembles them
    into a single-row DataFrame with the columns and dtypes fed to the
    ``models/mm-text-no-price/`` AutoGluon predictor and, when the button is
    pressed, stores the prediction in ``st.session_state.price_text``. The
    stored value is shown multiplied by 1e6 and formatted as VND.
    """
    # Keep the last estimate across Streamlit reruns; 0 means "none yet".
    if "price_text" not in st.session_state:
        st.session_state.price_text = 0

    @st.cache_resource
    def load_mm_text_no_price_model():
        return MultiModalPredictor.load("models/mm-text-no-price/", verbosity=0)

    @st.cache_resource
    def load_city_map():
        return _read_json("city-map.json")

    @st.cache_resource
    def load_city_district_map():
        return _read_json("city-district-map.json")

    mm_text_no_price_predictor = load_mm_text_no_price_model()
    # city code -> city name
    city_map = load_city_map()
    # city code -> {district id (str) -> district name}
    city_district_map = load_city_district_map()

    city_dtype = pd.CategoricalDtype(categories=list(city_map), ordered=False)
    district_dtype = _district_dtype(city_district_map)

    # --- Location -------------------------------------------------------
    location_options = st.columns([1, 1, 2, 1, 1])
    with location_options[0]:
        city = st.selectbox(
            "Choose city", options=city_map.items(), format_func=lambda x: x[1]
        )
    with location_options[1]:
        # Only the districts of the selected city are offered.
        district = st.selectbox(
            "Choose district",
            options=city_district_map[city[0]].items(),
            format_func=lambda x: x[1],
        )
    with location_options[2]:
        precise_location = st.text_input("Enter precise location")
    location = _compose_location(precise_location, city[1], district[1])

    # Coordinates are typed in manually and default to NaN (unknown). No
    # geocoding is done: GoogleV3 is imported at the top of the file, but no
    # geocoder instance is configured anywhere in this module.
    with location_options[3]:
        latitude = st.number_input(
            "Enter latitude", value=float("nan"), step=1e-8, format="%.7f"
        )
    with location_options[4]:
        longitude = st.number_input(
            "Enter longitude", value=float("nan"), step=1e-8, format="%.7f"
        )

    # --- Numeric attributes ---------------------------------------------
    numerical_options = st.columns(6)
    with numerical_options[0]:
        area = st.number_input("Area (m2)", min_value=1.0)
    with numerical_options[1]:
        bedrooms = st.number_input("Number of bedrooms", min_value=1, value=1)
    with numerical_options[2]:
        bathrooms = st.number_input("Number of bathrooms", min_value=1, value=1)
    with numerical_options[3]:
        floors = st.number_input("Number of floors", min_value=1, value=1)
    with numerical_options[4]:
        front_width = st.number_input(
            "Front width, leave 0 for N/A", min_value=0.0, value=0.0, step=0.1
        )
    with numerical_options[5]:
        road_width = st.number_input(
            "Road width, leave 0 for N/A", min_value=0.0, value=0.0, step=0.1
        )

    # --- Date and categorical attributes --------------------------------
    cat_time_columns = st.columns(4)
    with cat_time_columns[0]:
        timestamp = st.date_input("Date posted", format="DD/MM/YYYY")
    with cat_time_columns[1]:
        cert_status = st.selectbox(
            "Certification status", options=CERT_STATUS.categories
        )
    with cat_time_columns[2]:
        direction = st.selectbox("Direction", options=DIRECTION.categories)
    with cat_time_columns[3]:
        balcony_direction = st.selectbox(
            "Balcony direction", options=DIRECTION.categories
        )

    # --- Free text and image --------------------------------------------
    description = st.text_area("Description")

    uploaded_image = st.file_uploader("Upload an image")
    image_tmp = None
    if uploaded_image is not None:
        # The image reaches the predictor as a file path, so the upload is
        # copied to a temporary file. Only the extension is reused as the
        # suffix (the client-supplied name may hold characters that are
        # invalid in a path), getvalue() returns all bytes regardless of the
        # stream position, and flush() makes sure every byte is on disk
        # before the predictor opens the path. The file is deleted when
        # ``image_tmp`` is closed, i.e. after this function returns.
        suffix = os.path.splitext(uploaded_image.name)[1]
        image_tmp = tempfile.NamedTemporaryFile(suffix=suffix)
        image_tmp.write(uploaded_image.getvalue())
        image_tmp.flush()

    record = {
        # The first sentence of the description doubles as the title.
        "Title": description.split(".", maxsplit=1)[0],
        "Area": area,
        "Location": location,
        "Time stamp": timestamp,
        "Certification status": cert_status,
        "Direction": direction,
        "Bedrooms": bedrooms,
        "Bathrooms": bathrooms,
        # 0 is the form's "N/A" value; the model receives NaN instead.
        "Front width": front_width or float("nan"),
        "Floor": floors,
        "Description": description,
        # NOTE(review): with no upload this is None, which the "str" cast in
        # _build_input_frame turns into the literal string "None"; confirm
        # the predictor treats that path as a missing image.
        "Image URL": image_tmp.name if image_tmp is not None else None,
        "Road width": road_width or float("nan"),
        "City_code": city[0],
        "DistrictId": int(district[0]),
        # Spelling presumably matches the model's training column — do not
        # rename without checking the model's expected inputs.
        "Lattitude": latitude,
        "Longitude": longitude,
        "Balcony_Direction": balcony_direction,
    }
    df = _build_input_frame(record, city_dtype, district_dtype)

    # --- Prediction -------------------------------------------------------
    if st.button("Get estimated price with text"):
        st.session_state.price_text = mm_text_no_price_predictor.predict(
            df, as_pandas=False
        ).item()
    st.text(
        "Estimated price: {0:,} VND".format(int(st.session_state.price_text * 1e6))
        if st.session_state.price_text
        else "No price estimated."
    )