|
|
|
import os |
|
from huggingface_hub import hf_hub_download |
|
# Download the fine-tuned model's mmseg config file and its checkpoint from the
# Hugging Face Hub; `config_path` and `ckpt` hold what hf_hub_download returns.
# The access token (if any) is taken from the "token" environment variable.
_HF_REPO_ID = "ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification"
_HF_TOKEN = os.environ.get("token")

config_path = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="multi_temporal_crop_classification_Prithvi_100M.py",
    token=_HF_TOKEN,
)
ckpt = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename='multi_temporal_crop_classification_best_mIoU_epoch_66.pth',
    token=_HF_TOKEN,
)
|
|
|
import argparse |
|
from mmcv import Config |
|
|
|
from mmseg.models import build_segmentor |
|
|
|
from mmseg.datasets.pipelines import Compose, LoadImageFromFile |
|
|
|
import rasterio |
|
import torch |
|
|
|
from mmseg.apis import init_segmentor |
|
|
|
from mmcv.parallel import collate, scatter |
|
|
|
import numpy as np |
|
import glob |
|
import os |
|
|
|
import time |
|
|
|
import numpy as np |
|
import gradio as gr |
|
from functools import partial |
|
|
|
import pdb |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def open_tiff(fname):
    """Read all bands of the raster at *fname* and return them as one array.

    The result is exactly what ``rasterio``'s ``read()`` returns when called
    with no band indexes on the opened dataset.
    """
    with rasterio.open(fname, "r") as dataset:
        bands = dataset.read()
    return bands
|
|
|
def write_tiff(img_wrt, filename, metadata):
    """Write an array to a raster file with rasterio.

    :param img_wrt: numpy array to write; a 2D array is written as a single
        band, a 3D array as one band per entry along its first axis
    :param filename: path of the output raster file
    :param metadata: keyword arguments forwarded to ``rasterio.open`` in
        write mode (e.g. the profile returned by ``get_meta``)
    :return: ``filename``, unchanged
    """
    with rasterio.open(filename, "w", **metadata) as dest:
        # Promote a lone 2D band to a 1-band stack so a single loop covers both cases.
        stack = img_wrt[None] if img_wrt.ndim == 2 else img_wrt
        for offset in range(stack.shape[0]):
            # rasterio band indexes are 1-based.
            dest.write(stack[offset, :, :], offset + 1)
    return filename
|
|
|
|
|
def get_meta(fname):
    """Return the ``meta`` dictionary of the raster at *fname*.

    The dataset is opened read-only and closed again before returning.
    """
    with rasterio.open(fname, "r") as dataset:
        metadata = dataset.meta
    return metadata
|
|
|
def preprocess_example(example_list):
    """Resolve every example path against the current working directory.

    Entries that are already absolute come back unchanged, because
    ``os.path.join`` discards the base when the second part is absolute.
    """
    cwd = os.path.abspath('')
    return [os.path.join(cwd, entry) for entry in example_list]
|
|
|
|
|
def inference_segmentor(model, imgs, custom_test_pipeline=None):
    """Run the segmentor on one or more image files.

    Args:
        model (nn.Module): The loaded segmentor. Its ``cfg`` attribute supplies
            the default test pipeline.
        imgs (str or list[str]): Path(s) of the image file(s) to segment. Each
            entry is handed to the pipeline as ``img_info['filename']``, so the
            pipeline's loading step must be able to read it from that key.
        custom_test_pipeline (list, optional): Pipeline steps to use instead of
            the config's test pipeline. When ``None``, the config's test
            pipeline is used with its first step replaced by
            ``LoadImageFromFile()``.

    Returns:
        Whatever ``model(return_loss=False, rescale=True, ...)`` returns for
        the batch (the segmentation result; callers here index it per image).
    """
    cfg = model.cfg
    first_param = next(model.parameters())
    device = first_param.device

    if custom_test_pipeline is None:
        test_pipeline = [LoadImageFromFile()] + cfg.data.test.pipeline[1:]
    else:
        test_pipeline = custom_test_pipeline
    test_pipeline = Compose(test_pipeline)

    imgs = imgs if isinstance(imgs, list) else [imgs]
    data = [test_pipeline({'img_info': {'filename': img}}) for img in imgs]

    # Batch all samples so they go through the model in a single forward pass.
    data = collate(data, samples_per_gpu=len(imgs))
    if first_param.is_cuda:
        # Move the batch to the device the model lives on.
        data = scatter(data, [device])[0]
    else:
        # Without scatter, unwrap the collated img_metas container by hand.
        img_metas = data['img_metas'].data[0]
        data = {'img': data['img'], 'img_metas': img_metas}

    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result
|
|
|
|
|
def inference_on_file(target_image, model, custom_test_pipeline):
    """Segment one uploaded multi-temporal GeoTIFF for the Gradio demo.

    Args:
        target_image: The uploaded file. Either an object exposing the file
            path as ``.name`` (the tempfile wrapper some Gradio versions pass
            from ``gr.File``) or the path itself.
        model: The loaded segmentor, forwarded to ``inference_segmentor``.
        custom_test_pipeline: Test pipeline forwarded to ``inference_segmentor``.

    Returns:
        tuple: ``(rgb1, rgb2, rgb3, prediction)``.

        - ``rgb1``, ``rgb2``, ``rgb3`` are band composites of the input built
          from bands [2, 1, 0], [8, 7, 6] and [14, 13, 12], moved to
          channels-last. They feed the T1/T2/T3 image outputs.
        - ``prediction`` is the first channel of the model result, with
          pixels that hold the file's nodata value in any band set to -1,
          then multiplied by 255.

    Raises:
        Any exception from reading the file or running the model is logged and
        re-raised, rather than swallowed and followed by a return on unbound
        names.
    """
    # Some Gradio releases pass a tempfile wrapper, newer ones a plain path.
    target_image = getattr(target_image, "name", target_image)

    try:
        st = time.time()
        print('Running inference...')
        result = inference_segmentor(model, target_image, custom_test_pipeline)
        print("Output has shape: " + str(result[0].shape))

        # Re-read the input stack for the RGB previews and the nodata mask.
        stack = open_tiff(target_image)
        rgb1 = stack[[2, 1, 0], :, :].transpose((1, 2, 0))
        rgb2 = stack[[8, 7, 6], :, :].transpose((1, 2, 0))
        rgb3 = stack[[14, 13, 12], :, :].transpose((1, 2, 0))

        # A pixel is masked out when any band holds the file's nodata value.
        meta = get_meta(target_image)
        nodata_mask = np.any(stack == meta['nodata'], axis=0)[None]
        result[0] = np.where(nodata_mask, -1, result[0])

        time_taken = np.round(time.time() - st, 1)
        print(f'Inference completed in {str(time_taken)} seconds')
    except Exception:
        print(f'Error on image {target_image}')
        raise

    # NOTE(review): scaling class indices by 255 pushes every class >= 1 past
    # the 8-bit range of the 'L' output image; confirm how Gradio renders it.
    return rgb1, rgb2, rgb3, result[0][0] * 255
|
|
|
def process_test_pipeline(custom_test_pipeline, bands=None):
    """Adapt a test pipeline config for single-file inference in this demo.

    The pipeline list is modified in place, and the same list object is
    returned.

    Args:
        custom_test_pipeline (list[dict]): Pipeline step configs. Each step
            has a ``'type'`` key.
        bands (str | list | None): Band selection for the first
            ``BandsExtract`` step. A string is parsed as a Python literal
            (e.g. ``"[0, 1, 2]"``). Any other non-None value is used as given.
            ``None`` leaves the step untouched.

    Returns:
        list[dict]: ``custom_test_pipeline``, with the first step whose type
        contains ``'Collect'`` given an explicit ``meta_keys`` list.

    Raises:
        ValueError: if ``bands`` is a string that is not a Python literal.
    """
    import ast

    if bands is not None:
        extract_step = next(
            (step for step in custom_test_pipeline if step['type'] == 'BandsExtract'),
            None,
        )
        if extract_step is not None:
            # literal_eval only accepts literals, unlike eval, so a band
            # string cannot execute arbitrary code.
            extract_step['bands'] = ast.literal_eval(bands) if isinstance(bands, str) else bands

    collect_step = next(
        (step for step in custom_test_pipeline if 'Collect' in step['type']),
        None,
    )
    if collect_step is not None:
        collect_step['meta_keys'] = ['img_info', 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']

    return custom_test_pipeline
|
|
|
# Build the segmentor on CPU from the downloaded config file and checkpoint.
config = Config.fromfile(config_path)
# Presumably cleared so model construction does not try to load separate
# backbone weights; the fine-tuned weights come from `ckpt` below.
config.model.backbone.pretrained = None
model = init_segmentor(config, ckpt, device='cpu')

# Adjust the model's test pipeline in place for this demo; bands=None keeps the
# band selection as configured.
custom_test_pipeline = process_test_pipeline(model.cfg.data.test.pipeline, bands=None)

# Gradio callback: only the uploaded file changes between calls.
func = partial(
    inference_on_file,
    model=model,
    custom_test_pipeline=custom_test_pipeline,
)
|
|
|
# Gradio UI: upload a multi-temporal GeoTIFF, then show three RGB previews
# (T1-T3) next to the model prediction.
with gr.Blocks() as demo:
    gr.Markdown(value='# Prithvi multi temporal crop classification')
    gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. This demo showcases how the model was finetuned to classify crop and other land use categories using multi temporal data. More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).

The user needs to provide an HLS geotiff image with 18 bands covering 3 time steps, where each time step contains the following 6 channels in order: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
''')

    with gr.Row():
        with gr.Column():
            inp = gr.File()
            btn = gr.Button("Submit")

    with gr.Row():
        gr.Markdown(value='### T1')
        gr.Markdown(value='### T2')
        gr.Markdown(value='### T3')
        gr.Markdown(value='### Model prediction')

    with gr.Row():
        inp1 = gr.Image(image_mode='RGB')
        inp2 = gr.Image(image_mode='RGB')
        inp3 = gr.Image(image_mode='RGB')
        out = gr.Image(image_mode='L')

    btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])

    with gr.Row():
        gr.Examples(
            examples=["chip_102_345_merged.tif",
                      "chip_104_104_merged.tif",
                      "chip_109_421_merged.tif"],
            inputs=inp,
            outputs=[inp1, inp2, inp3, out],
            # `preprocess` is a boolean flag (run the input component's
            # preprocessing before `fn`), not a callback. The previous value
            # was the function object preprocess_example, which was never
            # called and only acted as a truthy flag.
            preprocess=True,
            fn=func,
            cache_examples=True,
        )

demo.launch()