# geolocator / app.py
# Author: Samhita
# Last commit c64af88: "update coarse distance metric"
import base64
import json
import mimetypes
# import mimetypes
import os
import sys
from io import BytesIO
from typing import Dict, Tuple, Union
import banana_dev as banana
import geopy.distance
import gradio as gr
import pandas as pd
import plotly
import plotly.express as px
# import requests
from dotenv import load_dotenv
from smart_open import open as smartopen
sys.path.append("..")
from gantry_callback.gantry_util import GantryImageToTextLogger # noqa: E402
from gantry_callback.s3_util import ( # noqa: E402
add_access_policy,
enable_bucket_versioning,
get_or_create_bucket,
get_uri_of,
make_key,
make_unique_bucket_name,
)
from gantry_callback.string_img_util import read_b64_string # noqa: E402
# Populate os.environ from a local .env file (python-dotenv) before reading
# the configuration values below.
load_dotenv()

# Base URL of the REST prediction endpoint; it appears unused by the active
# banana.run code paths in this file.
URL = os.getenv("ENDPOINT")
GANTRY_APP_NAME = os.getenv("GANTRY_APP_NAME")
GANTRY_KEY = os.getenv("GANTRY_API_KEY")
MAPBOX_TOKEN = os.getenv("MAPBOX_TOKEN")
BANANA_API_KEY = os.getenv("BANANA_API_KEY")
BANANA_MODEL_KEY = os.getenv("BANANA_MODEL_KEY")

# Example inputs for the gr.Examples widgets; the UI in this file reads the
# "images", "videos" and "video_urls" keys. Use a context manager so the file
# handle is closed, and read as UTF-8 (the JSON standard encoding) rather than
# the platform locale encoding.
with open("examples.json", encoding="utf-8") as examples_file:
    examples = json.load(examples_file)
def compute_distance(map_data: Dict[str, Dict[str, Union[str, float, None]]]):
    """Return geodesic distances (in miles) measured from the hierarchy prediction.

    Args:
        map_data: Mapping with "hierarchy", "coarse" and "fine" entries, each
            holding "latitude" and "longitude" values.

    Returns:
        A ``(hierarchy_to_coarse, hierarchy_to_fine)`` tuple of distances in miles.
    """

    def _lat_long(subdivision):
        # (latitude, longitude) pair for one subdivision of the prediction.
        entry = map_data[subdivision]
        return entry["latitude"], entry["longitude"]

    # Resolve all three points up front, before any distance is computed.
    origin, coarse_point, fine_point = (
        _lat_long(name) for name in ("hierarchy", "coarse", "fine")
    )
    to_coarse = geopy.distance.geodesic(origin, coarse_point)
    to_fine = geopy.distance.geodesic(origin, fine_point)
    return to_coarse.miles, to_fine.miles
def get_plotly_graph(
    map_data: Dict[str, Dict[str, Union[str, float, None]]],
    *,
    coarse_threshold_miles: float = 5000,
    fine_threshold_miles: float = 30,
) -> plotly.graph_objects.Figure:
    """Build a Mapbox scatter plot of the predicted locations.

    The "hierarchy" prediction is always plotted. The "coarse" and "fine"
    predictions are plotted only when they lie strictly farther than their
    threshold (in miles) from the hierarchy prediction. Presumably this keeps
    near-identical points from stacking on the same marker.

    Args:
        map_data: Mapping of subdivision name ("hierarchy", "coarse", "fine")
            to a dict with "latitude", "longitude" and "location" entries.
        coarse_threshold_miles: Hierarchy-to-coarse distance above which the
            coarse prediction is drawn.
        fine_threshold_miles: Hierarchy-to-fine distance above which the fine
            prediction is drawn.

    Returns:
        A plotly Figure with one marker per plotted subdivision.
    """
    hierarchy_to_coarse, hierarchy_to_fine = compute_distance(map_data)

    what_to_consider = {"hierarchy"}
    if hierarchy_to_coarse > coarse_threshold_miles:
        what_to_consider.add("coarse")
    if hierarchy_to_fine > fine_threshold_miles:
        what_to_consider.add("fine")

    # Relative marker sizes: the hierarchy prediction is drawn largest.
    size_map = {"hierarchy": 3, "fine": 1, "coarse": 1}
    lat_long_data = [
        [
            subdivision,
            float(location_data["latitude"]),
            float(location_data["longitude"]),
            location_data["location"],
            size_map[subdivision],
        ]
        for subdivision, location_data in map_data.items()
        if subdivision in what_to_consider
    ]
    map_df = pd.DataFrame(
        lat_long_data,
        columns=["subdivision", "latitude", "longitude", "location", "size"],
    )

    # The "dark" Mapbox style used below requires an access token.
    px.set_mapbox_access_token(MAPBOX_TOKEN)
    fig = px.scatter_mapbox(
        map_df,
        lat="latitude",
        lon="longitude",
        hover_name="location",
        hover_data=["latitude", "longitude", "subdivision"],
        color="subdivision",
        color_discrete_map={
            "hierarchy": "fuchsia",
            "coarse": "blue",
            "fine": "yellow",
        },
        zoom=2,
        height=500,
        size="size",
    )
    fig.update_layout(mapbox_style="dark")
    fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0})
    return fig
def gradio_error():
    """Abort the current Gradio event with a user-facing error message.

    Raises:
        gr.Error: always.
    """
    message = "Unable to detect the location!"
    raise gr.Error(message)
def get_outputs(
    data: Union[Dict[str, Dict[str, Union[str, float, None]]], None]
) -> Tuple[str, str, plotly.graph_objects.Figure]:
    """Turn a raw model prediction into the three Gradio outputs.

    Args:
        data: Model output mapping subdivision names ("hierarchy", "coarse",
            "fine") to dicts with "location", "latitude" and "longitude"
            entries, or None.

    Returns:
        A ``(location, "lat,long", figure)`` tuple built from the hierarchy
        prediction.

    Raises:
        gr.Error: via ``gradio_error`` when ``data`` is None, has no (or an
            empty) "hierarchy" entry, or the hierarchy location is None.
    """
    # Treat a missing or empty "hierarchy" entry like a failed detection, so
    # the user sees the Gradio error instead of a raw KeyError traceback.
    hierarchy = data.get("hierarchy") if data is not None else None
    if not hierarchy or hierarchy.get("location") is None:
        gradio_error()

    location = hierarchy["location"]
    latitude, longitude = hierarchy["latitude"], hierarchy["longitude"]
    return (
        location,
        f"{latitude},{longitude}",
        get_plotly_graph(map_data=data),
    )
def image_gradio(img_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
    """Geolocate an uploaded image via the Banana-hosted model.

    Args:
        img_file: Path to the image file (gr.Image is configured with
            ``type="filepath"``).

    Returns:
        The ``(location, coordinates, figure)`` outputs from ``get_outputs``.

    Raises:
        gr.Error: if no image was provided, or if the model output has no
            usable location.
    """
    if not img_file:
        raise gr.Error("Please upload an image first!")

    # Encode the raw bytes directly; wrapping them in a BytesIO only to call
    # getvalue() made an extra full copy of the image.
    with open(img_file, "rb") as image_file:
        image_b64 = base64.b64encode(image_file.read()).decode("utf-8")

    data = banana.run(
        BANANA_API_KEY,
        BANANA_MODEL_KEY,
        {
            "image": image_b64,
            "filename": os.path.basename(img_file),
        },
    )["modelOutputs"][0]
    return get_outputs(data=data)
def _upload_video_to_s3(video_b64_string):
    """Decode a base64 data-URI video and store it in the app's S3 bucket.

    The bucket is fetched or created, then versioning and the access policy
    are applied before the upload.

    Args:
        video_b64_string: Video encoded as ``data:<mime>;base64,<payload>``.

    Returns:
        The S3 URI the video bytes were written to.
    """
    bucket_name = make_unique_bucket_name(prefix="geolocator-app", seed="420")
    bucket = get_or_create_bucket(bucket_name)
    enable_bucket_versioning(bucket)
    add_access_policy(bucket)

    data_type, buffer = read_b64_string(video_b64_string, return_data_type=True)
    payload = buffer.read()
    s3_uri = get_uri_of(bucket, make_key(payload, filetype=data_type))
    with smartopen(s3_uri, "wb") as destination:
        destination.write(payload)
    return s3_uri
def video_gradio(video_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
    """Geolocate an uploaded video via S3 and the Banana-hosted model.

    The video is base64-encoded into a data URI and uploaded to S3. The
    resulting URI is then sent to the model.

    Args:
        video_file: Path to the video file.

    Returns:
        The ``(location, coordinates, figure)`` outputs from ``get_outputs``.

    Raises:
        gr.Error: if no video was provided, or if the model output has no
            usable location.
    """
    if not video_file:
        raise gr.Error("Please upload a video first!")

    # Use a distinct name for the handle. The original `as video_file` rebound
    # the path parameter to the file object, which broke the guess_type() and
    # basename() calls below with a TypeError.
    with open(video_file, "rb") as video_handle:
        video_b64_string = base64.b64encode(video_handle.read()).decode("utf8")

    # guess_type() returns None for unknown extensions. Fall back to a generic
    # binary type rather than emitting "data:None;base64,...".
    video_mime = mimetypes.guess_type(video_file)[0] or "application/octet-stream"
    s3_uri = _upload_video_to_s3(f"data:{video_mime};base64,{video_b64_string}")

    data = banana.run(
        BANANA_API_KEY,
        BANANA_MODEL_KEY,
        {
            "video": s3_uri,
            "filename": os.path.basename(video_file),
        },
    )["modelOutputs"][0]
    return get_outputs(data=data)
def url_gradio(url: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
    """Geolocate a YouTube video link via the Banana-hosted model.

    Args:
        url: Video URL, forwarded unchanged to the model.

    Returns:
        The ``(location, coordinates, figure)`` outputs from ``get_outputs``.
    """
    response = banana.run(BANANA_API_KEY, BANANA_MODEL_KEY, {"url": url})
    model_output = response["modelOutputs"][0]
    return get_outputs(data=model_output)
# Gradio UI: an "Image" tab and a "YouTube Link" tab, each showing the predicted
# location, its coordinates and a map. The image tab also has a Gantry-backed
# flag button.
with gr.Blocks() as demo:
    gr.Markdown("# GeoLocator")
    # Emoji repaired from mis-decoded UTF-8 ("🌌" / "πŸ”—").
    gr.Markdown(
        "### An app that guesses the location of an image 🌌 or a YouTube video link 🔗."
    )

    with gr.Tab("Image"):
        with gr.Row():
            img_input = gr.Image(type="filepath", label="Image")
            with gr.Column():
                img_text_output = gr.Textbox(label="Location")
                img_coordinates = gr.Textbox(label="Coordinates")
                img_plot = gr.Plot()
        img_text_button = gr.Button("Go locate!")
        with gr.Row():
            # Flag button
            img_flag_button = gr.Button("Flag this output")
        gr.Examples(examples["images"], inputs=[img_input])

    # Video tab is currently disabled (video_gradio is still defined above).
    # with gr.Tab("Video"):
    #     with gr.Row():
    #         video_input = gr.Video(type="filepath", label="Video")
    #         with gr.Column():
    #             video_text_output = gr.Textbox(label="Location")
    #             video_coordinates = gr.Textbox(label="Coordinates")
    #             video_plot = gr.Plot()
    #     video_text_button = gr.Button("Go locate!")
    #     gr.Examples(examples["videos"], inputs=[video_input])

    with gr.Tab("YouTube Link"):
        with gr.Row():
            url_input = gr.Textbox(label="Link")
            with gr.Column():
                url_text_output = gr.Textbox(label="Location")
                url_coordinates = gr.Textbox(label="Coordinates")
                url_plot = gr.Plot()
        url_text_button = gr.Button("Go locate!")
        gr.Examples(examples["video_urls"], inputs=[url_input])

    # Gantry flagging for image #
    callback = GantryImageToTextLogger(application=GANTRY_APP_NAME, api_key=GANTRY_KEY)
    # NOTE(review): setup() registers two components, but the flag button
    # below passes three values (image, location, coordinates) to
    # callback.flag(). Gradio's FlaggingCallback contract expects flag_data to
    # line up with the components given to setup(). Confirm how
    # GantryImageToTextLogger handles the extra coordinates value.
    callback.setup(
        components=[img_input, img_text_output],
        flagging_dir=make_unique_bucket_name(prefix=GANTRY_APP_NAME, seed="420"),
    )
    img_flag_button.click(
        fn=lambda *args: callback.flag(args),
        inputs=[img_input, img_text_output, img_coordinates],
        outputs=None,
        preprocess=False,
    )
    ###################

    img_text_button.click(
        image_gradio,
        inputs=img_input,
        outputs=[img_text_output, img_coordinates, img_plot],
    )
    # video_text_button.click(
    #     video_gradio,
    #     inputs=video_input,
    #     outputs=[video_text_output, video_coordinates, video_plot],
    # )
    url_text_button.click(
        url_gradio,
        inputs=url_input,
        outputs=[url_text_output, url_coordinates, url_plot],
    )

    gr.Markdown(
        "Check out the [GitHub repository](https://github.com/samhita-alla/geolocator) that this demo is based off of."
    )
    gr.Markdown(
        "#### To understand what subdivision means, refer to the [Geolocation paper](https://openaccess.thecvf.com/content_ECCV_2018/papers/Eric_Muller-Budack_Geolocation_Estimation_of_ECCV_2018_paper.pdf)."
    )
    gr.Markdown(
        "#### TL;DR Fine and Coarse are spatial resolutions and Hierarchy generates predictions at fine scale but incorporates knowledge from coarse and middle partitionings."
    )

demo.launch()