File size: 5,271 Bytes
23a53cd
 
 
 
 
 
 
 
 
 
 
5dd3935
23a53cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dd3935
 
 
 
 
23a53cd
5dd3935
23a53cd
 
 
 
 
5dd3935
23a53cd
 
 
 
 
 
 
5dd3935
 
 
 
 
 
23a53cd
 
5dd3935
23a53cd
5dd3935
 
 
23a53cd
 
 
 
 
 
 
 
 
 
5dd3935
 
 
 
 
 
 
 
7a7548d
23a53cd
5dd3935
23a53cd
 
 
 
 
 
 
5dd3935
735264e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from plot_functions import *
import hydra

import torch
from model import LitUnsupervisedSegmenter
from helper import inference_on_location_and_month, inference_on_location
from plot_functions import segment_region

from functools import partial
import gradio as gr
import logging
import os
import sys

import geopandas as gpd
# Mapbox access token used for the map figures. It can be overridden via the
# MAPBOX_ACCESS_TOKEN environment variable; the hard-coded value is kept only
# as a fallback so existing deployments keep working unchanged.
# NOTE(review): a token committed to source control should be treated as
# public — consider rotating it and relying solely on the environment.
mapbox_access_token = os.environ.get(
    "MAPBOX_ACCESS_TOKEN",
    "pk.eyJ1IjoiamVyZW15LWVraW1ldHJpY3MiLCJhIjoiY2xrNjBwNGU2MDRhMjNqbWw0YTJrbnpvNCJ9.poVyIzhJuJmD6ffrL9lm2w",
)
# Natural Earth city points; their coordinates and names are drawn as
# reference markers by get_geomap.
# NOTE(review): ``gpd.datasets`` is deprecated since geopandas 0.14 and removed
# in 1.0 — pin geopandas<1.0 or load this dataset from another source.
geo_df = gpd.read_file(gpd.datasets.get_path('naturalearth_cities'))

def get_geomap(long, lat):
    """Build a Mapbox figure showing the reference cities and one highlighted point.

    Every row of the module-level ``geo_df`` is drawn as a marker labelled with
    its ``name``. The point (``lat``, ``long``) is added as a green marker
    labelled 'Actual position', and the map is centred on that point at zoom
    level 3.

    NOTE(review): the Gradio callbacks in this file pass their widgets in
    (latitude-labelled, longitude-labelled) order. As a result, ``long``
    receives the value of the "latitude" widget and ``lat`` receives the value
    of the "longitude" widget. Confirm which order is intended.

    Args:
        long: Longitude of the highlighted point and of the map centre.
        lat: Latitude of the highlighted point and of the map centre.

    Returns:
        The assembled ``go.Figure``.
    """
    marker_size = 14

    # Reference layer: one marker per city, hover text is the city name.
    city_layer = go.Scattermapbox(
        lat=geo_df.geometry.y,
        lon=geo_df.geometry.x,
        mode='markers',
        marker=go.scattermapbox.Marker(size=marker_size),
        text=geo_df.name,
    )

    # Single green marker for the requested location.
    position_layer = go.Scattermapbox(
        lat=[lat],
        lon=[long],
        mode='markers',
        marker=go.scattermapbox.Marker(size=marker_size, color="green"),
        text=['Actual position'],
    )

    figure = go.Figure(data=[city_layer, position_layer])
    figure.update_layout(
        showlegend=False,
        hovermode='closest',
        mapbox=dict(
            accesstoken=mapbox_access_token,
            center=dict(lat=lat, lon=long),
            zoom=3,
        ),
    )
    return figure


if __name__ == "__main__":
    # Log both to a file and to stdout. ``logging.basicConfig`` ignores its
    # ``encoding`` argument when explicit ``handlers`` are supplied (and rejects
    # it outright before Python 3.9), so UTF-8 is set on the file handler itself.
    file_handler = logging.FileHandler(filename='biomap.log', encoding='utf-8')
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    handlers = [file_handler, stdout_handler]

    logging.basicConfig(
        handlers=handlers,
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info("config : %s", cfg)

    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    model = model.cpu()
    logging.info("Model Initialized")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code — only point this at a trusted checkpoint.
    model_path = "biomap/checkpoint/model/model.pt"
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    # The app only serves inference. Switch any dropout / batch-norm layers
    # (if the model has them) to evaluation mode. This is harmless if the
    # inference helpers already do so.
    model.eval()
    logging.info("Model Loaded")

    with gr.Blocks(title="Biomap by Ekimetrics") as demo:
        gr.Markdown("<h1><center>🐢 Biomap by Ekimetrics 🐢</center></h1>")
        gr.Markdown("<h4><center>Estimate Biodiversity score in the world by using segmentation of land.</center></h4>")
        gr.Markdown("Land use is divided into 6 different classes. Each class is assigned a GBS score from 0 to 1.")
        gr.Markdown("Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 ")
        gr.Markdown("The score is then averaged over the full image.")

        # NOTE(review): the default values (2.98, 48.81) look like the
        # (longitude, latitude) of the Paris area, yet the widgets are labelled
        # (latitude, longitude). get_geomap(long, lat) receives them in that
        # same order, which is why its marker lands near Paris. Confirm the
        # argument order the inference helpers expect before relabelling.
        with gr.Tab("Single Image"):
            with gr.Row():
                input_map = gr.Plot()
                with gr.Column():
                    with gr.Row():
                        input_latitude = gr.Number(label="latitude", value=2.98)
                        input_longitude = gr.Number(label="longitude", value=48.81)
                    input_date = gr.Textbox(label="start_date", value="2020-03-20")

            single_button = gr.Button("Predict")
            with gr.Row():
                raw_image = gr.Image(label = "Localisation visualization")
                output_image = gr.Image(label = "Labeled visualisation")
                score_biodiv = gr.Number(label = "Biodiversity score")

        with gr.Tab("TimeLapse"):
            with gr.Row():
                input_map_2 = gr.Plot()
                with gr.Column():
                    with gr.Row():
                        timelapse_input_latitude = gr.Number(value=2.98, label="Latitude")
                        timelapse_input_longitude = gr.Number(value=48.81, label="Longitude")
                    with gr.Row():
                        timelapse_start_date = gr.Dropdown(choices=[2017,2018,2019,2020,2021,2022,2023], value=2020, label="Start Date")
                        timelapse_end_date = gr.Dropdown(choices=[2017,2018,2019,2020,2021,2022,2023], value=2021, label="End Date")
                    segmentation = gr.Radio(choices=['month', 'year', '2months'], value='year', label="Interval of time between two segmentation")
            timelapse_button = gr.Button(value="Predict")
            # Named ``timelapse_map`` rather than ``map`` to avoid shadowing the builtin.
            timelapse_map = gr.Plot()

        # Single-image tab: draw the location map on page load and on click,
        # and run the single-date inference on click.
        demo.load(get_geomap, [input_latitude, input_longitude], input_map)
        single_button.click(get_geomap, [input_latitude, input_longitude], input_map)
        single_button.click(partial(inference_on_location_and_month, model), inputs=[input_latitude, input_longitude, input_date], outputs=[raw_image, output_image, score_biodiv])

        # Timelapse tab: same map wiring, plus the multi-date inference.
        demo.load(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
        timelapse_button.click(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
        timelapse_button.click(partial(inference_on_location, model), inputs=[timelapse_input_latitude, timelapse_input_longitude, timelapse_start_date, timelapse_end_date, segmentation], outputs=[timelapse_map])
    demo.launch()