import numpy as np
import tensorflow as tf
import torch


def load_cesbio_sr(
    path: str = "weights/cesbio_model/sr4rs_sentinel2_bands4328_france2020_savedmodel",
) -> tf.function:
    """Load the CESBIO SR4RS SavedModel and return its first signature function.

    Args:
        path (str, optional): Directory of the exported SavedModel. A relative
            path is resolved against the current working directory. Defaults
            to ``weights/cesbio_model/sr4rs_sentinel2_bands4328_france2020_savedmodel``.

    Returns:
        tf.function: The callable stored under the first key of
        ``model.signatures``; call it with the input tensor to get the
        SR prediction dict.

    Raises:
        IndexError: If the SavedModel exposes no signatures.
    """
    model = tf.saved_model.load(path)

    # Take the first signature instead of hard-coding its name
    # (e.g. "serving_default"), so differently-exported models still load.
    signature_names = list(model.signatures.keys())
    if not signature_names:
        raise IndexError(f"SavedModel at {path!r} exposes no signatures")

    # NOTE(review): only the signature function is returned; the loaded
    # object `model` becomes unreachable once this function exits. If
    # inference ever fails with deleted/uninitialized-variable errors,
    # keep a reference to `model` as well — confirm for the TF version used.
    return model.signatures[signature_names[0]]


def run_sr4rs(
    model: tf.function,
    lr: np.ndarray,
    hr: np.ndarray,
) -> dict:
    """Run the SR4RS model on a low-resolution image.

    Bands ``[3, 2, 1, 7]`` of ``lr`` are fed to ``model`` as a float32,
    channels-last batch of one. The model's ``'output_32:0'`` output is turned
    back into a channels-first array. Only its first three channels are kept.
    They are clipped to the ``uint16`` range, cast to ``uint16``, and
    zero-padded by 32 pixels on every spatial border.

    Args:
        model (tf.function): Callable such as the one returned by
            ``load_cesbio_sr``. It is called with a float32 tensor of shape
            ``(1, H, W, 4)`` and must return a mapping containing the key
            ``'output_32:0'`` whose value has a ``.numpy()`` method.
        lr (np.ndarray): Low-resolution image, channels-first ``(C, H, W)``
            with at least 8 channels. Anything ``np.asarray`` accepts works.
        hr (np.ndarray): High-resolution image. Its first three channels
            (``hr[0:3]``) are returned unchanged.

    Returns:
        dict: ``{"lr": lr[[3, 2, 1]], "sr": uint16 array (3, H', W'),
        "hr": hr[0:3]}``.
    """
    # NOTE(review): the padding width is presumably the border cropped by
    # the 'output_32' head (inferred from the output name) — confirm.
    margin = 32
    uint16_max = 65535

    lr = np.asarray(lr)

    # NOTE(review): the model name mentions "bands4328", so indices
    # [3, 2, 1, 7] presumably select B4, B3, B2, B8 — confirm lr's band order.
    # (4, H, W) -> (1, H, W, 4): the model takes a channels-last batch.
    model_input = lr[[3, 2, 1, 7]].transpose(1, 2, 0)[None].astype(np.float32)
    pred = model(tf.convert_to_tensor(model_input, dtype=tf.float32))

    # (1, H', W', C) -> (C, H', W'). Index the batch axis instead of calling
    # squeeze(), which would also drop singleton spatial axes.
    sr = pred["output_32:0"].numpy()[0].transpose(2, 0, 1)[:3]

    # Clip before casting: converting out-of-range floats (negatives or
    # values > 65535) to uint16 is undefined and wraps around in practice.
    sr = np.clip(sr, 0, uint16_max).astype(np.uint16)
    sr = np.pad(
        sr,
        ((0, 0), (margin, margin), (margin, margin)),
        mode="constant",
        constant_values=0,
    )

    return {
        "lr": lr[[3, 2, 1]],
        "sr": sr,
        "hr": hr[0:3],
    }