jeremyLE-Ekimetrics commited on
Commit
23a53cd
·
1 Parent(s): 9780f87

Update biomap/app.py

Browse files
Files changed (1) hide show
  1. biomap/app.py +110 -110
biomap/app.py CHANGED
@@ -1,110 +1,110 @@
1
- from plot_functions import *
2
- import hydra
3
-
4
- import torch
5
- from model import LitUnsupervisedSegmenter
6
- from helper import inference_on_location_and_month, inference_on_location
7
- from plot_functions import segment_region
8
-
9
- from functools import partial
10
- import gradio as gr
11
- import logging
12
-
13
- import geopandas as gpd
14
- mapbox_access_token = "pk.eyJ1IjoiamVyZW15LWVraW1ldHJpY3MiLCJhIjoiY2xrNjBwNGU2MDRhMjNqbWw0YTJrbnpvNCJ9.poVyIzhJuJmD6ffrL9lm2w"
15
- geo_df = gpd.read_file(gpd.datasets.get_path('naturalearth_cities'))
16
-
17
- def get_geomap(long, lat ):
18
-
19
-
20
- fig = go.Figure(go.Scattermapbox(
21
- lat=geo_df.geometry.y,
22
- lon=geo_df.geometry.x,
23
- mode='markers',
24
- marker=go.scattermapbox.Marker(
25
- size=14
26
- ),
27
- text=geo_df.name,
28
- ))
29
-
30
- fig.add_trace(go.Scattermapbox(lat=[lat],
31
- lon=[long],
32
- mode='markers',
33
- marker=go.scattermapbox.Marker(
34
- size=14
35
- ),
36
- marker_color="green",
37
- text=['Actual position']))
38
-
39
- fig.update_layout(
40
- showlegend=False,
41
- hovermode='closest',
42
- mapbox=dict(
43
- accesstoken=mapbox_access_token,
44
- center=go.layout.mapbox.Center(
45
- lat=lat,
46
- lon=long
47
- ),
48
- zoom=3
49
- )
50
- )
51
-
52
- return fig
53
-
54
-
55
- if __name__ == "__main__":
56
-
57
-
58
- logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)
59
- # Initialize hydra with configs
60
- #hydra.initialize(config_path="configs", job_name="corine")
61
- cfg = hydra.compose(config_name="my_train_config.yml")
62
- logging.info(f"config : {cfg}")
63
- # Load the model
64
-
65
- nbclasses = cfg.dir_dataset_n_classes
66
- model = LitUnsupervisedSegmenter(nbclasses, cfg)
67
- logging.info(f"Model Initialiazed")
68
-
69
- model_path = "checkpoint/model/model.pt"
70
- saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
71
- logging.info(f"Model weights Loaded")
72
- model.load_state_dict(saved_state_dict)
73
- logging.info(f"Model Loaded")
74
- # css=".VIDEO video{height: 100%;width:50%;margin:auto};.VIDEO{height: 50%;};.svelte-1vnmhm4{height:auto}"
75
- with gr.Blocks() as demo:
76
- gr.Markdown("Estimate Biodiversity in the world.")
77
- with gr.Tab("Single Image"):
78
- with gr.Row():
79
- input_map = gr.Plot().style()
80
- with gr.Column():
81
- input_latitude = gr.Number(label="lattitude", value=2.98)
82
- input_longitude = gr.Number(label="longitude", value=48.81)
83
- input_date = gr.Textbox(label="start_date", value="2020-03-20")
84
-
85
- single_button = gr.Button("Predict")
86
- with gr.Row():
87
- raw_image = gr.Image(label = "Localisation visualization")
88
- output_image = gr.Image(label = "Labeled visualisation")
89
- score_biodiv = gr.Number(label = "Biodiversity score")
90
-
91
- with gr.Tab("TimeLapse"):
92
- with gr.Row():
93
- input_map_2 = gr.Plot().style()
94
- with gr.Row():
95
- timelapse_input_latitude = gr.Number(value=2.98, label="Latitude")
96
- timelapse_input_longitude = gr.Number(value=48.81, label="Longitude")
97
- timelapse_start_date = gr.Textbox(value='2020-05-01', label="Start Date")
98
- timelapse_end_date = gr.Textbox(value='2020-06-30', label="End Date")
99
- segmentation = gr.CheckboxGroup(choices=['month', 'year', '2months'], value=['month'], label="Select Segmentation Level:")
100
- timelapse_button = gr.Button(value="Predict")
101
- map = gr.Plot().style()
102
-
103
- demo.load(get_geomap, [input_latitude, input_longitude], input_map)
104
- single_button.click(get_geomap, [input_latitude, input_longitude], input_map)
105
- single_button.click(partial(inference_on_location_and_month, model), inputs=[input_latitude, input_longitude, input_date], outputs=[raw_image, output_image,score_biodiv])
106
-
107
- demo.load(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
108
- timelapse_button.click(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
109
- timelapse_button.click(segment_region, inputs=[timelapse_input_latitude, timelapse_input_longitude, timelapse_start_date, timelapse_end_date,segmentation], outputs=[map])
110
- demo.launch(share=True)
 
1
+ from plot_functions import *
2
+ import hydra
3
+
4
+ import torch
5
+ from model import LitUnsupervisedSegmenter
6
+ from helper import inference_on_location_and_month, inference_on_location
7
+ from plot_functions import segment_region
8
+
9
+ from functools import partial
10
+ import gradio as gr
11
+ import logging
12
+
13
+ import geopandas as gpd
14
+ mapbox_access_token = "pk.eyJ1IjoiamVyZW15LWVraW1ldHJpY3MiLCJhIjoiY2xrNjBwNGU2MDRhMjNqbWw0YTJrbnpvNCJ9.poVyIzhJuJmD6ffrL9lm2w"
15
+ geo_df = gpd.read_file(gpd.datasets.get_path('naturalearth_cities'))
16
+
17
+ def get_geomap(long, lat ):
18
+
19
+
20
+ fig = go.Figure(go.Scattermapbox(
21
+ lat=geo_df.geometry.y,
22
+ lon=geo_df.geometry.x,
23
+ mode='markers',
24
+ marker=go.scattermapbox.Marker(
25
+ size=14
26
+ ),
27
+ text=geo_df.name,
28
+ ))
29
+
30
+ fig.add_trace(go.Scattermapbox(lat=[lat],
31
+ lon=[long],
32
+ mode='markers',
33
+ marker=go.scattermapbox.Marker(
34
+ size=14
35
+ ),
36
+ marker_color="green",
37
+ text=['Actual position']))
38
+
39
+ fig.update_layout(
40
+ showlegend=False,
41
+ hovermode='closest',
42
+ mapbox=dict(
43
+ accesstoken=mapbox_access_token,
44
+ center=go.layout.mapbox.Center(
45
+ lat=lat,
46
+ lon=long
47
+ ),
48
+ zoom=3
49
+ )
50
+ )
51
+
52
+ return fig
53
+
54
+
55
+ if __name__ == "__main__":
56
+
57
+
58
+ logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)
59
+ # Initialize hydra with configs
60
+ #hydra.initialize(config_path="configs", job_name="corine")
61
+ cfg = hydra.compose(config_name="my_train_config.yml")
62
+ logging.info(f"config : {cfg}")
63
+ # Load the model
64
+
65
+ nbclasses = cfg.dir_dataset_n_classes
66
+ model = LitUnsupervisedSegmenter(nbclasses, cfg)
67
+ logging.info(f"Model Initialiazed")
68
+
69
+ model_path = "biomap/checkpoint/model/model.pt"
70
+ saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
71
+ logging.info(f"Model weights Loaded")
72
+ model.load_state_dict(saved_state_dict)
73
+ logging.info(f"Model Loaded")
74
+ # css=".VIDEO video{height: 100%;width:50%;margin:auto};.VIDEO{height: 50%;};.svelte-1vnmhm4{height:auto}"
75
+ with gr.Blocks() as demo:
76
+ gr.Markdown("Estimate Biodiversity in the world.")
77
+ with gr.Tab("Single Image"):
78
+ with gr.Row():
79
+ input_map = gr.Plot().style()
80
+ with gr.Column():
81
+ input_latitude = gr.Number(label="lattitude", value=2.98)
82
+ input_longitude = gr.Number(label="longitude", value=48.81)
83
+ input_date = gr.Textbox(label="start_date", value="2020-03-20")
84
+
85
+ single_button = gr.Button("Predict")
86
+ with gr.Row():
87
+ raw_image = gr.Image(label = "Localisation visualization")
88
+ output_image = gr.Image(label = "Labeled visualisation")
89
+ score_biodiv = gr.Number(label = "Biodiversity score")
90
+
91
+ with gr.Tab("TimeLapse"):
92
+ with gr.Row():
93
+ input_map_2 = gr.Plot().style()
94
+ with gr.Row():
95
+ timelapse_input_latitude = gr.Number(value=2.98, label="Latitude")
96
+ timelapse_input_longitude = gr.Number(value=48.81, label="Longitude")
97
+ timelapse_start_date = gr.Textbox(value='2020-05-01', label="Start Date")
98
+ timelapse_end_date = gr.Textbox(value='2020-06-30', label="End Date")
99
+ segmentation = gr.CheckboxGroup(choices=['month', 'year', '2months'], value=['month'], label="Select Segmentation Level:")
100
+ timelapse_button = gr.Button(value="Predict")
101
+ map = gr.Plot().style()
102
+
103
+ demo.load(get_geomap, [input_latitude, input_longitude], input_map)
104
+ single_button.click(get_geomap, [input_latitude, input_longitude], input_map)
105
+ single_button.click(partial(inference_on_location_and_month, model), inputs=[input_latitude, input_longitude, input_date], outputs=[raw_image, output_image,score_biodiv])
106
+
107
+ demo.load(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
108
+ timelapse_button.click(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
109
+ timelapse_button.click(segment_region, inputs=[timelapse_input_latitude, timelapse_input_longitude, timelapse_start_date, timelapse_end_date,segmentation], outputs=[map])
110
+ demo.launch(share=True)