# Biomap / biomap/plot_functions.py
# jeremyLE-Ekimetrics
# Commit message: clean
# Commit: 5dd3935
from PIL import Image
import matplotlib as mpl
from utils import prep_for_plot
import torch.multiprocessing
import torchvision.transforms as T
from utils_gee import extract_img, transform_ee_img
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from plotly.subplots import make_subplots
import os
# NOTE(review): presumably set to stop Intel OpenMP from aborting when two copies
# of its runtime get loaded by different native libraries; Intel documents this
# as an unsafe workaround — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Segmentation classes. Index i of `colors`, `class_names` and `scores_init`
# describes the same class; predicted label value i + 1 is counted as class i.
colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
# Colormap whose i-th colour is colors[i]; used to paint predicted label maps.
cmap = mpl.colors.ListedColormap(colors)
# Per-class biodiversity weight. An image's score is the pixel-weighted mean of
# these weights divided by max(scores_init).
scores_init = [1,2,4,3,4,1,0]
# Fetch an Earth Engine image for one location/period and segment it.
# Monthly, bi-monthly or multi-month windows help avoid cloudy images.
# NOTE: segment_loc is defined again further down in this module; that later
# definition replaces this one at import time.
def segment_loc(model, location, month, year, how="month", month_end='12', year_end=None):
    """Fetch one Earth Engine image for a location and period, then segment it.

    Args:
        model: Segmentation model; this function uses ``model.net``,
            ``model.linear_probe`` and ``model.cfg.n_images``.
        location: Forwarded unchanged to ``extract_img``.
        month: Two-digit start month string, e.g. ``'03'``.
        year: Four-digit start year string, e.g. ``'2020'``.
        how: ``'month'`` requests ``year-month-01`` .. ``year-month-28``;
            ``'year'`` requests ``year-month-01`` .. ``year_end-month_end-28``
            with ``width=0.04, len=0.04`` passed to ``extract_img``.
        month_end: Two-digit end month, used when ``how='year'``.
        year_end: End year string, used when ``how='year'``; defaults to ``year``.

    Returns:
        dict with ``'img'`` (the preprocessed input batch) and ``'linear_preds'``
        (argmax over dim 1 of the linear-probe output, presumably per-pixel class
        indices), both detached, on CPU and truncated to ``model.cfg.n_images``.

    Raises:
        ValueError: if ``how`` is neither ``'month'`` nor ``'year'``.
    """
    start = year + '-' + month + '-01'
    if how == 'month':
        img = extract_img(location, start, year + '-' + month + '-28')
    elif how == 'year':
        end_year = year if year_end is None else year_end
        img = extract_img(location, start, end_year + '-' + month_end + '-28', width=0.04, len=0.04)
    else:
        # Previously fell through and crashed later with UnboundLocalError on `img`.
        raise ValueError(f"how must be 'month' or 'year', got {how!r}")
    img_test = transform_ee_img(img, max=0.25)
    # Resize/normalise, then add a leading batch dimension of size 1.
    x = torch.unsqueeze(preprocess(img_test), dim=0).cpu()
    with torch.no_grad():
        _, code = model.net(x)
        linear_preds = model.linear_probe(x, code).argmax(1)
    n_images = model.cfg.n_images
    return {
        'img': x[:n_images].detach().cpu(),
        'linear_preds': linear_preds[:n_images].detach().cpu(),
    }
# Segment every period of a date range; each result is paired with its date label.
# NOTE: segment_group is defined again further down in this module; that later
# definition replaces this one at import time.
def segment_group(location, start_date, end_date, how='month', model=None):
    """Segment every period between two dates at ``location``.

    Args:
        location: Forwarded unchanged to ``segment_loc``.
        start_date: Range start; characters ``[0:4]`` are read as the year and
            ``[5:7]`` as the month (e.g. ``'2020-03-15'``); the rest is ignored.
        end_date: Range end in the same format; its month is included.
        how: Aggregation period:
            ``'month'`` - one image per calendar month;
            ``'2months'`` - one image per month covering that month and the
            next (December rolls over into January of the following year);
            ``'year'`` - one image per calendar year covering its first to last
            month inside the range.
        model: Segmentation model passed to ``segment_loc`` as its first
            argument. It is a keyword with a ``None`` default only so that
            existing four-argument calls still resolve; it is required.

    Returns:
        list of ``(label, outputs)`` tuples. ``label`` is the period start as
        ``'YYYY-MM'`` and ``outputs`` is the value returned by ``segment_loc``.

    Raises:
        ValueError: if ``how`` is not a known mode or ``model`` is None.
    """
    if how not in ('month', '2months', 'year'):
        # Previously an unknown mode silently produced an empty list.
        raise ValueError(f"how must be 'month', '2months' or 'year', got {how!r}")
    if model is None:
        # segment_loc(model, location, month, year, ...) needs a model; the old
        # calls omitted it entirely and always failed with a TypeError.
        raise ValueError("segment_group needs a segmentation model: pass model=...")
    st_year, st_month = int(start_date[0:4]), int(start_date[5:7])
    end_year, end_month = int(end_date[0:4]), int(end_date[5:7])
    outputs = []
    for year in range(st_year, end_year + 1):
        # Clip the month span to the requested range in the first/last year.
        first = st_month if year == st_year else 1
        last = end_month if year == end_year else 12
        year_str = str(year)
        if how == 'year':
            first_str = f"{first:02d}"
            outputs.append((year_str + '-' + first_str,
                            segment_loc(model, location, first_str, year_str,
                                        how='year', month_end=f"{last:02d}")))
            continue
        for month in range(first, last + 1):
            month_str = f"{month:02d}"
            if how == 'month':
                result = segment_loc(model, location, month_str, year_str)
            else:  # '2months': this month plus the next one
                next_month = month % 12 + 1
                next_year = year + 1 if next_month < month else year
                result = segment_loc(model, location, month_str, year_str, how='year',
                                     month_end=f"{next_month:02d}", year_end=str(next_year))
            outputs.append((year_str + '-' + month_str, result))
    return outputs
# NOTE: values_from_output is defined again further down in this module; that
# later definition replaces this one at import time.
def values_from_output(output):
    """Turn one segmentation result into plot-ready values.

    Args:
        output: dict from ``segment_loc`` with ``'img'`` and ``'linear_preds'``.

    Returns:
        tuple ``(img, labeled_img, nb_values, score)``:
        ``img`` - the input image as an RGB numpy array;
        ``labeled_img`` - the image blended with its coloured label map
        (alpha 0.3), as an RGB numpy array;
        ``nb_values`` - pixel count per class, class ``i`` being label value ``i + 1``;
        ``score`` - pixel-weighted mean of ``scores_init`` divided by
        ``max(scores_init)``, or ``0.0`` when no pixel carries a counted label.
    """
    img, _, labeled_img = transform_to_pil(output, alpha=0.3)
    img = np.array(img.convert('RGB'))
    labeled_img = np.array(labeled_img.convert('RGB'))
    labels = output['linear_preds'][0]
    # One count per weighted class, so counts and weights always line up.
    nb_values = [np.count_nonzero(labels == i + 1) for i in range(len(scores_init))]
    total = sum(nb_values)
    if total == 0:
        # Previously raised ZeroDivisionError when no pixel had a label in 1..N.
        score = 0.0
    else:
        score = sum(w * n for w, n in zip(scores_init, nb_values)) / total / max(scores_init)
    return img, labeled_img, nb_values, score
# NOTE: values_from_outputs is defined again further down in this module; that
# later definition replaces this one at import time.
def values_from_outputs(outputs):
    """Collect plot inputs for every ``(date_label, segmentation_output)`` pair.

    Each pair's output goes through ``values_from_output``. The results are
    regrouped into five parallel lists that keep the order of ``outputs``.

    Returns:
        tuple ``(months, imgs, imgs_label, nb_values, scores)`` of lists.
    """
    rows = [(entry[0], *values_from_output(entry[1])) for entry in outputs]
    if not rows:
        return [], [], [], [], []
    months, imgs, imgs_label, nb_values, scores = (list(column) for column in zip(*rows))
    return months, imgs, imgs_label, nb_values, scores
def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores):
    """Build the animated figure: image, labelled image, class pie, score trend.

    There is one animation frame (and one slider step labelled ``months[k]``)
    per observation ``k``.

    Args:
        months: Frame labels, one per observation.
        imgs: RGB images as numpy arrays; ``imgs[0].shape[:2]`` sets the axis ranges.
        imgs_label: Labelled (blended) images, one per observation.
        nb_values: Per-class pixel counts for each observation (pie values).
        scores: Biodiversity score for each observation.

    Returns:
        plotly ``go.Figure`` with frames, a play button and a slider.

    Raises:
        ValueError: if ``imgs`` is empty.
    """
    number_frames = len(imgs)
    if number_frames == 0:
        raise ValueError("plot_imgs_labels needs at least one image to plot")
    last = number_frames - 1
    height, width = imgs[0].shape[0], imgs[0].shape[1]

    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)

    def make_pie(values):
        # Class repartition pie for one observation.
        return go.Pie(labels=class_names,
                      values=values,
                      marker_colors=colors,
                      name="Segment repartition",
                      textposition='inside',
                      texttemplate="%{percent:.0%}",
                      textfont_size=14)

    # Score trend: frame k shows scores[0..k] at x = 1..k+1 (inside the xaxis3
    # range [0, len(scores) + 1]), and only the newest point is labelled.
    scatters = []
    for k, score in enumerate(scores):
        text = [""] * k + [str(round(score, 2))]
        scatters.append(go.Scatter(x=list(range(1, k + 2)), y=list(scores[:k + 1]),
                                   mode="lines+markers+text", marker_color="black",
                                   text=text, textposition="top center"))

    fig = make_subplots(
        rows=1, cols=4,
        # Pie traces need a 'domain' cell (plotly rejects them in 'xy' cells).
        # With this layout the score plot in column 4 uses xaxis3/yaxis3.
        specs=[[{"type": "xy"}, {"type": "xy"}, {"type": "domain"}, {"type": "xy"}]],
        column_widths=[0.6, 0.6, 0.3, 0.3],
        subplot_titles=("Localisation visualization", "labeled visualisation", "Segments repartition", "Biodiversity scores"),
    )
    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
    fig.add_trace(make_pie(nb_values[0]), row=1, col=3)
    fig.add_trace(scatters[0], row=1, col=4)

    # Each frame's axis range carries a tiny k-dependent offset (k / 100000),
    # presumably so plotly sees a layout change and redraws the image and pie
    # traces even though transitions use redraw=False.
    frames = [dict(
        name=f'{k}',
        data=[fig2["frames"][k]["data"][0],
              fig3["frames"][k]["data"][0],
              make_pie(nb_values[k]),
              scatters[k]],
        traces=[0, 1, 2, 3],  # indices in fig.data replaced by the entries above
        layout=dict(xaxis=dict(range=[0, width + k / 100000]),
                    yaxis=dict(range=[height + k / 100000, 0]),
                    title=dict(text=months[k])),
    ) for k in range(number_frames)]
    fig.update(frames=frames)

    updatemenus = [dict(type='buttons',
                        buttons=[dict(label='Play',
                                      method='animate',
                                      args=[[f'{k}' for k in range(number_frames)],
                                            dict(frame=dict(duration=500, redraw=False),
                                                 transition=dict(duration=0),
                                                 easing='linear',
                                                 fromcurrent=True,
                                                 mode='immediate')])],
                        direction='left',
                        pad=dict(r=10, t=85),
                        showactive=True, x=0.1, y=0.13, xanchor='right', yanchor='top')]
    sliders = [{'yanchor': 'top',
                'xanchor': 'left',
                'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
                'transition': {'duration': 500.0, 'easing': 'linear'},
                'pad': {'b': 10, 't': 50},
                'len': 0.9, 'x': 0.1, 'y': 0,
                # Frame names are strings, so the step args use the same strings.
                'steps': [{'args': [[f'{k}'], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
                                                'transition': {'duration': 0, 'easing': 'linear'}}],
                           'label': months[k], 'method': 'animate'} for k in range(number_frames)]}]

    hidden_axis = dict(showgrid=False, zeroline=False, visible=False)
    fig.update_layout(
        title_text='tot',
        # Initial image-axis range matches the last frame's offset, as before.
        xaxis=dict(range=[0, width + last / 100000], **hidden_axis),
        yaxis=dict(range=[height + last / 100000, 0], **hidden_axis),
        xaxis3=dict(range=[0, len(scores) + 1], autorange=False, **hidden_axis),
        yaxis3=dict(range=[0, 1.5], autorange=False, **hidden_axis),
        legend=dict(yanchor="bottom", y=0.99, xanchor="center", x=0.01),
        updatemenus=updatemenus,
        sliders=sliders,
        margin=dict(b=0, r=0),
    )
    return fig
# Top-level entry point: segment a location over a date range and plot it.
def segment_region(location, start_date, end_date, how = 'month'):
    """Segment ``location`` between two dates and build the animated figure.

    ``how`` selects the aggregation period: ``'month'``, ``'2months'`` or
    ``'year'``; it is forwarded to ``segment_group``.

    NOTE: segment_region is defined again at the end of this module with a
    ``(latitude, longitude, ...)`` signature; that later definition replaces
    this one at import time.
    """
    per_period = segment_group(location, start_date, end_date, how=how)
    dates, frames, overlays, counts, trend = values_from_outputs(per_period)
    return plot_imgs_labels(dates, frames, overlays, counts, trend)
# Preprocessing applied to each image before inference: convert to a PIL image,
# resize to 320x320, convert to a tensor, then normalise each channel with the
# standard ImageNet mean/std values.
preprocess = T.Compose(
    [
        T.ToPILImage(),
        T.Resize((320, 320)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Fetch an Earth Engine image for one location/period and segment it.
# Monthly, bi-monthly or multi-month windows help avoid cloudy images.
def segment_loc(model, location, month, year, how="month", month_end='12', year_end=None):
    """Fetch one Earth Engine image for a location and period, then segment it.

    Args:
        model: Segmentation model; this function uses ``model.net``,
            ``model.linear_probe`` and ``model.cfg.n_images``.
        location: Forwarded unchanged to ``extract_img``.
        month: Two-digit start month string, e.g. ``'03'``.
        year: Four-digit start year string, e.g. ``'2020'``.
        how: ``'month'`` requests ``year-month-01`` .. ``year-month-28``;
            ``'year'`` requests ``year-month-01`` .. ``year_end-month_end-28``
            with ``width=0.04, len=0.04`` passed to ``extract_img``.
        month_end: Two-digit end month, used when ``how='year'``.
        year_end: End year string, used when ``how='year'``; defaults to ``year``.

    Returns:
        dict with ``'img'`` (the preprocessed input batch) and ``'linear_preds'``
        (argmax over dim 1 of the linear-probe output, presumably per-pixel class
        indices), both detached, on CPU and truncated to ``model.cfg.n_images``.

    Raises:
        ValueError: if ``how`` is neither ``'month'`` nor ``'year'``.
    """
    start = year + '-' + month + '-01'
    if how == 'month':
        img = extract_img(location, start, year + '-' + month + '-28')
    elif how == 'year':
        end_year = year if year_end is None else year_end
        img = extract_img(location, start, end_year + '-' + month_end + '-28', width=0.04, len=0.04)
    else:
        # Previously fell through and crashed later with UnboundLocalError on `img`.
        raise ValueError(f"how must be 'month' or 'year', got {how!r}")
    img_test = transform_ee_img(img, max=0.25)
    # Resize/normalise, then add a leading batch dimension of size 1.
    x = torch.unsqueeze(preprocess(img_test), dim=0).cpu()
    with torch.no_grad():
        _, code = model.net(x)
        linear_preds = model.linear_probe(x, code).argmax(1)
    n_images = model.cfg.n_images
    return {
        'img': x[:n_images].detach().cpu(),
        'linear_preds': linear_preds[:n_images].detach().cpu(),
    }
# Segment every period of a date range; each result is paired with its date label.
def segment_group(location, start_date, end_date, how='month', model=None):
    """Segment every period between two dates at ``location``.

    Args:
        location: Forwarded unchanged to ``segment_loc``.
        start_date: Range start; characters ``[0:4]`` are read as the year and
            ``[5:7]`` as the month (e.g. ``'2020-03-15'``); the rest is ignored.
        end_date: Range end in the same format; its month is included.
        how: Aggregation period:
            ``'month'`` - one image per calendar month;
            ``'2months'`` - one image per month covering that month and the
            next (December rolls over into January of the following year);
            ``'year'`` - one image per calendar year covering its first to last
            month inside the range.
        model: Segmentation model passed to ``segment_loc`` as its first
            argument. It is a keyword with a ``None`` default only so that
            existing four-argument calls still resolve; it is required.

    Returns:
        list of ``(label, outputs)`` tuples. ``label`` is the period start as
        ``'YYYY-MM'`` and ``outputs`` is the value returned by ``segment_loc``.

    Raises:
        ValueError: if ``how`` is not a known mode or ``model`` is None.
    """
    if how not in ('month', '2months', 'year'):
        # Previously an unknown mode silently produced an empty list.
        raise ValueError(f"how must be 'month', '2months' or 'year', got {how!r}")
    if model is None:
        # segment_loc(model, location, month, year, ...) needs a model; the old
        # calls omitted it entirely and always failed with a TypeError.
        raise ValueError("segment_group needs a segmentation model: pass model=...")
    st_year, st_month = int(start_date[0:4]), int(start_date[5:7])
    end_year, end_month = int(end_date[0:4]), int(end_date[5:7])
    outputs = []
    for year in range(st_year, end_year + 1):
        # Clip the month span to the requested range in the first/last year.
        first = st_month if year == st_year else 1
        last = end_month if year == end_year else 12
        year_str = str(year)
        if how == 'year':
            first_str = f"{first:02d}"
            outputs.append((year_str + '-' + first_str,
                            segment_loc(model, location, first_str, year_str,
                                        how='year', month_end=f"{last:02d}")))
            continue
        for month in range(first, last + 1):
            month_str = f"{month:02d}"
            if how == 'month':
                result = segment_loc(model, location, month_str, year_str)
            else:  # '2months': this month plus the next one
                next_month = month % 12 + 1
                next_year = year + 1 if next_month < month else year
                result = segment_loc(model, location, month_str, year_str, how='year',
                                     month_end=f"{next_month:02d}", year_end=str(next_year))
            outputs.append((year_str + '-' + month_str, result))
    return outputs
def transform_to_pil(outputs, alpha=0.3):
    """Render one segmentation result as PIL images.

    Args:
        outputs: dict with ``'img'`` and ``'linear_preds'``; only element 0 of
            each is used.
        alpha: Blend weight of the coloured label map over the image.

    Returns:
        tuple ``(img, label, labeled_img)``: the input image, the label map
        painted with ``cmap`` colours, and their RGBA blend.
    """
    to_pil = T.ToPILImage()
    # prep_for_plot presumably yields a channels-last image; move the last
    # axis to the front for ToPILImage.
    chw = torch.moveaxis(prep_for_plot(outputs['img'][0]), -1, 0)
    img = to_pil(chw)
    # Label value v is painted with the colormap's (v - 1)-th RGBA colour;
    # v == 0 becomes index -1 and so wraps to the last colour.
    palette = np.array([np.array(cmap(idx)) for idx in range(cmap.N)])
    label_idx = np.array(outputs['linear_preds'][0]) - 1
    label = to_pil((palette[label_idx] * 255).astype(np.uint8))
    labeled_img = Image.blend(img.convert("RGBA"), label.convert("RGBA"), alpha)
    return img, label, labeled_img
def values_from_output(output):
    """Turn one segmentation result into plot-ready values.

    Args:
        output: dict from ``segment_loc`` with ``'img'`` and ``'linear_preds'``.

    Returns:
        tuple ``(img, labeled_img, nb_values, score)``:
        ``img`` - the input image as an RGB numpy array;
        ``labeled_img`` - the image blended with its coloured label map
        (alpha 0.3), as an RGB numpy array;
        ``nb_values`` - pixel count per class, class ``i`` being label value ``i + 1``;
        ``score`` - pixel-weighted mean of ``scores_init`` divided by
        ``max(scores_init)``, or ``0.0`` when no pixel carries a counted label.
    """
    img, _, labeled_img = transform_to_pil(output, alpha=0.3)
    img = np.array(img.convert('RGB'))
    labeled_img = np.array(labeled_img.convert('RGB'))
    labels = output['linear_preds'][0]
    # One count per weighted class, so counts and weights always line up.
    nb_values = [np.count_nonzero(labels == i + 1) for i in range(len(scores_init))]
    total = sum(nb_values)
    if total == 0:
        # Previously raised ZeroDivisionError when no pixel had a label in 1..N.
        score = 0.0
    else:
        score = sum(w * n for w, n in zip(scores_init, nb_values)) / total / max(scores_init)
    return img, labeled_img, nb_values, score
def values_from_outputs(outputs):
    """Collect plot inputs for every ``(date_label, segmentation_output)`` pair.

    Each pair's output goes through ``values_from_output``. The results are
    regrouped into five parallel lists that keep the order of ``outputs``.

    Returns:
        tuple ``(months, imgs, imgs_label, nb_values, scores)`` of lists.
    """
    rows = [(entry[0], *values_from_output(entry[1])) for entry in outputs]
    if not rows:
        return [], [], [], [], []
    months, imgs, imgs_label, nb_values, scores = (list(column) for column in zip(*rows))
    return months, imgs, imgs_label, nb_values, scores
# Top-level entry point: segment a location over a date range and plot it.
def segment_region(latitude, longitude, start_date, end_date, how = 'month'):
    """Segment a point over a date range and return the animated figure.

    Args:
        latitude: Converted with ``float``; first element of the location.
        longitude: Converted with ``float``; second element of the location.
        start_date: Range start, ``'YYYY-MM...'`` string (see ``segment_group``).
        end_date: Range end, same format.
        how: ``'month'``, ``'2months'`` or ``'year'``. A list or tuple is also
            accepted, in which case its first element is used.

    Returns:
        The figure built by ``plot_imgs_labels``.

    NOTE(review): no segmentation model is forwarded to ``segment_group`` here,
    but ``segment_loc`` needs one; TODO wire the loaded model through.
    """
    location = [float(latitude), float(longitude)]
    # Only unwrap sequence inputs. Indexing a plain string kept just its first
    # character ('month' -> 'm'), which matches no aggregation mode.
    if isinstance(how, (list, tuple)):
        how = how[0]
    outputs = segment_group(location, start_date, end_date, how=how)
    months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
    return plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)