Spaces:
Runtime error
Runtime error
from datetime import datetime | |
import ee | |
from func_timeout import func_set_timeout | |
import pandas as pd | |
from PIL import Image | |
import requests | |
import tempfile | |
import io | |
from tqdm import tqdm | |
import functools | |
import re # Used in an eval statement | |
from typing import List | |
from typing import Union | |
from typing import Any | |
class DataLoader:
    """
    Main class for loading and exploring data from satellite images.

    The goal is to load an ImageCollection and to filter that collection according to needs, with methods like
    filter, filterDate, filterBounds, select. These work just like Earth Engine's methods with the same names.

    Just like Earth Engine, this class works with lazy loading and computation: running filterBounds does not
    actually filter the image collection until a result is required, e.g. when counting the images through the
    ``count`` property. Each piece of information is loaded only once, unless additional filtering is performed.

    This works thanks to the ``signal_change`` decorator. If you develop a new filtering method for this class,
    decorate it with ``@signal_change`` so that every cached result is recomputed on next access.
    If you develop a new method that needs to run ``getInfo`` to actually load data from Google Earth Engine,
    call ``self._get_timeout_info(ee_object)`` instead of ``ee_object.getInfo()``. This runs getInfo with a
    timeout (currently 10 seconds); a timeout is important to avoid unexpectedly long run times.

    Usage:
    >>> dl = DataLoader(satellite_name="COPERNICUS/S2_SR",
    ...                 start_date='2021-01-01',
    ...                 end_date='2021-01-15',
    ...                 bands=["TCI_R", "TCI_G", "TCI_B"],
    ...                 geographic_bounds=ee.Geometry.Point(*[5.238728194366604, 44.474864056855935]).buffer(500)
    ...                 )

    Get a pandas dataframe with all pixel values as a timeseries:
    >>> dl.getRegion(dl.bounds, 500)
    >>> dl.region.head(2)
    [Out]
    id longitude latitude time B1 B2 B3 B4 B5 B6 ... WVP SCL TCI_R TCI_G TCI_B MSK_CLDPRB MSK_SNWPRB QA10 QA20 QA60
    0 20210102T104441_20210102T104435_T31TFK 5.234932 44.473344 2021-01-02 10:48:36.299 6297 5955 5768 5773 5965 5883 ... 393 8 255 255 255 0 95 0 0 1024
    1 20210104T103329_20210104T103331_T31TFK 5.234932 44.473344 2021-01-04 10:38:38.304 5547 5355 5184 5090 5254 5229 ... 314 9 255 255 255 29 9 0 0 1024

    Dates are naive UTC datetimes, on the same clock as the "time" column of ``region``:
    >>> dl.date_range
    [Out]
    {'max': datetime.datetime(2021, 1, 14, 10, 38, 39, 208000),
     'min': datetime.datetime(2021, 1, 2, 10, 48, 36, 299000)}
    >>> dl.count
    [Out]
    6
    >>> dl.collection_info  # contains an HTML description of the dataset in "description"
    >>> dl.image_ids
    [Out]
    ['COPERNICUS/S2_SR/20210102T104441_20210102T104435_T31TFK',
     'COPERNICUS/S2_SR/20210104T103329_20210104T103331_T31TFK',
     'COPERNICUS/S2_SR/20210107T104329_20210107T104328_T31TFK',
     'COPERNICUS/S2_SR/20210109T103421_20210109T103431_T31TFK',
     'COPERNICUS/S2_SR/20210112T104411_20210112T104438_T31TFK',
     'COPERNICUS/S2_SR/20210114T103309_20210114T103305_T31TFK']

    Download one image:
    >>> img = dl.download_image(dl.image_ids[3])
    Download all images as a list:
    >>> imgs = dl.download_all_images(scale=1)
    """

    def __init__(self,
                 satellite_name: str,
                 bands: Union[List, str] = None,
                 start_date: str = None,
                 end_date: str = None,
                 geographic_bounds: ee.Geometry = None,
                 scale: int = 10,
                 crs: str = "EPSG:32630"
                 ):
        """
        Args:
            satellite_name: satellite to use. Examples: COPERNICUS/S2_SR, COPERNICUS/CORINE/V20/100m.
                See https://developers.google.com/earth-engine/datasets for the full list.
            bands: list of bands to load (a single band name is also accepted).
            start_date: lowest possible date. Might be lower than the actual date of the first picture.
            end_date: latest possible date. Must be given together with start_date; when both are
                omitted, no date filter is applied.
            geographic_bounds: region of interest.
            scale: default scale passed to ``reproject`` (and to ``getRegion`` when ``region`` is
                accessed without an explicit ``getRegion`` call).
            crs: default CRS passed to ``reproject``.
        Raises:
            ValueError: if only one of start_date / end_date is provided.
        """
        if (start_date is None) != (end_date is None):
            raise ValueError("start_date and end_date must both be provided")
        self.satellite_name = satellite_name
        if isinstance(bands, str):
            bands = [bands]
        # Copy so that later changes to self.bands never alias the caller's list.
        self.bands = list(bands) if bands is not None else []
        self.start_date = start_date
        self.end_date = end_date
        self.bounds = geographic_bounds
        self.scale = scale
        self.crs = crs

        # Lazily computed, cached results. They are cleared by _sync_cache whenever a
        # method decorated with @signal_change has modified the image collection.
        self._available_images = None
        self.image_list = None
        self._region_args = None  # (args, kwargs) of the last getRegion call
        self._df_image_list = None
        self.image_collection_info = None
        self._date_range = None
        self._count = None
        self._describe = None
        # Dirty flags, set by @signal_change and consumed by _sync_cache.
        self.filter_change = True
        self.date_filter_change = True

        # Start building the (lazy) Earth Engine query; nothing is computed yet.
        self.image_collection = ee.ImageCollection(self.satellite_name) if satellite_name else None
        if self.image_collection is not None:
            if self.bounds is not None:
                self.filterBounds(self.bounds)
            if self.start_date is not None:
                self.filterDate(self.start_date, self.end_date)

    def signal_change(func):
        """Decorator for methods that modify the image collection.

        Sets the dirty flags so that every cached result (count, collection info, region,
        date range...) is recomputed on its next access.
        """
        @functools.wraps(func)
        def wrap(self, *args, **kwargs):
            self.filter_change = True
            self.date_filter_change = True
            return func(self, *args, **kwargs)
        return wrap

    @staticmethod
    @func_set_timeout(10)
    def _get_timeout_info(instance: Any):
        """Runs getInfo on any Earth Engine object, aborting after 10 seconds.

        Raises:
            func_timeout.FunctionTimedOut: if getInfo does not return within the timeout.
        """
        return instance.getInfo()

    @staticmethod
    def _authenticate_gee():
        """Authenticates Earth Engine if needed, and initializes it."""
        try:
            ee.Initialize()
        except Exception:
            # Initialization failed (typically no stored credentials): trigger the
            # authentication flow, then initialize again.
            ee.Authenticate()
            ee.Initialize()

    def _sync_cache(self):
        """Drops cached results made stale by filtering since they were computed.

        Each cache has its own "is None" check, so resetting the shared flags here, in one
        place, cannot leave another cache stale (which happened when individual properties
        reset the shared flag themselves).
        """
        if self.filter_change:
            self._count = None
            self._available_images = None
            self.image_collection_info = None
            self._df_image_list = None
            if self._region_args is not None:
                # Re-issue the last getRegion query on the current (filtered) collection.
                # This is lazy: no data is fetched here.
                args, kwargs = self._region_args
                self.image_list = self.image_collection.getRegion(*args, **kwargs)
            self.filter_change = False
        if self.date_filter_change:
            self._date_range = None
            self.date_filter_change = False

    @signal_change
    def filter(self, ee_filter: ee.Filter):
        """Applies a filter to the image_collection attribute. This can be useful, for example,
        to filter out clouds.

        Args:
            ee_filter: filter to apply, must be an instance of ee.Filter.
        Returns: self, for operation chaining as possible with the Earth Engine API.
        """
        self.image_collection = self.image_collection.filter(ee_filter)
        return self

    @property
    def count(self):
        """Number of images in the ImageCollection (fetched once per filter state)."""
        self._sync_cache()
        if self._count is None:
            self._count = self._get_timeout_info(self.image_collection.size())
        return self._count

    @property
    def available_images(self):
        """Gets the ImageCollection info (fetched once per filter state)."""
        self._sync_cache()
        if self._available_images is None:
            self._available_images = self._get_timeout_info(self.image_collection)
        return self._available_images

    @signal_change
    def filterDate(self, *args, **kwargs):
        """Wrapper for the filterDate method in Earth Engine on the ImageCollection."""
        self.image_collection = self.image_collection.filterDate(*args, **kwargs)
        return self

    def getRegion(self, *args, **kwargs):
        """Wrapper for the getRegion method in Earth Engine on the ImageCollection.

        Caveat! getRegion does not return an image collection, so the image_list attribute gets
        updated instead of the image_collection attribute. However, the DataLoader instance
        is still returned, so this can be chained with another method on ImageCollection, which
        wouldn't be possible using Earth Engine.

        The arguments are remembered so that the query is re-issued against the filtered
        collection if more filtering is performed afterwards.
        """
        self.image_list = self.image_collection.getRegion(*args, **kwargs)
        self._region_args = (args, kwargs)
        # The DataFrame built by `region` no longer matches this query.
        self._df_image_list = None
        return self

    @signal_change
    def filterBounds(self, geometry, *args, **kwargs):
        """Wrapper for the filterBounds method in Earth Engine on the ImageCollection.
        Also records the geometry as the region of interest (self.bounds)."""
        self.image_collection = self.image_collection.filterBounds(geometry, *args, **kwargs)
        self.bounds = geometry
        return self

    @signal_change
    def select(self, *bands, **kwargs):
        """Wrapper for the select method in Earth Engine on the ImageCollection.

        The selected bands are added to self.bands. Order is preserved and duplicates are
        dropped; order matters because download_image stacks bands in that order. Both
        ``select("B1", "B2")`` and ``select(["B1", "B2"], [new_names])`` are supported.
        New names are not recorded, since download_image reads the raw images by id.
        """
        self.image_collection = self.image_collection.select(*bands, **kwargs)
        if bands and isinstance(bands[0], (list, tuple)):
            selectors = list(bands[0])
        else:
            selectors = list(bands)
        self.bands = list(dict.fromkeys([*self.bands, *selectors]))
        return self

    @property
    def date_range(self):
        """Actual date range of the images in the image collection.

        Returns a dict with "min" and "max" keys holding naive UTC datetimes, on the same clock
        as the "time" column of ``region``. A value is None if Earth Engine returns no value.
        """
        self._sync_cache()
        if self._date_range is None:
            raw = self._get_timeout_info(
                self.image_collection.reduceColumns(ee.Reducer.minMax(), ["system:time_start"])
            )
            # system:time_start is in milliseconds since the epoch; convert as UTC rather than
            # host-local time so results do not depend on the machine's timezone.
            self._date_range = {
                key: None if value is None else pd.to_datetime(value, unit="ms").to_pydatetime()
                for key, value in raw.items()
            }
        return self._date_range

    @property
    def region(self):
        """Time series, as a pandas DataFrame, of the band values for the region.

        Uses the last getRegion query; if getRegion was never called, it defaults to
        getRegion(self.bounds, self.scale). The "time" column is converted from milliseconds
        to datetimes.
        """
        self._sync_cache()
        if self._df_image_list is None:
            if self.image_list is None:
                self.getRegion(self.bounds, self.scale)
            res_list = self._get_timeout_info(self.image_list)
            # The first row returned by getRegion is the header.
            df = pd.DataFrame(res_list[1:], columns=res_list[0])
            df["time"] = pd.to_datetime(df["time"], unit="ms")
            self._df_image_list = df
        return self._df_image_list

    @property
    def collection_info(self):
        """Runs getInfo on the image collection. It is fetched once; later accesses return the
        cached value until more filtering is performed.

        Raises:
            Exception: if the collection holds more than 5000 images.
        """
        self._sync_cache()
        if self.image_collection_info is None:
            if self.count > 5000:
                raise Exception("Too many images to load. Try filtering more")
            self.image_collection_info = self._get_timeout_info(self.image_collection)
        return self.image_collection_info

    @property
    def image_ids(self):
        """List of ids of the available images in the image collection."""
        return [feature["id"] for feature in self.collection_info["features"]]

    def __repr__(self):
        try:
            return (
                "\n"
                f"Size: {self.count}\n"
                "Dataset date ranges:\n"
                f"From: {self.date_range['min']}\n"
                f"To: {self.date_range['max']}\n"
                "Selected bands:\n"
                f"{self.bands}\n"
            )
        except Exception as e:
            raise Exception("Impossible to represent the dataset. Try filtering more. Error handling to do.") from e

    def reproject(self, image, **kwargs):
        """Reprojects `image` using the crs / scale from kwargs, falling back to the instance's
        attributes. The image is returned unchanged if neither resolves to a value."""
        def resolve(name: str):
            # A keyword argument wins over the instance attribute; a falsy attribute counts as unset.
            if name in kwargs:
                return kwargs[name]
            if getattr(self, name):
                return getattr(self, name)
            return None

        crs = resolve("crs")
        scale = resolve("scale")
        if crs is not None or scale is not None:
            image = image.reproject(crs, None, scale)
        return image

    def _reference_data_type(self) -> dict:
        """Pixel data type description (may hold "min"/"max") of the first selected band,
        read from the first image of the collection. Returns {} if nothing is found."""
        features = self.collection_info["features"]
        if not features:
            return {}
        all_bands = features[0].get("bands", [])
        if self.bands:
            candidates = [band for band in all_bands if band["id"] in self.bands]
        else:
            candidates = all_bands
        if not candidates:
            return {}
        return candidates[0].get("data_type", {})

    def download_image(self, image_id: str, **kwargs):
        """Downloads an image based on its id / name, as a PIL image.

        The additional arguments are passed to getThumbUrl, and could be scale, max, min...
        When min / max are not given, they default to the data type range of the first
        selected band, if Earth Engine reports one.

        Raises:
            requests.HTTPError: if the thumbnail request fails.
        """
        img = ee.Image(image_id)
        if self.bands:
            img = img.select(*self.bands)
        img = self.reproject(img, **kwargs)
        input_args = {"region": self.bounds}
        input_args.update(kwargs)
        # Only load collection metadata when it is actually needed.
        if "min" not in input_args or "max" not in input_args:
            data_type = self._reference_data_type()
            for key in ("min", "max"):
                if key not in input_args and key in data_type:
                    input_args[key] = data_type[key]
        url = img.getThumbUrl(input_args)
        # The timeout bounds connection and inter-byte waits, consistent with the class's
        # policy of never blocking indefinitely on remote calls.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))

    @staticmethod
    def _regex(regex: str, im_id_list: List[str], include: bool) -> list:
        """
        Filters im_id_list based on a regular expression, matched from the start of each id.
        This is useful before downloading a collection of images. For example, using (.*)TXT
        with include=True will only download images that end with TXT, which for Nantes means
        filtering out empty or half-empty images.

        Args:
            regex: python regex as a string
            im_id_list: list, image id list
            include: whether to keep (True) or drop (False) the elements that match the regex.
        Returns: filtered list.
        """
        pattern = re.compile(regex)
        return [im_id for im_id in im_id_list if (pattern.match(im_id) is not None) == include]

    def download_all_images(self, regex_exclude: str = None, regex_include: str = None, **kwargs):
        """
        Runs download_image for each available image.
        Images to download can be filtered with regexes.

        Args:
            regex_exclude: any image that matches this regex will be excluded.
            regex_include: only images that match this regex will be included.
            **kwargs: arguments to be passed to getThumbUrl
        Returns: list of PIL images
        """
        image_ids = self.image_ids
        if regex_exclude is not None:
            image_ids = self._regex(regex_exclude, image_ids, include=False)
        if regex_include is not None:
            image_ids = self._regex(regex_include, image_ids, include=True)
        return [self.download_image(image_id, **kwargs) for image_id in tqdm(image_ids)]