# Scraped file-view header (not part of the program):
# jeremyLE-Ekimetrics's picture
# first commit
# 5c718d1
# raw
# history blame
# 14.5 kB
from datetime import datetime
import ee
from func_timeout import func_set_timeout
import pandas as pd
from PIL import Image
import requests
import tempfile
import io
from tqdm import tqdm
import functools
import re # Used in an eval statement
from typing import List
from typing import Union
from typing import Any
class DataLoader:
    """
    Main class for loading and exploring data from satellite images.

    The goal is to load an ImageCollection and to filter that collection according to needs, with methods like
    filter, filterDate, filterBounds, select. These work just like the Earth Engine methods with the same names.

    This class, just like Earth Engine, works with lazy loading and computation. This means that running
    filterBounds will not actually filter the image collection until required, e.g. when counting the images
    by accessing the .count property.
    However, each piece of information is only loaded once, unless additional filtering is made.
    This works thanks to the signal_change decorator, which drops every cached result. If you develop a new
    filtering method for this class, you will need to decorate your method with @signal_change.

    In addition, if you develop a new method that needs to run getInfo to actually load data from
    Google Earth Engine, you will need to use _get_timeout_info(your object before getInfo). This runs
    getInfo with a timeout (currently set to 10 seconds).
    It is important to use a timeout to avoid unexpected run times.

    Usage:
    >>> dl = DataLoader(satellite_name="COPERNICUS/S2_SR",
    ...                 start_date='2021-01-01',
    ...                 end_date='2021-01-15',
    ...                 bands=["TCI_R", "TCI_G", "TCI_B"],
    ...                 geographic_bounds=ee.Geometry.Point(*[5.238728194366604, 44.474864056855935]).buffer(500)
    ...                 )

    Get a pandas dataframe with all pixel values as a timeseries:
    >>> dl.getRegion(dl.bounds, 500)
    >>> dl.region.head(2)
    [Out]
    id longitude latitude time B1 B2 B3 B4 B5 B6 ... WVP SCL TCI_R TCI_G TCI_B MSK_CLDPRB MSK_SNWPRB QA10 QA20 QA60
    0 20210102T104441_20210102T104435_T31TFK 5.234932 44.473344 2021-01-02 10:48:36.299 6297 5955 5768 5773 5965 5883 ... 393 8 255 255 255 0 95 0 0 1024
    1 20210104T103329_20210104T103331_T31TFK 5.234932 44.473344 2021-01-04 10:38:38.304 5547 5355 5184 5090 5254 5229 ... 314 9 255 255 255 29 9 0 0 1024
    >>> dl.date_range  # naive UTC datetimes, like the "time" column of dl.region
    [Out]
    {'max': datetime.datetime(2021, 1, 14, 10, 38, 39, 208000),
    'min': datetime.datetime(2021, 1, 2, 10, 48, 36, 299000)}
    >>> dl.count
    [Out]
    6
    >>> dl.collection_info  # contains an html description of the dataset in "description"
    >>> dl.image_ids
    [Out]
    ['COPERNICUS/S2_SR/20210102T104441_20210102T104435_T31TFK',
    'COPERNICUS/S2_SR/20210104T103329_20210104T103331_T31TFK',
    'COPERNICUS/S2_SR/20210107T104329_20210107T104328_T31TFK',
    'COPERNICUS/S2_SR/20210109T103421_20210109T103431_T31TFK',
    'COPERNICUS/S2_SR/20210112T104411_20210112T104438_T31TFK',
    'COPERNICUS/S2_SR/20210114T103309_20210114T103305_T31TFK']

    Download the image:
    >>> img = dl.download_image(dl.image_ids[3])

    Download all images as a list:
    >>> imgs = dl.download_all_images(scale=1)
    """

    def __init__(self,
                 satellite_name: str,
                 bands: Union[List, str] = None,
                 start_date: str = None,
                 end_date: str = None,
                 geographic_bounds: ee.Geometry = None,
                 scale: int = 10,
                 crs: str = "EPSG:32630"
                 ):
        """
        Args:
            satellite_name: satellite to use. Examples: COPERNICUS/S2_SR, COPERNICUS/CORINE/V20/100m.
                See https://developers.google.com/earth-engine/datasets for the full list.
            bands: band name or list of band names used when downloading images.
            start_date: lowest possible date. Might be lower than the actual date of the first picture.
            end_date: latest possible date. Must be given together with start_date; when both are
                omitted, no date filter is applied.
            geographic_bounds: region of interest; the collection is filtered with filterBounds on it.
            scale: default scale used by reproject (and by region when getRegion was never called).
            crs: default CRS used by reproject.

        Raises:
            ValueError: if only one of start_date / end_date is provided.
        """
        self.satellite_name = satellite_name
        if isinstance(bands, str):
            bands = [bands]
        self.bands = list(bands) if bands is not None else []
        # Either both dates or neither: filterDate needs a complete interval.
        if (start_date is None) != (end_date is None):
            raise ValueError("start_date and end_date must both be provided (or both omitted)")
        self.start_date = start_date
        self.end_date = end_date
        self.bounds = geographic_bounds
        self.scale = scale
        self.crs = crs
        # Lazily computed results. Each is None until first loaded, and signal_change resets them
        # to None whenever the collection is modified.
        self._available_images = None
        self._count = None
        self._date_range = None
        self._df_image_list = None
        self.image_collection_info = None
        self._describe = None
        # Lazy ee object produced by getRegion, and the (args, kwargs) it was built with, so that
        # `region` can rebuild it against the current (possibly re-filtered) collection.
        self.image_list = None
        self._region_args = None
        # Start building the (lazy) request to Google Earth Engine.
        if satellite_name:
            self.image_collection = ee.ImageCollection(self.satellite_name)
            if self.bounds is not None:
                self.filterBounds(self.bounds)
            if self.start_date is not None:
                self.filterDate(self.start_date, self.end_date)
        # Legacy public flags, set to True by signal_change. Caching does not rely on them any more
        # (every cache is reset to None instead); they are only kept for backward compatibility.
        self.filter_change = True
        self.date_filter_change = True

    def signal_change(func):
        """Decorator for methods that modify the image collection (or data derived from it).

        Before running the wrapped method it raises the legacy change flags and drops every cached
        result, so that count, available_images, date_range, collection_info and region reload
        their data on next access.
        """
        @functools.wraps(func)
        def wrap(self, *args, **kwargs):
            self.filter_change = True
            self.date_filter_change = True
            self._invalidate_cache()
            return func(self, *args, **kwargs)
        return wrap

    def _invalidate_cache(self):
        """Forgets every result loaded from Earth Engine so it is recomputed on next access."""
        self._count = None
        self._available_images = None
        self.image_collection_info = None
        self._date_range = None
        self._df_image_list = None

    @staticmethod
    @func_set_timeout(10)
    def _get_timeout_info(instance: Any):
        """Runs getInfo on anything that is passed, with a 10 second timeout."""
        return instance.getInfo()

    @staticmethod
    def _authenticate_gee():
        """Initializes Earth Engine, running the authentication flow first if initialization fails."""
        try:
            ee.Initialize()
        except Exception:
            # Trigger the authentication flow, then initialize the library.
            ee.Authenticate()
            ee.Initialize()

    @signal_change
    def filter(self, ee_filter: ee.Filter):
        """Applies a filter to the image_collection attribute. This can be useful for example
        to filter out clouds.

        Args:
            ee_filter: filter to apply, must be an instance of ee.Filter.
        Returns: self, for operation chaining as possible with the earth engine API.
        """
        self.image_collection = self.image_collection.filter(ee_filter)
        return self

    @property
    def count(self):
        """Number of images in the ImageCollection (loaded once per filter state)."""
        if self._count is None:
            self._count = self._get_timeout_info(self.image_collection.size())
            self.filter_change = False
        return self._count

    @property
    def available_images(self):
        """Gets the ImageCollection info (loaded once per filter state)."""
        if self._available_images is None:
            self._available_images = self._get_timeout_info(self.image_collection)
        return self._available_images

    @signal_change
    def filterDate(self, *args, **kwargs):
        """Wrapper for the filterDate method in earth engine on the ImageCollection."""
        self.image_collection = self.image_collection.filterDate(*args, **kwargs)
        return self

    @signal_change
    def getRegion(self, *args, **kwargs):
        """Wrapper for the getRegion method in earth engine on the ImageCollection.

        Caveat! getRegion does not return an image collection, so the image_list attribute gets
        updated instead of the image_collection attribute. However, the instance of the DataLoader class
        is still returned, so this could be chained with another method on ImageCollection, which wouldn't be
        possible using earth engine. The arguments are remembered so that `region` can replay them
        after further filtering.
        """
        self._region_args = (args, kwargs)
        self.image_list = self.image_collection.getRegion(*args, **kwargs)
        return self

    @signal_change
    def filterBounds(self, geometry, *args, **kwargs):
        """Wrapper for the filterBounds method in earth engine on the ImageCollection.
        Also stores the geometry as the instance's bounds."""
        self.image_collection = self.image_collection.filterBounds(geometry, *args, **kwargs)
        self.bounds = geometry
        return self

    @signal_change
    def select(self, *bands, **kwargs):
        """Wrapper for the select method in earth engine on the ImageCollection.

        The selected bands are added to self.bands (without duplicates, keeping their order).
        Both call forms are supported: select("B1", "B2") and select(["B1", "B2"], ...).
        """
        self.image_collection = self.image_collection.select(*bands, **kwargs)
        if bands and isinstance(bands[0], (list, tuple)):
            # List form: the first argument holds the selectors; any following list holds new names.
            selectors = list(bands[0])
        else:
            selectors = list(bands)
        # Order-preserving union: band order matters (e.g. R, G, B for thumbnails), a set would shuffle it.
        self.bands = list(dict.fromkeys([*self.bands, *selectors]))
        return self

    @staticmethod
    def _ms_to_datetime(value):
        """Converts an Earth Engine millisecond timestamp to a naive UTC datetime; None stays None."""
        if value is None:
            return None
        return pd.to_datetime(value, unit="ms").to_pydatetime()

    @property
    def date_range(self):
        """Gets the actual date range of the images in the image collection.

        Returns: dict with "min" and "max" keys holding naive UTC datetimes (same convention as the
        "time" column of `region`).
        """
        if self._date_range is None:
            reduced = self._get_timeout_info(
                self.image_collection.reduceColumns(ee.Reducer.minMax(), ["system:time_start"])
            )
            self._date_range = {key: self._ms_to_datetime(value) for key, value in reduced.items()}
            self.date_filter_change = False
        return self._date_range

    @property
    def region(self):
        """Gets a time series as a pandas DataFrame of the band values for the specified region.

        Uses the arguments of the last getRegion call, or (bounds, scale) if getRegion was never
        called. The request is rebuilt against the current collection, so later filters are honoured.
        """
        if self._df_image_list is None:
            if self._region_args is None:
                args, kwargs = (self.bounds, self.scale), {}
            else:
                args, kwargs = self._region_args
            self.image_list = self.image_collection.getRegion(*args, **kwargs)
            rows = self._get_timeout_info(self.image_list)
            # The first row returned by getRegion is the header.
            df = pd.DataFrame(rows[1:], columns=rows[0])
            df["time"] = pd.to_datetime(df["time"], unit="ms")
            self._df_image_list = df
            self.filter_change = False
        return self._df_image_list

    @property
    def collection_info(self):
        """Runs getInfo on the image collection (only once per filter state).

        Raises:
            Exception: if the collection holds more than 5000 images.
        """
        if self.count > 5000:
            raise Exception("Too many images to load. Try filtering more")
        if self.image_collection_info is None:
            self.image_collection_info = self._get_timeout_info(self.image_collection)
        return self.image_collection_info

    @property
    def image_ids(self):
        """List of names of available images in the image collection."""
        return [feature["id"] for feature in self.collection_info["features"]]

    def __repr__(self):
        try:
            size = self.count
            dates = self.date_range
            return (
                "\n"
                f"Size: {size}\n"
                "Dataset date ranges:\n"
                f"From: {dates['min']}\n"
                f"To: {dates['max']}\n"
                "Selected bands:\n"
                f"{self.bands}\n"
            )
        except Exception as e:
            raise Exception("Impossible to represent the dataset. Try filtering more. Error handling to do.") from e

    def reproject(self, image, **kwargs):
        """Reprojects an ee.Image with crs/scale taken from kwargs, else from the instance attributes.

        Args:
            image: ee.Image to reproject.
            **kwargs: may contain "crs" and/or "scale"; other keys are ignored here.
        Returns: the reprojected image, or `image` unchanged when neither crs nor scale is set.
        """
        def resolve(name: str):
            # Explicit keyword wins; otherwise a truthy instance attribute; otherwise None.
            if name in kwargs:
                return kwargs[name]
            return getattr(self, name, None) or None

        crs = resolve("crs")
        scale = resolve("scale")
        if crs is not None or scale is not None:
            image = image.reproject(crs, None, scale)
        return image

    def _first_selected_band_type(self) -> dict:
        """Returns the "data_type" metadata of the first selected band of the collection's first image.

        Returns an empty dict when the collection is empty or no selected band is found.
        """
        features = self.collection_info["features"]
        if not features:
            return {}
        for band in features[0]["bands"]:
            if band["id"] in self.bands:
                return band.get("data_type", {})
        return {}

    def download_image(self, image_id: str, **kwargs):
        """Downloads an image based on its id / name and returns it as a PIL image.

        The additional arguments are passed to reproject (crs, scale) and to getThumbUrl
        (e.g. scale, max, min...). When min / max are not given, they default to the value range
        of the first selected band's data type, if that metadata provides them.

        Raises:
            requests.HTTPError: if the thumbnail request fails.
        """
        img = ee.Image(image_id)
        if self.bands:
            img = img.select(*self.bands)
        img = self.reproject(img, **kwargs)
        thumb_params = {"region": self.bounds, **kwargs}
        # Only fetch collection metadata when a default is actually needed.
        if "min" not in thumb_params or "max" not in thumb_params:
            data_type = self._first_selected_band_type()
            for bound in ("min", "max"):
                if bound not in thumb_params and bound in data_type:
                    thumb_params[bound] = data_type[bound]
        url = img.getThumbUrl(thumb_params)
        buffer = io.BytesIO()
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=1024):
                buffer.write(chunk)
        buffer.seek(0)
        return Image.open(buffer)

    @staticmethod
    def _regex(regex: str, im_id_list: List[str], include: bool) -> list:
        """
        Filters the im_id_list based on a regular expression. This is useful before downloading
        a collection of images. Matching uses re.match, i.e. it is anchored at the start of the id
        only: (.*)TXT keeps ids containing TXT anywhere, (.*)TXT$ keeps ids ending with TXT, which for
        Nantes means filtering out empty or half empty images.

        Args:
            regex: python regex as a string.
            im_id_list: list, image id list.
            include: whether to include or exclude elements that match the regex.
        Returns: filtered list, in the original order.
        """
        pattern = re.compile(regex)
        return [im_id for im_id in im_id_list if (pattern.match(im_id) is not None) == include]

    def download_all_images(self, regex_exclude: str = None, regex_include: str = None, **kwargs):
        """
        Runs download_image in a for loop around the available images.
        Makes it possible to filter images to download based on a regex.

        Args:
            regex_exclude: any image that matches this regex will be excluded (applied first).
            regex_include: only images that match this regex will be kept.
            **kwargs: arguments to be passed to download_image / getThumbUrl.
        Returns: list of PIL images.
        """
        image_ids = self.image_ids
        if regex_exclude is not None:
            image_ids = self._regex(regex_exclude, image_ids, include=False)
        if regex_include is not None:
            image_ids = self._regex(regex_include, image_ids, include=True)
        return [self.download_image(image_id, **kwargs) for image_id in tqdm(image_ids)]