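"""Gradio demo for the Prithvi-100M masked autoencoder (MAE).

Downloads the model config, checkpoint and model definition from the Hugging Face
Hub, reconstructs three input GeoTIFFs (one per timestamp) with random masking,
and displays the original, masked and reconstructed RGB images.
"""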
#### pull files from hub
from huggingface_hub import hf_hub_download
import os
yaml_file_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename="Prithvi_100M_config.yaml", token=os.environ.get("token"))
checkpoint=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename='Prithvi_100M.pt', token=os.environ.get("token"))
model_def=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename='Prithvi.py', token=os.environ.get("token"))
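# copy the model definition next to this script so it can be imported below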
os.system(f'cp {model_def} .')
#####
import functools
from typing import List
import numpy as np
import rasterio
import torch
import yaml
from einops import rearrange
from Prithvi import MaskedAutoencoderViT  # model definition downloaded from the Hub above
import gradio as gr
from functools import partial
NO_DATA = -9999            # nodata value in the input GeoTIFFs
NO_DATA_FLOAT = 0.0001     # marker for nodata pixels after normalization
PERCENTILES = (0.1, 99.9)  # contrast-stretch percentiles for visualization
def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
""" Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
original range using *data_mean* and *data_std* and then lowest and highest percentiles are
removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
Args:
orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
new_img: torch.Tensor representing image with shape = (bands, H, W).
channels: list of indices representing RGB channels.
data_mean: list of mean values for each band.
data_std: list of std values for each band.
Returns:
torch.Tensor with shape (num_channels, height, width) for original image
torch.Tensor with shape (num_channels, height, width) for the other image
"""
    stack_c = ([], [])  # channel stacks for (original, new) images
for c in channels:
orig_ch = orig_img[c, ...]
valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
        valid_mask[orig_ch == NO_DATA_FLOAT] = False  # exclude pixels flagged as nodata during normalization
# Back to original data range
orig_ch = (orig_ch * data_std[c]) + data_mean[c]
new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]
# Rescale (enhancing contrast)
min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)
orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)
# No data as zeros
orig_ch[~valid_mask] = 0
new_ch[~valid_mask] = 0
stack_c[0].append(orig_ch)
stack_c[1].append(new_ch)
# Channels first
stack_orig = torch.stack(stack_c[0], dim=0)
stack_rec = torch.stack(stack_c[1], dim=0)
return stack_orig, stack_rec
def read_geotiff(file_path: str):
""" Read all bands from *file_path* and returns image + meta info.
Args:
file_path: path to image file.
Returns:
np.ndarray with shape (bands, height, width)
meta info dict
"""
with rasterio.open(file_path) as src:
img = src.read()
meta = src.meta
return img, meta
def save_geotiff(image, output_path: str, meta: dict):
""" Save multi-band image in Geotiff file.
Args:
image: np.ndarray with shape (bands, height, width)
output_path: path where to save the image
meta: dict with meta info.
"""
with rasterio.open(output_path, "w", **meta) as dest:
for i in range(image.shape[0]):
dest.write(image[i, :, :], i + 1)
return
def _convert_np_uint8(float_image: torch.Tensor):
image = float_image.numpy() * 255.0
image = image.astype(dtype=np.uint8)
image = image.transpose((1, 2, 0))
return image
def load_example(file_paths: List[str], mean: List[float], std: List[float]):
""" Build an input example by loading images in *file_paths*.
Args:
        file_paths: list of file paths.
mean: list containing mean values for each band in the images in *file_paths*.
std: list containing std values for each band in the images in *file_paths*.
Returns:
        np.ndarray containing the created example
list of meta info for each image in *file_paths*
"""
imgs = []
metas = []
for file in file_paths:
img, meta = read_geotiff(file)
        img = img[:6] * 10000  # keep the six model bands; rescale to the range assumed by the config's mean/std
# Rescaling (don't normalize on nodata)
img = np.moveaxis(img, 0, -1) # channels last for rescaling
img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
imgs.append(img)
metas.append(meta)
imgs = np.stack(imgs, axis=0) # num_frames, img_size, img_size, C
imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, img_size, img_size
imgs = np.expand_dims(imgs, axis=0) # add batch dim
return imgs, metas
def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
""" Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
Args:
model: MAE model to run.
input_data: torch.Tensor with shape (B, C, T, H, W).
mask_ratio: mask ratio to use.
device: device where model should run.
Returns:
        2 torch.Tensor with shape (B, C, T, H, W): the reconstructed image and the mask image.
"""
with torch.no_grad():
x = input_data.to(device)
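        # model returns (loss, predicted patch tokens, binary patch mask: 0 = keep, 1 = remove)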
_, pred, mask = model(x, mask_ratio)
# Create mask and prediction images (un-patchify)
mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
pred_img = model.unpatchify(pred).detach().cpu()
# Mix visible and predicted patches
rec_img = input_data.clone()
rec_img[mask_img == 1] = pred_img[mask_img == 1] # binary mask: 0 is keep, 1 is remove
# Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
return rec_img, mask_img
def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
""" Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
Args:
input_img: input torch.Tensor with shape (C, T, H, W).
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
mask_img: mask torch.Tensor with shape (C, T, H, W).
channels: list of indices representing RGB channels.
mean: list of mean values for each band.
std: list of std values for each band.
output_dir: directory where to save outputs.
meta_data: list of dicts with geotiff meta info.
"""
for t in range(input_img.shape[1]):
rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
new_img=rec_img[:, t, :, :],
channels=channels, data_mean=mean,
data_std=std)
rgb_mask = mask_img[channels, t, :, :] * rgb_orig
# Saving images
save_geotiff(image=_convert_np_uint8(rgb_orig),
output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
meta=meta_data[t])
save_geotiff(image=_convert_np_uint8(rgb_pred),
output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
meta=meta_data[t])
save_geotiff(image=_convert_np_uint8(rgb_mask),
output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
meta=meta_data[t])
def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
""" Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
Args:
input_img: input torch.Tensor with shape (C, T, H, W).
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
mask_img: mask torch.Tensor with shape (C, T, H, W).
channels: list of indices representing RGB channels.
mean: list of mean values for each band.
std: list of std values for each band.
output_dir: directory where to save outputs.
meta_data: list of dicts with geotiff meta info.
"""
rgb_orig_list = []
rgb_mask_list = []
rgb_pred_list = []
for t in range(input_img.shape[1]):
rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
new_img=rec_img[:, t, :, :],
channels=channels, data_mean=mean,
data_std=std)
rgb_mask = mask_img[channels, t, :, :] * rgb_orig
# extract images
rgb_orig_list.append(_convert_np_uint8(rgb_orig))
rgb_mask_list.append(_convert_np_uint8(rgb_mask))
rgb_pred_list.append(_convert_np_uint8(rgb_pred))
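    # order matches the Gradio outputs below: originals, then masked, then reconstructed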
outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
return outputs
def predict_on_images(data_files: list, mask_ratio: float = None, yaml_file_path: str = None, checkpoint: str = None):
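    """ Run the MAE over *data_files* and return original, masked and reconstructed RGB images.
    When *mask_ratio* is None, the value from the config file is used.
    """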
# os.makedirs(output_dir, exist_ok=True)
# Get parameters --------
with open(yaml_file_path, 'r') as f:
params = yaml.safe_load(f)
# data related
num_frames = params['num_frames']
img_size = params['img_size']
bands = params['bands']
mean = params['data_mean']
std = params['data_std']
# model related
depth = params['depth']
patch_size = params['patch_size']
embed_dim = params['embed_dim']
num_heads = params['num_heads']
tubelet_size = params['tubelet_size']
decoder_embed_dim = params['decoder_embed_dim']
decoder_num_heads = params['decoder_num_heads']
decoder_depth = params['decoder_depth']
batch_size = params['batch_size']
mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
# We must have *num_frames* files to build one example!
    assert len(data_files) == num_frames, "The number of files must match the expected number of frames."
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
print(f"Using {device} device.\n")
# Loading data ---------------------------------------------------------------------------------
input_data, meta_data = load_example(file_paths=data_files, mean=mean, std=std)
# Create model and load checkpoint -------------------------------------------------------------
model = MaskedAutoencoderViT(
img_size=img_size,
patch_size=patch_size,
num_frames=num_frames,
tubelet_size=tubelet_size,
in_chans=len(bands),
embed_dim=embed_dim,
depth=depth,
num_heads=num_heads,
decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth,
decoder_num_heads=decoder_num_heads,
mlp_ratio=4.,
norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
norm_pix_loss=False)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n--> model has {total_params / 1e6} Million params.\n")
model.to(device)
state_dict = torch.load(checkpoint, map_location=device)
model.load_state_dict(state_dict)
print(f"Loaded checkpoint from {checkpoint}")
# Running model --------------------------------------------------------------------------------
model.eval()
    channels = [bands.index(b) for b in ['B04', 'B03', 'B02']]  # B04 = red, B03 = green, B02 = blue
# Build sliding window
batch = torch.tensor(input_data, device='cpu')
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
h1, w1 = windows.shape[3:5]
windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)
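    # unfold produces non-overlapping img_size x img_size tiles; any right/bottom remainder
    # that does not fill a full tile is handled after reconstruction below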
# Split into batches if number of windows > batch_size
num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
windows = torch.tensor_split(windows, num_batches, dim=0)
# Run model
rec_imgs = []
mask_imgs = []
for x in windows:
rec_img, mask_img = run_model(model, x, mask_ratio, device)
rec_imgs.append(rec_img)
mask_imgs.append(mask_img)
rec_imgs = torch.concat(rec_imgs, dim=0)
mask_imgs = torch.concat(mask_imgs, dim=0)
# Build images from patches
rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
# Mix original image with patches
h, w = rec_imgs.shape[-2:]
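    # pixels outside the tiled area keep their original values and are marked as
    # visible (ones) in the mask image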
rec_imgs_full = batch.clone()
rec_imgs_full[..., :h, :w] = rec_imgs
mask_imgs_full = torch.ones_like(batch)
mask_imgs_full[..., :h, :w] = mask_imgs
# Build RGB images
for d in meta_data:
d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
# save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
# channels, mean, std, output_dir, meta_data)
outputs = extract_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
channels, mean, std)
print("Done!")
return outputs
func = partial(predict_on_images, yaml_file_path=yaml_file_path, checkpoint=checkpoint)  # Gradio supplies only the file list
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
inp_files = gr.Files(elem_id='files')
# inp_slider = gr.Slider(0, 100, value=50, label="Mask ratio", info="Choose ratio of masking between 0 and 100", elem_id='slider'),
btn = gr.Button("Submit")
with gr.Row():
gr.Markdown(value='Original images')
with gr.Row():
gr.Markdown(value='T1')
gr.Markdown(value='T2')
gr.Markdown(value='T3')
with gr.Row():
        out1_orig_t1 = gr.Image(image_mode='RGB')
out2_orig_t2 = gr.Image(image_mode='RGB')
out3_orig_t3 = gr.Image(image_mode='RGB')
with gr.Row():
gr.Markdown(value='Masked images')
with gr.Row():
gr.Markdown(value='T1')
gr.Markdown(value='T2')
gr.Markdown(value='T3')
with gr.Row():
        out4_masked_t1 = gr.Image(image_mode='RGB')
out5_masked_t2 = gr.Image(image_mode='RGB')
out6_masked_t3 = gr.Image(image_mode='RGB')
with gr.Row():
        gr.Markdown(value='Reconstructed images')
with gr.Row():
gr.Markdown(value='T1')
gr.Markdown(value='T2')
gr.Markdown(value='T3')
with gr.Row():
        out7_pred_t1 = gr.Image(image_mode='RGB')
out8_pred_t2 = gr.Image(image_mode='RGB')
out9_pred_t3 = gr.Image(image_mode='RGB')
btn.click(fn=func,
# inputs=[inp_files, inp_slider],
inputs=inp_files,
outputs=[out1_orig_t1,
out2_orig_t2,
out3_orig_t3,
out4_masked_t1,
out5_masked_t2,
out6_masked_t3,
out7_pred_t1,
out8_pred_t2,
out9_pred_t3])
demo.launch()