# superIX / sr4rs / utils.py
# Uploaded by csaybar ("Upload 5 files"), commit 9c55c41 (verified), 1.52 kB.
import tensorflow as tf
import torch
def load_cesbio_sr(
    model_path: str = "weights/cesbio_model/sr4rs_sentinel2_bands4328_france2020_savedmodel",
) -> tf.function:
    """Load the CESBIO SR4RS SavedModel and return its inference function.

    Args:
        model_path (str, optional): Directory containing the exported
            SavedModel. Defaults to the CESBIO weights path used by this
            project; a relative path is resolved against the current
            working directory.

    Returns:
        tf.function: The callable registered under the first signature key
            of the loaded SavedModel.
    """
    # Read the model from disk.
    model = tf.saved_model.load(model_path)

    # The signature name is not hard-coded: use whichever key the SavedModel
    # lists first.
    signature = list(model.signatures.keys())[0]

    # Hand back only the signature function; callers invoke it directly.
    return model.signatures[signature]
def run_sr4rs(
    model: tf.function,
    lr: tf.Tensor,
    hr: tf.Tensor,
) -> dict:
    """Run the SR4RS model on a single low-resolution image.

    Args:
        model (tf.function): The model to use. It is called with one float32
            tensor and must return a mapping that contains the key
            ``"output_32:0"``.
        lr (tf.Tensor): The low resolution image. NOTE: despite the
            annotation, it is fancy-indexed and passed to
            ``torch.from_numpy``, so it has to be a 3-D NumPy array with at
            least 8 entries along its first axis (presumably bands-first,
            ``(C, H, W)`` -- confirm against the caller).
        hr (tf.Tensor): The high resolution image. Only its first three
            entries along the first axis are returned; it is not used for
            inference.

    Returns:
        dict: A dictionary with three entries:
            ``"lr"``: entries 3, 2, 1 of ``lr`` along its first axis.
            ``"sr"``: first three channels of the zero-padded prediction,
            as ``uint16``.
            ``"hr"``: first three entries of ``hr``.
    """
    # Run inference.
    # Pick the four model inputs (indices 3, 2, 1, 7 -- presumably the
    # Sentinel-2 bands 4/3/2/8 the weights are named after), add a batch axis
    # and move the selected axis last: (1, C, H, W) -> (1, H, W, C).
    model_input = torch.from_numpy(lr[[3, 2, 1, 7]][None]).permute(0, 2, 3, 1)
    model_input = tf.convert_to_tensor(model_input, dtype=tf.float32)
    pred = model(model_input)

    # Back to channels-first: (1, H, W, C) -> (1, C, H, W).
    sr = torch.from_numpy(pred["output_32:0"].numpy()).permute(0, 3, 1, 2)

    # Saturate to the uint16 range before casting: converting negative or
    # too-large floats to an unsigned integer is not well defined and would
    # otherwise produce wrapped-around pixel values.
    sr = sr.clamp(0, 65535)

    # Zero-pad 32 pixels on every spatial side (presumably to restore the
    # border the "output_32" head crops away -- confirm against the model).
    sr = torch.nn.functional.pad(
        sr,
        (32, 32, 32, 32),
        mode="constant",
        value=0,
    )

    # Drop only the batch axis; a bare squeeze() would also remove any other
    # axis that happens to have size 1.
    sr = sr.squeeze(0).numpy().astype("uint16")

    # Save the results.
    return {
        "lr": lr[[3, 2, 1]],
        "sr": sr[0:3],
        "hr": hr[0:3],
    }