Gary0205 committed on
Commit 6d70ed4
1 Parent(s): a5b9d70

Upload 25 files

graphcast/autoregressive.py ADDED
@@ -0,0 +1,312 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """A Predictor wrapping a one-step Predictor to make autoregressive predictions.
15
+ """
16
+
17
+ from typing import Optional, cast
18
+
19
+ from absl import logging
20
+ from graphcast import predictor_base
21
+ from graphcast import xarray_jax
22
+ from graphcast import xarray_tree
23
+ import haiku as hk
24
+ import jax
25
+ import xarray
26
+
27
+
28
+ def _unflatten_and_expand_time(flat_variables, tree_def, time_coords):
29
+ variables = jax.tree_util.tree_unflatten(tree_def, flat_variables)
30
+ return variables.expand_dims(time=time_coords, axis=0)
31
+
32
+
33
+ def _get_flat_arrays_and_single_timestep_treedef(variables):
34
+ flat_arrays = jax.tree_util.tree_leaves(variables.transpose('time', ...))
35
+ _, treedef = jax.tree_util.tree_flatten(variables.isel(time=0, drop=True))
36
+ return flat_arrays, treedef
37
+
38
+
39
+ class Predictor(predictor_base.Predictor):
40
+ """Wraps a one-step Predictor to make multi-step predictions autoregressively.
41
+
42
+ The wrapped Predictor will be used to predict a single timestep conditional
43
+ on the inputs passed to the outer Predictor. Its predictions are then
44
+ passed back in as inputs at the next timestep, for as many timesteps as are
45
+ requested in the targets_template. (When multiple timesteps of input are
46
+ used, a rolling window of inputs is maintained with new predictions
47
+ concatenated onto the end).
48
+
49
+ You may ask for additional variables to be predicted as targets which aren't
50
+ used as inputs. These will be predicted as output variables only and not fed
51
+ back in autoregressively. All target variables must be time-dependent, however.
52
+
53
+ You may also specify static (non-time-dependent) inputs which will be passed
54
+ in at each timestep but are not predicted.
55
+
56
+ At present, any time-dependent inputs must also be present as targets (so they
57
+ can be fed back autoregressively) or as forcings (supplied at each step).
58
+
59
+ The loss of the wrapped one-step Predictor is averaged over all timesteps to
60
+ give a loss for the autoregressive Predictor.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ predictor: predictor_base.Predictor,
66
+ noise_level: Optional[float] = None,
67
+ gradient_checkpointing: bool = False,
68
+ ):
69
+ """Initializes an autoregressive predictor wrapper.
70
+
71
+ Args:
72
+ predictor: A predictor to wrap in an auto-regressive way.
73
+ noise_level: Optional value that multiplies the standard normal noise
74
+ added to the time-dependent variables of the predictor inputs. In
75
+ particular, no noise is added to the predictions that are fed back
76
+ auto-regressively. Defaults to not adding noise.
77
+ gradient_checkpointing: If True, gradient checkpointing will be
78
+ used at each step of the computation to save memory. Roughly, this
79
+ should make the backwards pass about twice as expensive, and the time
80
+ per step, counting the forward pass, should only increase by about 50%.
81
+ Note this parameter will be ignored with a warning if the scan sequence
82
+ length is 1.
83
+ """
84
+ self._predictor = predictor
85
+ self._noise_level = noise_level
86
+ self._gradient_checkpointing = gradient_checkpointing
87
+
88
+ def _get_and_validate_constant_inputs(self, inputs, targets, forcings):
89
+ constant_inputs = inputs.drop_vars(targets.keys(), errors='ignore')
90
+ constant_inputs = constant_inputs.drop_vars(
91
+ forcings.keys(), errors='ignore')
92
+ for name, var in constant_inputs.items():
93
+ if 'time' in var.dims:
94
+ raise ValueError(
95
+ f'Time-dependent input variable {name} must either be a forcing '
96
+ 'variable, or a target variable to allow for auto-regressive '
97
+ 'feedback.')
98
+ return constant_inputs
99
+
100
+ def _validate_targets_and_forcings(self, targets, forcings):
101
+ for name, var in targets.items():
102
+ if 'time' not in var.dims:
103
+ raise ValueError(f'Target variable {name} must be time-dependent.')
104
+
105
+ for name, var in forcings.items():
106
+ if 'time' not in var.dims:
107
+ raise ValueError(f'Forcing variable {name} must be time-dependent.')
108
+
109
+ overlap = forcings.keys() & targets.keys()
110
+ if overlap:
111
+ raise ValueError('The following were specified as both targets and '
112
+ f'forcings, which isn\'t allowed: {overlap}')
113
+
114
+ def _update_inputs(self, inputs, next_frame):
115
+ num_inputs = inputs.dims['time']
116
+
117
+ predicted_or_forced_inputs = next_frame[list(inputs.keys())]
118
+
119
+ # Combining datasets with inputs and target time stamps aligns them.
120
+ # Only keep the num_inputs trailing frames for use as next inputs.
121
+ return (xarray.concat([inputs, predicted_or_forced_inputs], dim='time')
122
+ .tail(time=num_inputs)
123
+ # Update the time coordinate to reset the lead times for
124
+ # next AR iteration.
125
+ .assign_coords(time=inputs.coords['time']))
126
+
127
+ def __call__(self,
128
+ inputs: xarray.Dataset,
129
+ targets_template: xarray.Dataset,
130
+ forcings: xarray.Dataset,
131
+ **kwargs) -> xarray.Dataset:
132
+ """Calls the Predictor.
133
+
134
+ Args:
135
+ inputs: input variable used to make predictions. Inputs can include both
136
+ time-dependent and time independent variables. Any time-dependent
137
+ input variables must also be present in the targets_template or the
138
+ forcings.
139
+ targets_template: A target template containing information about which
140
+ variables should be predicted and the time alignment of the predictions.
141
+ All target variables must be time-dependent.
142
+ The number of time frames determines the number of unroll steps of the AR
143
+ predictor (multiple unrolls of the inner predictor for a single target
144
+ time step are not supported yet).
145
+ forcings: Variables that will be fed to the model. The variables
146
+ should not overlap with the target ones. The time coordinates of the
147
+ forcing variables should match the target ones.
148
+ Forcing variables which are also present in the inputs will be used to
149
+ supply ground-truth values for those inputs when they are passed to the
150
+ underlying predictor at timesteps beyond the first timestep.
151
+ **kwargs: Additional arguments passed along to the inner Predictor.
152
+
153
+ Returns:
154
+ predictions: the model predictions matching the target template.
155
+
156
+ Raises:
157
+ ValueError: if the time coordinates of the inputs and targets are not
158
+ separated by a constant time step.
159
+ """
160
+
161
+ constant_inputs = self._get_and_validate_constant_inputs(
162
+ inputs, targets_template, forcings)
163
+ self._validate_targets_and_forcings(targets_template, forcings)
164
+
165
+ # After the above checks, the remaining inputs must be time-dependent:
166
+ inputs = inputs.drop_vars(constant_inputs.keys())
167
+
168
+ # A predictions template only including the next time to predict.
169
+ target_template = targets_template.isel(time=[0])
170
+
171
+ flat_forcings, forcings_treedef = (
172
+ _get_flat_arrays_and_single_timestep_treedef(forcings))
173
+ scan_variables = flat_forcings
174
+
175
+ def one_step_prediction(inputs, scan_variables):
176
+
177
+ flat_forcings = scan_variables
178
+ forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef,
179
+ target_template.coords['time'])
180
+
181
+ # Add constant inputs:
182
+ all_inputs = xarray.merge([constant_inputs, inputs])
183
+ predictions: xarray.Dataset = self._predictor(
184
+ all_inputs, target_template,
185
+ forcings=forcings,
186
+ **kwargs)
187
+
188
+ next_frame = xarray.merge([predictions, forcings])
189
+ next_inputs = self._update_inputs(inputs, next_frame)
190
+
191
+ # Drop the length-1 time dimension, since scan will concat all the outputs
192
+ # for different times along a new leading time dimension:
193
+ predictions = predictions.squeeze('time', drop=True)
194
+ # We return the prediction flattened into plain jax arrays, because the
195
+ # extra leading dimension added by scan prevents the tree_util
196
+ # registrations in xarray_jax from unflattening them back into an
197
+ # xarray.Dataset automatically:
198
+ flat_pred = jax.tree_util.tree_leaves(predictions)
199
+ return next_inputs, flat_pred
200
+
201
+ if self._gradient_checkpointing:
202
+ scan_length = targets_template.dims['time']
203
+ if scan_length <= 1:
204
+ logging.warning(
205
+ 'Skipping gradient checkpointing for sequence length of 1')
206
+ else:
207
+ # Just in case we take gradients (e.g. for control), although
208
+ # in most cases this will just be for a forward pass.
209
+ one_step_prediction = hk.remat(one_step_prediction)
210
+
211
+ # Loop (without unroll) with hk states in cell (jax.lax.scan won't do).
212
+ _, flat_preds = hk.scan(one_step_prediction, inputs, scan_variables)
213
+
214
+ # The result of scan will have an extra leading axis on all arrays,
215
+ # corresponding to the target times in this case. We need to be prepared for
216
+ # it when unflattening the arrays back into a Dataset:
217
+ scan_result_template = (
218
+ target_template.squeeze('time', drop=True)
219
+ .expand_dims(time=targets_template.coords['time'], axis=0))
220
+ _, scan_result_treedef = jax.tree_util.tree_flatten(scan_result_template)
221
+ predictions = jax.tree_util.tree_unflatten(scan_result_treedef, flat_preds)
222
+ return predictions
223
+
224
+ def loss(self,
225
+ inputs: xarray.Dataset,
226
+ targets: xarray.Dataset,
227
+ forcings: xarray.Dataset,
228
+ **kwargs
229
+ ) -> predictor_base.LossAndDiagnostics:
230
+ """The mean of the per-timestep losses of the underlying predictor."""
231
+ if targets.sizes['time'] == 1:
232
+ # If there is only a single target timestep then we don't need any
233
+ # autoregressive feedback and can delegate the loss directly to the
234
+ # underlying single-step predictor. This means the underlying predictor
235
+ # doesn't need to implement .loss_and_predictions.
236
+ return self._predictor.loss(inputs, targets, forcings, **kwargs)
237
+
238
+ constant_inputs = self._get_and_validate_constant_inputs(
239
+ inputs, targets, forcings)
240
+ self._validate_targets_and_forcings(targets, forcings)
241
+ # After the above checks, the remaining inputs must be time-dependent:
242
+ inputs = inputs.drop_vars(constant_inputs.keys())
243
+
244
+ if self._noise_level:
245
+ def add_noise(x):
246
+ return x + self._noise_level * jax.random.normal(
247
+ hk.next_rng_key(), shape=x.shape)
248
+ # Add noise to time-dependent variables of the inputs.
249
+ inputs = jax.tree_map(add_noise, inputs)
250
+
251
+ # The per-timestep targets passed by scan to one_step_loss below will have
252
+ # no leading time axis. We need a treedef without the time axis to use
253
+ # inside one_step_loss to unflatten it back into a dataset:
254
+ flat_targets, target_treedef = _get_flat_arrays_and_single_timestep_treedef(
255
+ targets)
256
+ scan_variables = flat_targets
257
+
258
+ flat_forcings, forcings_treedef = (
259
+ _get_flat_arrays_and_single_timestep_treedef(forcings))
260
+ scan_variables = (flat_targets, flat_forcings)
261
+
262
+ def one_step_loss(inputs, scan_variables):
263
+ flat_target, flat_forcings = scan_variables
264
+ forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef,
265
+ targets.coords['time'][:1])
266
+
267
+ target = _unflatten_and_expand_time(flat_target, target_treedef,
268
+ targets.coords['time'][:1])
269
+
270
+ # Add constant inputs:
271
+ all_inputs = xarray.merge([constant_inputs, inputs])
272
+
273
+ (loss, diagnostics), predictions = self._predictor.loss_and_predictions(
274
+ all_inputs,
275
+ target,
276
+ forcings=forcings,
277
+ **kwargs)
278
+
279
+ # Unwrap to jax arrays shape (batch,):
280
+ loss, diagnostics = xarray_tree.map_structure(
281
+ xarray_jax.unwrap_data, (loss, diagnostics))
282
+
283
+ predictions = cast(xarray.Dataset, predictions) # Keeps pytype happy.
284
+ next_frame = xarray.merge([predictions, forcings])
285
+ next_inputs = self._update_inputs(inputs, next_frame)
286
+
287
+ return next_inputs, (loss, diagnostics)
288
+
289
+ if self._gradient_checkpointing:
290
+ scan_length = targets.dims['time']
291
+ if scan_length <= 1:
292
+ logging.warning(
293
+ 'Skipping gradient checkpointing for sequence length of 1')
294
+ else:
295
+ one_step_loss = hk.remat(one_step_loss)
296
+
297
+ # We can pass inputs (the initial state of the loop) in directly as a
298
+ # Dataset because the shape we pass in to scan is the same as the shape scan
299
+ # passes to the inner function. But, for scan_variables, we must flatten the
300
+ # targets (and unflatten them inside the inner function) because they are
301
+ # passed to the inner function per-timestep without the original time axis.
302
+ # The same applies to the optional forcings.
303
+ _, (per_timestep_losses, per_timestep_diagnostics) = hk.scan(
304
+ one_step_loss, inputs, scan_variables)
305
+
306
+ # Re-wrap loss and diagnostics as DataArray and average them over time:
307
+ (loss, diagnostics) = jax.tree_util.tree_map(
308
+ lambda x: xarray_jax.DataArray(x, dims=('time', 'batch')).mean( # pylint: disable=g-long-lambda
309
+ 'time', skipna=False),
310
+ (per_timestep_losses, per_timestep_diagnostics))
311
+
312
+ return loss, diagnostics
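Note: a minimal usage sketch for the wrapper above, assuming a hypothetical one-step model `MyOneStepPredictor` that implements the `predictor_base.Predictor` interface and pre-built `inputs`, `targets_template` and `forcings` xarray Datasets (all of these names are illustrative, not part of this commit):

import haiku as hk
import jax
from graphcast import autoregressive

def run_forward(inputs, targets_template, forcings):
  # Hypothetical single-step predictor; any predictor_base.Predictor works.
  one_step = MyOneStepPredictor()
  # Unroll over the time axis of targets_template, feeding each prediction
  # back in as inputs; remat each scan step to save memory.
  rollout = autoregressive.Predictor(
      one_step,
      noise_level=None,
      gradient_checkpointing=True,
  )
  return rollout(inputs, targets_template, forcings)

# forward = hk.transform(run_forward)
# params = forward.init(jax.random.PRNGKey(0), inputs, targets_template, forcings)
# predictions = forward.apply(params, jax.random.PRNGKey(1), inputs, targets_template, forcings)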
graphcast/casting.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Wrappers that take care of casting."""
15
+
16
+ import contextlib
17
+ from typing import Any, Mapping, Tuple
18
+
19
+ import chex
20
+ from graphcast import predictor_base
21
+ import haiku as hk
22
+ import jax
23
+ import jax.numpy as jnp
24
+ import numpy as np
25
+ import xarray
26
+
27
+
28
+ PyTree = Any
29
+
30
+
31
+ class Bfloat16Cast(predictor_base.Predictor):
32
+ """Wrapper that casts all inputs to bfloat16 and outputs to targets dtype."""
33
+
34
+ def __init__(self, predictor: predictor_base.Predictor, enabled: bool = True):
35
+ """Inits the wrapper.
36
+
37
+ Args:
38
+ predictor: predictor being wrapped.
39
+ enabled: disables the wrapper if False, for simpler hyperparameter scans.
40
+
41
+ """
42
+ self._enabled = enabled
43
+ self._predictor = predictor
44
+
45
+ def __call__(self,
46
+ inputs: xarray.Dataset,
47
+ targets_template: xarray.Dataset,
48
+ forcings: xarray.Dataset,
49
+ **kwargs
50
+ ) -> xarray.Dataset:
51
+ if not self._enabled:
52
+ return self._predictor(inputs, targets_template, forcings, **kwargs)
53
+
54
+ with bfloat16_variable_view():
55
+ predictions = self._predictor(
56
+ *_all_inputs_to_bfloat16(inputs, targets_template, forcings),
57
+ **kwargs,)
58
+
59
+ predictions_dtype = infer_floating_dtype(predictions) # pytype: disable=wrong-arg-types
60
+ if predictions_dtype != jnp.bfloat16:
61
+ raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')
62
+
63
+ targets_dtype = infer_floating_dtype(targets_template) # pytype: disable=wrong-arg-types
64
+ return tree_map_cast(
65
+ predictions, input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
66
+
67
+ def loss(self,
68
+ inputs: xarray.Dataset,
69
+ targets: xarray.Dataset,
70
+ forcings: xarray.Dataset,
71
+ **kwargs,
72
+ ) -> predictor_base.LossAndDiagnostics:
73
+ if not self._enabled:
74
+ return self._predictor.loss(inputs, targets, forcings, **kwargs)
75
+
76
+ with bfloat16_variable_view():
77
+ loss, scalars = self._predictor.loss(
78
+ *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)
79
+
80
+ if loss.dtype != jnp.bfloat16:
81
+ raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')
82
+
83
+ targets_dtype = infer_floating_dtype(targets) # pytype: disable=wrong-arg-types
84
+
85
+ # Note that casting back the loss to e.g. float32 should not affect data
86
+ # types of the backwards pass, because the first thing the backwards pass
87
+ # should do is to go backwards through the casting op and cast back to bfloat16
88
+ # (and xprofs seem to confirm this).
89
+ return tree_map_cast((loss, scalars),
90
+ input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
91
+
92
+ def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
93
+ self,
94
+ inputs: xarray.Dataset,
95
+ targets: xarray.Dataset,
96
+ forcings: xarray.Dataset,
97
+ **kwargs,
98
+ ) -> Tuple[predictor_base.LossAndDiagnostics,
99
+ xarray.Dataset]:
100
+ if not self._enabled:
101
+ return self._predictor.loss_and_predictions(inputs, targets, forcings, # pytype: disable=bad-return-type # jax-ndarray
102
+ **kwargs)
103
+
104
+ with bfloat16_variable_view():
105
+ (loss, scalars), predictions = self._predictor.loss_and_predictions(
106
+ *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)
107
+
108
+ if loss.dtype != jnp.bfloat16:
109
+ raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')
110
+
111
+ predictions_dtype = infer_floating_dtype(predictions) # pytype: disable=wrong-arg-types
112
+ if predictions_dtype != jnp.bfloat16:
113
+ raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')
114
+
115
+ targets_dtype = infer_floating_dtype(targets) # pytype: disable=wrong-arg-types
116
+ return tree_map_cast(((loss, scalars), predictions),
117
+ input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
118
+
119
+
120
+ def infer_floating_dtype(data_vars: Mapping[str, chex.Array]) -> np.dtype:
121
+ """Infers a floating dtype from an input mapping of data."""
122
+ dtypes = {
123
+ v.dtype
124
+ for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
125
+ if len(dtypes) != 1:
126
+ dtypes_and_shapes = {
127
+ k: (v.dtype, v.shape)
128
+ for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
129
+ raise ValueError(
130
+ f'Did not find exactly one floating dtype {dtypes} in input variables: '
131
+ f'{dtypes_and_shapes}')
132
+ return list(dtypes)[0]
133
+
134
+
135
+ def _all_inputs_to_bfloat16(
136
+ inputs: xarray.Dataset,
137
+ targets: xarray.Dataset,
138
+ forcings: xarray.Dataset,
139
+ ) -> Tuple[xarray.Dataset,
140
+ xarray.Dataset,
141
+ xarray.Dataset]:
142
+ return (inputs.astype(jnp.bfloat16),
143
+ jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets),
144
+ forcings.astype(jnp.bfloat16))
145
+
146
+
147
+ def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype,
148
+ ) -> PyTree:
149
+ def cast_fn(x):
150
+ if x.dtype == input_dtype:
151
+ return x.astype(output_dtype)
152
+ return jax.tree_map(cast_fn, inputs)
153
+
154
+
155
+ @contextlib.contextmanager
156
+ def bfloat16_variable_view(enabled: bool = True):
157
+ """Context for Haiku modules with float32 params, but bfloat16 activations.
158
+
159
+ It works as follows:
160
+ * Every time a variable is requested to be created/set as np.bfloat16,
161
+ it will create an underlying float32 variable, instead.
162
+ * Every time a variable is requested as bfloat16, it will check that the
163
+ variable is of float32 type, and cast the variable to bfloat16.
164
+
165
+ Note the gradients are still computed and accumulated as float32, because
166
+ the params returned by init are float32, so the gradient function with
167
+ respect to the params will already include an implicit casting to float32.
168
+
169
+ Args:
170
+ enabled: Only enables bfloat16 behavior if True.
171
+
172
+ Yields:
173
+ None
174
+ """
175
+
176
+ if enabled:
177
+ with hk.custom_creator(
178
+ _bfloat16_creator, state=True), hk.custom_getter(
179
+ _bfloat16_getter, state=True), hk.custom_setter(
180
+ _bfloat16_setter):
181
+ yield
182
+ else:
183
+ yield
184
+
185
+
186
+ def _bfloat16_creator(next_creator, shape, dtype, init, context):
187
+ """Creates float32 variables when bfloat16 is requested."""
188
+ if context.original_dtype == jnp.bfloat16:
189
+ dtype = jnp.float32
190
+ return next_creator(shape, dtype, init)
191
+
192
+
193
+ def _bfloat16_getter(next_getter, value, context):
194
+ """Casts float32 to bfloat16 when bfloat16 was originally requested."""
195
+ if context.original_dtype == jnp.bfloat16:
196
+ assert value.dtype == jnp.float32
197
+ value = value.astype(jnp.bfloat16)
198
+ return next_getter(value)
199
+
200
+
201
+ def _bfloat16_setter(next_setter, value, context):
202
+ """Casts bfloat16 to float32 when bfloat16 was originally set."""
203
+ if context.original_dtype == jnp.bfloat16:
204
+ value = value.astype(jnp.float32)
205
+ return next_setter(value)
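Note: a short sketch of how this casting wrapper is meant to compose with the other predictor wrappers in this commit; the inner `one_step` predictor is a placeholder for any `predictor_base.Predictor`:

from graphcast import autoregressive, casting

def construct_wrapped_predictor(one_step, enable_bfloat16=True):
  # Run the inner model in bfloat16 (params stay float32 thanks to
  # bfloat16_variable_view), then unroll it autoregressively in time.
  predictor = casting.Bfloat16Cast(one_step, enabled=enable_bfloat16)
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor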
graphcast/checkpoint.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Serialize and deserialize trees."""
15
+
16
+ import dataclasses
17
+ import io
18
+ import types
19
+ from typing import Any, BinaryIO, Optional, TypeVar
20
+
21
+ import numpy as np
22
+
23
+ _T = TypeVar("_T")
24
+
25
+
26
+ def dump(dest: BinaryIO, value: Any) -> None:
27
+ """Dump a tree of dicts/dataclasses to a file object.
28
+
29
+ Args:
30
+ dest: a file object to write to.
31
+ value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and
32
+ other basic types. Unions are not supported, other than Optional/None
33
+ which is only supported in dataclasses, not in dicts, lists or tuples.
34
+ All leaves must be coercible to a numpy array, and recoverable as a single
35
+ arg to a type.
36
+ """
37
+ buffer = io.BytesIO() # In case the destination doesn't support seeking.
38
+ np.savez(buffer, **_flatten(value))
39
+ dest.write(buffer.getvalue())
40
+
41
+
42
+ def load(source: BinaryIO, typ: type[_T]) -> _T:
43
+ """Load from a file object and convert it to the specified type.
44
+
45
+ Args:
46
+ source: a file object to read from.
47
+ typ: a type object that acts as a schema for deserialization. It must match
48
+ what was serialized. If a type is Any, it will be returned however numpy
49
+ serialized it, which is what you want for a tree of numpy arrays.
50
+
51
+ Returns:
52
+ the deserialized value as the specified type.
53
+ """
54
+ return _convert_types(typ, _unflatten(np.load(source)))
55
+
56
+
57
+ _SEP = ":"
58
+
59
+
60
+ def _flatten(tree: Any) -> dict[str, Any]:
61
+ """Flatten a tree of dicts/dataclasses/lists/tuples to a single dict."""
62
+ if dataclasses.is_dataclass(tree):
63
+ # Don't use dataclasses.asdict as it is recursive so skips dropping None.
64
+ tree = {f.name: v for f in dataclasses.fields(tree)
65
+ if (v := getattr(tree, f.name)) is not None}
66
+ elif isinstance(tree, (list, tuple)):
67
+ tree = dict(enumerate(tree))
68
+
69
+ assert isinstance(tree, dict)
70
+
71
+ flat = {}
72
+ for k, v in tree.items():
73
+ k = str(k)
74
+ assert _SEP not in k
75
+ if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)):
76
+ for a, b in _flatten(v).items():
77
+ flat[f"{k}{_SEP}{a}"] = b
78
+ else:
79
+ assert v is not None
80
+ flat[k] = v
81
+ return flat
82
+
83
+
84
+ def _unflatten(flat: dict[str, Any]) -> dict[str, Any]:
85
+ """Unflatten a dict to a tree of dicts."""
86
+ tree = {}
87
+ for flat_key, v in flat.items():
88
+ node = tree
89
+ keys = flat_key.split(_SEP)
90
+ for k in keys[:-1]:
91
+ if k not in node:
92
+ node[k] = {}
93
+ node = node[k]
94
+ node[keys[-1]] = v
95
+ return tree
96
+
97
+
98
+ def _convert_types(typ: type[_T], value: Any) -> _T:
99
+ """Convert some structure into the given type. The structures must match."""
100
+ if typ in (Any, ...):
101
+ return value
102
+
103
+ if typ in (int, float, str, bool):
104
+ return typ(value)
105
+
106
+ if typ is np.ndarray:
107
+ assert isinstance(value, np.ndarray)
108
+ return value
109
+
110
+ if dataclasses.is_dataclass(typ):
111
+ kwargs = {}
112
+ for f in dataclasses.fields(typ):
113
+ # Only support Optional for dataclasses, as numpy can't serialize it
114
+ # directly (without pickle), and dataclasses are the only case where we
115
+ # can know the full set of values and types and therefore know the
116
+ # non-existence must mean None.
117
+ if isinstance(f.type, (types.UnionType, type(Optional[int]))):
118
+ constructors = [t for t in f.type.__args__ if t is not types.NoneType]
119
+ if len(constructors) != 1:
120
+ raise TypeError(
121
+ "Optional works, Union with anything except None doesn't")
122
+ if f.name not in value:
123
+ kwargs[f.name] = None
124
+ continue
125
+ constructor = constructors[0]
126
+ else:
127
+ constructor = f.type
128
+
129
+ if f.name in value:
130
+ kwargs[f.name] = _convert_types(constructor, value[f.name])
131
+ else:
132
+ raise ValueError(f"Missing value: {f.name}")
133
+ return typ(**kwargs)
134
+
135
+ base_type = getattr(typ, "__origin__", None)
136
+
137
+ if base_type is dict:
138
+ assert len(typ.__args__) == 2
139
+ key_type, value_type = typ.__args__
140
+ return {_convert_types(key_type, k): _convert_types(value_type, v)
141
+ for k, v in value.items()}
142
+
143
+ if base_type is list:
144
+ assert len(typ.__args__) == 1
145
+ value_type = typ.__args__[0]
146
+ return [_convert_types(value_type, v)
147
+ for _, v in sorted(value.items(), key=lambda x: int(x[0]))]
148
+
149
+ if base_type is tuple:
150
+ if len(typ.__args__) == 2 and typ.__args__[1] == ...:
151
+ # An arbitrary length tuple of a single type, eg: tuple[int, ...]
152
+ value_type = typ.__args__[0]
153
+ return tuple(_convert_types(value_type, v)
154
+ for _, v in sorted(value.items(), key=lambda x: int(x[0])))
155
+ else:
156
+ # A fixed length tuple of arbitrary types, eg: tuple[int, str, float]
157
+ assert len(typ.__args__) == len(value)
158
+ return tuple(
159
+ _convert_types(t, v)
160
+ for t, (_, v) in zip(
161
+ typ.__args__, sorted(value.items(), key=lambda x: int(x[0]))))
162
+
163
+ # This is probably unreachable with reasonable serializable inputs.
164
+ try:
165
+ return typ(value)
166
+ except TypeError as e:
167
+ raise TypeError(
168
+ "_convert_types expects the type argument to be a dataclass defined "
169
+ "with types that are valid constructors (eg tuple is fine, Tuple "
170
+ "isn't), and accept a numpy array as the sole argument.") from e
graphcast/checkpoint_test.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Check that the checkpoint serialization is reversable."""
15
+
16
+ import dataclasses
17
+ import io
18
+ from typing import Any, Optional, Union
19
+
20
+ from absl.testing import absltest
21
+ from graphcast import checkpoint
22
+ import numpy as np
23
+
24
+
25
+ @dataclasses.dataclass
26
+ class SubConfig:
27
+ a: int
28
+ b: str
29
+
30
+
31
+ @dataclasses.dataclass
32
+ class Config:
33
+ bt: bool
34
+ bf: bool
35
+ i: int
36
+ f: float
37
+ o1: Optional[int]
38
+ o2: Optional[int]
39
+ o3: Union[int, None]
40
+ o4: Union[int, None]
41
+ o5: int | None
42
+ o6: int | None
43
+ li: list[int]
44
+ ls: list[str]
45
+ ldc: list[SubConfig]
46
+ tf: tuple[float, ...]
47
+ ts: tuple[str, ...]
48
+ t: tuple[str, int, SubConfig]
49
+ tdc: tuple[SubConfig, ...]
50
+ dsi: dict[str, int]
51
+ dss: dict[str, str]
52
+ dis: dict[int, str]
53
+ dsdis: dict[str, dict[int, str]]
54
+ dc: SubConfig
55
+ dco: Optional[SubConfig]
56
+ ddc: dict[str, SubConfig]
57
+
58
+
59
+ @dataclasses.dataclass
60
+ class Checkpoint:
61
+ params: dict[str, Any]
62
+ config: Config
63
+
64
+
65
+ class DataclassTest(absltest.TestCase):
66
+
67
+ def test_serialize_dataclass(self):
68
+ ckpt = Checkpoint(
69
+ params={
70
+ "layer1": {
71
+ "w": np.arange(10).reshape(2, 5),
72
+ "b": np.array([2, 6]),
73
+ },
74
+ "layer2": {
75
+ "w": np.arange(8).reshape(2, 4),
76
+ "b": np.array([2, 6]),
77
+ },
78
+ "blah": np.array([3, 9]),
79
+ },
80
+ config=Config(
81
+ bt=True,
82
+ bf=False,
83
+ i=42,
84
+ f=3.14,
85
+ o1=1,
86
+ o2=None,
87
+ o3=2,
88
+ o4=None,
89
+ o5=3,
90
+ o6=None,
91
+ li=[12, 9, 7, 15, 16, 14, 1, 6, 11, 4, 10, 5, 13, 3, 8, 2],
92
+ ls=list("qhjfdxtpzgemryoikwvblcaus"),
93
+ ldc=[SubConfig(1, "hello"), SubConfig(2, "world")],
94
+ tf=(1, 4, 2, 10, 5, 9, 13, 16, 15, 8, 12, 7, 11, 14, 3, 6),
95
+ ts=("hello", "world"),
96
+ t=("foo", 42, SubConfig(1, "bar")),
97
+ tdc=(SubConfig(1, "hello"), SubConfig(2, "world")),
98
+ dsi={"a": 1, "b": 2, "c": 3},
99
+ dss={"d": "e", "f": "g"},
100
+ dis={1: "a", 2: "b", 3: "c"},
101
+ dsdis={"a": {1: "hello", 2: "world"}, "b": {1: "world"}},
102
+ dc=SubConfig(1, "hello"),
103
+ dco=None,
104
+ ddc={"a": SubConfig(1, "hello"), "b": SubConfig(2, "world")},
105
+ ))
106
+
107
+ buffer = io.BytesIO()
108
+ checkpoint.dump(buffer, ckpt)
109
+ buffer.seek(0)
110
+ ckpt2 = checkpoint.load(buffer, Checkpoint)
111
+ np.testing.assert_array_equal(ckpt.params["layer1"]["w"],
112
+ ckpt2.params["layer1"]["w"])
113
+ np.testing.assert_array_equal(ckpt.params["layer1"]["b"],
114
+ ckpt2.params["layer1"]["b"])
115
+ np.testing.assert_array_equal(ckpt.params["layer2"]["w"],
116
+ ckpt2.params["layer2"]["w"])
117
+ np.testing.assert_array_equal(ckpt.params["layer2"]["b"],
118
+ ckpt2.params["layer2"]["b"])
119
+ np.testing.assert_array_equal(ckpt.params["blah"], ckpt2.params["blah"])
120
+ self.assertEqual(ckpt.config, ckpt2.config)
121
+
122
+
123
+ if __name__ == "__main__":
124
+ absltest.main()
graphcast/data_utils.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Dataset utilities."""
15
+
16
+ from typing import Any, Mapping, Sequence, Tuple, Union
17
+
18
+ from graphcast import solar_radiation
19
+ import numpy as np
20
+ import pandas as pd
21
+ import xarray
22
+
23
+ TimedeltaLike = Any # Something convertible to pd.Timedelta.
24
+ TimedeltaStr = str # A string convertible to pd.Timedelta.
25
+
26
+ TargetLeadTimes = Union[
27
+ TimedeltaLike,
28
+ Sequence[TimedeltaLike],
29
+ slice # with TimedeltaLike as its start and stop.
30
+ ]
31
+
32
+ _SEC_PER_HOUR = 3600
33
+ _HOUR_PER_DAY = 24
34
+ SEC_PER_DAY = _SEC_PER_HOUR * _HOUR_PER_DAY
35
+ _AVG_DAY_PER_YEAR = 365.24219
36
+ AVG_SEC_PER_YEAR = SEC_PER_DAY * _AVG_DAY_PER_YEAR
37
+
38
+ DAY_PROGRESS = "day_progress"
39
+ YEAR_PROGRESS = "year_progress"
40
+ _DERIVED_VARS = {
41
+ DAY_PROGRESS,
42
+ f"{DAY_PROGRESS}_sin",
43
+ f"{DAY_PROGRESS}_cos",
44
+ YEAR_PROGRESS,
45
+ f"{YEAR_PROGRESS}_sin",
46
+ f"{YEAR_PROGRESS}_cos",
47
+ }
48
+ TISR = "toa_incident_solar_radiation"
49
+
50
+
51
+ def get_year_progress(seconds_since_epoch: np.ndarray) -> np.ndarray:
52
+ """Computes year progress for times in seconds.
53
+
54
+ Args:
55
+ seconds_since_epoch: Times in seconds since the "epoch" (the point at which
56
+ UNIX time starts).
57
+
58
+ Returns:
59
+ Year progress normalized to be in the [0, 1) interval for each time point.
60
+ """
61
+
62
+ # Perform the division in float64 to keep as much precision as possible
63
+ # before the final cast to float32 below.
64
+ years_since_epoch = (
65
+ seconds_since_epoch / SEC_PER_DAY / np.float64(_AVG_DAY_PER_YEAR)
66
+ )
67
+ # Note depending on how these ops are done, we may end up with a "weak_type"
69
+ # which can cause issues in subtle ways that are hard to track down here.
69
+ # In any case, casting to float32 should get rid of the weak type.
70
+ # [0, 1.) Interval.
71
+ return np.mod(years_since_epoch, 1.0).astype(np.float32)
72
+
73
+
74
+ def get_day_progress(
75
+ seconds_since_epoch: np.ndarray,
76
+ longitude: np.ndarray,
77
+ ) -> np.ndarray:
78
+ """Computes day progress for times in seconds at each longitude.
79
+
80
+ Args:
81
+ seconds_since_epoch: 1D array of times in seconds since the 'epoch' (the
82
+ point at which UNIX time starts).
83
+ longitude: 1D array of longitudes at which day progress is computed.
84
+
85
+ Returns:
86
+ 2D array of day progress values normalized to be in the [0, 1) interval
87
+ for each time point at each longitude.
88
+ """
89
+
90
+ # [0.0, 1.0) Interval.
91
+ day_progress_greenwich = (
92
+ np.mod(seconds_since_epoch, SEC_PER_DAY) / SEC_PER_DAY
93
+ )
94
+
95
+ # Offset the day progress to the longitude of each point on Earth.
96
+ longitude_offsets = np.deg2rad(longitude) / (2 * np.pi)
97
+ day_progress = np.mod(
98
+ day_progress_greenwich[..., np.newaxis] + longitude_offsets, 1.0
99
+ )
100
+ return day_progress.astype(np.float32)
101
+
102
+
103
+ def featurize_progress(
104
+ name: str, dims: Sequence[str], progress: np.ndarray
105
+ ) -> Mapping[str, xarray.Variable]:
106
+ """Derives features used by ML models from the `progress` variable.
107
+
108
+ Args:
109
+ name: Base variable name from which features are derived.
110
+ dims: List of the output feature dimensions, e.g. ("day", "lon").
111
+ progress: Progress variable values.
112
+
113
+ Returns:
114
+ Dictionary of xarray variables derived from the `progress` values. It
115
+ includes the original `progress` variable along with its sin and cos
116
+ transformations.
117
+
118
+ Raises:
119
+ ValueError if the number of feature dimensions is not equal to the number
120
+ of data dimensions.
121
+ """
122
+ if len(dims) != progress.ndim:
123
+ raise ValueError(
124
+ f"Number of feature dimensions ({len(dims)}) must be equal to the"
125
+ f" number of data dimensions: {progress.ndim}."
126
+ )
127
+ progress_phase = progress * (2 * np.pi)
128
+ return {
129
+ name: xarray.Variable(dims, progress),
130
+ name + "_sin": xarray.Variable(dims, np.sin(progress_phase)),
131
+ name + "_cos": xarray.Variable(dims, np.cos(progress_phase)),
132
+ }
133
+
134
+
135
+ def add_derived_vars(data: xarray.Dataset) -> None:
136
+ """Adds year and day progress features to `data` in place if missing.
137
+
138
+ Args:
139
+ data: Xarray dataset to which derived features will be added.
140
+
141
+ Raises:
142
+ ValueError if `datetime` or `lon` are not in `data` coordinates.
143
+ """
144
+
145
+ for coord in ("datetime", "lon"):
146
+ if coord not in data.coords:
147
+ raise ValueError(f"'{coord}' must be in `data` coordinates.")
148
+
149
+ # Compute seconds since epoch.
150
+ # Note `data.coords["datetime"].astype("datetime64[s]").astype(np.int64)`
151
+ # does not work, as xarray always casts dates into nanoseconds!
152
+ seconds_since_epoch = (
153
+ data.coords["datetime"].data.astype("datetime64[s]").astype(np.int64)
154
+ )
155
+ batch_dim = ("batch",) if "batch" in data.dims else ()
156
+
157
+ # Add year progress features if missing.
158
+ if YEAR_PROGRESS not in data.data_vars:
159
+ year_progress = get_year_progress(seconds_since_epoch)
160
+ data.update(
161
+ featurize_progress(
162
+ name=YEAR_PROGRESS,
163
+ dims=batch_dim + ("time",),
164
+ progress=year_progress,
165
+ )
166
+ )
167
+
168
+ # Add day progress features if missing.
169
+ if DAY_PROGRESS not in data.data_vars:
170
+ longitude_coord = data.coords["lon"]
171
+ day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data)
172
+ data.update(
173
+ featurize_progress(
174
+ name=DAY_PROGRESS,
175
+ dims=batch_dim + ("time",) + longitude_coord.dims,
176
+ progress=day_progress,
177
+ )
178
+ )
179
+
180
+
181
+ def add_tisr_var(data: xarray.Dataset) -> None:
182
+ """Adds TISR feature to `data` in place if missing.
183
+
184
+ Args:
185
+ data: Xarray dataset to which TISR feature will be added.
186
+
187
+ Raises:
188
+ ValueError if `datetime`, 'lat', or `lon` are not in `data` coordinates.
189
+ """
190
+
191
+ if TISR in data.data_vars:
192
+ return
193
+
194
+ for coord in ("datetime", "lat", "lon"):
195
+ if coord not in data.coords:
196
+ raise ValueError(f"'{coord}' must be in `data` coordinates.")
197
+
198
+ # Remove `batch` dimension of size one if present. An error will be raised if
199
+ # the `batch` dimension exists and has size greater than one.
200
+ data_no_batch = data.squeeze("batch") if "batch" in data.dims else data
201
+
202
+ tisr = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
203
+ data_no_batch, use_jit=True
204
+ )
205
+
206
+ if "batch" in data.dims:
207
+ tisr = tisr.expand_dims("batch", axis=0)
208
+
209
+ data.update({TISR: tisr})
210
+
211
+
212
+ def extract_input_target_times(
213
+ dataset: xarray.Dataset,
214
+ input_duration: TimedeltaLike,
215
+ target_lead_times: TargetLeadTimes,
216
+ ) -> Tuple[xarray.Dataset, xarray.Dataset]:
217
+ """Extracts inputs and targets for prediction, from a Dataset with a time dim.
218
+
219
+ The input period is assumed to be contiguous (specified by a duration), but
220
+ the targets can be a list of arbitrary lead times.
221
+
222
+ Examples:
223
+
224
+ # Use 18 hours of data as inputs, and two specific lead times as targets:
225
+ # 3 days and 5 days after the final input.
226
+ extract_inputs_targets(
227
+ dataset,
228
+ input_duration='18h',
229
+ target_lead_times=('3d', '5d')
230
+ )
231
+
232
+ # Use 1 day of data as input, and all lead times between 6 hours and
233
+ # 24 hours inclusive as targets. Demonstrates a friendlier supported string
234
+ # syntax.
235
+ extract_inputs_targets(
236
+ dataset,
237
+ input_duration='1 day',
238
+ target_lead_times=slice('6 hours', '24 hours')
239
+ )
240
+
241
+ # Just use a single target lead time of 3 days:
242
+ extract_inputs_targets(
243
+ dataset,
244
+ input_duration='24h',
245
+ target_lead_times='3d'
246
+ )
247
+
248
+ Args:
249
+ dataset: An xarray.Dataset with a 'time' dimension whose coordinates are
250
+ timedeltas. It's assumed that the time coordinates have a fixed offset /
251
+ time resolution, and that the input_duration and target_lead_times are
252
+ multiples of this.
253
+ input_duration: pandas.Timedelta or something convertible to it (e.g. a
254
+ shorthand string like '6h' or '5d12h').
255
+ target_lead_times: Either a single lead time, a slice with start and stop
256
+ (inclusive) lead times, or a sequence of lead times. Lead times should be
257
+ Timedeltas (or something convertible to). They are given relative to the
258
+ final input timestep, and should be positive.
259
+
260
+ Returns:
261
+ inputs:
262
+ targets:
263
+ Two datasets with the same shape as the input dataset except that a
264
+ selection has been made from the time axis, and the origin of the
265
+ time coordinate will be shifted to refer to lead times relative to the
266
+ final input timestep. So for inputs the times will end at lead time 0,
267
+ for targets the time coordinates will refer to the lead times requested.
268
+ """
269
+
270
+ (target_lead_times, target_duration
271
+ ) = _process_target_lead_times_and_get_duration(target_lead_times)
272
+
273
+ # Shift the coordinates for the time axis so that a timedelta of zero
274
+ # corresponds to the forecast reference time. That is, the final timestep
275
+ # that's available as input to the forecast, with all following timesteps
276
+ # forming the target period which needs to be predicted.
277
+ # This means the time coordinates are now forecast lead times.
278
+ time = dataset.coords["time"]
279
+ dataset = dataset.assign_coords(time=time + target_duration - time[-1])
280
+
281
+ # Slice out targets:
282
+ targets = dataset.sel({"time": target_lead_times})
283
+
284
+ input_duration = pd.Timedelta(input_duration)
285
+ # Both endpoints are inclusive with label-based slicing, so we offset by a
286
+ # small epsilon to make one of the endpoints non-inclusive:
287
+ zero = pd.Timedelta(0)
288
+ epsilon = pd.Timedelta(1, "ns")
289
+ inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)})
290
+ return inputs, targets
291
+
292
+
293
+ def _process_target_lead_times_and_get_duration(
294
+ target_lead_times: TargetLeadTimes) -> Tuple[TargetLeadTimes, pd.Timedelta]:
295
+ """Returns processed target lead times and the duration needed to cover them."""
296
+ if isinstance(target_lead_times, slice):
297
+ # A slice of lead times. xarray already accepts timedelta-like values for
298
+ # the begin/end/step of the slice.
299
+ if target_lead_times.start is None:
300
+ # If the start isn't specified, we assume it starts at the next timestep
301
+ # after lead time 0 (lead time 0 is the final input timestep):
302
+ target_lead_times = slice(
303
+ pd.Timedelta(1, "ns"), target_lead_times.stop, target_lead_times.step
304
+ )
305
+ target_duration = pd.Timedelta(target_lead_times.stop)
306
+ else:
307
+ if not isinstance(target_lead_times, (list, tuple, set)):
308
+ # A single lead time, which we wrap as a length-1 array to ensure there
309
+ # still remains a time dimension (here of length 1) for consistency.
310
+ target_lead_times = [target_lead_times]
311
+
312
+ # A list of multiple (not necessarily contiguous) lead times:
313
+ target_lead_times = [pd.Timedelta(x) for x in target_lead_times]
314
+ target_lead_times.sort()
315
+ target_duration = target_lead_times[-1]
316
+ return target_lead_times, target_duration
317
+
318
+
319
+ def extract_inputs_targets_forcings(
320
+ dataset: xarray.Dataset,
321
+ *,
322
+ input_variables: Tuple[str, ...],
323
+ target_variables: Tuple[str, ...],
324
+ forcing_variables: Tuple[str, ...],
325
+ pressure_levels: Tuple[int, ...],
326
+ input_duration: TimedeltaLike,
327
+ target_lead_times: TargetLeadTimes,
328
+ ) -> Tuple[xarray.Dataset, xarray.Dataset, xarray.Dataset]:
329
+ """Extracts inputs, targets and forcings according to requirements."""
330
+ dataset = dataset.sel(level=list(pressure_levels))
331
+
332
+ # "Forcings" include derived variables that do not exist in the original ERA5
333
+ # or HRES datasets, as well as other variables (e.g. tisr) that need to be
334
+ # computed manually for the target lead times. Compute the requested ones.
335
+ if set(forcing_variables) & _DERIVED_VARS:
336
+ add_derived_vars(dataset)
337
+ if set(forcing_variables) & {TISR}:
338
+ add_tisr_var(dataset)
339
+
340
+ # `datetime` is needed by add_derived_vars but breaks autoregressive rollouts.
341
+ dataset = dataset.drop_vars("datetime")
342
+
343
+ inputs, targets = extract_input_target_times(
344
+ dataset,
345
+ input_duration=input_duration,
346
+ target_lead_times=target_lead_times)
347
+
348
+ if set(forcing_variables) & set(target_variables):
349
+ raise ValueError(
350
+ f"Forcing variables {forcing_variables} should not "
351
+ f"overlap with target variables {target_variables}."
352
+ )
353
+
354
+ inputs = inputs[list(input_variables)]
355
+ # The forcing uses the same time coordinates as the target.
356
+ forcings = targets[list(forcing_variables)]
357
+ targets = targets[list(target_variables)]
358
+
359
+ return inputs, targets, forcings
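Note: an illustrative call to the extraction helper above, assuming `example_dataset` is an xarray.Dataset with "time", "level", "lat", "lon" and "datetime" coordinates (e.g. a decoded ERA5 example batch); the variable names and lead times below are placeholders, not part of this commit:

from graphcast import data_utils

inputs, targets, forcings = data_utils.extract_inputs_targets_forcings(
    example_dataset,
    input_variables=("2m_temperature", "geopotential"),
    target_variables=("2m_temperature", "geopotential"),
    forcing_variables=("toa_incident_solar_radiation",
                       "year_progress_sin", "year_progress_cos",
                       "day_progress_sin", "day_progress_cos"),
    pressure_levels=(500, 850),
    input_duration="12h",                  # two 6-hourly input frames
    target_lead_times=slice("6h", "24h"),  # 6h to 24h inclusive, 6-hourly
)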
graphcast/data_utils_test.py ADDED
@@ -0,0 +1,310 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for `data_utils.py`."""
15
+
16
+ import datetime
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ from graphcast import data_utils
20
+ import numpy as np
21
+ import xarray as xa
22
+
23
+
24
+ class DataUtilsTest(parameterized.TestCase):
25
+
26
+ def setUp(self):
27
+ super().setUp()
28
+ # Fix the seed for reproducibility.
29
+ np.random.seed(0)
30
+
31
+ def test_year_progress_is_zero_at_year_start_or_end(self):
32
+ year_progress = data_utils.get_year_progress(
33
+ np.array([
34
+ 0,
35
+ data_utils.AVG_SEC_PER_YEAR,
36
+ data_utils.AVG_SEC_PER_YEAR * 42, # 42 years.
37
+ ])
38
+ )
39
+ np.testing.assert_array_equal(year_progress, np.zeros(year_progress.shape))
40
+
41
+ def test_year_progress_is_almost_one_before_year_ends(self):
42
+ year_progress = data_utils.get_year_progress(
43
+ np.array([
44
+ data_utils.AVG_SEC_PER_YEAR - 1,
45
+ (data_utils.AVG_SEC_PER_YEAR - 1) * 42, # ~42 years
46
+ ])
47
+ )
48
+ with self.subTest("Year progress values are close to 1"):
49
+ self.assertTrue(np.all(year_progress > 0.999))
50
+ with self.subTest("Year progress values != 1"):
51
+ self.assertTrue(np.all(year_progress < 1.0))
52
+
53
+ def test_day_progress_computes_for_all_times_and_longitudes(self):
54
+ times = np.random.randint(low=0, high=1e10, size=10)
55
+ longitudes = np.arange(0, 360.0, 1.0)
56
+ day_progress = data_utils.get_day_progress(times, longitudes)
57
+ with self.subTest("Day progress is computed for all times and longinutes"):
58
+ self.assertSequenceEqual(
59
+ day_progress.shape, (len(times), len(longitudes))
60
+ )
61
+
62
+ @parameterized.named_parameters(
63
+ dict(
64
+ testcase_name="random_date_1",
65
+ year=1988,
66
+ month=11,
67
+ day=7,
68
+ hour=2,
69
+ minute=45,
70
+ second=34,
71
+ ),
72
+ dict(
73
+ testcase_name="random_date_2",
74
+ year=2022,
75
+ month=3,
76
+ day=12,
77
+ hour=7,
78
+ minute=1,
79
+ second=0,
80
+ ),
81
+ )
82
+ def test_day_progress_is_in_between_zero_and_one(
83
+ self, year, month, day, hour, minute, second
84
+ ):
85
+ # Datetime from a timestamp.
86
+ dt = datetime.datetime(year, month, day, hour, minute, second)
87
+ # Epoch time.
88
+ epoch_time = datetime.datetime(1970, 1, 1)
89
+ # Seconds since epoch.
90
+ seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()])
91
+
92
+ # Longitudes with 1 degree resolution.
93
+ longitudes = np.arange(0, 360.0, 1.0)
94
+
95
+ day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes)
96
+ with self.subTest("Day progress >= 0"):
97
+ self.assertTrue(np.all(day_progress >= 0.0))
98
+ with self.subTest("Day progress < 1"):
99
+ self.assertTrue(np.all(day_progress < 1.0))
100
+
101
+ def test_day_progress_is_zero_at_day_start_or_end(self):
102
+ day_progress = data_utils.get_day_progress(
103
+ seconds_since_epoch=np.array([
104
+ 0,
105
+ data_utils.SEC_PER_DAY,
106
+ data_utils.SEC_PER_DAY * 42, # 42 days.
107
+ ]),
108
+ longitude=np.array([0.0]),
109
+ )
110
+ np.testing.assert_array_equal(day_progress, np.zeros(day_progress.shape))
111
+
112
+ def test_day_progress_specific_value(self):
113
+ day_progress = data_utils.get_day_progress(
114
+ seconds_since_epoch=np.array([123]),
115
+ longitude=np.array([0.0]),
116
+ )
117
+ np.testing.assert_array_almost_equal(
118
+ day_progress, np.array([[0.00142361]]), decimal=6
119
+ )
120
+
121
+ def test_featurize_progress_valid_values_and_dimensions(self):
122
+ day_progress = np.array([0.0, 0.45, 0.213])
123
+ feature_dimensions = ("time",)
124
+ progress_features = data_utils.featurize_progress(
125
+ name="day_progress", dims=feature_dimensions, progress=day_progress
126
+ )
127
+ for feature in progress_features.values():
128
+ with self.subTest(f"Valid dimensions for {feature}"):
129
+ self.assertSequenceEqual(feature.dims, feature_dimensions)
130
+
131
+ with self.subTest("Valid values for day_progress"):
132
+ np.testing.assert_array_equal(
133
+ day_progress, progress_features["day_progress"].values
134
+ )
135
+
136
+ with self.subTest("Valid values for day_progress_sin"):
137
+ np.testing.assert_array_almost_equal(
138
+ np.array([0.0, 0.30901699, 0.97309851]),
139
+ progress_features["day_progress_sin"].values,
140
+ decimal=6,
141
+ )
142
+
143
+ with self.subTest("Valid values for day_progress_cos"):
144
+ np.testing.assert_array_almost_equal(
145
+ np.array([1.0, -0.95105652, 0.23038943]),
146
+ progress_features["day_progress_cos"].values,
147
+ decimal=6,
148
+ )
149
+
150
+ def test_featurize_progress_invalid_dimensions(self):
151
+ year_progress = np.array([0.0, 0.45, 0.213])
152
+ feature_dimensions = ("time", "longitude")
153
+ with self.assertRaises(ValueError):
154
+ data_utils.featurize_progress(
155
+ name="year_progress", dims=feature_dimensions, progress=year_progress
156
+ )
157
+
158
+ def test_add_derived_vars_variables_added(self):
159
+ data = xa.Dataset(
160
+ data_vars={
161
+ "var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3))
162
+ },
163
+ coords={
164
+ "lon": np.array([0.0, 0.5]),
165
+ "datetime": np.array([
166
+ datetime.datetime(2021, 1, 1),
167
+ datetime.datetime(2023, 1, 1),
168
+ datetime.datetime(2023, 1, 3),
169
+ ]),
170
+ },
171
+ )
172
+ data_utils.add_derived_vars(data)
173
+ all_variables = set(data.variables)
174
+
175
+ with self.subTest("Original value was not removed"):
176
+ self.assertIn("var1", all_variables)
177
+ with self.subTest("Year progress feature was added"):
178
+ self.assertIn(data_utils.YEAR_PROGRESS, all_variables)
179
+ with self.subTest("Day progress feature was added"):
180
+ self.assertIn(data_utils.DAY_PROGRESS, all_variables)
181
+
182
+ def test_add_derived_vars_existing_vars_not_overridden(self):
183
+ dims = ["x", "lon", "datetime"]
184
+ data = xa.Dataset(
185
+ data_vars={
186
+ "var1": (dims, 8 * np.random.randn(2, 2, 3)),
187
+ data_utils.YEAR_PROGRESS: (dims, np.full((2, 2, 3), 0.111)),
188
+ data_utils.DAY_PROGRESS: (dims, np.full((2, 2, 3), 0.222)),
189
+ },
190
+ coords={
191
+ "lon": np.array([0.0, 0.5]),
192
+ "datetime": np.array([
193
+ datetime.datetime(2021, 1, 1),
194
+ datetime.datetime(2023, 1, 1),
195
+ datetime.datetime(2023, 1, 3),
196
+ ]),
197
+ },
198
+ )
199
+
200
+ data_utils.add_derived_vars(data)
201
+
202
+ with self.subTest("Year progress feature was not overridden"):
203
+ np.testing.assert_allclose(data[data_utils.YEAR_PROGRESS], 0.111)
204
+ with self.subTest("Day progress feature was not overridden"):
205
+ np.testing.assert_allclose(data[data_utils.DAY_PROGRESS], 0.222)
206
+
207
+ @parameterized.named_parameters(
208
+ dict(testcase_name="missing_datetime", coord_name="lon"),
209
+ dict(testcase_name="missing_lon", coord_name="datetime"),
210
+ )
211
+ def test_add_derived_vars_missing_coordinate_raises_value_error(
212
+ self, coord_name
213
+ ):
214
+ with self.subTest(f"Missing {coord_name} coordinate"):
215
+ data = xa.Dataset(
216
+ data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))},
217
+ coords={
218
+ coord_name: np.array([0.0, 0.5]),
219
+ },
220
+ )
221
+ with self.assertRaises(ValueError):
222
+ data_utils.add_derived_vars(data)
223
+
224
+ def test_add_tisr_var_variable_added(self):
225
+ data = xa.Dataset(
226
+ data_vars={
227
+ "var1": (["time", "lat", "lon"], np.full((2, 2, 2), 8.0))
228
+ },
229
+ coords={
230
+ "lat": np.array([2.0, 1.0]),
231
+ "lon": np.array([0.0, 0.5]),
232
+ "time": np.array([100, 200], dtype="timedelta64[s]"),
233
+ "datetime": xa.Variable(
234
+ "time", np.array([10, 20], dtype="datetime64[D]")
235
+ ),
236
+ },
237
+ )
238
+
239
+ data_utils.add_tisr_var(data)
240
+
241
+ self.assertIn(data_utils.TISR, set(data.variables))
242
+
243
+ def test_add_tisr_var_existing_var_not_overridden(self):
244
+ dims = ["time", "lat", "lon"]
245
+ data = xa.Dataset(
246
+ data_vars={
247
+ "var1": (dims, np.full((2, 2, 2), 8.0)),
248
+ data_utils.TISR: (dims, np.full((2, 2, 2), 1200.0)),
249
+ },
250
+ coords={
251
+ "lat": np.array([2.0, 1.0]),
252
+ "lon": np.array([0.0, 0.5]),
253
+ "time": np.array([100, 200], dtype="timedelta64[s]"),
254
+ "datetime": xa.Variable(
255
+ "time", np.array([10, 20], dtype="datetime64[D]")
256
+ ),
257
+ },
258
+ )
259
+
260
+ data_utils.add_derived_vars(data)
261
+
262
+ np.testing.assert_allclose(data[data_utils.TISR], 1200.0)
263
+
264
+ def test_add_tisr_var_works_with_batch_dim_size_one(self):
265
+ data = xa.Dataset(
266
+ data_vars={
267
+ "var1": (
268
+ ["batch", "time", "lat", "lon"],
269
+ np.full((1, 2, 2, 2), 8.0),
270
+ )
271
+ },
272
+ coords={
273
+ "lat": np.array([2.0, 1.0]),
274
+ "lon": np.array([0.0, 0.5]),
275
+ "time": np.array([100, 200], dtype="timedelta64[s]"),
276
+ "datetime": xa.Variable(
277
+ ("batch", "time"), np.array([[10, 20]], dtype="datetime64[D]")
278
+ ),
279
+ },
280
+ )
281
+
282
+ data_utils.add_tisr_var(data)
283
+
284
+ self.assertIn(data_utils.TISR, set(data.variables))
285
+
286
+ def test_add_tisr_var_fails_with_batch_dim_size_greater_than_one(self):
287
+ data = xa.Dataset(
288
+ data_vars={
289
+ "var1": (
290
+ ["batch", "time", "lat", "lon"],
291
+ np.full((2, 2, 2, 2), 8.0),
292
+ )
293
+ },
294
+ coords={
295
+ "lat": np.array([2.0, 1.0]),
296
+ "lon": np.array([0.0, 0.5]),
297
+ "time": np.array([100, 200], dtype="timedelta64[s]"),
298
+ "datetime": xa.Variable(
299
+ ("batch", "time"),
300
+ np.array([[10, 20], [100, 200]], dtype="datetime64[D]"),
301
+ ),
302
+ },
303
+ )
304
+
305
+ with self.assertRaisesRegex(ValueError, r"cannot select a dimension"):
306
+ data_utils.add_tisr_var(data)
307
+
308
+
309
+ if __name__ == "__main__":
310
+ absltest.main()
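For context (not part of the uploaded files), a minimal usage sketch of the data_utils helpers exercised by the tests above; the toy dataset mirrors the test fixtures and the variable names are illustrative only.

import datetime

import numpy as np
import xarray as xa

from graphcast import data_utils

# Toy dataset with the "lon" and "datetime" coordinates the helpers expect.
ds = xa.Dataset(
    data_vars={"var1": (["x", "lon", "datetime"], np.random.randn(2, 2, 3))},
    coords={
        "lon": np.array([0.0, 0.5]),
        "datetime": np.array([
            datetime.datetime(2021, 1, 1),
            datetime.datetime(2023, 1, 1),
            datetime.datetime(2023, 1, 3),
        ]),
    },
)

# Adds the year/day progress features in place; existing variables are kept
# and any already-present progress variables are not overridden.
data_utils.add_derived_vars(ds)
assert data_utils.YEAR_PROGRESS in ds.variables
assert data_utils.DAY_PROGRESS in ds.variables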
graphcast/deep_typed_graph_net.py ADDED
@@ -0,0 +1,391 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """JAX implementation of Graph Networks Simulator.
15
+
16
+ Generalization to TypedGraphs of the deep Graph Neural Network from:
17
+
18
+ @inproceedings{pfaff2021learning,
19
+ title={Learning Mesh-Based Simulation with Graph Networks},
20
+ author={Pfaff, Tobias and Fortunato, Meire and Sanchez-Gonzalez, Alvaro and
21
+ Battaglia, Peter},
22
+ booktitle={International Conference on Learning Representations},
23
+ year={2021}
24
+ }
25
+
26
+ @inproceedings{sanchez2020learning,
27
+ title={Learning to simulate complex physics with graph networks},
28
+ author={Sanchez-Gonzalez, Alvaro and Godwin, Jonathan and Pfaff, Tobias and
29
+ Ying, Rex and Leskovec, Jure and Battaglia, Peter},
30
+ booktitle={International conference on machine learning},
31
+ pages={8459--8468},
32
+ year={2020},
33
+ organization={PMLR}
34
+ }
35
+ """
36
+
37
+ from typing import Mapping, Optional
38
+
39
+ from graphcast import typed_graph
40
+ from graphcast import typed_graph_net
41
+ import haiku as hk
42
+ import jax
43
+ import jax.numpy as jnp
44
+ import jraph
45
+
46
+
47
+ class DeepTypedGraphNet(hk.Module):
48
+ """Deep Graph Neural Network.
49
+
50
+ It works with TypedGraphs with typed nodes and edges. It runs message
51
+ passing on all of the node sets and all of the edge sets in the graph. For
52
+ each message passing step a `typed_graph_net.InteractionNetwork` is used to
53
+ update the full TypedGraph by using different MLPs for each of the node sets
54
+ and each of the edge sets.
55
+
56
+ If embed_{nodes,edges} is specified the node/edge features will be embedded
57
+ into a fixed dimensionality before running the first step of message passing.
58
+
59
+ If {node,edge}_output_size is specified, the final node/edge features will be
60
+ embedded into the specified output size.
61
+
62
+ This class may be used for shared or unshared message passing:
63
+ * num_message_passing_steps = N, num_processor_repetitions = 1, gives
64
+ N layers of message passing with fully unshared weights:
65
+ [W_1, W_2, ... , W_N] (default)
66
+ * num_message_passing_steps = 1, num_processor_repetitions = M, gives
67
+ M layers of message passing with fully shared weights:
68
+ [W_1] * M
69
+ * num_message_passing_steps = N, num_processor_repetitions = M, gives
70
+ M*N layers of message passing with both shared and unshared message passing
71
+ such that the weights used at each iteration are:
72
+ [W_1, W_2, ... , W_N] * M
73
+
74
+ """
75
+
76
+ def __init__(self,
77
+ *,
78
+ node_latent_size: Mapping[str, int],
79
+ edge_latent_size: Mapping[str, int],
80
+ mlp_hidden_size: int,
81
+ mlp_num_hidden_layers: int,
82
+ num_message_passing_steps: int,
83
+ num_processor_repetitions: int = 1,
84
+ embed_nodes: bool = True,
85
+ embed_edges: bool = True,
86
+ node_output_size: Optional[Mapping[str, int]] = None,
87
+ edge_output_size: Optional[Mapping[str, int]] = None,
88
+ include_sent_messages_in_node_update: bool = False,
89
+ use_layer_norm: bool = True,
90
+ activation: str = "relu",
91
+ f32_aggregation: bool = False,
92
+ aggregate_edges_for_nodes_fn: str = "segment_sum",
93
+ aggregate_normalization: Optional[float] = None,
94
+ name: str = "DeepTypedGraphNet"):
95
+ """Inits the model.
96
+
97
+ Args:
98
+ node_latent_size: Size of the node latent representations.
99
+ edge_latent_size: Size of the edge latent representations.
100
+ mlp_hidden_size: Hidden layer size for all MLPs.
101
+ mlp_num_hidden_layers: Number of hidden layers in all MLPs.
102
+ num_message_passing_steps: Number of unshared message passing steps
103
+ in the processor steps.
104
+ num_processor_repetitions: Number of times that the same processor is
105
+ applied sequentially.
106
+ embed_nodes: If False, the node embedder will be omitted.
107
+ embed_edges: If False, the edge embedder will be omitted.
108
+ node_output_size: Size of the output node representations for
109
+ each node type. For node types not specified here, the latent node
110
+ representation from the output of the processor will be returned.
111
+ edge_output_size: Size of the output edge representations for
112
+ each edge type. For edge types not specified here, the latent edge
113
+ representation from the output of the processor will be returned.
114
+ include_sent_messages_in_node_update: Whether to include pooled sent
115
+ messages from each node in the node update.
116
+ use_layer_norm: Whether it uses layer norm or not.
117
+ activation: name of activation function.
118
+ f32_aggregation: Use float32 in the edge aggregation.
119
+ aggregate_edges_for_nodes_fn: function used to aggregate messages to each
120
+ node.
121
+ aggregate_normalization: An optional constant that normalizes the output
122
+ of aggregate_edges_for_nodes_fn. For context, this can be used to
123
+ reduce the shock the model undergoes when switching resolution, which
124
+ increases the number of edges connected to a node. In particular, this is
125
+ useful when using segment_sum, but should not be combined with
126
+ segment_mean.
127
+ name: Name of the model.
128
+ """
129
+
130
+ super().__init__(name=name)
131
+
132
+ self._node_latent_size = node_latent_size
133
+ self._edge_latent_size = edge_latent_size
134
+ self._mlp_hidden_size = mlp_hidden_size
135
+ self._mlp_num_hidden_layers = mlp_num_hidden_layers
136
+ self._num_message_passing_steps = num_message_passing_steps
137
+ self._num_processor_repetitions = num_processor_repetitions
138
+ self._embed_nodes = embed_nodes
139
+ self._embed_edges = embed_edges
140
+ self._node_output_size = node_output_size
141
+ self._edge_output_size = edge_output_size
142
+ self._include_sent_messages_in_node_update = (
143
+ include_sent_messages_in_node_update)
144
+ self._use_layer_norm = use_layer_norm
145
+ self._activation = _get_activation_fn(activation)
146
+ self._initialized = False
147
+ self._f32_aggregation = f32_aggregation
148
+ self._aggregate_edges_for_nodes_fn = _get_aggregate_edges_for_nodes_fn(
149
+ aggregate_edges_for_nodes_fn)
150
+ self._aggregate_normalization = aggregate_normalization
151
+
152
+ if aggregate_normalization:
153
+ # using aggregate_normalization only makes sense with segment_sum.
154
+ assert aggregate_edges_for_nodes_fn == "segment_sum"
155
+
156
+ def __call__(self,
157
+ input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
158
+ """Forward pass of the learnable dynamics model."""
159
+ self._networks_builder(input_graph)
160
+
161
+ # Embed input features (if applicable).
162
+ latent_graph_0 = self._embed(input_graph)
163
+
164
+ # Do `m` message passing steps in the latent graphs.
165
+ latent_graph_m = self._process(latent_graph_0)
166
+
167
+ # Compute outputs from the last latent graph (if applicable).
168
+ return self._output(latent_graph_m)
169
+
170
+ def _networks_builder(self, graph_template):
171
+ if self._initialized:
172
+ return
173
+ self._initialized = True
174
+
175
+ def build_mlp(name, output_size):
176
+ mlp = hk.nets.MLP(
177
+ output_sizes=[self._mlp_hidden_size] * self._mlp_num_hidden_layers + [
178
+ output_size], name=name + "_mlp", activation=self._activation)
179
+ return jraph.concatenated_args(mlp)
180
+
181
+ def build_mlp_with_maybe_layer_norm(name, output_size):
182
+ network = build_mlp(name, output_size)
183
+ if self._use_layer_norm:
184
+ layer_norm = hk.LayerNorm(
185
+ axis=-1, create_scale=True, create_offset=True,
186
+ name=name + "_layer_norm")
187
+ network = hk.Sequential([network, layer_norm])
188
+ return jraph.concatenated_args(network)
189
+
190
+ # The embedder graph network independently embeds edge and node features.
191
+ if self._embed_edges:
192
+ embed_edge_fn = _build_update_fns_for_edge_types(
193
+ build_mlp_with_maybe_layer_norm,
194
+ graph_template,
195
+ "encoder_edges_",
196
+ output_sizes=self._edge_latent_size)
197
+ else:
198
+ embed_edge_fn = None
199
+ if self._embed_nodes:
200
+ embed_node_fn = _build_update_fns_for_node_types(
201
+ build_mlp_with_maybe_layer_norm,
202
+ graph_template,
203
+ "encoder_nodes_",
204
+ output_sizes=self._node_latent_size)
205
+ else:
206
+ embed_node_fn = None
207
+ embedder_kwargs = dict(
208
+ embed_edge_fn=embed_edge_fn,
209
+ embed_node_fn=embed_node_fn,
210
+ )
211
+ self._embedder_network = typed_graph_net.GraphMapFeatures(
212
+ **embedder_kwargs)
213
+
214
+ if self._f32_aggregation:
215
+ def aggregate_fn(data, *args, **kwargs):
216
+ dtype = data.dtype
217
+ data = data.astype(jnp.float32)
218
+ output = self._aggregate_edges_for_nodes_fn(data, *args, **kwargs)
219
+ if self._aggregate_normalization:
220
+ output = output / self._aggregate_normalization
221
+ output = output.astype(dtype)
222
+ return output
223
+
224
+ else:
225
+ def aggregate_fn(data, *args, **kwargs):
226
+ output = self._aggregate_edges_for_nodes_fn(data, *args, **kwargs)
227
+ if self._aggregate_normalization:
228
+ output = output / self._aggregate_normalization
229
+ return output
230
+
231
+ # Create `num_message_passing_steps` graph networks with unshared parameters
232
+ # that update the node and edge latent features.
233
+ # Note that we can use `modules.InteractionNetwork` because
234
+ # it also outputs the messages as updated edge latent features.
235
+ self._processor_networks = []
236
+ for step_i in range(self._num_message_passing_steps):
237
+ self._processor_networks.append(
238
+ typed_graph_net.InteractionNetwork(
239
+ update_edge_fn=_build_update_fns_for_edge_types(
240
+ build_mlp_with_maybe_layer_norm,
241
+ graph_template,
242
+ f"processor_edges_{step_i}_",
243
+ output_sizes=self._edge_latent_size),
244
+ update_node_fn=_build_update_fns_for_node_types(
245
+ build_mlp_with_maybe_layer_norm,
246
+ graph_template,
247
+ f"processor_nodes_{step_i}_",
248
+ output_sizes=self._node_latent_size),
249
+ aggregate_edges_for_nodes_fn=aggregate_fn,
250
+ include_sent_messages_in_node_update=(
251
+ self._include_sent_messages_in_node_update),
252
+ ))
253
+
254
+ # The output MLPs convert edge/node latent features into the output sizes.
255
+ output_kwargs = dict(
256
+ embed_edge_fn=_build_update_fns_for_edge_types(
257
+ build_mlp, graph_template, "decoder_edges_", self._edge_output_size)
258
+ if self._edge_output_size else None,
259
+ embed_node_fn=_build_update_fns_for_node_types(
260
+ build_mlp, graph_template, "decoder_nodes_", self._node_output_size)
261
+ if self._node_output_size else None,)
262
+ self._output_network = typed_graph_net.GraphMapFeatures(
263
+ **output_kwargs)
264
+
265
+ def _embed(
266
+ self, input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
267
+ """Embeds the input graph features into a latent graph."""
268
+
269
+ # Copy the context to all of the node types, if applicable.
270
+ context_features = input_graph.context.features
271
+ if jax.tree_util.tree_leaves(context_features):
272
+ # This code assumes a single input feature array for the context and for
273
+ # each node type.
274
+ assert len(jax.tree_util.tree_leaves(context_features)) == 1
275
+ new_nodes = {}
276
+ for node_set_name, node_set in input_graph.nodes.items():
277
+ node_features = node_set.features
278
+ broadcasted_context = jnp.repeat(
279
+ context_features, node_set.n_node, axis=0,
280
+ total_repeat_length=node_features.shape[0])
281
+ new_nodes[node_set_name] = node_set._replace(
282
+ features=jnp.concatenate(
283
+ [node_features, broadcasted_context], axis=-1))
284
+ input_graph = input_graph._replace(
285
+ nodes=new_nodes,
286
+ context=input_graph.context._replace(features=()))
287
+
288
+ # Embeds the node and edge features.
289
+ latent_graph_0 = self._embedder_network(input_graph)
290
+ return latent_graph_0
291
+
292
+ def _process(
293
+ self, latent_graph_0: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
294
+ """Processes the latent graph with several steps of message passing."""
295
+
296
+ # Do `num_message_passing_steps` with each of the `self._processor_networks`
297
+ # with unshared weights, and repeat that `self._num_processor_repetitions`
298
+ # times.
299
+ latent_graph = latent_graph_0
300
+ for unused_repetition_i in range(self._num_processor_repetitions):
301
+ for processor_network in self._processor_networks:
302
+ latent_graph = self._process_step(processor_network, latent_graph)
303
+
304
+ return latent_graph
305
+
306
+ def _process_step(
307
+ self, processor_network_k,
308
+ latent_graph_prev_k: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
309
+ """Single step of message passing with node/edge residual connections."""
310
+
311
+ # One step of message passing.
312
+ latent_graph_k = processor_network_k(latent_graph_prev_k)
313
+
314
+ # Add residuals.
315
+ nodes_with_residuals = {}
316
+ for k, prev_set in latent_graph_prev_k.nodes.items():
317
+ nodes_with_residuals[k] = prev_set._replace(
318
+ features=prev_set.features + latent_graph_k.nodes[k].features)
319
+
320
+ edges_with_residuals = {}
321
+ for k, prev_set in latent_graph_prev_k.edges.items():
322
+ edges_with_residuals[k] = prev_set._replace(
323
+ features=prev_set.features + latent_graph_k.edges[k].features)
324
+
325
+ latent_graph_k = latent_graph_k._replace(
326
+ nodes=nodes_with_residuals, edges=edges_with_residuals)
327
+ return latent_graph_k
328
+
329
+ def _output(self,
330
+ latent_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
331
+ """Produces the output from the latent graph."""
332
+ return self._output_network(latent_graph)
333
+
334
+
335
+ def _build_update_fns_for_node_types(
336
+ builder_fn, graph_template, prefix, output_sizes=None):
337
+ """Builds an update function for all node types or a subset of them."""
338
+
339
+ output_fns = {}
340
+ for node_set_name in graph_template.nodes.keys():
341
+ if output_sizes is None:
342
+ # Use the default output size for all types.
343
+ output_size = None
344
+ else:
345
+ # Otherwise, ignore any type that does not have an explicit output size.
346
+ if node_set_name in output_sizes:
347
+ output_size = output_sizes[node_set_name]
348
+ else:
349
+ continue
350
+ output_fns[node_set_name] = builder_fn(
351
+ f"{prefix}{node_set_name}", output_size)
352
+ return output_fns
353
+
354
+
355
+ def _build_update_fns_for_edge_types(
356
+ builder_fn, graph_template, prefix, output_sizes=None):
357
+ """Builds an update function for all edge types or a subset of them."""
358
+ output_fns = {}
359
+ for edge_set_key in graph_template.edges.keys():
360
+ edge_set_name = edge_set_key.name
361
+ if output_sizes is None:
362
+ # Use the default output size for all types.
363
+ output_size = None
364
+ else:
365
+ # Otherwise, ignore any type that does not have an explicit output size.
366
+ if edge_set_name in output_sizes:
367
+ output_size = output_sizes[edge_set_name]
368
+ else:
369
+ continue
370
+ output_fns[edge_set_name] = builder_fn(
371
+ f"{prefix}{edge_set_name}", output_size)
372
+ return output_fns
373
+
374
+
375
+ def _get_activation_fn(name):
376
+ """Return activation function corresponding to function_name."""
377
+ if name == "identity":
378
+ return lambda x: x
379
+ if hasattr(jax.nn, name):
380
+ return getattr(jax.nn, name)
381
+ if hasattr(jnp, name):
382
+ return getattr(jnp, name)
383
+ raise ValueError(f"Unknown activation function {name} specified.")
384
+
385
+
386
+ def _get_aggregate_edges_for_nodes_fn(name):
387
+ """Return aggregate_edges_for_nodes_fn corresponding to function_name."""
388
+ if hasattr(jraph, name):
389
+ return getattr(jraph, name)
390
+ raise ValueError(
391
+ f"Unknown aggregate_edges_for_nodes_fn function {name} specified.")
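As an illustration of the weight-sharing options described in the DeepTypedGraphNet docstring above, a sketch (not part of the uploaded files) of two processors of equal depth; the latent sizes are illustrative placeholders and the input is a typed_graph.TypedGraph as in __call__.

import haiku as hk

from graphcast import deep_typed_graph_net


def unshared_processor(graph):
  # 16 message passing steps with 16 unshared weight sets: [W_1, ..., W_16].
  net = deep_typed_graph_net.DeepTypedGraphNet(
      node_latent_size=dict(mesh_nodes=64),
      edge_latent_size=dict(mesh=64),
      mlp_hidden_size=64,
      mlp_num_hidden_layers=1,
      num_message_passing_steps=16,
      num_processor_repetitions=1)
  return net(graph)


def shared_processor(graph):
  # Also 16 steps of message passing, but one shared weight set: [W_1] * 16.
  net = deep_typed_graph_net.DeepTypedGraphNet(
      node_latent_size=dict(mesh_nodes=64),
      edge_latent_size=dict(mesh=64),
      mlp_hidden_size=64,
      mlp_num_hidden_layers=1,
      num_message_passing_steps=1,
      num_processor_repetitions=16)
  return net(graph)


# Haiku modules must be built inside a transformed function.
unshared_fn = hk.transform(unshared_processor)
shared_fn = hk.transform(shared_processor)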
graphcast/graphcast.py ADDED
@@ -0,0 +1,796 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """A predictor that runs multiple graph neural networks on mesh data.
15
+
16
+ It learns to interpolate between the grid and the mesh nodes, with the loss
17
+ and the rollouts ultimately computed at the grid level.
18
+
19
+ It uses ideas similar to those in Keisler (2022):
20
+
21
+ Reference:
22
+ https://arxiv.org/pdf/2202.07575.pdf
23
+
24
+ It assumes data across time and level is stacked, and operates only on
25
+ a 2D mesh over latitudes and longitudes.
26
+ """
27
+
28
+ from typing import Any, Callable, Mapping, Optional
29
+
30
+ import chex
31
+ from graphcast import deep_typed_graph_net
32
+ from graphcast import grid_mesh_connectivity
33
+ from graphcast import icosahedral_mesh
34
+ from graphcast import losses
35
+ from graphcast import model_utils
36
+ from graphcast import predictor_base
37
+ from graphcast import typed_graph
38
+ from graphcast import xarray_jax
39
+ import jax.numpy as jnp
40
+ import jraph
41
+ import numpy as np
42
+ import xarray
43
+
44
+ Kwargs = Mapping[str, Any]
45
+
46
+ GNN = Callable[[jraph.GraphsTuple], jraph.GraphsTuple]
47
+
48
+
49
+ # https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5
50
+ PRESSURE_LEVELS_ERA5_37 = (
51
+ 1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, 150, 175, 200, 225, 250, 300,
52
+ 350, 400, 450, 500, 550, 600, 650, 700, 750, 775, 800, 825, 850, 875, 900,
53
+ 925, 950, 975, 1000)
54
+
55
+ # https://www.ecmwf.int/en/forecasts/datasets/set-i
56
+ PRESSURE_LEVELS_HRES_25 = (
57
+ 1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 400, 500, 600,
58
+ 700, 800, 850, 900, 925, 950, 1000)
59
+
60
+ # https://agupubs.onlinelibrary.wiley.com/doi/full/10.1029/2020MS002203
61
+ PRESSURE_LEVELS_WEATHERBENCH_13 = (
62
+ 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000)
63
+
64
+ PRESSURE_LEVELS = {
65
+ 13: PRESSURE_LEVELS_WEATHERBENCH_13,
66
+ 25: PRESSURE_LEVELS_HRES_25,
67
+ 37: PRESSURE_LEVELS_ERA5_37,
68
+ }
69
+
70
+ # The list of all possible atmospheric variables. Taken from:
71
+ # https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#ERA5:datadocumentation-Table9
72
+ ALL_ATMOSPHERIC_VARS = (
73
+ "potential_vorticity",
74
+ "specific_rain_water_content",
75
+ "specific_snow_water_content",
76
+ "geopotential",
77
+ "temperature",
78
+ "u_component_of_wind",
79
+ "v_component_of_wind",
80
+ "specific_humidity",
81
+ "vertical_velocity",
82
+ "vorticity",
83
+ "divergence",
84
+ "relative_humidity",
85
+ "ozone_mass_mixing_ratio",
86
+ "specific_cloud_liquid_water_content",
87
+ "specific_cloud_ice_water_content",
88
+ "fraction_of_cloud_cover",
89
+ )
90
+
91
+ TARGET_SURFACE_VARS = (
92
+ "2m_temperature",
93
+ "mean_sea_level_pressure",
94
+ "10m_v_component_of_wind",
95
+ "10m_u_component_of_wind",
96
+ "total_precipitation_6hr",
97
+ )
98
+ TARGET_SURFACE_NO_PRECIP_VARS = (
99
+ "2m_temperature",
100
+ "mean_sea_level_pressure",
101
+ "10m_v_component_of_wind",
102
+ "10m_u_component_of_wind",
103
+ )
104
+ TARGET_ATMOSPHERIC_VARS = (
105
+ "temperature",
106
+ "geopotential",
107
+ "u_component_of_wind",
108
+ "v_component_of_wind",
109
+ "vertical_velocity",
110
+ "specific_humidity",
111
+ )
112
+ TARGET_ATMOSPHERIC_NO_W_VARS = (
113
+ "temperature",
114
+ "geopotential",
115
+ "u_component_of_wind",
116
+ "v_component_of_wind",
117
+ "specific_humidity",
118
+ )
119
+ EXTERNAL_FORCING_VARS = (
120
+ "toa_incident_solar_radiation",
121
+ )
122
+ GENERATED_FORCING_VARS = (
123
+ "year_progress_sin",
124
+ "year_progress_cos",
125
+ "day_progress_sin",
126
+ "day_progress_cos",
127
+ )
128
+ FORCING_VARS = EXTERNAL_FORCING_VARS + GENERATED_FORCING_VARS
129
+ STATIC_VARS = (
130
+ "geopotential_at_surface",
131
+ "land_sea_mask",
132
+ )
133
+
134
+
135
+ @chex.dataclass(frozen=True, eq=True)
136
+ class TaskConfig:
137
+ """Defines inputs and targets on which a model is trained and/or evaluated."""
138
+ input_variables: tuple[str, ...]
139
+ # Target variables which the model is expected to predict.
140
+ target_variables: tuple[str, ...]
141
+ forcing_variables: tuple[str, ...]
142
+ pressure_levels: tuple[int, ...]
143
+ input_duration: str
144
+
145
+ TASK = TaskConfig(
146
+ input_variables=(
147
+ TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
148
+ STATIC_VARS),
149
+ target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
150
+ forcing_variables=FORCING_VARS,
151
+ pressure_levels=PRESSURE_LEVELS_ERA5_37,
152
+ input_duration="12h",
153
+ )
154
+ TASK_13 = TaskConfig(
155
+ input_variables=(
156
+ TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
157
+ STATIC_VARS),
158
+ target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
159
+ forcing_variables=FORCING_VARS,
160
+ pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13,
161
+ input_duration="12h",
162
+ )
163
+ TASK_13_PRECIP_OUT = TaskConfig(
164
+ input_variables=(
165
+ TARGET_SURFACE_NO_PRECIP_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
166
+ STATIC_VARS),
167
+ target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
168
+ forcing_variables=FORCING_VARS,
169
+ pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13,
170
+ input_duration="12h",
171
+ )
172
+
173
+
174
+ @chex.dataclass(frozen=True, eq=True)
175
+ class ModelConfig:
176
+ """Defines the architecture of the GraphCast neural network.
177
+
178
+ Properties:
179
+ resolution: The resolution of the data, in degrees (e.g. 0.25 or 1.0).
180
+ mesh_size: How many refinements to do on the multi-mesh.
181
+ gnn_msg_steps: How many Graph Network message passing steps to do.
182
+ latent_size: How many latent features to include in the various MLPs.
183
+ hidden_layers: How many hidden layers for each MLP.
184
+ radius_query_fraction_edge_length: Scalar that will be multiplied by the
185
+ length of the longest edge of the finest mesh to define the radius of
186
+ connectivity to use in the Grid2Mesh graph. Reasonable values are
187
+ between 0.6 and 1. 0.6 reduces the number of grid points feeding into
188
+ multiple mesh nodes and therefore reduces edge count and memory use, but
189
+ 1 gives better predictions.
190
+ mesh2grid_edge_normalization_factor: Allows explicitly controlling edge
191
+ normalization for mesh2grid edges. If None, defaults to max edge length.
192
+ This supports using pre-trained model weights with a different graph
193
+ structure to what it was trained on.
194
+ """
195
+ resolution: float
196
+ mesh_size: int
197
+ latent_size: int
198
+ gnn_msg_steps: int
199
+ hidden_layers: int
200
+ radius_query_fraction_edge_length: float
201
+ mesh2grid_edge_normalization_factor: Optional[float] = None
202
+
203
+
204
+ @chex.dataclass(frozen=True, eq=True)
205
+ class CheckPoint:
206
+ params: dict[str, Any]
207
+ model_config: ModelConfig
208
+ task_config: TaskConfig
209
+ description: str
210
+ license: str
211
+
212
+
213
+ class GraphCast(predictor_base.Predictor):
214
+ """GraphCast Predictor.
215
+
216
+ The model works on graphs that take into account:
217
+ * Mesh nodes: nodes for the vertices of the mesh.
218
+ * Grid nodes: nodes for the points of the grid.
219
+ * Nodes: When referring to just "nodes", this means the joint set of
220
+ both mesh nodes, concatenated with grid nodes.
221
+
222
+ The model works with 3 graphs:
223
+ * Grid2Mesh graph: Graph that contains all nodes. This graph is strictly
224
+ bipartite with edges going from grid nodes to mesh nodes using a
225
+ fixed radius query. The grid2mesh_gnn will operate in this graph. The output
226
+ of this stage will be a latent representation for the mesh nodes, and a
227
+ latent representation for the grid nodes.
228
+ * Mesh graph: Graph that contains mesh nodes only. The mesh_gnn will
229
+ operate in this graph. It will update the latent state of the mesh nodes
230
+ only.
231
+ * Mesh2Grid graph: Graph that contains all nodes. This graph is strictly
232
+ bipartite with edges going from mesh nodes to grid nodes such that each grid
233
+ node is connected to the 3 nodes of the mesh triangular face that contains
234
+ the grid point. The mesh2grid_gnn will operate in this graph. It will
235
+ process the updated latent state of the mesh nodes, and the latent state
236
+ of the grid nodes, to produce the final output for the grid nodes.
237
+
238
+ The model is built on top of `TypedGraph`s so the different types of nodes and
239
+ edges can be stored and treated separately.
240
+
241
+ """
242
+
243
+ def __init__(self, model_config: ModelConfig, task_config: TaskConfig):
244
+ """Initializes the predictor."""
245
+ self._spatial_features_kwargs = dict(
246
+ add_node_positions=False,
247
+ add_node_latitude=True,
248
+ add_node_longitude=True,
249
+ add_relative_positions=True,
250
+ relative_longitude_local_coordinates=True,
251
+ relative_latitude_local_coordinates=True,
252
+ )
253
+
254
+ # Specification of the multimesh.
255
+ self._meshes = (
256
+ icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
257
+ splits=model_config.mesh_size))
258
+
259
+ # Encoder, which moves data from the grid to the mesh with a single message
260
+ # passing step.
261
+ self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
262
+ embed_nodes=True, # Embed raw features of the grid and mesh nodes.
263
+ embed_edges=True, # Embed raw features of the grid2mesh edges.
264
+ edge_latent_size=dict(grid2mesh=model_config.latent_size),
265
+ node_latent_size=dict(
266
+ mesh_nodes=model_config.latent_size,
267
+ grid_nodes=model_config.latent_size),
268
+ mlp_hidden_size=model_config.latent_size,
269
+ mlp_num_hidden_layers=model_config.hidden_layers,
270
+ num_message_passing_steps=1,
271
+ use_layer_norm=True,
272
+ include_sent_messages_in_node_update=False,
273
+ activation="swish",
274
+ f32_aggregation=True,
275
+ aggregate_normalization=None,
276
+ name="grid2mesh_gnn",
277
+ )
278
+
279
+ # Processor, which performs message passing on the multi-mesh.
280
+ self._mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
281
+ embed_nodes=False, # Node features already embedded by previous layers.
282
+ embed_edges=True, # Embed raw features of the multi-mesh edges.
283
+ node_latent_size=dict(mesh_nodes=model_config.latent_size),
284
+ edge_latent_size=dict(mesh=model_config.latent_size),
285
+ mlp_hidden_size=model_config.latent_size,
286
+ mlp_num_hidden_layers=model_config.hidden_layers,
287
+ num_message_passing_steps=model_config.gnn_msg_steps,
288
+ use_layer_norm=True,
289
+ include_sent_messages_in_node_update=False,
290
+ activation="swish",
291
+ f32_aggregation=False,
292
+ name="mesh_gnn",
293
+ )
294
+
295
+ num_surface_vars = len(
296
+ set(task_config.target_variables) - set(ALL_ATMOSPHERIC_VARS))
297
+ num_atmospheric_vars = len(
298
+ set(task_config.target_variables) & set(ALL_ATMOSPHERIC_VARS))
299
+ num_outputs = (num_surface_vars +
300
+ len(task_config.pressure_levels) * num_atmospheric_vars)
301
+
302
+ # Decoder, which moves data from the mesh back into the grid with a single
303
+ # message passing step.
304
+ self._mesh2grid_gnn = deep_typed_graph_net.DeepTypedGraphNet(
305
+ # Require a specific node dimensionality for the grid node outputs.
306
+ node_output_size=dict(grid_nodes=num_outputs),
307
+ embed_nodes=False, # Node features already embedded by previous layers.
308
+ embed_edges=True, # Embed raw features of the mesh2grid edges.
309
+ edge_latent_size=dict(mesh2grid=model_config.latent_size),
310
+ node_latent_size=dict(
311
+ mesh_nodes=model_config.latent_size,
312
+ grid_nodes=model_config.latent_size),
313
+ mlp_hidden_size=model_config.latent_size,
314
+ mlp_num_hidden_layers=model_config.hidden_layers,
315
+ num_message_passing_steps=1,
316
+ use_layer_norm=True,
317
+ include_sent_messages_in_node_update=False,
318
+ activation="swish",
319
+ f32_aggregation=False,
320
+ name="mesh2grid_gnn",
321
+ )
322
+
323
+ # Obtain the query radius in absolute units for the unit-sphere for the
324
+ # grid2mesh model, by rescaling the `radius_query_fraction_edge_length`.
325
+ self._query_radius = (_get_max_edge_distance(self._finest_mesh)
326
+ * model_config.radius_query_fraction_edge_length)
327
+ self._mesh2grid_edge_normalization_factor = (
328
+ model_config.mesh2grid_edge_normalization_factor
329
+ )
330
+
331
+ # Other initialization is delayed until the first call (`_maybe_init`)
332
+ # when we get some sample data so we know the lat/lon values.
333
+ self._initialized = False
334
+
335
+ # A "_init_mesh_properties":
336
+ # This one could be initialized at init but we delay it for consistency too.
337
+ self._num_mesh_nodes = None # num_mesh_nodes
338
+ self._mesh_nodes_lat = None # [num_mesh_nodes]
339
+ self._mesh_nodes_lon = None # [num_mesh_nodes]
340
+
341
+ # A "_init_grid_properties":
342
+ self._grid_lat = None # [num_lat_points]
343
+ self._grid_lon = None # [num_lon_points]
344
+ self._num_grid_nodes = None # num_lat_points * num_lon_points
345
+ self._grid_nodes_lat = None # [num_grid_nodes]
346
+ self._grid_nodes_lon = None # [num_grid_nodes]
347
+
348
+ # A "_init_{grid2mesh,processor,mesh2grid}_graph"
349
+ self._grid2mesh_graph_structure = None
350
+ self._mesh_graph_structure = None
351
+ self._mesh2grid_graph_structure = None
352
+
353
+ @property
354
+ def _finest_mesh(self):
355
+ return self._meshes[-1]
356
+
357
+ def __call__(self,
358
+ inputs: xarray.Dataset,
359
+ targets_template: xarray.Dataset,
360
+ forcings: xarray.Dataset,
361
+ is_training: bool = False,
362
+ ) -> xarray.Dataset:
363
+ self._maybe_init(inputs)
364
+
365
+ # Convert all input data into flat vectors for each of the grid nodes.
366
+ # xarray (batch, time, lat, lon, level, multiple vars, forcings)
367
+ # -> [num_grid_nodes, batch, num_channels]
368
+ grid_node_features = self._inputs_to_grid_node_features(inputs, forcings)
369
+
370
+ # Transfer data from the grid to the mesh,
371
+ # [num_mesh_nodes, batch, latent_size], [num_grid_nodes, batch, latent_size]
372
+ (latent_mesh_nodes, latent_grid_nodes
373
+ ) = self._run_grid2mesh_gnn(grid_node_features)
374
+
375
+ # Run message passing in the multimesh.
376
+ # [num_mesh_nodes, batch, latent_size]
377
+ updated_latent_mesh_nodes = self._run_mesh_gnn(latent_mesh_nodes)
378
+
379
+ # Transfer data from the mesh to the grid.
380
+ # [num_grid_nodes, batch, output_size]
381
+ output_grid_nodes = self._run_mesh2grid_gnn(
382
+ updated_latent_mesh_nodes, latent_grid_nodes)
383
+
384
+ # Convert output flat vectors for the grid nodes to the format of the output.
385
+ # [num_grid_nodes, batch, output_size] ->
386
+ # xarray (batch, one time step, lat, lon, level, multiple vars)
387
+ return self._grid_node_outputs_to_prediction(
388
+ output_grid_nodes, targets_template)
389
+
390
+ def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
391
+ self,
392
+ inputs: xarray.Dataset,
393
+ targets: xarray.Dataset,
394
+ forcings: xarray.Dataset,
395
+ ) -> tuple[predictor_base.LossAndDiagnostics, xarray.Dataset]:
396
+ # Forward pass.
397
+ predictions = self(
398
+ inputs, targets_template=targets, forcings=forcings, is_training=True)
399
+ # Compute loss.
400
+ loss = losses.weighted_mse_per_level(
401
+ predictions, targets,
402
+ per_variable_weights={
403
+ # Any variables not specified here are weighted as 1.0.
404
+ # A single-level variable, but an important headline variable
405
+ # and also one which we have struggled to get good performance
406
+ # on at short lead times, so leaving it weighted at 1.0, equal
407
+ # to the multi-level variables:
408
+ "2m_temperature": 1.0,
409
+ # New single-level variables, which we don't weight too highly
410
+ # to avoid hurting performance on other variables.
411
+ "10m_u_component_of_wind": 0.1,
412
+ "10m_v_component_of_wind": 0.1,
413
+ "mean_sea_level_pressure": 0.1,
414
+ "total_precipitation_6hr": 0.1,
415
+ })
416
+ return loss, predictions # pytype: disable=bad-return-type # jax-ndarray
417
+
418
+ def loss( # pytype: disable=signature-mismatch # jax-ndarray
419
+ self,
420
+ inputs: xarray.Dataset,
421
+ targets: xarray.Dataset,
422
+ forcings: xarray.Dataset,
423
+ ) -> predictor_base.LossAndDiagnostics:
424
+ loss, _ = self.loss_and_predictions(inputs, targets, forcings)
425
+ return loss # pytype: disable=bad-return-type # jax-ndarray
426
+
427
+ def _maybe_init(self, sample_inputs: xarray.Dataset):
428
+ """Inits everything that has a dependency on the input coordinates."""
429
+ if not self._initialized:
430
+ self._init_mesh_properties()
431
+ self._init_grid_properties(
432
+ grid_lat=sample_inputs.lat, grid_lon=sample_inputs.lon)
433
+ self._grid2mesh_graph_structure = self._init_grid2mesh_graph()
434
+ self._mesh_graph_structure = self._init_mesh_graph()
435
+ self._mesh2grid_graph_structure = self._init_mesh2grid_graph()
436
+
437
+ self._initialized = True
438
+
439
+ def _init_mesh_properties(self):
440
+ """Inits static properties that have to do with mesh nodes."""
441
+ self._num_mesh_nodes = self._finest_mesh.vertices.shape[0]
442
+ mesh_phi, mesh_theta = model_utils.cartesian_to_spherical(
443
+ self._finest_mesh.vertices[:, 0],
444
+ self._finest_mesh.vertices[:, 1],
445
+ self._finest_mesh.vertices[:, 2])
446
+ (
447
+ mesh_nodes_lat,
448
+ mesh_nodes_lon,
449
+ ) = model_utils.spherical_to_lat_lon(
450
+ phi=mesh_phi, theta=mesh_theta)
451
+ # Convert to f32 to ensure the lat/lon features aren't in f64.
452
+ self._mesh_nodes_lat = mesh_nodes_lat.astype(np.float32)
453
+ self._mesh_nodes_lon = mesh_nodes_lon.astype(np.float32)
454
+
455
+ def _init_grid_properties(self, grid_lat: np.ndarray, grid_lon: np.ndarray):
456
+ """Inits static properties that have to do with grid nodes."""
457
+ self._grid_lat = grid_lat.astype(np.float32)
458
+ self._grid_lon = grid_lon.astype(np.float32)
459
+ # Initialize the counters.
460
+ self._num_grid_nodes = grid_lat.shape[0] * grid_lon.shape[0]
461
+
462
+ # Initialize lat and lon for the grid.
463
+ grid_nodes_lon, grid_nodes_lat = np.meshgrid(grid_lon, grid_lat)
464
+ self._grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32)
465
+ self._grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32)
466
+
467
+ def _init_grid2mesh_graph(self) -> typed_graph.TypedGraph:
468
+ """Build Grid2Mesh graph."""
469
+
470
+ # Create some edges according to distance between mesh and grid nodes.
471
+ assert self._grid_lat is not None and self._grid_lon is not None
472
+ (grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices(
473
+ grid_latitude=self._grid_lat,
474
+ grid_longitude=self._grid_lon,
475
+ mesh=self._finest_mesh,
476
+ radius=self._query_radius)
477
+
478
+ # Edges sending info from grid to mesh.
479
+ senders = grid_indices
480
+ receivers = mesh_indices
481
+
482
+ # Precompute structural node and edge features according to config options.
483
+ # Structural features are those that depend on the fixed values of the
484
+ # latitude and longitudes of the nodes.
485
+ (senders_node_features, receivers_node_features,
486
+ edge_features) = model_utils.get_bipartite_graph_spatial_features(
487
+ senders_node_lat=self._grid_nodes_lat,
488
+ senders_node_lon=self._grid_nodes_lon,
489
+ receivers_node_lat=self._mesh_nodes_lat,
490
+ receivers_node_lon=self._mesh_nodes_lon,
491
+ senders=senders,
492
+ receivers=receivers,
493
+ edge_normalization_factor=None,
494
+ **self._spatial_features_kwargs,
495
+ )
496
+
497
+ n_grid_node = np.array([self._num_grid_nodes])
498
+ n_mesh_node = np.array([self._num_mesh_nodes])
499
+ n_edge = np.array([mesh_indices.shape[0]])
500
+ grid_node_set = typed_graph.NodeSet(
501
+ n_node=n_grid_node, features=senders_node_features)
502
+ mesh_node_set = typed_graph.NodeSet(
503
+ n_node=n_mesh_node, features=receivers_node_features)
504
+ edge_set = typed_graph.EdgeSet(
505
+ n_edge=n_edge,
506
+ indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
507
+ features=edge_features)
508
+ nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
509
+ edges = {
510
+ typed_graph.EdgeSetKey("grid2mesh", ("grid_nodes", "mesh_nodes")):
511
+ edge_set
512
+ }
513
+ grid2mesh_graph = typed_graph.TypedGraph(
514
+ context=typed_graph.Context(n_graph=np.array([1]), features=()),
515
+ nodes=nodes,
516
+ edges=edges)
517
+ return grid2mesh_graph
518
+
519
+ def _init_mesh_graph(self) -> typed_graph.TypedGraph:
520
+ """Build Mesh graph."""
521
+ merged_mesh = icosahedral_mesh.merge_meshes(self._meshes)
522
+
523
+ # Work simply on the mesh edges.
524
+ senders, receivers = icosahedral_mesh.faces_to_edges(merged_mesh.faces)
525
+
526
+ # Precompute structural node and edge features according to config options.
527
+ # Structural features are those that depend on the fixed values of the
528
+ # latitude and longitudes of the nodes.
529
+ assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
530
+ node_features, edge_features = model_utils.get_graph_spatial_features(
531
+ node_lat=self._mesh_nodes_lat,
532
+ node_lon=self._mesh_nodes_lon,
533
+ senders=senders,
534
+ receivers=receivers,
535
+ **self._spatial_features_kwargs,
536
+ )
537
+
538
+ n_mesh_node = np.array([self._num_mesh_nodes])
539
+ n_edge = np.array([senders.shape[0]])
540
+ assert n_mesh_node == len(node_features)
541
+ mesh_node_set = typed_graph.NodeSet(
542
+ n_node=n_mesh_node, features=node_features)
543
+ edge_set = typed_graph.EdgeSet(
544
+ n_edge=n_edge,
545
+ indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
546
+ features=edge_features)
547
+ nodes = {"mesh_nodes": mesh_node_set}
548
+ edges = {
549
+ typed_graph.EdgeSetKey("mesh", ("mesh_nodes", "mesh_nodes")): edge_set
550
+ }
551
+ mesh_graph = typed_graph.TypedGraph(
552
+ context=typed_graph.Context(n_graph=np.array([1]), features=()),
553
+ nodes=nodes,
554
+ edges=edges)
555
+
556
+ return mesh_graph
557
+
558
+ def _init_mesh2grid_graph(self) -> typed_graph.TypedGraph:
559
+ """Build Mesh2Grid graph."""
560
+
561
+ # Create some edges according to how the grid nodes are contained by
562
+ # mesh triangles.
563
+ (grid_indices,
564
+ mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices(
565
+ grid_latitude=self._grid_lat,
566
+ grid_longitude=self._grid_lon,
567
+ mesh=self._finest_mesh)
568
+
569
+ # Edges sending info from mesh to grid.
570
+ senders = mesh_indices
571
+ receivers = grid_indices
572
+
573
+ # Precompute structural node and edge features according to config options.
574
+ assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
575
+ (senders_node_features, receivers_node_features,
576
+ edge_features) = model_utils.get_bipartite_graph_spatial_features(
577
+ senders_node_lat=self._mesh_nodes_lat,
578
+ senders_node_lon=self._mesh_nodes_lon,
579
+ receivers_node_lat=self._grid_nodes_lat,
580
+ receivers_node_lon=self._grid_nodes_lon,
581
+ senders=senders,
582
+ receivers=receivers,
583
+ edge_normalization_factor=self._mesh2grid_edge_normalization_factor,
584
+ **self._spatial_features_kwargs,
585
+ )
586
+
587
+ n_grid_node = np.array([self._num_grid_nodes])
588
+ n_mesh_node = np.array([self._num_mesh_nodes])
589
+ n_edge = np.array([senders.shape[0]])
590
+ grid_node_set = typed_graph.NodeSet(
591
+ n_node=n_grid_node, features=receivers_node_features)
592
+ mesh_node_set = typed_graph.NodeSet(
593
+ n_node=n_mesh_node, features=senders_node_features)
594
+ edge_set = typed_graph.EdgeSet(
595
+ n_edge=n_edge,
596
+ indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
597
+ features=edge_features)
598
+ nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
599
+ edges = {
600
+ typed_graph.EdgeSetKey("mesh2grid", ("mesh_nodes", "grid_nodes")):
601
+ edge_set
602
+ }
603
+ mesh2grid_graph = typed_graph.TypedGraph(
604
+ context=typed_graph.Context(n_graph=np.array([1]), features=()),
605
+ nodes=nodes,
606
+ edges=edges)
607
+ return mesh2grid_graph
608
+
609
+ def _run_grid2mesh_gnn(self, grid_node_features: chex.Array,
610
+ ) -> tuple[chex.Array, chex.Array]:
611
+ """Runs the grid2mesh_gnn, extracting latent mesh and grid nodes."""
612
+
613
+ # Concatenate node structural features with input features.
614
+ batch_size = grid_node_features.shape[1]
615
+
616
+ grid2mesh_graph = self._grid2mesh_graph_structure
617
+ assert grid2mesh_graph is not None
618
+ grid_nodes = grid2mesh_graph.nodes["grid_nodes"]
619
+ mesh_nodes = grid2mesh_graph.nodes["mesh_nodes"]
620
+ new_grid_nodes = grid_nodes._replace(
621
+ features=jnp.concatenate([
622
+ grid_node_features,
623
+ _add_batch_second_axis(
624
+ grid_nodes.features.astype(grid_node_features.dtype),
625
+ batch_size)
626
+ ],
627
+ axis=-1))
628
+
629
+ # To make sure the capacity of the embedding is identical for the grid nodes and
630
+ # the mesh nodes, we also append some dummy zero input features for the
631
+ # mesh nodes.
632
+ dummy_mesh_node_features = jnp.zeros(
633
+ (self._num_mesh_nodes,) + grid_node_features.shape[1:],
634
+ dtype=grid_node_features.dtype)
635
+ new_mesh_nodes = mesh_nodes._replace(
636
+ features=jnp.concatenate([
637
+ dummy_mesh_node_features,
638
+ _add_batch_second_axis(
639
+ mesh_nodes.features.astype(dummy_mesh_node_features.dtype),
640
+ batch_size)
641
+ ],
642
+ axis=-1))
643
+
644
+ # Broadcast edge structural features to the required batch size.
645
+ grid2mesh_edges_key = grid2mesh_graph.edge_key_by_name("grid2mesh")
646
+ edges = grid2mesh_graph.edges[grid2mesh_edges_key]
647
+
648
+ new_edges = edges._replace(
649
+ features=_add_batch_second_axis(
650
+ edges.features.astype(dummy_mesh_node_features.dtype), batch_size))
651
+
652
+ input_graph = self._grid2mesh_graph_structure._replace(
653
+ edges={grid2mesh_edges_key: new_edges},
654
+ nodes={
655
+ "grid_nodes": new_grid_nodes,
656
+ "mesh_nodes": new_mesh_nodes
657
+ })
658
+
659
+ # Run the GNN.
660
+ grid2mesh_out = self._grid2mesh_gnn(input_graph)
661
+ latent_mesh_nodes = grid2mesh_out.nodes["mesh_nodes"].features
662
+ latent_grid_nodes = grid2mesh_out.nodes["grid_nodes"].features
663
+ return latent_mesh_nodes, latent_grid_nodes
664
+
665
+ def _run_mesh_gnn(self, latent_mesh_nodes: chex.Array) -> chex.Array:
666
+ """Runs the mesh_gnn, extracting updated latent mesh nodes."""
667
+
668
+ # Add the structural edge features of this graph. Note we don't need
669
+ # to add the structural node features, because these are already part of
670
+ # the latent state, via the original Grid2Mesh gnn, however, we need
671
+ # the edge ones, because it is the first time we are seeing this particular
672
+ # set of edges.
673
+ batch_size = latent_mesh_nodes.shape[1]
674
+
675
+ mesh_graph = self._mesh_graph_structure
676
+ assert mesh_graph is not None
677
+ mesh_edges_key = mesh_graph.edge_key_by_name("mesh")
678
+ edges = mesh_graph.edges[mesh_edges_key]
679
+
680
+ # We are assuming here that the mesh gnn uses a single set of edge keys
681
+ # named "mesh" for the edges and that it uses a single set of nodes named
682
+ # "mesh_nodes"
683
+ msg = ("The setup currently requires to only have one kind of edge in the"
684
+ " mesh GNN.")
685
+ assert len(mesh_graph.edges) == 1, msg
686
+
687
+ new_edges = edges._replace(
688
+ features=_add_batch_second_axis(
689
+ edges.features.astype(latent_mesh_nodes.dtype), batch_size))
690
+
691
+ nodes = mesh_graph.nodes["mesh_nodes"]
692
+ nodes = nodes._replace(features=latent_mesh_nodes)
693
+
694
+ input_graph = mesh_graph._replace(
695
+ edges={mesh_edges_key: new_edges}, nodes={"mesh_nodes": nodes})
696
+
697
+ # Run the GNN.
698
+ return self._mesh_gnn(input_graph).nodes["mesh_nodes"].features
699
+
700
+ def _run_mesh2grid_gnn(self,
701
+ updated_latent_mesh_nodes: chex.Array,
702
+ latent_grid_nodes: chex.Array,
703
+ ) -> chex.Array:
704
+ """Runs the mesh2grid_gnn, extracting the output grid nodes."""
705
+
706
+ # Add the structural edge features of this graph. Note we don't need
707
+ # to add the structural node features, because these are already part of
708
+ # the latent state, via the original Grid2Mesh gnn, however, we need
709
+ # the edge ones, because it is the first time we are seeing this particular
710
+ # set of edges.
711
+ batch_size = updated_latent_mesh_nodes.shape[1]
712
+
713
+ mesh2grid_graph = self._mesh2grid_graph_structure
714
+ assert mesh2grid_graph is not None
715
+ mesh_nodes = mesh2grid_graph.nodes["mesh_nodes"]
716
+ grid_nodes = mesh2grid_graph.nodes["grid_nodes"]
717
+ new_mesh_nodes = mesh_nodes._replace(features=updated_latent_mesh_nodes)
718
+ new_grid_nodes = grid_nodes._replace(features=latent_grid_nodes)
719
+ mesh2grid_key = mesh2grid_graph.edge_key_by_name("mesh2grid")
720
+ edges = mesh2grid_graph.edges[mesh2grid_key]
721
+
722
+ new_edges = edges._replace(
723
+ features=_add_batch_second_axis(
724
+ edges.features.astype(latent_grid_nodes.dtype), batch_size))
725
+
726
+ input_graph = mesh2grid_graph._replace(
727
+ edges={mesh2grid_key: new_edges},
728
+ nodes={
729
+ "mesh_nodes": new_mesh_nodes,
730
+ "grid_nodes": new_grid_nodes
731
+ })
732
+
733
+ # Run the GNN.
734
+ output_graph = self._mesh2grid_gnn(input_graph)
735
+ output_grid_nodes = output_graph.nodes["grid_nodes"].features
736
+
737
+ return output_grid_nodes
738
+
739
+ def _inputs_to_grid_node_features(
740
+ self,
741
+ inputs: xarray.Dataset,
742
+ forcings: xarray.Dataset,
743
+ ) -> chex.Array:
744
+ """xarrays -> [num_grid_nodes, batch, num_channels]."""
745
+
746
+ # xarray `Dataset` (batch, time, lat, lon, level, multiple vars)
747
+ # to xarray `DataArray` (batch, lat, lon, channels)
748
+ stacked_inputs = model_utils.dataset_to_stacked(inputs)
749
+ stacked_forcings = model_utils.dataset_to_stacked(forcings)
750
+ stacked_inputs = xarray.concat(
751
+ [stacked_inputs, stacked_forcings], dim="channels")
752
+
753
+ # xarray `DataArray` (batch, lat, lon, channels)
754
+ # to single numpy array with shape [lat_lon_node, batch, channels]
755
+ grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes(
756
+ stacked_inputs)
757
+ return xarray_jax.unwrap(grid_xarray_lat_lon_leading.data).reshape(
758
+ (-1,) + grid_xarray_lat_lon_leading.data.shape[2:])
759
+
760
+ def _grid_node_outputs_to_prediction(
761
+ self,
762
+ grid_node_outputs: chex.Array,
763
+ targets_template: xarray.Dataset,
764
+ ) -> xarray.Dataset:
765
+ """[num_grid_nodes, batch, num_outputs] -> xarray."""
766
+
767
+ # numpy array with shape [lat_lon_node, batch, channels]
768
+ # to xarray `DataArray` (batch, lat, lon, channels)
769
+ assert self._grid_lat is not None and self._grid_lon is not None
770
+ grid_shape = (self._grid_lat.shape[0], self._grid_lon.shape[0])
771
+ grid_outputs_lat_lon_leading = grid_node_outputs.reshape(
772
+ grid_shape + grid_node_outputs.shape[1:])
773
+ dims = ("lat", "lon", "batch", "channels")
774
+ grid_xarray_lat_lon_leading = xarray_jax.DataArray(
775
+ data=grid_outputs_lat_lon_leading,
776
+ dims=dims)
777
+ grid_xarray = model_utils.restore_leading_axes(grid_xarray_lat_lon_leading)
778
+
779
+ # xarray `DataArray` (batch, lat, lon, channels)
780
+ # to xarray `Dataset` (batch, one time step, lat, lon, level, multiple vars)
781
+ return model_utils.stacked_to_dataset(
782
+ grid_xarray.variable, targets_template)
783
+
784
+
785
+ def _add_batch_second_axis(data, batch_size):
786
+ # data [leading_dim, trailing_dim]
787
+ assert data.ndim == 2
788
+ ones = jnp.ones([batch_size, 1], dtype=data.dtype)
789
+ return data[:, None] * ones # [leading_dim, batch, trailing_dim]
790
+
791
+
792
+ def _get_max_edge_distance(mesh):
793
+ senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces)
794
+ edge_distances = np.linalg.norm(
795
+ mesh.vertices[senders] - mesh.vertices[receivers], axis=-1)
796
+ return edge_distances.max()
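To make the configuration dataclasses above concrete, a construction sketch (not part of the uploaded files); the ModelConfig values are illustrative placeholders, not the published GraphCast settings.

import haiku as hk

from graphcast import graphcast as graphcast_lib

# Illustrative configuration. With TASK_13 there are 5 surface targets and 6
# atmospheric targets on 13 pressure levels, so the decoder emits
# 5 + 6 * 13 = 83 output channels per grid node.
MODEL_CONFIG = graphcast_lib.ModelConfig(
    resolution=1.0,
    mesh_size=4,
    latent_size=256,
    gnn_msg_steps=8,
    hidden_layers=1,
    radius_query_fraction_edge_length=0.6)
TASK_CONFIG = graphcast_lib.TASK_13


def forward(inputs, targets_template, forcings):
  # GraphCast builds Haiku modules internally, so it is constructed and called
  # inside a transformed function.
  predictor = graphcast_lib.GraphCast(MODEL_CONFIG, TASK_CONFIG)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)


run_forward = hk.transform(forward)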
graphcast/grid_mesh_connectivity.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tools for converting from regular grids on a sphere, to triangular meshes."""
15
+
16
+ from graphcast import icosahedral_mesh
17
+ import numpy as np
18
+ import scipy
19
+ import trimesh
20
+
21
+
22
+ def _grid_lat_lon_to_coordinates(
23
+ grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray:
24
+ """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3]."""
25
+ # Convert to spherical coordinates phi and theta defined in the grid.
26
+ # Each [num_latitude_points, num_longitude_points]
27
+ phi_grid, theta_grid = np.meshgrid(
28
+ np.deg2rad(grid_longitude),
29
+ np.deg2rad(90 - grid_latitude))
30
+
31
+ # [num_latitude_points, num_longitude_points, 3]
32
+ # Note this assumes unit radius, since for now we model the earth as a
33
+ # sphere of unit radius, and keep any vertical dimension as a regular grid.
34
+ return np.stack(
35
+ [np.cos(phi_grid)*np.sin(theta_grid),
36
+ np.sin(phi_grid)*np.sin(theta_grid),
37
+ np.cos(theta_grid)], axis=-1)
38
+
39
+
40
+ def radius_query_indices(
41
+ *,
42
+ grid_latitude: np.ndarray,
43
+ grid_longitude: np.ndarray,
44
+ mesh: icosahedral_mesh.TriangularMesh,
45
+ radius: float) -> tuple[np.ndarray, np.ndarray]:
46
+ """Returns mesh-grid edge indices for radius query.
47
+
48
+ Args:
49
+ grid_latitude: Latitude values for the grid [num_lat_points]
50
+ grid_longitude: Longitude values for the grid [num_lon_points]
51
+ mesh: Mesh object.
52
+ radius: Radius of connectivity in R3, for a sphere of unit radius.
53
+
54
+ Returns:
55
+ tuple with `grid_indices` and `mesh_indices` indicating edges between the
56
+ grid and the mesh such that the distances in a straight line (not geodesic)
57
+ are smaller than or equal to `radius`.
58
+ * grid_indices: Indices of shape [num_edges], that index into a
59
+ [num_lat_points, num_lon_points] grid, after flattening the leading axes.
60
+ * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
61
+ """
62
+
63
+ # [num_grid_points=num_lat_points * num_lon_points, 3]
64
+ grid_positions = _grid_lat_lon_to_coordinates(
65
+ grid_latitude, grid_longitude).reshape([-1, 3])
66
+
67
+ # [num_mesh_points, 3]
68
+ mesh_positions = mesh.vertices
69
+ kd_tree = scipy.spatial.cKDTree(mesh_positions)
70
+
71
+ # [num_grid_points, num_mesh_points_per_grid_point]
72
+ # Note `num_mesh_points_per_grid_point` is not constant, so this is a list
73
+ # of arrays, rather than a 2d array.
74
+ query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)
75
+
76
+ grid_edge_indices = []
77
+ mesh_edge_indices = []
78
+ for grid_index, mesh_neighbors in enumerate(query_indices):
79
+ grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors)))
80
+ mesh_edge_indices.append(mesh_neighbors)
81
+
82
+ # [num_edges]
83
+ grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int)
84
+ mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int)
85
+
86
+ return grid_edge_indices, mesh_edge_indices
87
+
88
+
89
+ def in_mesh_triangle_indices(
90
+ *,
91
+ grid_latitude: np.ndarray,
92
+ grid_longitude: np.ndarray,
93
+ mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]:
94
+ """Returns mesh-grid edge indices for grid points contained in mesh triangles.
95
+
96
+ Args:
97
+ grid_latitude: Latitude values for the grid [num_lat_points]
98
+ grid_longitude: Longitude values for the grid [num_lon_points]
99
+ mesh: Mesh object.
100
+
101
+ Returns:
102
+ tuple with `grid_indices` and `mesh_indices` indicating edges between the
103
+ grid and the mesh vertices of the triangle that contains each grid point.
104
+ The number of edges is always num_lat_points * num_lon_points * 3.
105
+ * grid_indices: Indices of shape [num_edges], that index into a
106
+ [num_lat_points, num_lon_points] grid, after flattening the leading axes.
107
+ * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
108
+ """
109
+
110
+ # [num_grid_points=num_lat_points * num_lon_points, 3]
111
+ grid_positions = _grid_lat_lon_to_coordinates(
112
+ grid_latitude, grid_longitude).reshape([-1, 3])
113
+
114
+ mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
115
+
116
+ # [num_grid_points] with mesh face indices for each grid point.
117
+ _, _, query_face_indices = trimesh.proximity.closest_point(
118
+ mesh_trimesh, grid_positions)
119
+
120
+ # [num_grid_points, 3] with mesh node indices for each grid point.
121
+ mesh_edge_indices = mesh.faces[query_face_indices]
122
+
123
+ # [num_grid_points, 3] with grid node indices, where every row simply contains
124
+ # the row (grid_point) index.
125
+ grid_indices = np.arange(grid_positions.shape[0])
126
+ grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3])
127
+
128
+ # Flatten to get a regular list.
129
+ # [num_edges=num_grid_points*3]
130
+ mesh_edge_indices = mesh_edge_indices.reshape([-1])
131
+ grid_edge_indices = grid_edge_indices.reshape([-1])
132
+
133
+ return grid_edge_indices, mesh_edge_indices
graphcast/grid_mesh_connectivity_test.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for graphcast.grid_mesh_connectivity."""
15
+
16
+ from absl.testing import absltest
17
+ from graphcast import grid_mesh_connectivity
18
+ from graphcast import icosahedral_mesh
19
+ import numpy as np
20
+
21
+
22
+ class GridMeshConnectivityTest(absltest.TestCase):
23
+
24
+ def test_grid_lat_lon_to_coordinates(self):
25
+
26
+ # Latitudes at 45-degree intervals, longitudes at 90-degree intervals.
27
+ grid_latitude = np.array([-45., 0., 45])
28
+ grid_longitude = np.array([0., 90., 180., 270.])
29
+
30
+ inv_sqrt2 = 1 / np.sqrt(2)
31
+ expected_coordinates = np.array([
32
+ [[inv_sqrt2, 0., -inv_sqrt2],
33
+ [0., inv_sqrt2, -inv_sqrt2],
34
+ [-inv_sqrt2, 0., -inv_sqrt2],
35
+ [0., -inv_sqrt2, -inv_sqrt2]],
36
+ [[1., 0., 0.],
37
+ [0., 1., 0.],
38
+ [-1., 0., 0.],
39
+ [0., -1., 0.]],
40
+ [[inv_sqrt2, 0., inv_sqrt2],
41
+ [0., inv_sqrt2, inv_sqrt2],
42
+ [-inv_sqrt2, 0., inv_sqrt2],
43
+ [0., -inv_sqrt2, inv_sqrt2]],
44
+ ])
45
+
46
+ coordinates = grid_mesh_connectivity._grid_lat_lon_to_coordinates(
47
+ grid_latitude, grid_longitude)
48
+ np.testing.assert_allclose(expected_coordinates, coordinates, atol=1e-15)
49
+
50
+ def test_radius_query_indices_smoke(self):
51
+ # TODO(alvarosg): Add non-smoke test?
52
+ grid_latitude = np.linspace(-75, 75, 6)
53
+ grid_longitude = np.arange(12) * 30.
54
+ mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
55
+ splits=3)[-1]
56
+ grid_mesh_connectivity.radius_query_indices(
57
+ grid_latitude=grid_latitude,
58
+ grid_longitude=grid_longitude,
59
+ mesh=mesh, radius=0.2)
60
+
61
+ def test_in_mesh_triangle_indices_smoke(self):
62
+ # TODO(alvarosg): Add non-smoke test?
63
+ grid_latitude = np.linspace(-75, 75, 6)
64
+ grid_longitude = np.arange(12) * 30.
65
+ mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
66
+ splits=3)[-1]
67
+ grid_mesh_connectivity.in_mesh_triangle_indices(
68
+ grid_latitude=grid_latitude,
69
+ grid_longitude=grid_longitude,
70
+ mesh=mesh)
71
+
72
+
73
+ if __name__ == "__main__":
74
+ absltest.main()
graphcast/icosahedral_mesh.py ADDED
@@ -0,0 +1,281 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utils for creating icosahedral meshes."""
15
+
16
+ import itertools
17
+ from typing import List, NamedTuple, Sequence, Tuple
18
+
19
+ import numpy as np
20
+ from scipy.spatial import transform
21
+
22
+
23
+ class TriangularMesh(NamedTuple):
24
+ """Data structure for triangular meshes.
25
+
26
+ Attributes:
27
+ vertices: spatial positions of the vertices of the mesh of shape
28
+ [num_vertices, num_dims].
29
+ faces: triangular faces of the mesh of shape [num_faces, 3]. Contains
30
+ integer indices into `vertices`.
31
+
32
+ """
33
+ vertices: np.ndarray
34
+ faces: np.ndarray
35
+
36
+
37
+ def merge_meshes(
38
+ mesh_list: Sequence[TriangularMesh]) -> TriangularMesh:
39
+ """Merges all meshes into one. Assumes the last mesh is the finest.
40
+
41
+ Args:
42
+ mesh_list: Sequence of meshes, from coarse to fine refinement levels. The
43
+ vertices and faces may contain those from preceding, coarser levels.
44
+
45
+ Returns:
46
+ `TriangularMesh` for which the vertices correspond to the highest
47
+ resolution mesh in the hierarchy, and the faces are the join set of the
48
+ faces at all levels of the hierarchy.
49
+ """
50
+ for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list):
51
+ num_nodes_mesh_i = mesh_i.vertices.shape[0]
52
+ assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i])
53
+
54
+ return TriangularMesh(
55
+ vertices=mesh_list[-1].vertices,
56
+ faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0))
57
+
58
+
59
+ def get_hierarchy_of_triangular_meshes_for_sphere(
60
+ splits: int) -> List[TriangularMesh]:
61
+ """Returns a sequence of meshes, each with triangularization sphere.
62
+
63
+ Starting with a regular icosahedron (12 vertices, 20 faces, 30 edges) with
64
+ circumscribed unit sphere. Then, each triangular face is iteratively
65
+ subdivided into 4 triangular faces `splits` times. The new vertices are then
66
+ projected back onto the unit sphere. All resulting meshes are returned in a
67
+ list, from lowest to highest resolution.
68
+
69
+ The vertices in each face are specified in counter-clockwise order as
70
+ observed from the outside of the icosahedron.
71
+
72
+ Args:
73
+ splits: How many times to split each triangle.
74
+ Returns:
75
+ Sequence of `TriangularMesh`s of length `splits + 1` each with:
76
+
77
+ vertices: [num_vertices, 3] vertex positions in 3D, all with unit norm.
78
+ faces: [num_faces, 3] with triangular faces joining sets of 3 vertices.
79
+ Each row contains three indices into the vertices array, indicating
80
+ the vertices adjacent to the face. Always with positive orientation
81
+ (counterclock-wise when looking from the outside).
82
+ """
83
+ current_mesh = get_icosahedron()
84
+ output_meshes = [current_mesh]
85
+ for _ in range(splits):
86
+ current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh)
87
+ output_meshes.append(current_mesh)
88
+ return output_meshes
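A quick sanity-check sketch (not from the committed file) of the sizes this hierarchy produces: faces quadruple per split, and the vertex count follows from Euler's formula for a closed triangulation.
from graphcast import icosahedral_mesh

meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(splits=2)
for s, mesh in enumerate(meshes):
  assert mesh.faces.shape[0] == 20 * 4**s         # 20, 80, 320, ...
  assert mesh.vertices.shape[0] == 10 * 4**s + 2  # 12, 42, 162, ...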
89
+
90
+
91
+ def get_icosahedron() -> TriangularMesh:
92
+ """Returns a regular icosahedral mesh with circumscribed unit sphere.
93
+
94
+ See https://en.wikipedia.org/wiki/Regular_icosahedron#Cartesian_coordinates
95
+ for details on the construction of the regular icosahedron.
96
+
97
+ The vertices in each face are specified in counter-clockwise order as observed
98
+ from the outside of the icosahedron.
99
+
100
+ Returns:
101
+ TriangularMesh with:
102
+
103
+ vertices: [num_vertices=12, 3] vertex positions in 3D, all with unit norm.
104
+ faces: [num_faces=20, 3] with triangular faces joining sets of 3 vertices.
105
+ Each row contains three indices into the vertices array, indicating
106
+ the vertices adjacent to the face. Always with positive orientation (
107
+ counterclock-wise when looking from the outside).
108
+
109
+ """
110
+ phi = (1 + np.sqrt(5)) / 2
111
+ vertices = []
112
+ for c1 in [1., -1.]:
113
+ for c2 in [phi, -phi]:
114
+ vertices.append((c1, c2, 0.))
115
+ vertices.append((0., c1, c2))
116
+ vertices.append((c2, 0., c1))
117
+
118
+ vertices = np.array(vertices, dtype=np.float32)
119
+ vertices /= np.linalg.norm([1., phi])
120
+
121
+ # I did this manually, checking the orientation one by one.
122
+ faces = [(0, 1, 2),
123
+ (0, 6, 1),
124
+ (8, 0, 2),
125
+ (8, 4, 0),
126
+ (3, 8, 2),
127
+ (3, 2, 7),
128
+ (7, 2, 1),
129
+ (0, 4, 6),
130
+ (4, 11, 6),
131
+ (6, 11, 5),
132
+ (1, 5, 7),
133
+ (4, 10, 11),
134
+ (4, 8, 10),
135
+ (10, 8, 3),
136
+ (10, 3, 9),
137
+ (11, 10, 9),
138
+ (11, 9, 5),
139
+ (5, 9, 7),
140
+ (9, 3, 7),
141
+ (1, 6, 5),
142
+ ]
143
+
144
+ # By default the top is an edge parallel to the Y axis.
145
+ # Need to rotate around the y axis by half the supplement of the
146
+ # angle between faces to get the desired orientation.
147
+ # /O\ (top arist)
148
+ # / \ Z
149
+ # (adjacent face)/ \ (adjacent face) ^
150
+ # / angle_between_faces \ |
151
+ # / \ |
152
+ # / \ YO-----> X
153
+ # This results in:
154
+ # (adjacent face is now top plane)
155
+ # ----------------------O\ (top arist)
156
+ # \
157
+ # \
158
+ # \ (adjacent face)
159
+ # \
160
+ # \
161
+ # \
162
+
163
+ angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3))
164
+ rotation_angle = (np.pi - angle_between_faces) / 2
165
+ rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle)
166
+ rotation_matrix = rotation.as_matrix()
167
+ vertices = np.dot(vertices, rotation_matrix)
168
+
169
+ return TriangularMesh(vertices=vertices.astype(np.float32),
170
+ faces=np.array(faces, dtype=np.int32))
171
+
172
+
173
+ def _two_split_unit_sphere_triangle_faces(
174
+ triangular_mesh: TriangularMesh) -> TriangularMesh:
175
+ """Splits each triangular face into 4 triangles keeping the orientation."""
176
+
177
+ # Every time we split a triangle into 4 we will be adding 3 extra vertices,
178
+ # located at the edge centres.
179
+ # This class handles the positioning of the new vertices, and avoids creating
180
+ # duplicates.
181
+ new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices)
182
+
183
+ new_faces = []
184
+ for ind1, ind2, ind3 in triangular_mesh.faces:
185
+ # Transform each triangular face into 4 triangles,
186
+ # preserving the orientation.
187
+ # ind3
188
+ # / \
189
+ # / \
190
+ # / #3 \
191
+ # / \
192
+ # ind31 -------------- ind23
193
+ # / \ / \
194
+ # / \ #4 / \
195
+ # / #1 \ / #2 \
196
+ # / \ / \
197
+ # ind1 ------------ ind12 ------------ ind2
198
+ ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2))
199
+ ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3))
200
+ ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1))
201
+ # Note how each of the 4 triangular new faces specifies the order of the
202
+ # vertices to preserve the orientation of the original face. As the input
203
+ # face should always be counter-clockwise as specified in the diagram,
204
+ # this means child faces should also be counter-clockwise.
205
+ new_faces.extend([[ind1, ind12, ind31], # 1
206
+ [ind12, ind2, ind23], # 2
207
+ [ind31, ind23, ind3], # 3
208
+ [ind12, ind23, ind31], # 4
209
+ ])
210
+ return TriangularMesh(vertices=new_vertices_builder.get_all_vertices(),
211
+ faces=np.array(new_faces, dtype=np.int32))
212
+
213
+
214
+ class _ChildVerticesBuilder(object):
215
+ """Bookkeeping of new child vertices added to an existing set of vertices."""
216
+
217
+ def __init__(self, parent_vertices):
218
+
219
+ # Because the same new vertex will be required when splitting adjacent
220
+ # triangles (which share an edge) we keep them in a hash table indexed by
221
+ # sorted indices of the vertices adjacent to the edge, to avoid creating
222
+ # duplicated child vertices.
223
+ self._child_vertices_index_mapping = {}
224
+ self._parent_vertices = parent_vertices
225
+ # We start with all previous vertices.
226
+ self._all_vertices_list = list(parent_vertices)
227
+
228
+ def _get_child_vertex_key(self, parent_vertex_indices):
229
+ return tuple(sorted(parent_vertex_indices))
230
+
231
+ def _create_child_vertex(self, parent_vertex_indices):
232
+ """Creates a new vertex."""
233
+ # Position for new vertex is the middle point, between the parent points,
234
+ # projected to unit sphere.
235
+ child_vertex_position = self._parent_vertices[
236
+ list(parent_vertex_indices)].mean(0)
237
+ child_vertex_position /= np.linalg.norm(child_vertex_position)
238
+
239
+ # Add the vertex to the output list. The index for this new vertex will
240
+ # match the length of the list before adding it.
241
+ child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
242
+ self._child_vertices_index_mapping[child_vertex_key] = len(
243
+ self._all_vertices_list)
244
+ self._all_vertices_list.append(child_vertex_position)
245
+
246
+ def get_new_child_vertex_index(self, parent_vertex_indices):
247
+ """Returns index for a child vertex, creating it if necessary."""
248
+ # Get the key to see if we already have a new vertex in the middle.
249
+ child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
250
+ if child_vertex_key not in self._child_vertices_index_mapping:
251
+ self._create_child_vertex(parent_vertex_indices)
252
+ return self._child_vertices_index_mapping[child_vertex_key]
253
+
254
+ def get_all_vertices(self):
255
+ """Returns an array with old vertices."""
256
+ return np.array(self._all_vertices_list)
257
+
258
+
259
+ def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
260
+ """Transforms polygonal faces to sender and receiver indices.
261
+
262
+ It does so by transforming every face into N_i edges. For example, if the triangular
263
+ face has indices [0, 1, 2], three edges are added 0->1, 1->2, and 2->0.
264
+
265
+ If all faces have consistent orientation, and the surface represented by the
266
+ faces is closed, then every edge in a polygon with a certain orientation
267
+ is also part of another polygon with the opposite orientation. In this
268
+ situation, the edges returned by the method are always bidirectional.
269
+
270
+ Args:
271
+ faces: Integer array of shape [num_faces, 3]. Contains node indices
272
+ adjacent to each face.
273
+ Returns:
274
+ Tuple with sender/receiver indices, each of shape [num_edges=num_faces*3].
275
+
276
+ """
277
+ assert faces.ndim == 2
278
+ assert faces.shape[-1] == 3
279
+ senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]])
280
+ receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]])
281
+ return senders, receivers
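A hedged sketch (not from the committed file) of the bidirectionality property described in the `faces_to_edges` docstring, using the base icosahedron:
from graphcast import icosahedral_mesh

mesh = icosahedral_mesh.get_icosahedron()
senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces)
assert senders.shape == (60,)  # 20 faces * 3 directed edges each.
forward = set(zip(senders.tolist(), receivers.tolist()))
backward = set(zip(receivers.tolist(), senders.tolist()))
assert forward == backward  # Every edge appears in both directions.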
graphcast/icosahedral_mesh_test.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for icosahedral_mesh."""
15
+
16
+ from absl.testing import absltest
17
+ from absl.testing import parameterized
18
+ import chex
19
+ from graphcast import icosahedral_mesh
20
+ import numpy as np
21
+
22
+
23
+ def _get_mesh_spec(splits: int):
24
+ """Returns size of the final icosahedral mesh resulting from the splitting."""
25
+ num_vertices = 12
26
+ num_faces = 20
27
+ for _ in range(splits):
28
+ # Each previous face adds three new vertices, but each vertex is shared
29
+ # by two faces.
30
+ num_vertices += num_faces * 3 // 2
31
+ num_faces *= 4
32
+ return num_vertices, num_faces
33
+
34
+
35
+ class IcosahedralMeshTest(parameterized.TestCase):
36
+
37
+ def test_icosahedron(self):
38
+ mesh = icosahedral_mesh.get_icosahedron()
39
+ _assert_valid_mesh(
40
+ mesh, num_expected_vertices=12, num_expected_faces=20)
41
+
42
+ @parameterized.parameters(list(range(5)))
43
+ def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits):
44
+ meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
45
+ splits=splits)
46
+ prev_vertices = None
47
+ for mesh_i, mesh in enumerate(meshes):
48
+ # Check that `mesh` is valid.
49
+ num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i)
50
+ _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces)
51
+
52
+ # Check that the first N vertices from this mesh match all of the
53
+ # vertices from the previous mesh.
54
+ if prev_vertices is not None:
55
+ leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]]
56
+ np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices)
57
+
58
+ # Increase the expected/previous values for the next iteration.
59
+ if mesh_i < len(meshes) - 1:
60
+ prev_vertices = mesh.vertices
61
+
62
+ @parameterized.parameters(list(range(4)))
63
+ def test_merge_meshes(self, splits):
64
+ mesh_hierarchy = (
65
+ icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
66
+ splits=splits))
67
+ mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy)
68
+
69
+ expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0)
70
+ np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices)
71
+ np.testing.assert_array_equal(mesh.faces, expected_faces)
72
+
73
+ def test_faces_to_edges(self):
74
+
75
+ faces = np.array([[0, 1, 2],
76
+ [3, 4, 5]])
77
+
78
+ # This also documents the order of the edges returned by the method.
79
+ expected_edges = np.array(
80
+ [[0, 1],
81
+ [3, 4],
82
+ [1, 2],
83
+ [4, 5],
84
+ [2, 0],
85
+ [5, 3]])
86
+ expected_senders = expected_edges[:, 0]
87
+ expected_receivers = expected_edges[:, 1]
88
+
89
+ senders, receivers = icosahedral_mesh.faces_to_edges(faces)
90
+
91
+ np.testing.assert_array_equal(senders, expected_senders)
92
+ np.testing.assert_array_equal(receivers, expected_receivers)
93
+
94
+
95
+ def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces):
96
+ vertices = mesh.vertices
97
+ faces = mesh.faces
98
+ chex.assert_shape(vertices, [num_expected_vertices, 3])
99
+ chex.assert_shape(faces, [num_expected_faces, 3])
100
+
101
+ # Vertices norm should be 1.
102
+ vertices_norm = np.linalg.norm(vertices, axis=-1)
103
+ np.testing.assert_allclose(vertices_norm, 1., rtol=1e-6)
104
+
105
+ _assert_positive_face_orientation(vertices, faces)
106
+
107
+
108
+ def _assert_positive_face_orientation(vertices, faces):
109
+
110
+ # Obtain a unit vector that points in the direction of the face normal.
111
+ face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]],
112
+ vertices[faces[:, 2]] - vertices[faces[:, 1]])
113
+ face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True)
114
+
115
+ # And a unit vector pointing from the origin to the center of the face.
116
+ face_centers = vertices[faces].mean(1)
117
+ face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True)
118
+
119
+ # Positive orientation means those two vectors should be parallel
120
+ # (dot product, 1), and not anti-parallel (dot product, -1).
121
+ dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers)
122
+
123
+ # Check that the face normal is parallel to the vector that joins the center
124
+ # of the face to the center of the sphere. Note we need a small tolerance
125
+ # because some discretizations are not exactly uniform, so it will not be
126
+ # exactly parallel.
127
+ np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4)
128
+
129
+
130
+ if __name__ == "__main__":
131
+ absltest.main()
graphcast/losses.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Loss functions (and terms for use in loss functions) used for weather."""
15
+
16
+ from typing import Mapping
17
+
18
+ from graphcast import xarray_tree
19
+ import numpy as np
20
+ from typing_extensions import Protocol
21
+ import xarray
22
+
23
+
24
+ LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset]
25
+
26
+
27
+ class LossFunction(Protocol):
28
+ """A loss function.
29
+
30
+ This is a protocol so it's fine to use a plain function which 'quacks like'
31
+ this. This is just to document the interface.
32
+ """
33
+
34
+ def __call__(self,
35
+ predictions: xarray.Dataset,
36
+ targets: xarray.Dataset,
37
+ **optional_kwargs) -> LossAndDiagnostics:
38
+ """Computes a loss function.
39
+
40
+ Args:
41
+ predictions: Dataset of predictions.
42
+ targets: Dataset of targets.
43
+ **optional_kwargs: Implementations may support extra optional kwargs.
44
+
45
+ Returns:
46
+ loss: A DataArray with dimensions ('batch',) containing losses for each
47
+ element of the batch. These will be averaged to give the final
48
+ loss, locally and across replicas.
49
+ diagnostics: Mapping of additional quantities to log by name alongside the
50
+ loss. These will typically correspond to terms in the loss. They
51
+ should also have dimensions ('batch',) and will be averaged over the
52
+ batch before logging.
53
+ """
54
+
55
+
56
+ def weighted_mse_per_level(
57
+ predictions: xarray.Dataset,
58
+ targets: xarray.Dataset,
59
+ per_variable_weights: Mapping[str, float],
60
+ ) -> LossAndDiagnostics:
61
+ """Latitude- and pressure-level-weighted MSE loss."""
62
+ def loss(prediction, target):
63
+ loss = (prediction - target)**2
64
+ loss *= normalized_latitude_weights(target).astype(loss.dtype)
65
+ if 'level' in target.dims:
66
+ loss *= normalized_level_weights(target).astype(loss.dtype)
67
+ return _mean_preserving_batch(loss)
68
+
69
+ losses = xarray_tree.map_structure(loss, predictions, targets)
70
+ return sum_per_variable_losses(losses, per_variable_weights)
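A minimal usage sketch (toy coordinates and an assumed variable name, not from the committed file): with a constant prediction error of 1 and unit-mean latitude weights, the returned loss is 1 for the single batch element.
import numpy as np
import xarray

lat = [-45., 45.]            # Equispaced, ending at +-(90 - d_lat/2).
lon = [0., 90., 180., 270.]
targets = xarray.Dataset(
    {"2m_temperature": (("batch", "lat", "lon"),
                        np.zeros((1, len(lat), len(lon))))},
    coords={"lat": lat, "lon": lon})
predictions = targets + 1.0  # Constant error of 1 everywhere.
loss, diagnostics = weighted_mse_per_level(
    predictions, targets, per_variable_weights={"2m_temperature": 1.0})
# `loss` has dims ("batch",) and equals 1.0 here.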
71
+
72
+
73
+ def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
74
+ return x.mean([d for d in x.dims if d != 'batch'], skipna=False)
75
+
76
+
77
+ def sum_per_variable_losses(
78
+ per_variable_losses: Mapping[str, xarray.DataArray],
79
+ weights: Mapping[str, float],
80
+ ) -> LossAndDiagnostics:
81
+ """Weighted sum of per-variable losses."""
82
+ if not set(weights.keys()).issubset(set(per_variable_losses.keys())):
83
+ raise ValueError(
84
+ 'Passing a weight that does not correspond to any variable '
85
+ f'{set(weights.keys())-set(per_variable_losses.keys())}')
86
+
87
+ weighted_per_variable_losses = {
88
+ name: loss * weights.get(name, 1)
89
+ for name, loss in per_variable_losses.items()
90
+ }
91
+ total = xarray.concat(
92
+ weighted_per_variable_losses.values(), dim='variable', join='exact').sum(
93
+ 'variable', skipna=False)
94
+ return total, per_variable_losses # pytype: disable=bad-return-type
95
+
96
+
97
+ def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray:
98
+ """Weights proportional to pressure at each level."""
99
+ level = data.coords['level']
100
+ return level / level.mean(skipna=False)
101
+
102
+
103
+ def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray:
104
+ """Weights based on latitude, roughly proportional to grid cell area.
105
+
106
+ This method supports two use cases only (both for equispaced values):
107
+ * Latitude values such that the closest value to the pole is at latitude
108
+ (90 - d_lat/2), where d_lat is the difference between contiguous latitudes.
109
+ For example: [-89, -87, -85, ..., 85, 87, 89]) (d_lat = 2)
110
+ In this case each point with `lat` value represents a sphere slice between
111
+ `lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be
112
+ proportional to:
113
+ `sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`, and
114
+ we can simply omit the term `2 * sin(d_lat/2)` which is just a constant
115
+ that cancels during normalization.
116
+ * Latitude values that fall exactly at the poles.
117
+ For example: [-90, -88, -86, ..., 86, 88, 90]) (d_lat = 2)
118
+ In this case each point with `lat` value also represents
119
+ a sphere slice between `lat - d_lat/2` and `lat + d_lat/2`,
120
+ except for the points at the poles, that represent a slice between
121
+ `90 - d_lat/2` and `90` or, `-90` and `-90 + d_lat/2`.
122
+ The areas of the first type of point are still proportional to:
123
+ * sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)
124
+ but for the points at the poles now is:
125
+ * sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2
126
+ and we will be using these weights, depending on whether we are looking at
127
+ pole cells, or non-pole cells (omitting the common factor of 2 which will be
128
+ absorbed by the normalization).
129
+
130
+ It can be shown via a limit, or simple geometry, that in the small angles
131
+ regime, the proportion of area per pole-point is equal to 1/8th
132
+ the proportion of area covered by each of the nearest non-pole points, and we
133
+ test for this in the test.
134
+
135
+ Args:
136
+ data: `DataArray` with latitude coordinates.
137
+ Returns:
138
+ Unit mean latitude weights.
139
+ """
140
+ latitude = data.coords['lat']
141
+
142
+ if np.any(np.isclose(np.abs(latitude), 90.)):
143
+ weights = _weight_for_latitude_vector_with_poles(latitude)
144
+ else:
145
+ weights = _weight_for_latitude_vector_without_poles(latitude)
146
+
147
+ return weights / weights.mean(skipna=False)
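A small worked sketch (not from the committed file) for the pole-free case: three equispaced latitudes get cos(lat) weights, normalized to unit mean.
import numpy as np
import xarray

data = xarray.DataArray(
    np.zeros(3), dims=("lat",), coords={"lat": [-60., 0., 60.]})
weights = normalized_latitude_weights(data)
expected = np.cos(np.deg2rad(np.array([-60., 0., 60.])))
np.testing.assert_allclose(weights, expected / expected.mean())  # [0.75, 1.5, 0.75]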
148
+
149
+
150
+ def _weight_for_latitude_vector_without_poles(latitude):
151
+ """Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2]."""
152
+ delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
153
+ if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or
154
+ not np.isclose(np.min(latitude), -90 + delta_latitude/2)):
155
+ raise ValueError(
156
+ f'Latitude vector {latitude} does not start/end at '
157
+ '+- (90 - delta_latitude/2) degrees.')
158
+ return np.cos(np.deg2rad(latitude))
159
+
160
+
161
+ def _weight_for_latitude_vector_with_poles(latitude):
162
+ """Weights for uniform latitudes of the form [+- 90, ..., -+90]."""
163
+ delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
164
+ if (not np.isclose(np.max(latitude), 90.) or
165
+ not np.isclose(np.min(latitude), -90.)):
166
+ raise ValueError(
167
+ f'Latitude vector {latitude} does not start/end at +- 90 degrees.')
168
+ weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2))
169
+ # The two checks above are enough to guarantee that latitudes are sorted, so
170
+ # the extremes are the poles.
171
+ weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2
172
+ return weights
173
+
174
+
175
+ def _check_uniform_spacing_and_get_delta(vector):
176
+ diff = np.diff(vector)
177
+ if not np.all(np.isclose(diff[0], diff)):
178
+ raise ValueError(f'Vector {vector} is not uniformly spaced.')
179
+ return diff[0]
graphcast/model_utils.py ADDED
@@ -0,0 +1,724 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utilities for building models."""
15
+
16
+ from typing import Mapping, Optional, Tuple
17
+
18
+ import numpy as np
19
+ from scipy.spatial import transform
20
+ import xarray
21
+
22
+
23
+ def get_graph_spatial_features(
24
+ *, node_lat: np.ndarray, node_lon: np.ndarray,
25
+ senders: np.ndarray, receivers: np.ndarray,
26
+ add_node_positions: bool,
27
+ add_node_latitude: bool,
28
+ add_node_longitude: bool,
29
+ add_relative_positions: bool,
30
+ relative_longitude_local_coordinates: bool,
31
+ relative_latitude_local_coordinates: bool,
32
+ sine_cosine_encoding: bool = False,
33
+ encoding_num_freqs: int = 10,
34
+ encoding_multiplicative_factor: float = 1.2,
35
+ ) -> Tuple[np.ndarray, np.ndarray]:
36
+ """Computes spatial features for the nodes.
37
+
38
+ Args:
39
+ node_lat: Latitudes in the [-90, 90] interval of shape [num_nodes]
40
+ node_lon: Longitudes in the [0, 360] interval of shape [num_nodes]
41
+ senders: Sender indices of shape [num_edges]
42
+ receivers: Receiver indices of shape [num_edges]
43
+ add_node_positions: Add unit norm absolute positions.
44
+ add_node_latitude: Add a feature for latitude (cos(90 - lat))
45
+ Note even if this is set to False, the model may be able to infer the
46
+ latitude from relative features, unless
47
+ `relative_latitude_local_coordinates` is also True, or if there is any
48
+ bias on the relative edge sizes for different latitudes.
49
+ add_node_longitude: Add features for longitude (cos(lon), sin(lon)).
50
+ Note even if this is set to False, the model may be able to infer the
51
+ longitude from relative features, unless
52
+ `relative_longitude_local_coordinates` is also True, or if there is any
53
+ bias on the relative edge sizes for different longitudes.
54
+ add_relative_positions: Whether to add relative positions in R3 to the edges.
55
+ relative_longitude_local_coordinates: If True, relative positions are
56
+ computed in a local space where the receiver is at 0 longitude.
57
+ relative_latitude_local_coordinates: If True, relative positions are
58
+ computed in a local space where the receiver is at 0 latitude.
59
+ sine_cosine_encoding: If True, we will transform the node/edge features
60
+ with sine and cosine functions, similar to NERF.
61
+ encoding_num_freqs: frequency parameter
62
+ encoding_multiplicative_factor: used for calculating the frequency.
63
+
64
+ Returns:
65
+ Arrays of shape: [num_nodes, num_features] and [num_edges, num_features],
66
+ with node and edge features.
67
+
68
+ """
69
+
70
+ num_nodes = node_lat.shape[0]
71
+ num_edges = senders.shape[0]
72
+ dtype = node_lat.dtype
73
+ node_phi, node_theta = lat_lon_deg_to_spherical(node_lat, node_lon)
74
+
75
+ # Computing some node features.
76
+ node_features = []
77
+ if add_node_positions:
78
+ # Already in [-1, 1.] range.
79
+ node_features.extend(spherical_to_cartesian(node_phi, node_theta))
80
+
81
+ if add_node_latitude:
82
+ # Using the cos of theta.
83
+ # From 1. (north pole) to -1 (south pole).
84
+ node_features.append(np.cos(node_theta))
85
+
86
+ if add_node_longitude:
87
+ # Using the cos and sin, which is already normalized.
88
+ node_features.append(np.cos(node_phi))
89
+ node_features.append(np.sin(node_phi))
90
+
91
+ if not node_features:
92
+ node_features = np.zeros([num_nodes, 0], dtype=dtype)
93
+ else:
94
+ node_features = np.stack(node_features, axis=-1)
95
+
96
+ # Computing some edge features.
97
+ edge_features = []
98
+
99
+ if add_relative_positions:
100
+
101
+ relative_position = get_relative_position_in_receiver_local_coordinates(
102
+ node_phi=node_phi,
103
+ node_theta=node_theta,
104
+ senders=senders,
105
+ receivers=receivers,
106
+ latitude_local_coordinates=relative_latitude_local_coordinates,
107
+ longitude_local_coordinates=relative_longitude_local_coordinates
108
+ )
109
+
110
+ # Note this is L2 distance in 3d space, rather than geodesic distance.
111
+ relative_edge_distances = np.linalg.norm(
112
+ relative_position, axis=-1, keepdims=True)
113
+
114
+ # Normalize to the maximum edge distance. Note that we expect to always
115
+ # have an edge that goes in the opposite direction of any given edge
116
+ # so the distribution of relative positions should be symmetric around
117
+ # zero. So by scaling by the maximum length, we expect all relative
118
+ # positions to fall in the [-1., 1.] interval, and all relative distances
119
+ # to fall in the [0., 1.] interval.
120
+ max_edge_distance = relative_edge_distances.max()
121
+ edge_features.append(relative_edge_distances / max_edge_distance)
122
+ edge_features.append(relative_position / max_edge_distance)
123
+
124
+ if not edge_features:
125
+ edge_features = np.zeros([num_edges, 0], dtype=dtype)
126
+ else:
127
+ edge_features = np.concatenate(edge_features, axis=-1)
128
+
129
+ if sine_cosine_encoding:
130
+ def sine_cosine_transform(x: np.ndarray) -> np.ndarray:
131
+ freqs = encoding_multiplicative_factor**np.arange(encoding_num_freqs)
132
+ phases = freqs * x[..., None]
133
+ x_sin = np.sin(phases)
134
+ x_cos = np.cos(phases)
135
+ x_cat = np.concatenate([x_sin, x_cos], axis=-1)
136
+ return x_cat.reshape([x.shape[0], -1])
137
+
138
+ node_features = sine_cosine_transform(node_features)
139
+ edge_features = sine_cosine_transform(edge_features)
140
+
141
+ return node_features, edge_features
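A usage sketch (toy two-node graph, assumed flag values, not from the committed file) showing the shapes this function produces:
import numpy as np

node_lat = np.array([0., 0.])
node_lon = np.array([0., 90.])
node_feats, edge_feats = get_graph_spatial_features(
    node_lat=node_lat, node_lon=node_lon,
    senders=np.array([0, 1]), receivers=np.array([1, 0]),
    add_node_positions=False,
    add_node_latitude=True,
    add_node_longitude=True,
    add_relative_positions=True,
    relative_longitude_local_coordinates=True,
    relative_latitude_local_coordinates=True)
# node_feats: [2, 3] -> cos(theta), cos(lon), sin(lon) per node.
# edge_feats: [2, 4] -> normalized distance plus 3D relative position.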
142
+
143
+
144
+ def lat_lon_to_leading_axes(
145
+ grid_xarray: xarray.DataArray) -> xarray.DataArray:
146
+ """Reorders xarray so lat/lon axes come first."""
147
+ # leading + ["lat", "lon"] + trailing
148
+ # to
149
+ # ["lat", "lon"] + leading + trailing
150
+ return grid_xarray.transpose("lat", "lon", ...)
151
+
152
+
153
+ def restore_leading_axes(grid_xarray: xarray.DataArray) -> xarray.DataArray:
154
+ """Reorders xarray so batch/time/level axes come first (if present)."""
155
+
156
+ # ["lat", "lon"] + [(batch,) (time,) (level,)] + trailing
157
+ # to
158
+ # [(batch,) (time,) (level,)] + ["lat", "lon"] + trailing
159
+
160
+ input_dims = list(grid_xarray.dims)
161
+ output_dims = list(input_dims)
162
+ for leading_key in ["level", "time", "batch"]: # reverse order for insert
163
+ if leading_key in input_dims:
164
+ output_dims.remove(leading_key)
165
+ output_dims.insert(0, leading_key)
166
+ return grid_xarray.transpose(*output_dims)
167
+
168
+
169
+ def lat_lon_deg_to_spherical(node_lat: np.ndarray,
170
+ node_lon: np.ndarray,
171
+ ) -> Tuple[np.ndarray, np.ndarray]:
172
+ phi = np.deg2rad(node_lon)
173
+ theta = np.deg2rad(90 - node_lat)
174
+ return phi, theta
175
+
176
+
177
+ def spherical_to_lat_lon(phi: np.ndarray,
178
+ theta: np.ndarray,
179
+ ) -> Tuple[np.ndarray, np.ndarray]:
180
+ lon = np.mod(np.rad2deg(phi), 360)
181
+ lat = 90 - np.rad2deg(theta)
182
+ return lat, lon
183
+
184
+
185
+ def cartesian_to_spherical(x: np.ndarray,
186
+ y: np.ndarray,
187
+ z: np.ndarray,
188
+ ) -> Tuple[np.ndarray, np.ndarray]:
189
+ phi = np.arctan2(y, x)
190
+ with np.errstate(invalid="ignore"): # circumventing b/253179568
191
+ theta = np.arccos(z) # Assuming unit radius.
192
+ return phi, theta
193
+
194
+
195
+ def spherical_to_cartesian(
196
+ phi: np.ndarray, theta: np.ndarray
197
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
198
+ # Assuming unit radius.
199
+ return (np.cos(phi)*np.sin(theta),
200
+ np.sin(phi)*np.sin(theta),
201
+ np.cos(theta))
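A round-trip sketch (not from the committed file; toy angles assumed) tying the four conversion helpers together:
import numpy as np

lat, lon = np.array([37.5]), np.array([122.25])
phi, theta = lat_lon_deg_to_spherical(lat, lon)
x, y, z = spherical_to_cartesian(phi, theta)
phi2, theta2 = cartesian_to_spherical(x, y, z)
np.testing.assert_allclose(spherical_to_lat_lon(phi2, theta2), (lat, lon))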
202
+
203
+
204
+ def get_relative_position_in_receiver_local_coordinates(
205
+ node_phi: np.ndarray,
206
+ node_theta: np.ndarray,
207
+ senders: np.ndarray,
208
+ receivers: np.ndarray,
209
+ latitude_local_coordinates: bool,
210
+ longitude_local_coordinates: bool
211
+ ) -> np.ndarray:
212
+ """Returns relative position features for the edges.
213
+
214
+ The relative positions will be computed in a rotated space for a local
215
+ coordinate system as defined by the receiver. The relative positions are
216
+ simply obtained as the sender position minus the receiver position in
217
+ that local coordinate system after the rotation in R^3.
218
+
219
+ Args:
220
+ node_phi: [num_nodes] with azimuthal angles.
221
+ node_theta: [num_nodes] with polar angles.
222
+ senders: [num_edges] with indices.
223
+ receivers: [num_edges] with indices.
224
+ latitude_local_coordinates: Whether to rotate edges such that the
225
+ positions are computed with the receiver always at latitude 0.
226
+ longitude_local_coordinates: Whether to rotate edges such that the
227
+ positions are computed with the receiver always at longitude 0.
228
+
229
+ Returns:
230
+ Array of relative positions in R3 [num_edges, 3]
231
+ """
232
+
233
+ node_pos = np.stack(spherical_to_cartesian(node_phi, node_theta), axis=-1)
234
+
235
+ # No rotation in this case.
236
+ if not (latitude_local_coordinates or longitude_local_coordinates):
237
+ return node_pos[senders] - node_pos[receivers]
238
+
239
+ # Get rotation matrices for the local space for every node.
240
+ rotation_matrices = get_rotation_matrices_to_local_coordinates(
241
+ reference_phi=node_phi,
242
+ reference_theta=node_theta,
243
+ rotate_latitude=latitude_local_coordinates,
244
+ rotate_longitude=longitude_local_coordinates)
245
+
246
+ # Each edge will be rotated according to the rotation matrix of its receiver
247
+ # node.
248
+ edge_rotation_matrices = rotation_matrices[receivers]
249
+
250
+ # Rotate all nodes to the rotated space of the corresponding edge.
251
+ # Note for receivers we can also do the matmul first and the gather second:
252
+ # ```
253
+ # receiver_pos_in_rotated_space = rotate_with_matrices(
254
+ # rotation_matrices, node_pos)[receivers]
255
+ # ```
256
+ # which is more efficient, however, we do gather first to keep it more
257
+ # symmetric with the sender computation.
258
+ receiver_pos_in_rotated_space = rotate_with_matrices(
259
+ edge_rotation_matrices, node_pos[receivers])
260
+ sender_pos_in_in_rotated_space = rotate_with_matrices(
261
+ edge_rotation_matrices, node_pos[senders])
262
+ # Note, here, that because the rotated space is chosen according to the
263
+ # receiver, if:
264
+ # * latitude_local_coordinates = True: latitude for the receivers will be
265
+ # 0, that is the z coordinate will always be 0.
266
+ # * longitude_local_coordinates = True: longitude for the receivers will be
267
+ # 0, that is the y coordinate will be 0.
268
+
269
+ # Now we can just subtract.
270
+ # Note we are rotating to a local coordinate system, where the y-z axes are
271
+ # parallel to a tangent plane to the sphere, but still remain in a 3d space.
272
+ # Note that if both `latitude_local_coordinates` and
273
+ # `longitude_local_coordinates` are True, and edges are short,
274
+ # then the difference in x coordinate between sender and receiver
275
+ # should be small, so we could consider dropping the new x coordinate if
276
+ # we wanted to work in the tangent plane, however in doing so
277
+ # we would lose information about the curvature of the mesh, which may be
278
+ # important for very coarse meshes.
279
+ return sender_pos_in_in_rotated_space - receiver_pos_in_rotated_space
280
+
281
+
282
+ def get_rotation_matrices_to_local_coordinates(
283
+ reference_phi: np.ndarray,
284
+ reference_theta: np.ndarray,
285
+ rotate_latitude: bool,
286
+ rotate_longitude: bool) -> np.ndarray:
287
+
288
+ """Returns a rotation matrix to rotate to a point based on a reference vector.
289
+
290
+ The rotation matrix is built such that a vector in the
291
+ same coordinate system at the reference point that points towards the pole
292
+ before the rotation, continues to point towards the pole after the rotation.
293
+
294
+ Args:
295
+ reference_phi: [leading_axis] Azimuthal angles of the reference.
296
+ reference_theta: [leading_axis] Polar angles of the reference.
297
+ rotate_latitude: Whether to produce a rotation matrix that would rotate
298
+ R^3 vectors to zero latitude.
299
+ rotate_longitude: Whether to produce a rotation matrix that would rotate
300
+ R^3 vectors to zero longitude.
301
+
302
+ Returns:
303
+ Matrices of shape [leading_axis] such that when applied to the reference
304
+ position with `rotate_with_matrices(rotation_matrices, reference_pos)`
305
+
306
+ * phi goes to 0. if "rotate_longitude" is True.
307
+
308
+ * theta goes to np.pi / 2 if "rotate_latitude" is True.
309
+
310
+ The rotation consists of:
311
+ * rotate_latitude = False, rotate_longitude = True:
312
+ Latitude preserving rotation.
313
+ * rotate_latitude = True, rotate_longitude = True:
314
+ Latitude preserving rotation, followed by longitude preserving
315
+ rotation.
316
+ * rotate_latitude = True, rotate_longitude = False:
317
+ Latitude preserving rotation, followed by longitude preserving
318
+ rotation, and the inverse of the latitude preserving rotation. Note
319
+ this is computationally different from rotating the longitude only.
320
+ We do it like this so the polar geodesic curve continues
321
+ to be aligned with one of the axes after the rotation.
322
+
323
+ """
324
+
325
+ if rotate_longitude and rotate_latitude:
326
+
327
+ # We first rotate around the z axis "minus the azimuthal angle", to get the
328
+ # point with zero longitude
329
+ azimuthal_rotation = - reference_phi
330
+
331
+ # Then we will do a polar rotation (which can be done along the y
332
+ # axis now that we are at longitude 0), "minus the polar angle plus pi/2"
333
+ # to get the point with zero latitude.
334
+ polar_rotation = - reference_theta + np.pi/2
335
+
336
+ return transform.Rotation.from_euler(
337
+ "zy", np.stack([azimuthal_rotation, polar_rotation],
338
+ axis=1)).as_matrix()
339
+ elif rotate_longitude:
340
+ # Just like the previous case, but applying only the azimuthal rotation.
341
+ azimuthal_rotation = - reference_phi
342
+ return transform.Rotation.from_euler("z", -reference_phi).as_matrix()
343
+ elif rotate_latitude:
344
+ # Just like the first case, but after doing the polar rotation, undoing
345
+ # the azimuthal rotation.
346
+ azimuthal_rotation = - reference_phi
347
+ polar_rotation = - reference_theta + np.pi/2
348
+
349
+ return transform.Rotation.from_euler(
350
+ "zyz", np.stack(
351
+ [azimuthal_rotation, polar_rotation, -azimuthal_rotation]
352
+ , axis=1)).as_matrix()
353
+ else:
354
+ raise ValueError(
355
+ "At least one of longitude and latitude should be rotated.")
356
+
357
+
358
+ def rotate_with_matrices(rotation_matrices: np.ndarray, positions: np.ndarray
359
+ ) -> np.ndarray:
360
+ return np.einsum("bji,bi->bj", rotation_matrices, positions)
361
+
362
+
363
+ def get_bipartite_graph_spatial_features(
364
+ *,
365
+ senders_node_lat: np.ndarray,
366
+ senders_node_lon: np.ndarray,
367
+ senders: np.ndarray,
368
+ receivers_node_lat: np.ndarray,
369
+ receivers_node_lon: np.ndarray,
370
+ receivers: np.ndarray,
371
+ add_node_positions: bool,
372
+ add_node_latitude: bool,
373
+ add_node_longitude: bool,
374
+ add_relative_positions: bool,
375
+ edge_normalization_factor: Optional[float] = None,
376
+ relative_longitude_local_coordinates: bool,
377
+ relative_latitude_local_coordinates: bool,
378
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
379
+ """Computes spatial features for the nodes.
380
+
381
+ This function is almost identical to `get_graph_spatial_features`. The only
382
+ difference is that sender nodes and receiver nodes can be in different arrays.
383
+ This is necessary to enable combination with typed Graph.
384
+
385
+ Args:
386
+ senders_node_lat: Latitudes in the [-90, 90] interval of shape
387
+ [num_sender_nodes]
388
+ senders_node_lon: Longitudes in the [0, 360] interval of shape
389
+ [num_sender_nodes]
390
+ senders: Sender indices of shape [num_edges], indices in [0,
391
+ num_sender_nodes)
392
+ receivers_node_lat: Latitudes in the [-90, 90] interval of shape
393
+ [num_receiver_nodes]
394
+ receivers_node_lon: Longitudes in the [0, 360] interval of shape
395
+ [num_receiver_nodes]
396
+ receivers: Receiver indices of shape [num_edges], indices in [0,
397
+ num_receiver_nodes)
398
+ add_node_positions: Add unit norm absolute positions.
399
+ add_node_latitude: Add a feature for latitude (cos(90 - lat)) Note even if
400
+ this is set to False, the model may be able to infer the latitude from
401
+ relative features, unless `relative_latitude_local_coordinates` is also
402
+ True, or if there is any bias on the relative edge sizes for different
403
+ latitudes.
404
+ add_node_longitude: Add features for longitude (cos(lon), sin(lon)). Note
405
+ even if this is set to False, the model may be able to infer the longitude
406
+ from relative features, unless `relative_longitude_local_coordinates` is
407
+ also True, or if there is any bias on the relative edge sizes for
408
+ different longitudes.
409
+ add_relative_positions: Whether to add relative positions in R3 to the edges.
410
+ edge_normalization_factor: Allows explicitly controlling edge normalization.
411
+ If None, defaults to max edge length. This supports using pre-trained
412
+ model weights with a different graph structure to what it was trained on.
413
+ relative_longitude_local_coordinates: If True, relative positions are
414
+ computed in a local space where the receiver is at 0 longitude.
415
+ relative_latitude_local_coordinates: If True, relative positions are
416
+ computed in a local space where the receiver is at 0 latitude.
417
+
418
+ Returns:
419
+ Arrays of shape: [num_sender_nodes, num_features], [num_receiver_nodes,
420
+ num_features] and [num_edges, num_features], with node and edge features.
421
+
422
+ """
423
+
424
+ num_senders = senders_node_lat.shape[0]
425
+ num_receivers = receivers_node_lat.shape[0]
426
+ num_edges = senders.shape[0]
427
+ dtype = senders_node_lat.dtype
428
+ assert receivers_node_lat.dtype == dtype
429
+ senders_node_phi, senders_node_theta = lat_lon_deg_to_spherical(
430
+ senders_node_lat, senders_node_lon)
431
+ receivers_node_phi, receivers_node_theta = lat_lon_deg_to_spherical(
432
+ receivers_node_lat, receivers_node_lon)
433
+
434
+ # Computing some node features.
435
+ senders_node_features = []
436
+ receivers_node_features = []
437
+ if add_node_positions:
438
+ # Already in [-1, 1.] range.
439
+ senders_node_features.extend(
440
+ spherical_to_cartesian(senders_node_phi, senders_node_theta))
441
+ receivers_node_features.extend(
442
+ spherical_to_cartesian(receivers_node_phi, receivers_node_theta))
443
+
444
+ if add_node_latitude:
445
+ # Using the cos of theta.
446
+ # From 1. (north pole) to -1 (south pole).
447
+ senders_node_features.append(np.cos(senders_node_theta))
448
+ receivers_node_features.append(np.cos(receivers_node_theta))
449
+
450
+ if add_node_longitude:
451
+ # Using the cos and sin, which is already normalized.
452
+ senders_node_features.append(np.cos(senders_node_phi))
453
+ senders_node_features.append(np.sin(senders_node_phi))
454
+
455
+ receivers_node_features.append(np.cos(receivers_node_phi))
456
+ receivers_node_features.append(np.sin(receivers_node_phi))
457
+
458
+ if not senders_node_features:
459
+ senders_node_features = np.zeros([num_senders, 0], dtype=dtype)
460
+ receivers_node_features = np.zeros([num_receivers, 0], dtype=dtype)
461
+ else:
462
+ senders_node_features = np.stack(senders_node_features, axis=-1)
463
+ receivers_node_features = np.stack(receivers_node_features, axis=-1)
464
+
465
+ # Computing some edge features.
466
+ edge_features = []
467
+
468
+ if add_relative_positions:
469
+
470
+ relative_position = get_bipartite_relative_position_in_receiver_local_coordinates( # pylint: disable=line-too-long
471
+ senders_node_phi=senders_node_phi,
472
+ senders_node_theta=senders_node_theta,
473
+ receivers_node_phi=receivers_node_phi,
474
+ receivers_node_theta=receivers_node_theta,
475
+ senders=senders,
476
+ receivers=receivers,
477
+ latitude_local_coordinates=relative_latitude_local_coordinates,
478
+ longitude_local_coordinates=relative_longitude_local_coordinates)
479
+
480
+ # Note this is L2 distance in 3d space, rather than geodesic distance.
481
+ relative_edge_distances = np.linalg.norm(
482
+ relative_position, axis=-1, keepdims=True)
483
+
484
+ if edge_normalization_factor is None:
485
+ # Normalize to the maximum edge distance. Note that we expect to always
486
+ # have an edge that goes in the opposite direction of any given edge
487
+ # so the distribution of relative positions should be symmetric around
488
+ # zero. So by scaling by the maximum length, we expect all relative
489
+ # positions to fall in the [-1., 1.] interval, and all relative distances
490
+ # to fall in the [0., 1.] interval.
491
+ edge_normalization_factor = relative_edge_distances.max()
492
+
493
+ edge_features.append(relative_edge_distances / edge_normalization_factor)
494
+ edge_features.append(relative_position / edge_normalization_factor)
495
+
496
+ if not edge_features:
497
+ edge_features = np.zeros([num_edges, 0], dtype=dtype)
498
+ else:
499
+ edge_features = np.concatenate(edge_features, axis=-1)
500
+
501
+ return senders_node_features, receivers_node_features, edge_features
502
+
503
+
504
+ def get_bipartite_relative_position_in_receiver_local_coordinates(
505
+ senders_node_phi: np.ndarray,
506
+ senders_node_theta: np.ndarray,
507
+ senders: np.ndarray,
508
+ receivers_node_phi: np.ndarray,
509
+ receivers_node_theta: np.ndarray,
510
+ receivers: np.ndarray,
511
+ latitude_local_coordinates: bool,
512
+ longitude_local_coordinates: bool) -> np.ndarray:
513
+ """Returns relative position features for the edges.
514
+
515
+ This function is equivalent to
516
+ `get_relative_position_in_receiver_local_coordinates`, but adapted to work
517
+ with bipartite typed graphs.
518
+
519
+ The relative positions will be computed in a rotated space for a local
520
+ coordinate system as defined by the receiver. The relative positions are
521
+ simply obtained by subtracting sender position minues receiver position in
522
+ that local coordinate system after the rotation in R^3.
523
+
524
+ Args:
525
+ senders_node_phi: [num_sender_nodes] with polar angles.
526
+ senders_node_theta: [num_sender_nodes] with azimuthal angles.
527
+ senders: [num_edges] with indices into sender nodes.
528
+ receivers_node_phi: [num_sender_nodes] with polar angles.
529
+ receivers_node_theta: [num_sender_nodes] with azimuthal angles.
530
+ receivers: [num_edges] with indices into receiver nodes.
531
+ latitude_local_coordinates: Whether to rotate edges such that in the
532
+ positions are computed such that the receiver is always at latitude 0.
533
+ longitude_local_coordinates: Whether to rotate edges such that in the
534
+ positions are computed such that the receiver is always at longitude 0.
535
+
536
+ Returns:
537
+ Array of relative positions in R3 [num_edges, 3]
538
+ """
539
+
540
+ senders_node_pos = np.stack(
541
+ spherical_to_cartesian(senders_node_phi, senders_node_theta), axis=-1)
542
+
543
+ receivers_node_pos = np.stack(
544
+ spherical_to_cartesian(receivers_node_phi, receivers_node_theta), axis=-1)
545
+
546
+ # No rotation in this case.
547
+ if not (latitude_local_coordinates or longitude_local_coordinates):
548
+ return senders_node_pos[senders] - receivers_node_pos[receivers]
549
+
550
+ # Get rotation matrices for the local space space for every receiver node.
551
+ receiver_rotation_matrices = get_rotation_matrices_to_local_coordinates(
552
+ reference_phi=receivers_node_phi,
553
+ reference_theta=receivers_node_theta,
554
+ rotate_latitude=latitude_local_coordinates,
555
+ rotate_longitude=longitude_local_coordinates)
556
+
557
+ # Each edge will be rotated according to the rotation matrix of its receiver
558
+ # node.
559
+ edge_rotation_matrices = receiver_rotation_matrices[receivers]
560
+
561
+ # Rotate all nodes to the rotated space of the corresponding edge.
562
+ # Note for receivers we can also do the matmul first and the gather second:
563
+ # ```
564
+ # receiver_pos_in_rotated_space = rotate_with_matrices(
565
+ # rotation_matrices, node_pos)[receivers]
566
+ # ```
567
+ # which is more efficient, however, we do gather first to keep it more
568
+ # symmetric with the sender computation.
569
+ receiver_pos_in_rotated_space = rotate_with_matrices(
570
+ edge_rotation_matrices, receivers_node_pos[receivers])
571
+ sender_pos_in_in_rotated_space = rotate_with_matrices(
572
+ edge_rotation_matrices, senders_node_pos[senders])
573
+ # Note, here, that because the rotated space is chosen according to the
574
+ # receiver, if:
575
+ # * latitude_local_coordinates = True: latitude for the receivers will be
576
+ # 0, that is the z coordinate will always be 0.
577
+ # * longitude_local_coordinates = True: longitude for the receivers will be
578
+ # 0, that is the y coordinate will be 0.
579
+
580
+ # Now we can just subtract.
581
+ # Note we are rotating to a local coordinate system, where the y-z axes are
582
+ # parallel to a tangent plane to the sphere, but still remain in a 3d space.
583
+ # Note that if both `latitude_local_coordinates` and
584
+ # `longitude_local_coordinates` are True, and edges are short,
585
+ # then the difference in x coordinate between sender and receiver
586
+ # should be small, so we could consider dropping the new x coordinate if
587
+ we wanted to project onto the tangent plane; however, in doing so
588
+ # we would lose information about the curvature of the mesh, which may be
589
+ # important for very coarse meshes.
590
+ return sender_pos_in_rotated_space - receiver_pos_in_rotated_space
591
+
592
+
593
+ def variable_to_stacked(
594
+ variable: xarray.Variable,
595
+ sizes: Mapping[str, int],
596
+ preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
597
+ ) -> xarray.Variable:
598
+ """Converts an xarray.Variable to preserved_dims + ("channels",).
599
+
600
+ Any dimensions other than those included in preserved_dims get stacked into a
601
+ final "channels" dimension. If any of the preserved_dims are missing then they
602
+ are added, with the data broadcast/tiled to match the sizes specified in
603
+ `sizes`.
604
+
605
+ Args:
606
+ variable: An xarray.Variable.
607
+ sizes: Mapping including sizes for any dimensions which are not present in
608
+ `variable` but are needed for the output. This may be needed for example
609
+ for a static variable with only ("lat", "lon") dims, or if you want to
610
+ encode just the latitude coordinates (a variable with dims ("lat",)).
611
+ preserved_dims: dimensions of variable to not be folded in channels.
612
+
613
+ Returns:
614
+ An xarray.Variable with dimensions preserved_dims + ("channels",).
615
+ """
616
+ stack_to_channels_dims = [
617
+ d for d in variable.dims if d not in preserved_dims]
618
+ if stack_to_channels_dims:
619
+ variable = variable.stack(channels=stack_to_channels_dims)
620
+ dims = {dim: variable.sizes.get(dim) or sizes[dim] for dim in preserved_dims}
621
+ dims["channels"] = variable.sizes.get("channels", 1)
622
+ return variable.set_dims(dims)
623
+
624
+
625
+ def dataset_to_stacked(
626
+ dataset: xarray.Dataset,
627
+ sizes: Optional[Mapping[str, int]] = None,
628
+ preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
629
+ ) -> xarray.DataArray:
630
+ """Converts an xarray.Dataset to a single stacked array.
631
+
632
+ This takes each constituent data_var, converts it into BHWC layout
633
+ using `variable_to_stacked`, then concats them all along the channels axis.
634
+
635
+ Args:
636
+ dataset: An xarray.Dataset.
637
+ sizes: Mapping including sizes for any dimensions which are not present in
638
+ the `dataset` but are needed for the output. See variable_to_stacked.
639
+ preserved_dims: dimensions from the dataset that should not be folded in
640
+ the predictions channels.
641
+
642
+ Returns:
643
+ An xarray.DataArray with dimensions preserved_dims + ("channels",).
644
+ Existing coordinates for preserved_dims axes will be preserved, however
645
+ there will be no coordinates for "channels".
646
+ """
647
+ data_vars = [
648
+ variable_to_stacked(dataset.variables[name], sizes or dataset.sizes,
649
+ preserved_dims)
650
+ for name in sorted(dataset.data_vars.keys())
651
+ ]
652
+ coords = {
653
+ dim: coord
654
+ for dim, coord in dataset.coords.items()
655
+ if dim in preserved_dims
656
+ }
657
+ return xarray.DataArray(
658
+ data=xarray.Variable.concat(data_vars, dim="channels"), coords=coords)
659
+
660
+
661
+ def stacked_to_dataset(
662
+ stacked_array: xarray.Variable,
663
+ template_dataset: xarray.Dataset,
664
+ preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
665
+ ) -> xarray.Dataset:
666
+ """The inverse of dataset_to_stacked.
667
+
668
+ Requires a template dataset to demonstrate the variables/shapes/coordinates
669
+ required.
670
+ All variables must have preserved_dims dimensions.
671
+
672
+ Args:
673
+ stacked_array: Data in BHWC layout, encoded the same as dataset_to_stacked
674
+ would if it was asked to encode `template_dataset`.
675
+ template_dataset: A template Dataset (or other mapping of DataArrays)
676
+ demonstrating the shape of output required (variables, shapes,
677
+ coordinates etc).
678
+ preserved_dims: dimensions from the target_template that were not folded in
679
+ the predictions channels. The preserved_dims need to be a subset of the
680
+ dims of all the variables of template_dataset.
681
+
682
+ Returns:
683
+ An xarray.Dataset (or other mapping of DataArrays) with the same shape and
684
+ type as template_dataset.
685
+ """
686
+ unstack_from_channels_sizes = {}
687
+ var_names = sorted(template_dataset.keys())
688
+ for name in var_names:
689
+ template_var = template_dataset[name]
690
+ if not all(dim in template_var.dims for dim in preserved_dims):
691
+ raise ValueError(
692
+ f"stacked_to_dataset requires all Variables to have {preserved_dims} "
693
+ f"dimensions, but found only {template_var.dims}.")
694
+ unstack_from_channels_sizes[name] = {
695
+ dim: size for dim, size in template_var.sizes.items()
696
+ if dim not in preserved_dims}
697
+
698
+ channels = {name: np.prod(list(unstack_sizes.values()), dtype=np.int64)
699
+ for name, unstack_sizes in unstack_from_channels_sizes.items()}
700
+ total_expected_channels = sum(channels.values())
701
+ found_channels = stacked_array.sizes["channels"]
702
+ if total_expected_channels != found_channels:
703
+ raise ValueError(
704
+ f"Expected {total_expected_channels} channels but found "
705
+ f"{found_channels}, when trying to convert a stacked array of shape "
706
+ f"{stacked_array.sizes} to a dataset of shape {template_dataset}.")
707
+
708
+ data_vars = {}
709
+ index = 0
710
+ for name in var_names:
711
+ template_var = template_dataset[name]
712
+ var = stacked_array.isel({"channels": slice(index, index + channels[name])})
713
+ index += channels[name]
714
+ var = var.unstack({"channels": unstack_from_channels_sizes[name]})
715
+ var = var.transpose(*template_var.dims)
716
+ data_vars[name] = xarray.DataArray(
717
+ data=var,
718
+ coords=template_var.coords,
719
+ # This might not always be the same as the name it's keyed under; it
720
+ # will refer to the original variable name, whereas the key might be
721
+ # some alias e.g. temperature_850 under which it should be logged:
722
+ name=template_var.name,
723
+ )
724
+ return type(template_dataset)(data_vars) # pytype:disable=not-callable,wrong-arg-count
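
For orientation, a minimal round-trip sketch of the stacking helpers above. It assumes the helpers are importable as `graphcast.model_utils` (this hunk does not show the module name), and the toy dataset is purely illustrative:

import numpy as np
import xarray

from graphcast import model_utils

# One surface variable and one 3-level variable, on a tiny 4x8 grid.
ds = xarray.Dataset({
    "t2m": (("batch", "lat", "lon"), np.zeros((1, 4, 8))),
    "z": (("batch", "level", "lat", "lon"), np.zeros((1, 3, 4, 8))),
})

stacked = model_utils.dataset_to_stacked(ds)
# Stacked layout is ("batch", "lat", "lon", "channels"):
# 1 channel for t2m plus 3 levels of z.
assert stacked.sizes["channels"] == 4

# `stacked_to_dataset` inverts the stacking, given a template dataset.
roundtrip = model_utils.stacked_to_dataset(stacked.variable, ds)
assert set(roundtrip.data_vars) == {"t2m", "z"}
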
graphcast/normalization.py ADDED
@@ -0,0 +1,196 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Wrappers for Predictors which allow them to work with normalized data.
15
+
16
+ The Predictor which is wrapped sees normalized inputs and targets, and makes
17
+ normalized predictions. The wrapper handles translating the predictions back
18
+ to the original domain.
19
+ """
20
+
21
+ import logging
22
+ from typing import Optional, Tuple
23
+
24
+ from graphcast import predictor_base
25
+ from graphcast import xarray_tree
26
+ import xarray
27
+
28
+
29
+ def normalize(values: xarray.Dataset,
30
+ scales: xarray.Dataset,
31
+ locations: Optional[xarray.Dataset],
32
+ ) -> xarray.Dataset:
33
+ """Normalize variables using the given scales and (optionally) locations."""
34
+ def normalize_array(array):
35
+ if array.name is None:
36
+ raise ValueError(
37
+ "Can't look up normalization constants because array has no name.")
38
+ if locations is not None:
39
+ if array.name in locations:
40
+ array = array - locations[array.name].astype(array.dtype)
41
+ else:
42
+ logging.warning('No normalization location found for %s', array.name)
43
+ if array.name in scales:
44
+ array = array / scales[array.name].astype(array.dtype)
45
+ else:
46
+ logging.warning('No normalization scale found for %s', array.name)
47
+ return array
48
+ return xarray_tree.map_structure(normalize_array, values)
49
+
50
+
51
+ def unnormalize(values: xarray.Dataset,
52
+ scales: xarray.Dataset,
53
+ locations: Optional[xarray.Dataset],
54
+ ) -> xarray.Dataset:
55
+ """Unnormalize variables using the given scales and (optionally) locations."""
56
+ def unnormalize_array(array):
57
+ if array.name is None:
58
+ raise ValueError(
59
+ "Can't look up normalization constants because array has no name.")
60
+ if array.name in scales:
61
+ array = array * scales[array.name].astype(array.dtype)
62
+ else:
63
+ logging.warning('No normalization scale found for %s', array.name)
64
+ if locations is not None:
65
+ if array.name in locations:
66
+ array = array + locations[array.name].astype(array.dtype)
67
+ else:
68
+ logging.warning('No normalization location found for %s', array.name)
69
+ return array
70
+ return xarray_tree.map_structure(unnormalize_array, values)
71
+
72
+
73
+ class InputsAndResiduals(predictor_base.Predictor):
74
+ """Wraps with a residual connection, normalizing inputs and target residuals.
75
+
76
+ The inner predictor is given inputs that are normalized using `locations`
77
+ and `scales` to roughly zero-mean unit variance.
78
+
79
+ For target variables that are present in the inputs, the inner predictor is
80
+ trained to predict residuals (target - last_frame_of_input) that have been
81
+ normalized using `residual_scales` (and optionally `residual_locations`) to
82
+ roughly unit variance / zero mean.
83
+
84
+ This replaces `residual.Predictor` in the case where you want normalization
85
+ that's based on the scales of the residuals.
86
+
87
+ Since we return the underlying predictor's loss on the normalized residuals,
88
+ if that loss is a sum of per-variable losses, the normalization
89
+ will affect the relative weighting of the per-variable loss terms (hopefully
90
+ in a good way).
91
+
92
+ For target variables *not* present in the inputs, the inner predictor is
93
+ trained to predict targets directly, that have been normalized in the same
94
+ way as the inputs.
95
+
96
+ The transforms applied to the targets (the residual connection and the
97
+ normalization) are applied in reverse to the predictions before returning
98
+ them.
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ predictor: predictor_base.Predictor,
104
+ stddev_by_level: xarray.Dataset,
105
+ mean_by_level: xarray.Dataset,
106
+ diffs_stddev_by_level: xarray.Dataset):
107
+ self._predictor = predictor
108
+ self._scales = stddev_by_level
109
+ self._locations = mean_by_level
110
+ self._residual_scales = diffs_stddev_by_level
111
+ self._residual_locations = None
112
+
113
+ def _unnormalize_prediction_and_add_input(self, inputs, norm_prediction):
114
+ if norm_prediction.sizes.get('time') != 1:
115
+ raise ValueError(
116
+ 'normalization.InputsAndResiduals only supports predicting a '
117
+ 'single timestep.')
118
+ if norm_prediction.name in inputs:
119
+ # Residuals are assumed to be predicted as normalized (unit variance),
120
+ # but the scale and location they need mapping to is that of the residuals
121
+ # not of the values themselves.
122
+ prediction = unnormalize(
123
+ norm_prediction, self._residual_scales, self._residual_locations)
124
+ # A prediction for which we have a corresponding input -- we are
125
+ # predicting the residual:
126
+ last_input = inputs[norm_prediction.name].isel(time=-1)
127
+ prediction = prediction + last_input
128
+ return prediction
129
+ else:
130
+ # A predicted variable which is not an input variable. We are predicting
131
+ # it directly, so unnormalize it directly to the target scale/location:
132
+ return unnormalize(norm_prediction, self._scales, self._locations)
133
+
134
+ def _subtract_input_and_normalize_target(self, inputs, target):
135
+ if target.sizes.get('time') != 1:
136
+ raise ValueError(
137
+ 'normalization.InputsAndResiduals only supports wrapping predictors '
138
+ 'that predict a single timestep.')
139
+ if target.name in inputs:
140
+ target_residual = target
141
+ last_input = inputs[target.name].isel(time=-1)
142
+ target_residual = target_residual - last_input
143
+ return normalize(
144
+ target_residual, self._residual_scales, self._residual_locations)
145
+ else:
146
+ return normalize(target, self._scales, self._locations)
147
+
148
+ def __call__(self,
149
+ inputs: xarray.Dataset,
150
+ targets_template: xarray.Dataset,
151
+ forcings: xarray.Dataset,
152
+ **kwargs
153
+ ) -> xarray.Dataset:
154
+ norm_inputs = normalize(inputs, self._scales, self._locations)
155
+ norm_forcings = normalize(forcings, self._scales, self._locations)
156
+ norm_predictions = self._predictor(
157
+ norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
158
+ return xarray_tree.map_structure(
159
+ lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
160
+ norm_predictions)
161
+
162
+ def loss(self,
163
+ inputs: xarray.Dataset,
164
+ targets: xarray.Dataset,
165
+ forcings: xarray.Dataset,
166
+ **kwargs,
167
+ ) -> predictor_base.LossAndDiagnostics:
168
+ """Returns the loss computed on normalized inputs and targets."""
169
+ norm_inputs = normalize(inputs, self._scales, self._locations)
170
+ norm_forcings = normalize(forcings, self._scales, self._locations)
171
+ norm_target_residuals = xarray_tree.map_structure(
172
+ lambda t: self._subtract_input_and_normalize_target(inputs, t),
173
+ targets)
174
+ return self._predictor.loss(
175
+ norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
176
+
177
+ def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
178
+ self,
179
+ inputs: xarray.Dataset,
180
+ targets: xarray.Dataset,
181
+ forcings: xarray.Dataset,
182
+ **kwargs,
183
+ ) -> Tuple[predictor_base.LossAndDiagnostics,
184
+ xarray.Dataset]:
185
+ """The loss computed on normalized data, with unnormalized predictions."""
186
+ norm_inputs = normalize(inputs, self._scales, self._locations)
187
+ norm_forcings = normalize(forcings, self._scales, self._locations)
188
+ norm_target_residuals = xarray_tree.map_structure(
189
+ lambda t: self._subtract_input_and_normalize_target(inputs, t),
190
+ targets)
191
+ (loss, scalars), norm_predictions = self._predictor.loss_and_predictions(
192
+ norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
193
+ predictions = xarray_tree.map_structure(
194
+ lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
195
+ norm_predictions)
196
+ return (loss, scalars), predictions
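
A construction sketch for the wrapper above. The inner predictor and the three statistics datasets are placeholders for whatever the surrounding code provides (e.g. precomputed per-level means and standard deviations):

from graphcast import normalization

# `inner_predictor` is any one-step predictor_base.Predictor; the three
# xarray.Datasets below hold per-variable normalization statistics.
wrapped = normalization.InputsAndResiduals(
    predictor=inner_predictor,
    stddev_by_level=stddev_by_level,
    mean_by_level=mean_by_level,
    diffs_stddev_by_level=diffs_stddev_by_level,
)

# Called like any Predictor: inputs, targets and forcings stay in the
# original (unnormalized) domain, as do the returned predictions.
predictions = wrapped(inputs, targets_template=targets_template,
                      forcings=forcings)
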
graphcast/predictor_base.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Abstract base classes for an xarray-based Predictor API."""
15
+
16
+ import abc
17
+
18
+ from typing import Tuple
19
+
20
+ from graphcast import losses
21
+ from graphcast import xarray_jax
22
+ import jax.numpy as jnp
23
+ import xarray
24
+
25
+ LossAndDiagnostics = losses.LossAndDiagnostics
26
+
27
+
28
+ class Predictor(abc.ABC):
29
+ """A possibly-trainable predictor of weather, exposing an xarray-based API.
30
+
31
+ Typically wraps an underlying JAX model and handles translating the xarray
32
+ Dataset values to and from plain JAX arrays that are convenient for input to
33
+ (and output from) the underlying model.
34
+
35
+ Different subclasses may exist to wrap different kinds of underlying model,
36
+ e.g. models taking stacked inputs/outputs, models taking separate 2D and 3D
37
+ inputs/outputs, autoregressive models.
38
+
39
+ You can also implement a specific model directly as a Predictor if you want,
40
+ for example if it has quite specific/unique requirements for its input/output
41
+ or loss function, or if it's convenient to implement directly using xarray.
42
+ """
43
+
44
+ @abc.abstractmethod
45
+ def __call__(self,
46
+ inputs: xarray.Dataset,
47
+ targets_template: xarray.Dataset,
48
+ forcings: xarray.Dataset,
49
+ **optional_kwargs
50
+ ) -> xarray.Dataset:
51
+ """Makes predictions.
52
+
53
+ This is only used by the Experiment for inference / evaluation, with
54
+ training going via the .loss method. So it should default to making
55
+ predictions for evaluation, although you can also support making predictions
56
+ for use in the loss via an is_training argument -- see
57
+ LossFunctionPredictor which helps with that.
58
+
59
+ Args:
60
+ inputs: An xarray.Dataset of inputs.
61
+ targets_template: An xarray.Dataset or other mapping of xarray.DataArrays,
62
+ with the same shape as the targets, to demonstrate what kind of
63
+ predictions are required. You can use this to determine which variables,
64
+ levels and lead times must be predicted.
65
+ You are free to raise an error if you don't support predicting what is
66
+ requested.
67
+ forcings: An xarray.Dataset of forcings terms. Forcings are variables
68
+ that can be fed to the model, but do not need to be predicted. This is
69
+ often because this variable can be computed analytically (e.g. the toa
70
+ radiation of the sun is mostly a function of geometry) or are considered
71
+ to be controlled for the experiment (e.g., imposing a scenario of CO2
72
+ emission into the atmosphere). Unlike `inputs`, the `forcings` can
73
+ include information "from the future", that is, information at target
74
+ times specified in the `targets_template`.
75
+ **optional_kwargs: Implementations may support extra optional kwargs,
76
+ provided they set appropriate defaults for them.
77
+
78
+ Returns:
79
+ Predictions, as an xarray.Dataset or other mapping of DataArrays which
80
+ is capable of being evaluated against targets with shape given by
81
+ targets_template.
82
+ For probabilistic predictors which can return multiple samples from a
83
+ predictive distribution, these should (by convention) be returned along
84
+ an additional 'sample' dimension.
85
+ """
86
+
87
+ def loss(self,
88
+ inputs: xarray.Dataset,
89
+ targets: xarray.Dataset,
90
+ forcings: xarray.Dataset,
91
+ **optional_kwargs,
92
+ ) -> LossAndDiagnostics:
93
+ """Computes a training loss, for predictors that are trainable.
94
+
95
+ Why make this the Predictor's responsibility, rather than letting callers
96
+ compute their own loss function using predictions obtained from
97
+ Predictor.__call__?
98
+
99
+ Doing it this way gives Predictors more control over their training setup.
100
+ For example, some predictors may wish to train using different targets to
101
+ the ones they predict at evaluation time -- perhaps different lead times and
102
+ variables, perhaps training to predict transformed versions of targets
103
+ where the transform needs to be inverted at evaluation time, etc.
104
+
105
+ It's also necessary for generative models (VAEs, GANs, ...) where the
106
+ training loss is more complex and isn't expressible as a parameter-free
107
+ function of predictions and targets.
108
+
109
+ Args:
110
+ inputs: An xarray.Dataset.
111
+ targets: An xarray.Dataset or other mapping of xarray.DataArrays. See
112
+ docs on __call__ for an explanation about the targets.
113
+ forcings: xarray.Dataset of forcing terms.
114
+ **optional_kwargs: Implementations may support extra optional kwargs,
115
+ provided they set appropriate defaults for them.
116
+
117
+ Returns:
118
+ loss: A DataArray with dimensions ('batch',) containing losses for each
119
+ element of the batch. These will be averaged to give the final
120
+ loss, locally and across replicas.
121
+ diagnostics: Mapping of additional quantities to log by name alongside the
122
+ loss. These will typically correspond to terms in the loss. They
123
+ should also have dimensions ('batch',) and will be averaged over the
124
+ batch before logging.
125
+ You need not include the loss itself in this dict; it will be added for
126
+ you.
127
+ """
128
+ del targets, forcings, optional_kwargs
129
+ batch_size = inputs.sizes['batch']
130
+ dummy_loss = xarray_jax.DataArray(jnp.zeros(batch_size), dims=('batch',))
131
+ return dummy_loss, {} # pytype: disable=bad-return-type
132
+
133
+ def loss_and_predictions(
134
+ self,
135
+ inputs: xarray.Dataset,
136
+ targets: xarray.Dataset,
137
+ forcings: xarray.Dataset,
138
+ **optional_kwargs,
139
+ ) -> Tuple[LossAndDiagnostics, xarray.Dataset]:
140
+ """Like .loss but also returns corresponding predictions.
141
+
142
+ Implementing this is optional as it's not used directly by the Experiment,
143
+ but it is required by autoregressive.Predictor when applying an inner
144
+ Predictor autoregressively at training time; we need a loss at each step but
145
+ also predictions to feed back in for the next step.
146
+
147
+ Note the loss itself may not be directly regressing the predictions towards
148
+ targets, the loss may be computed in terms of transformed predictions and
149
+ targets (or in some other way). For this reason we can't always cleanly
150
+ separate this into step 1: get predictions, step 2: compute loss from them,
151
+ hence the need for this combined method.
152
+
153
+ Args:
154
+ inputs:
155
+ targets:
156
+ forcings:
157
+ **optional_kwargs:
158
+ As for self.loss.
159
+
160
+ Returns:
161
+ (loss, diagnostics)
162
+ As for self.loss
163
+ predictions:
164
+ The predictions which the loss relates to. These should be of the same
165
+ shape as what you would get from
166
+ `self.__call__(inputs, targets_template=targets)`, and should be in the
167
+ same 'domain' as the inputs (i.e. they shouldn't be transformed
168
+ differently to how the predictor expects its inputs).
169
+ """
170
+ raise NotImplementedError
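
To make the interface above concrete, here is a tiny illustrative subclass (not part of the library): a persistence baseline that repeats the last input frame for every requested target time and relies on the dummy zero loss provided by the base class:

import xarray

from graphcast import predictor_base


class Persistence(predictor_base.Predictor):
  """Repeats the last input frame; assumes every target variable is an input."""

  def __call__(self, inputs, targets_template, forcings, **optional_kwargs):
    del forcings, optional_kwargs
    last_frame = inputs.isel(time=-1, drop=True)
    # Broadcast each variable onto the target time axis of the template.
    return xarray.Dataset({
        name: last_frame[name].broadcast_like(targets_template[name])
        for name in targets_template.data_vars
    })
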
graphcast/rollout.py ADDED
@@ -0,0 +1,269 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utils for rolling out models."""
15
+
16
+ from typing import Iterator
17
+
18
+ from absl import logging
19
+ import chex
20
+ import dask.array
21
+ from graphcast import xarray_tree
22
+ import jax
23
+ import numpy as np
24
+ import typing_extensions
25
+ import xarray
26
+
27
+
28
+ class PredictorFn(typing_extensions.Protocol):
29
+ """Functional version of base.Predictor.__call__ with explicit rng."""
30
+
31
+ def __call__(
32
+ self, rng: chex.PRNGKey, inputs: xarray.Dataset,
33
+ targets_template: xarray.Dataset,
34
+ forcings: xarray.Dataset,
35
+ **optional_kwargs,
36
+ ) -> xarray.Dataset:
37
+ ...
38
+
39
+
40
+ def chunked_prediction(
41
+ predictor_fn: PredictorFn,
42
+ rng: chex.PRNGKey,
43
+ inputs: xarray.Dataset,
44
+ targets_template: xarray.Dataset,
45
+ forcings: xarray.Dataset,
46
+ num_steps_per_chunk: int = 1,
47
+ verbose: bool = False,
48
+ ) -> xarray.Dataset:
49
+ """Outputs a long trajectory by iteratively concatenating chunked predictions.
50
+
51
+ Args:
52
+ predictor_fn: Function to use to make predictions for each chunk.
53
+ rng: Random key.
54
+ inputs: Inputs for the model.
55
+ targets_template: Template for the target prediction, requires targets
56
+ equispaced in time.
57
+ forcings: Optional forcing for the model.
58
+ num_steps_per_chunk: How many of the steps in `targets_template` to predict
59
+ at each call of `predictor_fn`. It must evenly divide the number of
60
+ steps in `targets_template`.
61
+ verbose: Whether to log the current chunk being predicted.
62
+
63
+ Returns:
64
+ Predictions for the targets template.
65
+
66
+ """
67
+ chunks_list = []
68
+ for prediction_chunk in chunked_prediction_generator(
69
+ predictor_fn=predictor_fn,
70
+ rng=rng,
71
+ inputs=inputs,
72
+ targets_template=targets_template,
73
+ forcings=forcings,
74
+ num_steps_per_chunk=num_steps_per_chunk,
75
+ verbose=verbose):
76
+ chunks_list.append(jax.device_get(prediction_chunk))
77
+ return xarray.concat(chunks_list, dim="time")
78
+
79
+
80
+ def chunked_prediction_generator(
81
+ predictor_fn: PredictorFn,
82
+ rng: chex.PRNGKey,
83
+ inputs: xarray.Dataset,
84
+ targets_template: xarray.Dataset,
85
+ forcings: xarray.Dataset,
86
+ num_steps_per_chunk: int = 1,
87
+ verbose: bool = False,
88
+ ) -> Iterator[xarray.Dataset]:
89
+ """Outputs a long trajectory by yielding chunked predictions.
90
+
91
+ Args:
92
+ predictor_fn: Function to use to make predictions for each chunk.
93
+ rng: Random key.
94
+ inputs: Inputs for the model.
95
+ targets_template: Template for the target prediction, requires targets
96
+ equispaced in time.
97
+ forcings: Optional forcing for the model.
98
+ num_steps_per_chunk: How many of the steps in `targets_template` to predict
99
+ at each call of `predictor_fn`. It must evenly divide the number of
100
+ steps in `targets_template`.
101
+ verbose: Whether to log the current chunk being predicted.
102
+
103
+ Yields:
104
+ The predictions for each chunk of the chunked rollout, such that
105
+ if all predictions are concatenated in time they would match the targets
106
+ template in structure.
107
+
108
+ """
109
+
110
+ # Create copies to avoid mutating inputs.
111
+ inputs = xarray.Dataset(inputs)
112
+ targets_template = xarray.Dataset(targets_template)
113
+ forcings = xarray.Dataset(forcings)
114
+
115
+ if "datetime" in inputs.coords:
116
+ del inputs.coords["datetime"]
117
+
118
+ if "datetime" in targets_template.coords:
119
+ output_datetime = targets_template.coords["datetime"]
120
+ del targets_template.coords["datetime"]
121
+ else:
122
+ output_datetime = None
123
+
124
+ if "datetime" in forcings.coords:
125
+ del forcings.coords["datetime"]
126
+
127
+ num_target_steps = targets_template.dims["time"]
128
+ num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk)
129
+ if remainder != 0:
130
+ raise ValueError(
131
+ f"The number of steps per chunk {num_steps_per_chunk} must "
132
+ f"evenly divide the number of target steps {num_target_steps} ")
133
+
134
+ if len(np.unique(np.diff(targets_template.coords["time"].data))) > 1:
135
+ raise ValueError("The targets time coordinates must be evenly spaced")
136
+
137
+ # Our template targets will always have a time axis corresponding for the
138
+ # timedeltas for the first chunk.
139
+ targets_chunk_time = targets_template.time.isel(
140
+ time=slice(0, num_steps_per_chunk))
141
+
142
+ current_inputs = inputs
143
+ for chunk_index in range(num_chunks):
144
+ if verbose:
145
+ logging.info("Chunk %d/%d", chunk_index, num_chunks)
146
+ logging.flush()
147
+
148
+ # Select targets for the time period that we are predicting for this chunk.
149
+ target_offset = num_steps_per_chunk * chunk_index
150
+ target_slice = slice(target_offset, target_offset + num_steps_per_chunk)
151
+ current_targets_template = targets_template.isel(time=target_slice)
152
+
153
+ # Replace the timedelta by the one corresponding to the first chunk, so we
154
+ # don't recompile at every iteration, keeping the time coordinates constant.
155
+ actual_target_time = current_targets_template.coords["time"]
156
+ current_targets_template = current_targets_template.assign_coords(
157
+ time=targets_chunk_time).compute()
158
+
159
+ current_forcings = forcings.isel(time=target_slice)
160
+ current_forcings = current_forcings.assign_coords(time=targets_chunk_time)
161
+ current_forcings = current_forcings.compute()
162
+ # Make predictions for the chunk.
163
+ rng, this_rng = jax.random.split(rng)
164
+ predictions = predictor_fn(
165
+ rng=this_rng,
166
+ inputs=current_inputs,
167
+ targets_template=current_targets_template,
168
+ forcings=current_forcings)
169
+
170
+ next_frame = xarray.merge([predictions, current_forcings])
171
+
172
+ next_inputs = _get_next_inputs(current_inputs, next_frame)
173
+
174
+ # Shift timedelta coordinates, so we don't recompile at every iteration.
175
+ next_inputs = next_inputs.assign_coords(time=current_inputs.coords["time"])
176
+ current_inputs = next_inputs
177
+
178
+ # At this point we can assign the actual targets time coordinates.
179
+ predictions = predictions.assign_coords(time=actual_target_time)
180
+ if output_datetime is not None:
181
+ predictions.coords["datetime"] = output_datetime.isel(
182
+ time=target_slice)
183
+ yield predictions
184
+ del predictions
185
+
186
+
187
+ def _get_next_inputs(
188
+ prev_inputs: xarray.Dataset, next_frame: xarray.Dataset,
189
+ ) -> xarray.Dataset:
190
+ """Computes next inputs, from previous inputs and predictions."""
191
+
192
+ # Make sure we are predicting all inputs that have a time axis.
193
+ non_predicted_or_forced_inputs = list(
194
+ set(prev_inputs.keys()) - set(next_frame.keys()))
195
+ if "time" in prev_inputs[non_predicted_or_forced_inputs].dims:
196
+ raise ValueError(
197
+ "Found an input with a time index that is not predicted or forced.")
198
+
199
+ # Keys we need to copy from predictions to inputs.
200
+ next_inputs_keys = list(
201
+ set(next_frame.keys()).intersection(set(prev_inputs.keys())))
202
+ next_inputs = next_frame[next_inputs_keys]
203
+
204
+ # Concatenate the next frame onto the inputs and crop what we don't need.
205
+ num_inputs = prev_inputs.dims["time"]
206
+ return (
207
+ xarray.concat(
208
+ [prev_inputs, next_inputs], dim="time", data_vars="different")
209
+ .tail(time=num_inputs))
210
+
211
+
212
+ def extend_targets_template(
213
+ targets_template: xarray.Dataset,
214
+ required_num_steps: int) -> xarray.Dataset:
215
+ """Extends `targets_template` to `required_num_steps` with lazy arrays.
216
+
217
+ It uses lazy dask arrays of zeros, so it does not require instantiating the
218
+ array in memory.
219
+
220
+ Args:
221
+ targets_template: Input template to extend.
222
+ required_num_steps: Number of steps required in the returned template.
223
+
224
+ Returns:
225
+ `xarray.Dataset` identical in variables and timestep to `targets_template`
226
+ full of `dask.array.zeros` such that the time axis has `required_num_steps`.
227
+
228
+ """
229
+
230
+ # Extend the "time" and "datetime" coordinates
231
+ time = targets_template.coords["time"]
232
+
233
+ # Assert the first target time corresponds to the timestep.
234
+ timestep = time[0].data
235
+ if time.shape[0] > 1:
236
+ assert np.all(timestep == time[1:] - time[:-1])
237
+
238
+ extended_time = (np.arange(required_num_steps) + 1) * timestep
239
+
240
+ if "datetime" in targets_template.coords:
241
+ datetime = targets_template.coords["datetime"]
242
+ extended_datetime = (datetime[0].data - timestep) + extended_time
243
+ else:
244
+ extended_datetime = None
245
+
246
+ # Replace the values with empty dask arrays extending the time coordinates.
247
+ datetime = targets_template.coords["time"]
248
+
249
+ def extend_time(data_array: xarray.DataArray) -> xarray.DataArray:
250
+ dims = data_array.dims
251
+ shape = list(data_array.shape)
252
+ shape[dims.index("time")] = required_num_steps
253
+ dask_data = dask.array.zeros(
254
+ shape=tuple(shape),
255
+ chunks=-1, # Will give chunk info directly to `ChunksToZarr`.
256
+ dtype=data_array.dtype)
257
+
258
+ coords = dict(data_array.coords)
259
+ coords["time"] = extended_time
260
+
261
+ if extended_datetime is not None:
262
+ coords["datetime"] = ("time", extended_datetime)
263
+
264
+ return xarray.DataArray(
265
+ dims=dims,
266
+ data=dask_data,
267
+ coords=coords)
268
+
269
+ return xarray_tree.map_structure(extend_time, targets_template)
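
A short driving sketch for the rollout helpers above. `run_forward`, `inputs`, `targets_template` and `forcings` are placeholders for whatever the surrounding inference code has prepared; `run_forward` must match the PredictorFn protocol:

import jax

from graphcast import rollout

predictions = rollout.chunked_prediction(
    predictor_fn=run_forward,    # (rng, inputs, targets_template, forcings) -> predictions
    rng=jax.random.PRNGKey(0),
    inputs=inputs,
    targets_template=targets_template,
    forcings=forcings,
    num_steps_per_chunk=1,       # must evenly divide the number of target steps
)
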
graphcast/solar_radiation.py ADDED
@@ -0,0 +1,605 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Computes TOA incident solar radiation compatible with ERA5.
15
+
16
+ The Top-Of-the-Atmosphere (TOA) incident solar radiation is available in the
17
+ ERA5 dataset as the parameter `toa_incident_solar_radiation` (or `tisr`). This
18
+ represents the TOA solar radiation flux integrated over a period of one hour
19
+ ending at the timestamp given by the `datetime` coordinate. See
20
+ https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation and
21
+ https://codes.ecmwf.int/grib/param-db/?id=212.
22
+ """
23
+
24
+ from collections.abc import Callable, Sequence
25
+ import dataclasses
26
+ import functools
27
+
28
+ import chex
29
+ import jax
30
+ import jax.numpy as jnp
31
+ import numpy as np
32
+ import pandas as pd
33
+ import xarray as xa
34
+
35
+
36
+ # Default value of the `integration_period` argument to be compatible with ERA5.
37
+ _DEFAULT_INTEGRATION_PERIOD = pd.Timedelta(hours=1)
38
+
39
+ # Default value for the `num_integration_bins` argument. This provides a good
40
+ # approximation of the solar radiation in ERA5.
41
+ _DEFAULT_NUM_INTEGRATION_BINS = 360
42
+
43
+ # The length of a Julian year in days.
44
+ # https://en.wikipedia.org/wiki/Julian_year_(astronomy)
45
+ _JULIAN_YEAR_LENGTH_IN_DAYS = 365.25
46
+
47
+ # Julian Date for the J2000 epoch, a standard reference used in astronomy.
48
+ # https://en.wikipedia.org/wiki/Epoch_(astronomy)#Julian_years_and_J2000
49
+ _J2000_EPOCH = 2451545.0
50
+
51
+ # Number of seconds in a day.
52
+ _SECONDS_PER_DAY = 60 * 60 * 24
53
+
54
+
55
+ _TimestampLike = str | pd.Timestamp | np.datetime64
56
+ _TimedeltaLike = str | pd.Timedelta | np.timedelta64
57
+
58
+
59
+ # Interface for loading Total Solar Irradiance (TSI) data.
60
+ # Returns an xa.DataArray containing yearly average TSI values with a `time`
61
+ # coordinate in units of years since 0000-1-1. E.g. 2023.5 corresponds to
62
+ # the middle of the year 2023.
63
+ TsiDataLoader = Callable[[], xa.DataArray]
64
+
65
+
66
+ # Total Solar Irradiance (TSI): Energy input to the top of the Earth's
67
+ # atmosphere in W⋅m⁻². TSI varies with time. This is the reference TSI value
68
+ # that can be used when more accurate data is not available.
69
+ # https://www.ncei.noaa.gov/products/climate-data-records/total-solar-irradiance
70
+ # https://github.com/ecmwf-ifs/ecrad/blob/6db82f929fb75028cc20606a04da87c0abe9b642/radiation/radiation_ecckd.F90#L296
71
+ _REFERENCE_TSI = 1361.0
72
+
73
+
74
+ def reference_tsi_data() -> xa.DataArray:
75
+ """A TsiDataProvider that returns a single reference TSI value."""
76
+ return xa.DataArray(
77
+ np.array([_REFERENCE_TSI]),
78
+ dims=["time"],
79
+ coords={"time": np.array([0.0])},
80
+ )
81
+
82
+
83
+ def era5_tsi_data() -> xa.DataArray:
84
+ """A TsiDataProvider that returns ERA5 compatible TSI data."""
85
+ # ECMWF provided the data used for ERA5, which was hardcoded in the IFS (cycle
86
+ # 41r2). The values were scaled down to agree better with more recent
87
+ # observations of the sun.
88
+ time = np.arange(1951.5, 2035.5, 1.0)
89
+ tsi = 0.9965 * np.array([
90
+ # fmt: off
91
+ # 1951-1995 (non-repeating sequence)
92
+ 1365.7765, 1365.7676, 1365.6284, 1365.6564, 1365.7773,
93
+ 1366.3109, 1366.6681, 1366.6328, 1366.3828, 1366.2767,
94
+ 1365.9199, 1365.7484, 1365.6963, 1365.6976, 1365.7341,
95
+ 1365.9178, 1366.1143, 1366.1644, 1366.2476, 1366.2426,
96
+ 1365.9580, 1366.0525, 1365.7991, 1365.7271, 1365.5345,
97
+ 1365.6453, 1365.8331, 1366.2747, 1366.6348, 1366.6482,
98
+ 1366.6951, 1366.2859, 1366.1992, 1365.8103, 1365.6416,
99
+ 1365.6379, 1365.7899, 1366.0826, 1366.6479, 1366.5533,
100
+ 1366.4457, 1366.3021, 1366.0286, 1365.7971, 1365.6996,
101
+ # 1996-2008 (13 year cycle, repeated below)
102
+ 1365.6121, 1365.7399, 1366.1021, 1366.3851, 1366.6836,
103
+ 1366.6022, 1366.6807, 1366.2300, 1366.0480, 1365.8545,
104
+ 1365.8107, 1365.7240, 1365.6918,
105
+ # 2009-2021
106
+ 1365.6121, 1365.7399, 1366.1021, 1366.3851, 1366.6836,
107
+ 1366.6022, 1366.6807, 1366.2300, 1366.0480, 1365.8545,
108
+ 1365.8107, 1365.7240, 1365.6918,
109
+ # 2022-2034
110
+ 1365.6121, 1365.7399, 1366.1021, 1366.3851, 1366.6836,
111
+ 1366.6022, 1366.6807, 1366.2300, 1366.0480, 1365.8545,
112
+ 1365.8107, 1365.7240, 1365.6918,
113
+ # fmt: on
114
+ ])
115
+ return xa.DataArray(tsi, dims=["time"], coords={"time": time})
116
+
117
+
118
+ # HRES compatible TSI data is from IFS cycle 47r1. The dataset can be obtained
119
+ # from the ECRAD package: https://confluence.ecmwf.int/display/ECRAD.
120
+ # The example code below can load this dataset from a local file.
121
+
122
+ # def hres_tsi_data() -> xa.DataArray:
123
+ # with open("total_solar_irradiance_CMIP6_47r1.nc", "rb") as f:
124
+ # with xa.load_dataset(f, decode_times=False) as ds:
125
+ # return ds["tsi"]
126
+
127
+
128
+ _DEFAULT_TSI_DATA_LOADER: TsiDataLoader = era5_tsi_data
129
+
130
+
131
+ def get_tsi(
132
+ timestamps: Sequence[_TimestampLike], tsi_data: xa.DataArray
133
+ ) -> chex.Array:
134
+ """Returns TSI values for the given timestamps.
135
+
136
+ TSI values are interpolated from the provided yearly TSI data.
137
+
138
+ Args:
139
+ timestamps: Timestamps for which to compute TSI values.
140
+ tsi_data: A DataArray with a single dimension `time` that has coordinates in
141
+ units of years since 0000-1-1. E.g. 2023.5 corresponds to the middle of
142
+ the year 2023.
143
+
144
+ Returns:
145
+ An Array containing interpolated TSI data.
146
+ """
147
+ timestamps = pd.DatetimeIndex(timestamps)
148
+ timestamps_date = pd.DatetimeIndex(timestamps.date)
149
+ day_fraction = (timestamps - timestamps_date) / pd.Timedelta(days=1)
150
+ year_length = 365 + timestamps.is_leap_year
151
+ year_fraction = (timestamps.dayofyear - 1 + day_fraction) / year_length
152
+ fractional_year = timestamps.year + year_fraction
153
+ return np.interp(fractional_year, tsi_data.coords["time"].data, tsi_data.data)
154
+
155
+
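
A small usage sketch for the TSI lookup above (assuming the package is importable as `graphcast`); values are indicative only:

from graphcast import solar_radiation

tsi_data = solar_radiation.era5_tsi_data()
tsi = solar_radiation.get_tsi(["2020-01-01", "2020-07-01"], tsi_data)
# Two interpolated values close to 1361 W m^-2 (the ERA5 series scaled by 0.9965).
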
156
+ @dataclasses.dataclass(frozen=True)
157
+ class _OrbitalParameters:
158
+ """Parameters characterising Earth's position relative to the Sun.
159
+
160
+ The parameters characterize the position of the Earth in its orbit around the
161
+ Sun for specific points in time. Each attribute is an N-dimensional array
162
+ to represent orbital parameters for multiple points in time.
163
+
164
+ Attributes:
165
+ theta: The number of Julian years since the Julian epoch J2000.0.
166
+ rotational_phase: The phase of the Earth's rotation along its axis as a
167
+ ratio with 0 representing the phase at Julian epoch J2000.0 at exactly
168
+ 12:00 Terrestrial Time (TT). Multiplying this value by `2*pi` yields the
169
+ phase in radians.
170
+ sin_declination: Sine of the declination of the Sun as seen from the Earth.
171
+ cos_declination: Cosine of the declination of the Sun as seen from the
172
+ Earth.
173
+ eq_of_time_seconds: The value of the equation of time, in seconds.
174
+ solar_distance_au: Earth-Sun distance in astronomical units.
175
+ """
176
+
177
+ theta: chex.Array
178
+ rotational_phase: chex.Array
179
+ sin_declination: chex.Array
180
+ cos_declination: chex.Array
181
+ eq_of_time_seconds: chex.Array
182
+ solar_distance_au: chex.Array
183
+
184
+
185
+ def _get_j2000_days(timestamp: pd.Timestamp) -> float:
186
+ """Returns the number of days since the J2000 epoch.
187
+
188
+ Args:
189
+ timestamp: A timestamp for which to compute the J2000 days.
190
+
191
+ Returns:
192
+ The J2000 days corresponding to the input timestamp.
193
+ """
194
+ return timestamp.to_julian_date() - _J2000_EPOCH
195
+
196
+
197
+ def _get_orbital_parameters(j2000_days: chex.Array) -> _OrbitalParameters:
198
+ """Computes the orbital parameters for the given J2000 days.
199
+
200
+ Args:
201
+ j2000_days: Timestamps represented as the number of days since the J2000
202
+ epoch.
203
+
204
+ Returns:
205
+ Orbital parameters for the given timestamps. Each attribute of the return
206
+ value is an array containing the same dimensions as the input.
207
+ """
208
+ # Orbital parameters are computed based on the formulas in this code, which
209
+ # were determined empirically to produce radiation values similar to ERA5:
210
+ # https://github.com/ECCC-ASTD-MRD/gem/blob/1d711f7b89971cd7b1e10afc7508d1135b51397d/src/rpnphy/src/base/sucst.F90
211
+ # https://github.com/ECCC-ASTD-MRD/gem/blob/1d711f7b89971cd7b1e10afc7508d1135b51397d/src/rpnphy/src/base/fctast.cdk
212
+ # https://github.com/ECCC-ASTD-MRD/gem/blob/1d711f7b89971cd7b1e10afc7508d1135b51397d/src/rpnphy/src/base/fcttim.cdk
213
+ # There are many variations to these formulas, but since the goal is to match
214
+ # the values in ERA5, the formulas were implemented as is. Comments reference
215
+ # the notation used in those sources. Here are some additional references
216
+ # related to the quantities being computed here:
217
+ # https://aa.usno.navy.mil/faq/sun_approx
218
+ # https://en.wikipedia.org/wiki/Position_of_the_Sun
219
+ # https://en.wikipedia.org/wiki/Equation_of_time
220
+
221
+ # Number of Julian years since the J2000 epoch (including fractional years).
222
+ theta = j2000_days / _JULIAN_YEAR_LENGTH_IN_DAYS
223
+ # The phase of the Earth's rotation along its axis as a ratio. 0 represents
224
+ # Julian epoch J2000.0 at exactly 12:00 Terrestrial Time (TT).
225
+ rotational_phase = j2000_days % 1.0
226
+
227
+ # REL(PTETA).
228
+ rel = 1.7535 + 6.283076 * theta
229
+ # REM(PTETA).
230
+ rem = 6.240041 + 6.283020 * theta
231
+ # RLLS(PTETA).
232
+ rlls = 4.8951 + 6.283076 * theta
233
+
234
+ # Variables used in the three polynomials below.
235
+ one = jnp.ones_like(theta)
236
+ sin_rel = jnp.sin(rel)
237
+ cos_rel = jnp.cos(rel)
238
+ sin_two_rel = jnp.sin(2.0 * rel)
239
+ cos_two_rel = jnp.cos(2.0 * rel)
240
+ sin_two_rlls = jnp.sin(2.0 * rlls)
241
+ cos_two_rlls = jnp.cos(2.0 * rlls)
242
+ sin_four_rlls = jnp.sin(4.0 * rlls)
243
+ sin_rem = jnp.sin(rem)
244
+ sin_two_rem = jnp.sin(2.0 * rem)
245
+
246
+ # Ecliptic longitude of the Sun - RLLLS(PTETA).
247
+ rllls = jnp.dot(
248
+ jnp.stack(
249
+ [one, theta, sin_rel, cos_rel, sin_two_rel, cos_two_rel], axis=-1
250
+ ),
251
+ jnp.array([4.8952, 6.283320, -0.0075, -0.0326, -0.0003, 0.0002]),
252
+ )
253
+
254
+ # Angle in radians between the Earth's rotational axis and its orbital axis.
255
+ # Equivalent to 23.4393°.
256
+ repsm = 0.409093
257
+
258
+ # Declination of the Sun - RDS(teta).
259
+ sin_declination = jnp.sin(repsm) * jnp.sin(rllls)
260
+ cos_declination = jnp.sqrt(1.0 - sin_declination**2)
261
+
262
+ # Equation of time in seconds - RET(PTETA).
263
+ eq_of_time_seconds = jnp.dot(
264
+ jnp.stack(
265
+ [
266
+ sin_two_rlls,
267
+ sin_rem,
268
+ sin_rem * cos_two_rlls,
269
+ sin_four_rlls,
270
+ sin_two_rem,
271
+ ],
272
+ axis=-1,
273
+ ),
274
+ jnp.array([591.8, -459.4, 39.5, -12.7, -4.8]),
275
+ )
276
+
277
+ # Earth-Sun distance in astronomical units - RRS(PTETA).
278
+ solar_distance_au = jnp.dot(
279
+ jnp.stack([one, sin_rel, cos_rel], axis=-1),
280
+ jnp.array([1.0001, -0.0163, 0.0037]),
281
+ )
282
+
283
+ return _OrbitalParameters(
284
+ theta=theta,
285
+ rotational_phase=rotational_phase,
286
+ sin_declination=sin_declination,
287
+ cos_declination=cos_declination,
288
+ eq_of_time_seconds=eq_of_time_seconds,
289
+ solar_distance_au=solar_distance_au,
290
+ )
291
+
292
+
293
+ def _get_solar_sin_altitude(
294
+ op: _OrbitalParameters,
295
+ sin_latitude: chex.Array,
296
+ cos_latitude: chex.Array,
297
+ longitude: chex.Array,
298
+ ) -> chex.Array:
299
+ """Returns the sine of the solar altitude angle.
300
+
301
+ All computations are vectorized. Dimensions of all the inputs should be
302
+ broadcastable using standard NumPy rules. For example, if `op` has shape
303
+ `(T, 1, 1)`, `latitude` has shape `(1, H, 1)`, and `longitude` has shape
304
+ `(1, H, W)`, the return value will have shape `(T, H, W)`.
305
+
306
+ Args:
307
+ op: Orbital parameters characterising Earth's position relative to the Sun.
308
+ sin_latitude: Sine of latitude coordinates.
309
+ cos_latitude: Cosine of latitude coordinates.
310
+ longitude: Longitude coordinates in radians.
311
+
312
+ Returns:
313
+ Sine of the solar altitude angle for each set of orbital parameters and each
314
+ geographical coordinates. The returned array has the shape resulting from
315
+ broadcasting all the inputs together.
316
+ """
317
+ solar_time = op.rotational_phase + op.eq_of_time_seconds / _SECONDS_PER_DAY
318
+ # https://en.wikipedia.org/wiki/Hour_angle#Solar_hour_angle
319
+ hour_angle = 2.0 * jnp.pi * solar_time + longitude
320
+ # https://en.wikipedia.org/wiki/Solar_zenith_angle
321
+ sin_altitude = (
322
+ cos_latitude * op.cos_declination * jnp.cos(hour_angle)
323
+ + sin_latitude * op.sin_declination
324
+ )
325
+ return sin_altitude
326
+
327
+
328
+ def _get_radiation_flux(
329
+ j2000_days: chex.Array,
330
+ sin_latitude: chex.Array,
331
+ cos_latitude: chex.Array,
332
+ longitude: chex.Array,
333
+ tsi: chex.Array,
334
+ ) -> chex.Array:
335
+ """Computes the instantaneous TOA incident solar radiation flux.
336
+
337
+ Computes the instantaneous Top-Of-the-Atmosphere (TOA) incident radiation flux
338
+ in W⋅m⁻² for the given timestamps and locations on the surface of the Earth.
339
+ See https://en.wikipedia.org/wiki/Solar_irradiance.
340
+
341
+ All inputs are assumed to be broadcastable together using standard NumPy
342
+ rules.
343
+
344
+ Args:
345
+ j2000_days: Timestamps represented as the number of days since the J2000
346
+ epoch.
347
+ sin_latitude: Sine of latitude coordinates.
348
+ cos_latitude: Cosine of latitude coordinates.
349
+ longitude: Longitude coordinates in radians.
350
+ tsi: Total Solar Irradiance (TSI) in W⋅m⁻². This can be a scalar (default)
351
+ to use the same TSI value for all the inputs, or an array to allow TSI to
352
+ depend on the timestamps.
353
+
354
+ Returns:
355
+ The instantaneous TOA incident solar radiation flux in W⋅m⁻² for the given
356
+ timestamps and geographical coordinates. The returned array has the shape
357
+ resulting from broadcasting all the inputs together.
358
+ """
359
+ op = _get_orbital_parameters(j2000_days)
360
+ # Attenuation of the solar radiation based on the solar distance.
361
+ solar_factor = (1.0 / op.solar_distance_au) ** 2
362
+ sin_altitude = _get_solar_sin_altitude(
363
+ op, sin_latitude, cos_latitude, longitude
364
+ )
365
+ return tsi * solar_factor * jnp.maximum(sin_altitude, 0.0)
366
+
367
+
368
+ def _get_integrated_radiation(
369
+ j2000_days: chex.Array,
370
+ sin_latitude: chex.Array,
371
+ cos_latitude: chex.Array,
372
+ longitude: chex.Array,
373
+ tsi: chex.Array,
374
+ integration_period: pd.Timedelta,
375
+ num_integration_bins: int,
376
+ ) -> chex.Array:
377
+ """Returns the TOA solar radiation flux integrated over a time period.
378
+
379
+ Integrates the instantaneous TOA solar radiation flux over a time period.
380
+ The input timestamps represent the end times of each integration period.
381
+ When the integration period is one hour this approximates the
382
+ `toa_incident_solar_radiation` (or `tisr`) parameter from the ERA5 dataset.
383
+ See https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation and
384
+ https://codes.ecmwf.int/grib/param-db/?id=212.
385
+
386
+ All inputs are assumed to be broadcastable together using standard NumPy
387
+ rules. To approximate the integral, the instantaneous radiation is computed
388
+ at `num_integration_bins+1` time steps using `_get_radiation_flux` and
389
+ integrated using the trapezoidal rule. A dimension is appended at the end
390
+ of all inputs to compute the instantaneous radiation, which is then integrated
391
+ over to compute the final result.
392
+
393
+ Args:
394
+ j2000_days: Timestamps represented as the number of days since the J2000
395
+ epoch. These correspond to the end times of each integration period.
396
+ sin_latitude: Sine of latitude coordinates.
397
+ cos_latitude: Cosine of latitude coordinates.
398
+ longitude: Longitude in radians.
399
+ tsi: Total Solar Irradiance (TSI) in W⋅m⁻².
400
+ integration_period: Integration period.
401
+ num_integration_bins: Number of bins into which to divide `integration_period` to
402
+ approximate the integral using the trapezoidal rule.
403
+
404
+ Returns:
405
+ The TOA solar radiation flux integrated over the requested time period for
406
+ the given timestamps and geographical coordinates. Unit is J⋅m⁻² .
407
+ """
408
+ # Offsets for the integration time steps.
409
+ offsets = (
410
+ pd.timedelta_range(
411
+ start=-integration_period,
412
+ end=pd.Timedelta(0),
413
+ periods=num_integration_bins + 1,
414
+ )
415
+ / pd.Timedelta(days=1)
416
+ ).to_numpy()
417
+
418
+ # Integration happens over the time dimension. Compute the instantaneous
419
+ # radiation flux for all the integration time steps by appending a dimension
420
+ # to all the inputs and adding `offsets` to `j2000_days` (will be broadcast
421
+ # over all the other dimensions).
422
+ fluxes = _get_radiation_flux(
423
+ j2000_days=jnp.expand_dims(j2000_days, axis=-1) + offsets,
424
+ sin_latitude=jnp.expand_dims(sin_latitude, axis=-1),
425
+ cos_latitude=jnp.expand_dims(cos_latitude, axis=-1),
426
+ longitude=jnp.expand_dims(longitude, axis=-1),
427
+ tsi=jnp.expand_dims(tsi, axis=-1),
428
+ )
429
+
430
+ # Size of each bin in seconds. The instantaneous solar radiation flux is
431
+ # returned in units of W⋅m⁻². Integrating over time expressed in seconds
432
+ # yields a result in units of J⋅m⁻².
433
+ dx = (integration_period / num_integration_bins) / pd.Timedelta(seconds=1)
434
+ return jax.scipy.integrate.trapezoid(fluxes, dx=dx)
435
+
436
+
437
+ _get_integrated_radiation_jitted = jax.jit(
438
+ _get_integrated_radiation,
439
+ static_argnames=["integration_period", "num_integration_bins"],
440
+ )
441
+
442
+
443
+ def get_toa_incident_solar_radiation(
444
+ timestamps: Sequence[_TimestampLike],
445
+ latitude: chex.Array,
446
+ longitude: chex.Array,
447
+ tsi_data: xa.DataArray | None = None,
448
+ integration_period: _TimedeltaLike = _DEFAULT_INTEGRATION_PERIOD,
449
+ num_integration_bins: int = _DEFAULT_NUM_INTEGRATION_BINS,
450
+ use_jit: bool = False,
451
+ ) -> chex.Array:
452
+ """Computes the solar radiation incident at the top of the atmosphere.
453
+
454
+ The solar radiation is computed for each element in `timestamps` for all the
455
+ locations on the grid determined by the `latitude` and `longitude` parameters.
456
+
457
+ To approximate the `toa_incident_solar_radiation` (or `tisr`) parameter from
458
+ the ERA5 dataset, set `integration_period` to one hour (default). See
459
+ https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation and
460
+ https://codes.ecmwf.int/grib/param-db/?id=212.
461
+
462
+ Args:
463
+ timestamps: Timestamps for which to compute the solar radiation.
464
+ latitude: The latitude coordinates in degrees of the grid for which to
465
+ compute the solar radiation.
466
+ longitude: The longitude coordinates in degrees of the grid for which to
467
+ compute the solar radiation.
468
+ tsi_data: A DataArray containing yearly TSI data as returned by a
469
+ `TsiDataLoader`. The default is to use ERA5 compatible TSI data.
470
+ integration_period: Timedelta to use to integrate the radiation, e.g. if
471
+ producing radiation for 1989-11-08 21:00:00, and `integration_period` is
472
+ "1h", radiation will be integrated from 1989-11-08 20:00:00 to 1989-11-08
473
+ 21:00:00. The default value ("1h") matches ERA5.
474
+ num_integration_bins: Number of equally spaced bins to divide the
475
+ `integration_period` into when approximating the integral using the
476
+ trapezoidal rule. Performance and peak memory usage are affected by this
477
+ value. The default (360) provides a good approximation, but lower values
478
+ may work to improve performance and reduce memory usage.
479
+ use_jit: Set to True to use the jitted implementation, or False (default) to
480
+ use the non-jitted one.
481
+
482
+ Returns:
483
+ A 3D array with dimensions (time, lat, lon) containing the total
484
+ top of atmosphere solar radiation integrated for the `integration_period`
485
+ up to each timestamp.
486
+ """
487
+ # Add a trailing dimension to latitude to get dimensions (lat, lon).
488
+ lat = jnp.radians(latitude).reshape((-1, 1))
489
+ lon = jnp.radians(longitude)
490
+ sin_lat = jnp.sin(lat)
491
+ cos_lat = jnp.cos(lat)
492
+ integration_period = pd.Timedelta(integration_period)
493
+ if tsi_data is None:
494
+ tsi_data = _DEFAULT_TSI_DATA_LOADER()
495
+ tsi = get_tsi(timestamps, tsi_data)
496
+ fn = (
497
+ _get_integrated_radiation_jitted if use_jit else _get_integrated_radiation
498
+ )
499
+
500
+ # Compute integral for each timestamp individually. Although this could be
501
+ # done in one step, peak memory usage would be proportional to
502
+ # `len(timestamps) * num_integration_bins`. Computing each timestamp
503
+ # individually reduces this to `max(len(timestamps), num_integration_bins)`.
504
+ # E.g. memory usage for a single timestamp, with a full 0.25° grid and 360
505
+ # integration bins is about 1.5 GB (1440 * 721 * 361 * 4 bytes); computing
506
+ # forcings for 40 prediction steps would require 60 GB.
507
+ results = []
508
+ for idx, timestamp in enumerate(timestamps):
509
+ results.append(
510
+ fn(
511
+ j2000_days=jnp.array(_get_j2000_days(pd.Timestamp(timestamp))),
512
+ sin_latitude=sin_lat,
513
+ cos_latitude=cos_lat,
514
+ longitude=lon,
515
+ tsi=tsi[idx],
516
+ integration_period=integration_period,
517
+ num_integration_bins=num_integration_bins,
518
+ )
519
+ )
520
+ return jnp.stack(results, axis=0)
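A minimal usage sketch of the function above, on a deliberately coarse grid (the grid resolution, timestamps and bin count are arbitrary illustrative choices):

import numpy as np
import pandas as pd
from graphcast import solar_radiation

timestamps = pd.date_range(start="2023-09-25", periods=2, freq="6h")
lat = np.linspace(-90.0, 90.0, num=19)                 # 10 degree spacing
lon = np.linspace(0.0, 360.0, num=36, endpoint=False)  # 10 degree spacing
tisr = solar_radiation.get_toa_incident_solar_radiation(
    timestamps, lat, lon,
    integration_period="1h", num_integration_bins=60)
print(tisr.shape)  # (2, 19, 36) -> (time, lat, lon)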
521
+
522
+
523
+ def get_toa_incident_solar_radiation_for_xarray(
524
+ data_array_like: xa.DataArray | xa.Dataset,
525
+ tsi_data: xa.DataArray | None = None,
526
+ integration_period: _TimedeltaLike = _DEFAULT_INTEGRATION_PERIOD,
527
+ num_integration_bins: int = _DEFAULT_NUM_INTEGRATION_BINS,
528
+ use_jit: bool = False,
529
+ ) -> xa.DataArray:
530
+ """Computes the solar radiation incident at the top of the atmosphere.
531
+
532
+ This method is a wrapper for `get_toa_incident_solar_radiation` using
533
+ coordinates from an Xarray and returning an Xarray.
534
+
535
+ Args:
536
+ data_array_like: An xa.Dataset or xa.DataArray from which to take the time
537
+ and spatial coordinates for which to compute the solar radiation. It must
538
+ contain `lat` and `lon` spatial dimensions with corresponding coordinates.
539
+ If a `time` dimension is present, the `datetime` coordinate should be a
540
+ vector associated with that dimension containing timestamps for which to
541
+ compute the solar radiation. Otherwise, the `datetime` coordinate should
542
+ be a scalar representing the timestamp for which to compute the solar
543
+ radiation.
544
+ tsi_data: A DataArray containing yearly TSI data as returned by a
545
+ `TsiDataLoader`. The default is to use ERA5 compatible TSI data.
546
+ integration_period: Timedelta to use to integrate the radiation, e.g. if
547
+ producing radiation for 1989-11-08 21:00:00, and `integration_period` is
548
+ "1h", radiation will be integrated from 1989-11-08 20:00:00 to 1989-11-08
549
+ 21:00:00. The default value ("1h") matches ERA5.
550
+ num_integration_bins: Number of equally spaced bins to divide the
551
+ `integration_period` in when approximating the integral using the
552
+ trapezoidal rule. Performance and peak memory usage are affected by this
553
+ value. The default (360) provides a good approximation, but lower values
554
+ can be used to improve performance and reduce memory usage.
555
+ use_jit: Set to True to use the jitted implementation, or False to use the
556
+ non-jitted one.
557
+
558
+ Returns:
559
+ xa.DataArray with dimensions `(time, lat, lon)` if `data_array_like` had
560
+ a `time` dimension; or dimensions `(lat, lon)` otherwise. The `datetime`
561
+ coordinates and those for the dimensions are copied to the returned array.
562
+ The array contains the total top of atmosphere solar radiation integrated
563
+ for `integration_period` up to the corresponding `datetime`.
564
+
565
+ Raises:
566
+ ValueError: If there are missing coordinates or dimensions.
567
+ """
568
+ missing_dims = set(["lat", "lon"]) - set(data_array_like.dims)
569
+ if missing_dims:
570
+ raise ValueError(
571
+ f"'{missing_dims}' dimensions are missing in `data_array_like`."
572
+ )
573
+
574
+ missing_coords = set(["datetime", "lat", "lon"]) - set(data_array_like.coords)
575
+ if missing_coords:
576
+ raise ValueError(
577
+ f"'{missing_coords}' coordinates are missing in `data_array_like`."
578
+ )
579
+
580
+ if "time" in data_array_like.dims:
581
+ timestamps = data_array_like.coords["datetime"].data
582
+ else:
583
+ timestamps = [data_array_like.coords["datetime"].data.item()]
584
+
585
+ radiation = get_toa_incident_solar_radiation(
586
+ timestamps=timestamps,
587
+ latitude=data_array_like.coords["lat"].data,
588
+ longitude=data_array_like.coords["lon"].data,
589
+ tsi_data=tsi_data,
590
+ integration_period=integration_period,
591
+ num_integration_bins=num_integration_bins,
592
+ use_jit=use_jit,
593
+ )
594
+
595
+ if "time" in data_array_like.dims:
596
+ output = xa.DataArray(radiation, dims=("time", "lat", "lon"))
597
+ else:
598
+ output = xa.DataArray(radiation[0], dims=("lat", "lon"))
599
+
600
+ # Preserve as many of the original coordinates as possible, so long as the
601
+ # dimension or the coordinate still exist in the output array.
602
+ for k, coord in data_array_like.coords.items():
603
+ if set(coord.dims).issubset(set(output.dims)):
604
+ output.coords[k] = coord
605
+ return output
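A short sketch of the xarray wrapper in use, mirroring the dataset layout exercised by the tests below (the variable name and grid values are placeholders):

import numpy as np
import xarray as xa
from graphcast import solar_radiation

data = xa.Dataset(
    data_vars={"var1": (["time", "lat", "lon"], np.zeros((2, 4, 8)))},
    coords={
        "lat": np.linspace(-90.0, 90.0, num=4),
        "lon": np.linspace(0.0, 360.0, num=8, endpoint=False),
        "time": np.array([0, 21600], dtype="timedelta64[s]"),
        "datetime": xa.Variable(
            "time",
            np.array(["2023-09-25T00", "2023-09-25T06"], dtype="datetime64[ns]")),
    },
)
tisr = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
    data, integration_period="1h", num_integration_bins=60)
print(tisr.dims, tisr.shape)  # ('time', 'lat', 'lon') (2, 4, 8)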
graphcast/solar_radiation_test.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import timeit
15
+ from typing import Sequence
16
+
17
+ from absl import logging
18
+ from absl.testing import absltest
19
+ from absl.testing import parameterized
20
+ from graphcast import solar_radiation
21
+ import numpy as np
22
+ import pandas as pd
23
+ import xarray as xa
24
+
25
+
26
+ def _get_grid_lat_lon_coords(
27
+ num_lat: int, num_lon: int
28
+ ) -> tuple[np.ndarray, np.ndarray]:
29
+ """Generates a linear latitude-longitude grid of the given size.
30
+
31
+ Args:
32
+ num_lat: Size of the latitude dimension of the grid.
33
+ num_lon: Size of the longitude dimension of the grid.
34
+
35
+ Returns:
36
+ A tuple `(lat, lon)` containing 1D arrays with the latitude and longitude
37
+ coordinates in degrees of the generated grid.
38
+ """
39
+ lat = np.linspace(-90.0, 90.0, num=num_lat, endpoint=True)
40
+ lon = np.linspace(0.0, 360.0, num=num_lon, endpoint=False)
41
+ return lat, lon
42
+
43
+
44
+ class SolarRadiationTest(parameterized.TestCase):
45
+
46
+ def setUp(self):
47
+ super().setUp()
48
+ np.random.seed(0)
49
+
50
+ def test_missing_dim_raises_value_error(self):
51
+ data = xa.DataArray(
52
+ np.random.randn(2, 2),
53
+ coords=[np.array([0.1, 0.2]), np.array([0.0, 0.5])],
54
+ dims=["lon", "x"],
55
+ )
56
+ with self.assertRaisesRegex(
57
+ ValueError, r".* dimensions are missing in `data_array_like`."
58
+ ):
59
+ solar_radiation.get_toa_incident_solar_radiation_for_xarray(
60
+ data, integration_period="1h", num_integration_bins=360
61
+ )
62
+
63
+ def test_missing_coordinate_raises_value_error(self):
64
+ data = xa.Dataset(
65
+ data_vars={"var1": (["x", "lat", "lon"], np.random.randn(2, 3, 2))},
66
+ coords={
67
+ "lat": np.array([0.0, 0.1, 0.2]),
68
+ "lon": np.array([0.0, 0.5]),
69
+ },
70
+ )
71
+ with self.assertRaisesRegex(
72
+ ValueError, r".* coordinates are missing in `data_array_like`."
73
+ ):
74
+ solar_radiation.get_toa_incident_solar_radiation_for_xarray(
75
+ data, integration_period="1h", num_integration_bins=360
76
+ )
77
+
78
+ def test_shape_multiple_timestamps(self):
79
+ data = xa.Dataset(
80
+ data_vars={"var1": (["time", "lat", "lon"], np.random.randn(2, 4, 2))},
81
+ coords={
82
+ "lat": np.array([0.0, 0.1, 0.2, 0.3]),
83
+ "lon": np.array([0.0, 0.5]),
84
+ "time": np.array([100, 200], dtype="timedelta64[s]"),
85
+ "datetime": xa.Variable(
86
+ "time", np.array([10, 20], dtype="datetime64[D]")
87
+ ),
88
+ },
89
+ )
90
+
91
+ actual = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
92
+ data, integration_period="1h", num_integration_bins=2
93
+ )
94
+
95
+ self.assertEqual(("time", "lat", "lon"), actual.dims)
96
+ self.assertEqual((2, 4, 2), actual.shape)
97
+
98
+ def test_shape_single_timestamp(self):
99
+ data = xa.Dataset(
100
+ data_vars={"var1": (["lat", "lon"], np.random.randn(4, 2))},
101
+ coords={
102
+ "lat": np.array([0.0, 0.1, 0.2, 0.3]),
103
+ "lon": np.array([0.0, 0.5]),
104
+ "datetime": np.datetime64(10, "D"),
105
+ },
106
+ )
107
+
108
+ actual = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
109
+ data, integration_period="1h", num_integration_bins=2
110
+ )
111
+
112
+ self.assertEqual(("lat", "lon"), actual.dims)
113
+ self.assertEqual((4, 2), actual.shape)
114
+
115
+ @parameterized.named_parameters(
116
+ dict(
117
+ testcase_name="one_timestamp_jitted",
118
+ periods=1,
119
+ repeats=3,
120
+ use_jit=True,
121
+ ),
122
+ dict(
123
+ testcase_name="one_timestamp_non_jitted",
124
+ periods=1,
125
+ repeats=3,
126
+ use_jit=False,
127
+ ),
128
+ dict(
129
+ testcase_name="ten_timestamps_non_jitted",
130
+ periods=10,
131
+ repeats=1,
132
+ use_jit=False,
133
+ ),
134
+ )
135
+ def test_full_spatial_resolution(
136
+ self, periods: int, repeats: int, use_jit: bool
137
+ ):
138
+ timestamps = pd.date_range(start="2023-09-25", periods=periods, freq="6h")
139
+ # Generate a linear grid with 0.25 degrees resolution similar to ERA5.
140
+ lat, lon = _get_grid_lat_lon_coords(num_lat=721, num_lon=1440)
141
+
142
+ def benchmark() -> None:
143
+ solar_radiation.get_toa_incident_solar_radiation(
144
+ timestamps,
145
+ lat,
146
+ lon,
147
+ integration_period="1h",
148
+ num_integration_bins=360,
149
+ use_jit=use_jit,
150
+ ).block_until_ready()
151
+
152
+ results = timeit.repeat(benchmark, repeat=repeats, number=1)
153
+
154
+ logging.info(
155
+ "Times to compute `tisr` for input of shape `%d, %d, %d` (seconds): %s",
156
+ len(timestamps),
157
+ len(lat),
158
+ len(lon),
159
+ np.array2string(np.array(results), precision=1),
160
+ )
161
+
162
+
163
+ class GetTsiTest(parameterized.TestCase):
164
+
165
+ @parameterized.named_parameters(
166
+ dict(
167
+ testcase_name="reference_tsi_data",
168
+ loader=solar_radiation.reference_tsi_data,
169
+ expected_tsi=np.array([1361.0]),
170
+ ),
171
+ dict(
172
+ testcase_name="era5_tsi_data",
173
+ loader=solar_radiation.era5_tsi_data,
174
+ expected_tsi=np.array([1360.9440]), # 0.9965 * 1365.7240
175
+ ),
176
+ )
177
+ def test_mid_2020_lookup(
178
+ self, loader: solar_radiation.TsiDataLoader, expected_tsi: np.ndarray
179
+ ):
180
+ tsi_data = loader()
181
+
182
+ tsi = solar_radiation.get_tsi(
183
+ [np.datetime64("2020-07-02T00:00:00")], tsi_data
184
+ )
185
+
186
+ np.testing.assert_allclose(expected_tsi, tsi)
187
+
188
+ @parameterized.named_parameters(
189
+ dict(
190
+ testcase_name="beginning_2020_left_boundary",
191
+ timestamps=[np.datetime64("2020-01-01T00:00:00")],
192
+ expected_tsi=np.array([1000.0]),
193
+ ),
194
+ dict(
195
+ testcase_name="mid_2020_exact",
196
+ timestamps=[np.datetime64("2020-07-02T00:00:00")],
197
+ expected_tsi=np.array([1000.0]),
198
+ ),
199
+ dict(
200
+ testcase_name="beginning_2021_interpolated",
201
+ timestamps=[np.datetime64("2021-01-01T00:00:00")],
202
+ expected_tsi=np.array([1150.0]),
203
+ ),
204
+ dict(
205
+ testcase_name="mid_2021_lookup",
206
+ timestamps=[np.datetime64("2021-07-02T12:00:00")],
207
+ expected_tsi=np.array([1300.0]),
208
+ ),
209
+ dict(
210
+ testcase_name="beginning_2022_interpolated",
211
+ timestamps=[np.datetime64("2022-01-01T00:00:00")],
212
+ expected_tsi=np.array([1250.0]),
213
+ ),
214
+ dict(
215
+ testcase_name="mid_2022_lookup",
216
+ timestamps=[np.datetime64("2022-07-02T12:00:00")],
217
+ expected_tsi=np.array([1200.0]),
218
+ ),
219
+ dict(
220
+ testcase_name="beginning_2023_right_boundary",
221
+ timestamps=[np.datetime64("2023-01-01T00:00:00")],
222
+ expected_tsi=np.array([1200.0]),
223
+ ),
224
+ )
225
+ def test_interpolation(
226
+ self, timestamps: Sequence[np.datetime64], expected_tsi: np.ndarray
227
+ ):
228
+ tsi_data = xa.DataArray(
229
+ np.array([1000.0, 1300.0, 1200.0]),
230
+ dims=["time"],
231
+ coords={"time": np.array([2020.5, 2021.5, 2022.5])},
232
+ )
233
+
234
+ tsi = solar_radiation.get_tsi(timestamps, tsi_data)
235
+
236
+ np.testing.assert_allclose(expected_tsi, tsi)
237
+
238
+
239
+ if __name__ == "__main__":
240
+ absltest.main()
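The expected values in GetTsiTest.test_interpolation follow from linear interpolation between the fixture's sample points (2020.5, 2021.5 and 2022.5, which the test names suggest are fractional years), with clamping at the boundaries: a timestamp at the start of 2021 sits halfway between the 1000.0 and 1300.0 samples, giving 1150.0. A tiny numpy sketch reproduces the same numbers; np.interp is used purely for illustration and is not necessarily how get_tsi is implemented:

import numpy as np

years = np.array([2020.5, 2021.5, 2022.5])  # sample points from the fixture
tsi = np.array([1000.0, 1300.0, 1200.0])    # TSI values from the fixture
# np.interp clamps outside the sample range, matching the boundary cases above.
print(np.interp([2020.0, 2021.0, 2022.0, 2023.0], years, tsi))
# -> [1000. 1150. 1250. 1200.]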
graphcast/typed_graph.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Data-structure for storing graphs with typed edges and nodes."""
15
+
16
+ from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar
17
+
18
+ ArrayLike = Union[Any] # np.ndarray, jnp.ndarray, tf.tensor
19
+ ArrayLikeTree = Union[Any, ArrayLike] # Nest of ArrayLike
20
+
21
+ _T = TypeVar('_T')
22
+
23
+
24
+ # All tensors have a "flat_batch_axis", which is similar to the leading
25
+ # axes of graph_tuples:
26
+ # * In the case of nodes this is simply a shared node and flat batch axis, with
27
+ # size corresponding to the total number of nodes in the flattened batch.
28
+ # * In the case of edges this is simply a shared edge and flat batch axis, with
29
+ # size corresponding to the total number of edges in the flattened batch.
30
+ # * In the case of globals this is simply the number of graphs in the flattened
31
+ # batch.
32
+
33
+ # All shapes may also have any additional leading shape "batch_shape".
34
+ # Options for building batches are:
35
+ # * Use a provided "flatten" method that takes a leading `batch_shape` and
36
+ # it into the flat_batch_axis (this will be useful when using `tf.Dataset`
37
+ # which supports batching into RaggedTensors, with leading batch shape even
38
+ # if graphs have different numbers of nodes and edges), so the RaggedBatches
39
+ # can then be converted into something without ragged dimensions that jax can
40
+ # use.
41
+ # * Directly build a "flat batch" using a provided function for batching a list
42
+ # of graphs (how it is done in `jraph`).
43
+
44
+
45
+ class NodeSet(NamedTuple):
46
+ """Represents a set of nodes."""
47
+ n_node: ArrayLike # [num_flat_graphs]
48
+ features: ArrayLikeTree # Prev. `nodes`: [num_flat_nodes] + feature_shape
49
+
50
+
51
+ class EdgesIndices(NamedTuple):
52
+ """Represents indices to nodes adjacent to the edges."""
53
+ senders: ArrayLike # [num_flat_edges]
54
+ receivers: ArrayLike # [num_flat_edges]
55
+
56
+
57
+ class EdgeSet(NamedTuple):
58
+ """Represents a set of edges."""
59
+ n_edge: ArrayLike # [num_flat_graphs]
60
+ indices: EdgesIndices
61
+ features: ArrayLikeTree # Prev. `edges`: [num_flat_edges] + feature_shape
62
+
63
+
64
+ class Context(NamedTuple):
65
+ # `n_graph` always contains ones but it is useful to query the leading shape
66
+ # in case of graphs without any nodes or edges sets.
67
+ n_graph: ArrayLike # [num_flat_graphs]
68
+ features: ArrayLikeTree # Prev. `globals`: [num_flat_graphs] + feature_shape
69
+
70
+
71
+ class EdgeSetKey(NamedTuple):
72
+ name: str # Name of the EdgeSet.
73
+
74
+ # Sender node set name and receiver node set name connected by the edge set.
75
+ node_sets: Tuple[str, str]
76
+
77
+
78
+ class TypedGraph(NamedTuple):
79
+ """A graph with typed nodes and edges.
80
+
81
+ A typed graph is made of a context, multiple sets of nodes and multiple
82
+ sets of edges connecting those nodes (as indicated by the EdgeSetKey).
83
+ """
84
+
85
+ context: Context
86
+ nodes: Mapping[str, NodeSet]
87
+ edges: Mapping[EdgeSetKey, EdgeSet]
88
+
89
+ def edge_key_by_name(self, name: str) -> EdgeSetKey:
90
+ found_key = [k for k in self.edges.keys() if k.name == name]
91
+ if len(found_key) != 1:
92
+ raise KeyError("invalid edge key '{}'. Available edges: [{}]".format(
93
+ name, ', '.join(x.name for x in self.edges.keys())))
94
+ return found_key[0]
95
+
96
+ def edge_by_name(self, name: str) -> EdgeSet:
97
+ return self.edges[self.edge_key_by_name(name)]
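To make the container concrete, here is a minimal sketch that builds a TypedGraph with two node sets and one edge set between them, then looks the edge set up by name (the set names, counts and feature sizes are arbitrary):

import numpy as np
from graphcast import typed_graph

grid = typed_graph.NodeSet(n_node=np.array([2]), features=np.zeros((2, 3)))
mesh = typed_graph.NodeSet(n_node=np.array([1]), features=np.zeros((1, 3)))
g2m_key = typed_graph.EdgeSetKey(name="grid2mesh", node_sets=("grid", "mesh"))
g2m = typed_graph.EdgeSet(
    n_edge=np.array([2]),
    indices=typed_graph.EdgesIndices(senders=np.array([0, 1]),
                                     receivers=np.array([0, 0])),
    features=np.zeros((2, 4)))
graph = typed_graph.TypedGraph(
    context=typed_graph.Context(n_graph=np.array([1]), features=()),
    nodes={"grid": grid, "mesh": mesh},
    edges={g2m_key: g2m})

assert graph.edge_key_by_name("grid2mesh") is g2m_key
assert graph.edge_by_name("grid2mesh") is g2m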
graphcast/typed_graph_net.py ADDED
@@ -0,0 +1,317 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """A library of typed Graph Neural Networks."""
15
+
16
+ from typing import Callable, Mapping, Optional, Union
17
+
18
+ from graphcast import typed_graph
19
+ import jax.numpy as jnp
20
+ import jax.tree_util as tree
21
+ import jraph
22
+
23
+
24
+ # All features will be an ArrayTree.
25
+ NodeFeatures = EdgeFeatures = SenderFeatures = ReceiverFeatures = Globals = (
26
+ jraph.ArrayTree)
27
+
28
+ # Signature:
29
+ # (node features, outgoing edge features, incoming edge features,
30
+ # globals) -> updated node features
31
+ GNUpdateNodeFn = Callable[
32
+ [NodeFeatures, Mapping[str, SenderFeatures], Mapping[str, ReceiverFeatures],
33
+ Globals],
34
+ NodeFeatures]
35
+
36
+ GNUpdateGlobalFn = Callable[
37
+ [Mapping[str, NodeFeatures], Mapping[str, EdgeFeatures], Globals],
38
+ Globals]
39
+
40
+
41
+ def GraphNetwork( # pylint: disable=invalid-name
42
+ update_edge_fn: Mapping[str, jraph.GNUpdateEdgeFn],
43
+ update_node_fn: Mapping[str, GNUpdateNodeFn],
44
+ update_global_fn: Optional[GNUpdateGlobalFn] = None,
45
+ aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph
46
+ .segment_sum,
47
+ aggregate_nodes_for_globals_fn: jraph.AggregateNodesToGlobalsFn = jraph
48
+ .segment_sum,
49
+ aggregate_edges_for_globals_fn: jraph.AggregateEdgesToGlobalsFn = jraph
50
+ .segment_sum,
51
+ ):
52
+ """Returns a method that applies a configured GraphNetwork.
53
+
54
+ This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
55
+ extended to Typed Graphs with multiple edge sets and node sets and extended to
56
+ allow aggregating not only edges received by the nodes, but also edges sent by
57
+ the nodes.
58
+
59
+ Example usage::
60
+
61
+ gn = GraphNetwork(update_edge_function,
62
+ update_node_function, **kwargs)
63
+ # Conduct multiple rounds of message passing with the same parameters:
64
+ for _ in range(num_message_passing_steps):
65
+ graph = gn(graph)
66
+
67
+ Args:
68
+ update_edge_fn: mapping of functions used to update a subset of the edge
69
+ types, indexed by edge type name.
70
+ update_node_fn: mapping of functions used to update a subset of the node
71
+ types, indexed by node type name.
72
+ update_global_fn: function used to update the globals or None to deactivate
73
+ globals updates.
74
+ aggregate_edges_for_nodes_fn: function used to aggregate messages to each
75
+ node.
76
+ aggregate_nodes_for_globals_fn: function used to aggregate the nodes for the
77
+ globals.
78
+ aggregate_edges_for_globals_fn: function used to aggregate the edges for the
79
+ globals.
80
+
81
+ Returns:
82
+ A method that applies the configured GraphNetwork.
83
+ """
84
+
85
+ def _apply_graph_net(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
86
+ """Applies a configured GraphNetwork to a graph.
87
+
88
+ This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
89
+ extended to Typed Graphs with multiple edge sets and node sets and extended
90
+ to allow aggregating not only edges received by the nodes, but also edges
91
+ sent by the nodes.
92
+
93
+ Args:
94
+ graph: a `TypedGraph` containing the graph.
95
+
96
+ Returns:
97
+ Updated `TypedGraph`.
98
+ """
99
+
100
+ updated_graph = graph
101
+
102
+ # Edge update.
103
+ updated_edges = dict(updated_graph.edges)
104
+ for edge_set_name, edge_fn in update_edge_fn.items():
105
+ edge_set_key = graph.edge_key_by_name(edge_set_name)
106
+ updated_edges[edge_set_key] = _edge_update(
107
+ updated_graph, edge_fn, edge_set_key)
108
+ updated_graph = updated_graph._replace(edges=updated_edges)
109
+
110
+ # Node update.
111
+ updated_nodes = dict(updated_graph.nodes)
112
+ for node_set_key, node_fn in update_node_fn.items():
113
+ updated_nodes[node_set_key] = _node_update(
114
+ updated_graph, node_fn, node_set_key, aggregate_edges_for_nodes_fn)
115
+ updated_graph = updated_graph._replace(nodes=updated_nodes)
116
+
117
+ # Global update.
118
+ if update_global_fn:
119
+ updated_context = _global_update(
120
+ updated_graph, update_global_fn,
121
+ aggregate_edges_for_globals_fn,
122
+ aggregate_nodes_for_globals_fn)
123
+ updated_graph = updated_graph._replace(context=updated_context)
124
+
125
+ return updated_graph
126
+
127
+ return _apply_graph_net
128
+
129
+
130
+ def _edge_update(graph, edge_fn, edge_set_key): # pylint: disable=invalid-name
131
+ """Updates an edge set of a given key."""
132
+
133
+ sender_nodes = graph.nodes[edge_set_key.node_sets[0]]
134
+ receiver_nodes = graph.nodes[edge_set_key.node_sets[1]]
135
+ edge_set = graph.edges[edge_set_key]
136
+ senders = edge_set.indices.senders # pytype: disable=attribute-error
137
+ receivers = edge_set.indices.receivers # pytype: disable=attribute-error
138
+
139
+ sent_attributes = tree.tree_map(
140
+ lambda n: n[senders], sender_nodes.features)
141
+ received_attributes = tree.tree_map(
142
+ lambda n: n[receivers], receiver_nodes.features)
143
+
144
+ n_edge = edge_set.n_edge
145
+ sum_n_edge = senders.shape[0]
146
+ global_features = tree.tree_map(
147
+ lambda g: jnp.repeat(g, n_edge, axis=0, total_repeat_length=sum_n_edge),
148
+ graph.context.features)
149
+ new_features = edge_fn(
150
+ edge_set.features, sent_attributes, received_attributes,
151
+ global_features)
152
+ return edge_set._replace(features=new_features)
153
+
154
+
155
+ def _node_update(graph, node_fn, node_set_key, aggregation_fn): # pylint: disable=invalid-name
156
+ """Updates a node set of a given key."""
157
+ node_set = graph.nodes[node_set_key]
158
+ sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0]
159
+
160
+ sent_features = {}
161
+ for edge_set_key, edge_set in graph.edges.items():
162
+ sender_node_set_key = edge_set_key.node_sets[0]
163
+ if sender_node_set_key == node_set_key:
164
+ assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
165
+ senders = edge_set.indices.senders
166
+ sent_features[edge_set_key.name] = tree.tree_map(
167
+ lambda e: aggregation_fn(e, senders, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop
168
+
169
+ received_features = {}
170
+ for edge_set_key, edge_set in graph.edges.items():
171
+ receiver_node_set_key = edge_set_key.node_sets[1]
172
+ if receiver_node_set_key == node_set_key:
173
+ assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
174
+ receivers = edge_set.indices.receivers
175
+ received_features[edge_set_key.name] = tree.tree_map(
176
+ lambda e: aggregation_fn(e, receivers, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop
177
+
178
+ n_node = node_set.n_node
179
+ global_features = tree.tree_map(
180
+ lambda g: jnp.repeat(g, n_node, axis=0, total_repeat_length=sum_n_node),
181
+ graph.context.features)
182
+ new_features = node_fn(
183
+ node_set.features, sent_features, received_features, global_features)
184
+ return node_set._replace(features=new_features)
185
+
186
+
187
+ def _global_update(graph, global_fn, edge_aggregation_fn, node_aggregation_fn): # pylint: disable=invalid-name
188
+ """Updates the global features from aggregated node and edge features."""
189
+ n_graph = graph.context.n_graph.shape[0]
190
+ graph_idx = jnp.arange(n_graph)
191
+
192
+ edge_features = {}
193
+ for edge_set_key, edge_set in graph.edges.items():
194
+ assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
195
+ sum_n_edge = edge_set.indices.senders.shape[0]
196
+ edge_gr_idx = jnp.repeat(
197
+ graph_idx, edge_set.n_edge, axis=0, total_repeat_length=sum_n_edge)
198
+ edge_features[edge_set_key.name] = tree.tree_map(
199
+ lambda e: edge_aggregation_fn(e, edge_gr_idx, n_graph), # pylint: disable=cell-var-from-loop
200
+ edge_set.features)
201
+
202
+ node_features = {}
203
+ for node_set_key, node_set in graph.nodes.items():
204
+ sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0]
205
+ node_gr_idx = jnp.repeat(
206
+ graph_idx, node_set.n_node, axis=0, total_repeat_length=sum_n_node)
207
+ node_features[node_set_key] = tree.tree_map(
208
+ lambda n: node_aggregation_fn(n, node_gr_idx, n_graph), # pylint: disable=cell-var-from-loop
209
+ node_set.features)
210
+
211
+ new_features = global_fn(node_features, edge_features, graph.context.features)
212
+ return graph.context._replace(features=new_features)
213
+
214
+
215
+ InteractionUpdateNodeFn = Callable[
216
+ [jraph.NodeFeatures,
217
+ Mapping[str, SenderFeatures],
218
+ Mapping[str, ReceiverFeatures]],
219
+ jraph.NodeFeatures]
220
+
221
+
222
+ InteractionUpdateNodeFnNoSentEdges = Callable[
223
+ [jraph.NodeFeatures,
224
+ Mapping[str, ReceiverFeatures]],
225
+ jraph.NodeFeatures]
226
+
227
+
228
+ def InteractionNetwork( # pylint: disable=invalid-name
229
+ update_edge_fn: Mapping[str, jraph.InteractionUpdateEdgeFn],
230
+ update_node_fn: Mapping[str, Union[InteractionUpdateNodeFn,
231
+ InteractionUpdateNodeFnNoSentEdges]],
232
+ aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph
233
+ .segment_sum,
234
+ include_sent_messages_in_node_update: bool = False):
235
+ """Returns a method that applies a configured InteractionNetwork.
236
+
237
+ An interaction network computes interactions on the edges based on the
238
+ previous edges features, and on the features of the nodes sending into those
239
+ edges. It then updates the nodes based on the incoming updated edges.
240
+ See https://arxiv.org/abs/1612.00222 for more details.
241
+
242
+ This implementation extends the behavior to `TypedGraphs` adding an option
243
+ to include edge features for which a node is a sender in the arguments to
244
+ the node update function.
245
+
246
+ Args:
247
+ update_edge_fn: mapping of functions used to update a subset of the edge
248
+ types, indexed by edge type name.
249
+ update_node_fn: mapping of functions used to update a subset of the node
250
+ types, indexed by node type name.
251
+ aggregate_edges_for_nodes_fn: function used to aggregate messages to each
252
+ node.
253
+ include_sent_messages_in_node_update: pass edge features for which a node is
254
+ a sender to the node update function.
255
+ """
256
+ # An InteractionNetwork is a GraphNetwork without globals features,
257
+ # so we implement the InteractionNetwork as a configured GraphNetwork.
258
+
259
+ # An InteractionNetwork edge function does not have global feature inputs,
260
+ # so we filter the passed global argument in the GraphNetwork.
261
+ wrapped_update_edge_fn = tree.tree_map(
262
+ lambda fn: lambda e, s, r, g: fn(e, s, r), update_edge_fn)
263
+
264
+ # Similarly, we wrap the update_node_fn to ensure only the expected
265
+ # arguments are passed to the Interaction net.
266
+ if include_sent_messages_in_node_update:
267
+ wrapped_update_node_fn = tree.tree_map(
268
+ lambda fn: lambda n, s, r, g: fn(n, s, r), update_node_fn)
269
+ else:
270
+ wrapped_update_node_fn = tree.tree_map(
271
+ lambda fn: lambda n, s, r, g: fn(n, r), update_node_fn)
272
+ return GraphNetwork(
273
+ update_edge_fn=wrapped_update_edge_fn,
274
+ update_node_fn=wrapped_update_node_fn,
275
+ aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn)
276
+
277
+
278
+ def GraphMapFeatures( # pylint: disable=invalid-name
279
+ embed_edge_fn: Optional[Mapping[str, jraph.EmbedEdgeFn]] = None,
280
+ embed_node_fn: Optional[Mapping[str, jraph.EmbedNodeFn]] = None,
281
+ embed_global_fn: Optional[jraph.EmbedGlobalFn] = None):
282
+ """Returns a function which embeds the components of a graph independently.
283
+
284
+ Args:
285
+ embed_edge_fn: mapping of functions used to embed each edge type,
286
+ indexed by edge type name.
287
+ embed_node_fn: mapping of functions used to embed each node type,
288
+ indexed by node type name.
289
+ embed_global_fn: function used to embed the globals.
290
+ """
291
+
292
+ def _embed(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
293
+
294
+ updated_edges = dict(graph.edges)
295
+ if embed_edge_fn:
296
+ for edge_set_name, embed_fn in embed_edge_fn.items():
297
+ edge_set_key = graph.edge_key_by_name(edge_set_name)
298
+ edge_set = graph.edges[edge_set_key]
299
+ updated_edges[edge_set_key] = edge_set._replace(
300
+ features=embed_fn(edge_set.features))
301
+
302
+ updated_nodes = dict(graph.nodes)
303
+ if embed_node_fn:
304
+ for node_set_key, embed_fn in embed_node_fn.items():
305
+ node_set = graph.nodes[node_set_key]
306
+ updated_nodes[node_set_key] = node_set._replace(
307
+ features=embed_fn(node_set.features))
308
+
309
+ updated_context = graph.context
310
+ if embed_global_fn:
311
+ updated_context = updated_context._replace(
312
+ features=embed_global_fn(updated_context.features))
313
+
314
+ return graph._replace(edges=updated_edges, nodes=updated_nodes,
315
+ context=updated_context)
316
+
317
+ return _embed
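A minimal sketch of wiring the pieces above together: it builds a tiny single-node-set TypedGraph and runs one round of message passing through GraphNetwork. The set names, feature sizes and update functions are placeholders, and jraph must be installed since this module imports it:

import jax.numpy as jnp
from graphcast import typed_graph, typed_graph_net

edge_key = typed_graph.EdgeSetKey(name="mesh2mesh", node_sets=("mesh", "mesh"))
graph = typed_graph.TypedGraph(
    context=typed_graph.Context(n_graph=jnp.array([1]),
                                features=jnp.zeros((1, 4))),
    nodes={"mesh": typed_graph.NodeSet(n_node=jnp.array([3]),
                                       features=jnp.ones((3, 4)))},
    edges={edge_key: typed_graph.EdgeSet(
        n_edge=jnp.array([2]),
        indices=typed_graph.EdgesIndices(senders=jnp.array([0, 1]),
                                         receivers=jnp.array([1, 2])),
        features=jnp.zeros((2, 4)))})

def update_edge(edges, senders, receivers, globals_):
  # New edge features from old features plus endpoint features and globals.
  return edges + senders + receivers + globals_

def update_node(nodes, sent, received, globals_):
  # `sent`/`received` are dicts of aggregated edge features keyed by edge set name.
  return nodes + received["mesh2mesh"] + globals_

gn = typed_graph_net.GraphNetwork(
    update_edge_fn={"mesh2mesh": update_edge},
    update_node_fn={"mesh": update_node})
out = gn(graph)
print(out.nodes["mesh"].features.shape)  # (3, 4)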
graphcast/xarray_jax.py ADDED
@@ -0,0 +1,810 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Helpers to use xarray.{Variable,DataArray,Dataset} with JAX.
15
+
16
+ Allows them to be based on JAX arrays without converting to numpy arrays under
17
+ the hood, so you can start with a JAX array, do some computation with it in
18
+ xarray-land, get a JAX array out the other end and (for example) jax.jit
19
+ through the whole thing. You can even jax.jit a function which accepts and
20
+ returns xarray.Dataset, DataArray and Variable.
21
+
22
+ ## Creating xarray datatypes from jax arrays, and vice-versa.
23
+
24
+ You can use the xarray_jax.{Variable,DataArray,Dataset} constructors, which have
25
+ the same API as the standard xarray constructors but will accept JAX arrays
26
+ without converting them to numpy.
27
+
28
+ It does this by wrapping the JAX array in a wrapper before passing it to
29
+ xarray; you can also do this manually by calling xarray_jax.wrap on your JAX
30
+ arrays before passing them to the standard xarray constructors.
31
+
32
+ To get non-wrapped JAX arrays out the other end, you can use e.g.:
33
+
34
+ xarray_jax.jax_vars(dataset)
35
+ xarray_jax.jax_data(dataset.some_var)
36
+
37
+ which will complain if the data isn't actually a JAX array. Use this if you need
38
+ to make sure the computation has gone via JAX, e.g. if it's the output of code
39
+ that you want to JIT or compute gradients through. If this is not the case and
40
+ you want to support passing plain numpy arrays through as well as potentially
41
+ JAX arrays, you can use:
42
+
43
+ xarray_jax.unwrap_vars(dataset)
44
+ xarray_jax.unwrap_data(dataset.some_var)
45
+
46
+ which will unwrap the data if it is a wrapped JAX array, but otherwise pass
47
+ it through to you without complaint.
48
+
49
+ The wrapped JAX arrays aim to support all the core operations from the numpy
50
+ array API that xarray expects, however there may still be some gaps; if you run
51
+ into any problems around this, you may need to add a few more proxy methods onto
52
+ the wrapper class below.
53
+
54
+ In future once JAX and xarray support the new Python array API standard
55
+ (https://data-apis.org/array-api/latest/index.html), we hope to avoid the need
56
+ for wrapping the JAX arrays like this.
57
+
58
+ ## jax.jit and pmap of functions taking and returning xarray datatypes
59
+
60
+ We register xarray datatypes with jax.tree_util, which allows them to be treated
61
+ as generic containers of jax arrays by various parts of jax including jax.jit.
62
+
63
+ This allows for, e.g.:
64
+
65
+ @jax.jit
66
+ def foo(input: xarray.Dataset) -> xarray.Dataset:
67
+ ...
68
+
69
+ It will not work out-of-the-box with shape-modifying transformations like
70
+ jax.pmap, or e.g. a jax.tree_util.tree_map with some transform that alters array
71
+ shapes or dimension order. That's because we won't know what dimension names
72
+ and/or coordinates to use when unflattening, if the results have a different
73
+ shape to the data that was originally flattened.
74
+
75
+ You can work around this using xarray_jax.dims_change_on_unflatten, however,
76
+ and in the case of jax.pmap we provide a wrapper xarray_jax.pmap which allows
77
+ it to be used with functions taking and returning xarrays.
78
+
79
+ ## Treatment of coordinates
80
+
81
+ We don't support passing jax arrays as coordinates when constructing a
82
+ DataArray/Dataset. This is because xarray's advanced indexing and slicing is
83
+ unlikely to work with jax arrays (at least when a Tracer is used during
84
+ jax.jit), and also because some important datatypes used for coordinates, like
85
+ timedelta64 and datetime64, are not supported by jax.
86
+
87
+ For the purposes of tree_util and jax.jit, coordinates are not treated as leaves
88
+ of the tree (array data 'contained' by a Dataset/DataArray), they are just a
89
+ static part of the structure. That means that if a jit'ed function is called
90
+ twice with Dataset inputs that use different coordinates, it will compile a
91
+ separate version of the function for each. The coordinates are treated like
92
+ static_argnums by jax.jit.
93
+
94
+ If you want to use dynamic data for coordinates, we recommend making it a
95
+ data_var instead of a coord. You won't be able to do indexing and slicing using
96
+ the coordinate, but that wasn't going to work with a jax array anyway.
97
+ """
98
+
99
+ import collections
100
+ import contextlib
101
+ import contextvars
102
+ from typing import Any, Callable, Hashable, Iterator, Mapping, Optional, Union, Tuple, TypeVar, cast
103
+
104
+ import jax
105
+ import jax.numpy as jnp
106
+ import numpy as np
107
+ import tree
108
+ import xarray
109
+
110
+
111
+ def Variable(dims, data, **kwargs) -> xarray.Variable: # pylint:disable=invalid-name
112
+ """Like xarray.Variable, but can wrap JAX arrays."""
113
+ return xarray.Variable(dims, wrap(data), **kwargs)
114
+
115
+
116
+ _JAX_COORD_ATTR_NAME = '_jax_coord'
117
+
118
+
119
+ def DataArray( # pylint:disable=invalid-name
120
+ data,
121
+ coords=None,
122
+ dims=None,
123
+ name=None,
124
+ attrs=None,
125
+ jax_coords=None,
126
+ ) -> xarray.DataArray:
127
+ """Like xarray.DataArray, but supports using JAX arrays.
128
+
129
+ Args:
130
+ data: As for xarray.DataArray, except jax arrays are also supported.
131
+ coords: Coordinates for the array, see xarray.DataArray. These coordinates
132
+ must be based on plain numpy arrays or something convertible to plain
133
+ numpy arrays. Their values will form a static part of the data structure
134
+ from the point of view of jax.tree_util. In particular this means these
135
+ coordinates will be passed as plain numpy arrays even inside a JIT'd
136
+ function, and the JIT'd function will be recompiled under the hood if the
137
+ coordinates of DataArrays passed into it change.
138
+ If this is not convenient for you, see also jax_coords below.
139
+ dims: See xarray.DataArray.
140
+ name: See xarray.DataArray.
141
+ attrs: See xarray.DataArray.
142
+ jax_coords: Additional coordinates, which *can* use JAX arrays. These
143
+ coordinates will be treated as JAX data from the point of view of
144
+ jax.tree_util, that means when JIT'ing they will be passed as tracers and
145
+ computation involving them will be JIT'd.
146
+ Unfortunately a side-effect of this is that they can't be used as index
147
+ coordinates (because xarray's indexing logic is not JIT-able). If you
148
+ specify a coordinate with the same name as a dimension here, it will not
149
+ be set as an index coordinate; this behaviour is different to the default
150
+ for `coords`, and it means that things like `.sel` based on the jax
151
+ coordinate will not work.
152
+ Note we require `jax_coords` to be explicitly specified via a different
153
+ constructor argument to `coords`, rather than just looking for jax arrays
154
+ within the `coords` and treating them differently. This is because it
155
+ affects the way jax.tree_util treats them, which is somewhat orthogonal to
156
+ whether the value is passed in as numpy or not, and generally needs to be
157
+ handled consistently so is something we encourage explicit control over.
158
+
159
+ Returns:
160
+ An instance of xarray.DataArray. Where JAX arrays are used as data or
161
+ coords, they will be wrapped with JaxArrayWrapper and can be unwrapped via
162
+ `unwrap` and `unwrap_data`.
163
+ """
164
+ result = xarray.DataArray(
165
+ wrap(data), dims=dims, name=name, attrs=attrs or {})
166
+ return assign_coords(result, coords=coords, jax_coords=jax_coords)
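A small sketch of this constructor in use, mixing a plain numpy coordinate with a JAX-backed coordinate (the dimension names and values are illustrative):

import jax.numpy as jnp
import numpy as np
from graphcast import xarray_jax

da = xarray_jax.DataArray(
    data=jnp.zeros((2, 3)),
    dims=("lat", "lon"),
    coords={"lat": np.array([10.0, 20.0])},          # static numpy coord
    jax_coords={"lon": jnp.array([0.0, 1.0, 2.0])})  # traced through jax.jit
print(type(xarray_jax.jax_data(da)))         # the underlying jax array
print(list(xarray_jax.get_jax_coords(da)))   # ['lon']

Because "lon" is a jax coordinate it is not created as an index coordinate, so coordinate-based indexing such as .sel(lon=...) will not work on it; use a regular numpy coord for anything you need to index by.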
167
+
168
+
169
+ def Dataset( # pylint:disable=invalid-name
170
+ data_vars,
171
+ coords=None,
172
+ attrs=None,
173
+ jax_coords=None,
174
+ ) -> xarray.Dataset:
175
+ """Like xarray.Dataset, but can wrap JAX arrays.
176
+
177
+ Args:
178
+ data_vars: As for xarray.Dataset, except jax arrays are also supported.
179
+ coords: Coordinates for the dataset, see xarray.Dataset. These coordinates
180
+ must be based on plain numpy arrays or something convertible to plain
181
+ numpy arrays. Their values will form a static part of the data structure
182
+ from the point of view of jax.tree_util. In particular this means these
183
+ coordinates will be passed as plain numpy arrays even inside a JIT'd
184
+ function, and the JIT'd function will be recompiled under the hood if the
185
+ coordinates of DataArrays passed into it change.
186
+ If this is not convenient for you, see also jax_coords below.
187
+ attrs: See xarray.Dataset.
188
+ jax_coords: Additional coordinates, which *can* use JAX arrays. These
189
+ coordinates will be treated as JAX data from the point of view of
190
+ jax.tree_util, that means when JIT'ing they will be passed as tracers and
191
+ computation involving them will be JIT'd.
192
+ Unfortunately a side-effect of this is that they can't be used as index
193
+ coordinates (because xarray's indexing logic is not JIT-able). If you
194
+ specify a coordinate with the same name as a dimension here, it will not
195
+ be set as an index coordinate; this behaviour is different to the default
196
+ for `coords`, and it means that things like `.sel` based on the jax
197
+ coordinate will not work.
198
+ Note we require `jax_coords` to be explicitly specified via a different
199
+ constructor argument to `coords`, rather than just looking for jax arrays
200
+ within the `coords` and treating them differently. This is because it
201
+ affects the way jax.tree_util treats them, which is somewhat orthogonal to
202
+ whether the value is passed in as numpy or not, and generally needs to be
203
+ handled consistently so is something we encourage explicit control over.
204
+
205
+ Returns:
206
+ An instance of xarray.Dataset. Where JAX arrays are used as data, they
207
+ will be wrapped with JaxArrayWrapper.
208
+ """
209
+ wrapped_data_vars = {}
210
+ for name, var_like in data_vars.items():
211
+ # xarray.Dataset accepts a few different formats for data_vars:
212
+ if isinstance(var_like, jax.Array):
213
+ wrapped_data_vars[name] = wrap(var_like)
214
+ elif isinstance(var_like, tuple):
215
+ # Layout is (dims, data, ...). We wrap data.
216
+ wrapped_data_vars[name] = (var_like[0], wrap(var_like[1])) + var_like[2:]
217
+ else:
218
+ # Could be a plain numpy array or scalar (we don't wrap), or an
219
+ # xarray.Variable, DataArray etc, which we must assume is already wrapped
220
+ # if necessary (e.g. if creating using xarray_jax.{Variable,DataArray}).
221
+ wrapped_data_vars[name] = var_like
222
+
223
+ result = xarray.Dataset(
224
+ data_vars=wrapped_data_vars,
225
+ attrs=attrs)
226
+
227
+ return assign_coords(result, coords=coords, jax_coords=jax_coords)
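And a sketch of the Dataset constructor combined with jax.jit, as promised in the module docstring. The variable name and shapes are placeholders; inside the jit boundary the variables arrive as tracers wrapped for xarray, the numpy coords stay static, and jax_data pulls the JAX value back out. This relies on the jax.tree_util registration that appears further down this file:

import jax
import jax.numpy as jnp
import numpy as np
import xarray
from graphcast import xarray_jax

ds = xarray_jax.Dataset(
    data_vars={"temperature": (("lat", "lon"), jnp.ones((2, 3)))},
    coords={"lat": np.array([0.0, 1.0]), "lon": np.array([0.0, 1.0, 2.0])})

@jax.jit
def double(inputs: xarray.Dataset) -> xarray.Dataset:
  t = xarray_jax.jax_data(inputs["temperature"])  # a jax tracer inside jit
  return xarray_jax.Dataset(
      data_vars={"temperature": (("lat", "lon"), t * 2.0)})

out = double(ds)
print(float(xarray_jax.jax_data(out["temperature"]).sum()))  # 12.0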
228
+
229
+
230
+ DatasetOrDataArray = TypeVar(
231
+ 'DatasetOrDataArray', xarray.Dataset, xarray.DataArray)
232
+
233
+
234
+ def assign_coords(
235
+ x: DatasetOrDataArray,
236
+ *,
237
+ coords: Optional[Mapping[Hashable, Any]] = None,
238
+ jax_coords: Optional[Mapping[Hashable, Any]] = None,
239
+ ) -> DatasetOrDataArray:
240
+ """Replacement for assign_coords which works in presence of jax_coords.
241
+
242
+ `jax_coords` allow certain specified coordinates to have their data passed as
243
+ JAX arrays (including through jax.jit boundaries). The compromise in return is
244
+ that they are not created as index coordinates and cannot be used for .sel
245
+ and other coordinate-based indexing operations. See docs for `jax_coords` on
246
+ xarray_jax.Dataset and xarray_jax.DataArray for more information.
247
+
248
+ This function can be used to set jax_coords on an existing DataArray or
249
+ Dataset, and also to set a mix of jax and non-jax coordinates. It implements
250
+ some workarounds to prevent xarray trying and failing to create IndexVariables
251
+ from jax arrays under the hood.
252
+
253
+ If you have any jax_coords with the same name as a dimension, you'll need to
254
+ use this function instead of data_array.assign_coords or dataset.assign_coords
255
+ in general, to avoid an xarray bug where it tries (and in our case fails) to
256
+ create indexes for existing jax coords. See
257
+ https://github.com/pydata/xarray/issues/7885.
258
+
259
+ Args:
260
+ x: An xarray Dataset or DataArray.
261
+ coords: Dict of (non-JAX) coords, or None if not assigning any.
262
+ jax_coords: Dict of JAX coords, or None if not assigning any. See docs for
263
+ xarray_jax.Dataset / DataArray for more information on jax_coords.
264
+
265
+ Returns:
266
+ The Dataset or DataArray with coordinates assigned, similarly to
267
+ Dataset.assign_coords / DataArray.assign_coords.
268
+ """
269
+ coords = {} if coords is None else dict(coords) # Copy before mutating.
270
+ jax_coords = {} if jax_coords is None else dict(jax_coords)
271
+
272
+ # Any existing JAX coords must be dropped and re-added via the workaround
273
+ # below, since otherwise .assign_coords will trigger an xarray bug where
274
+ # it tries to recreate the indexes again for the existing coordinates.
275
+ # Can remove if/when https://github.com/pydata/xarray/issues/7885 fixed.
276
+ existing_jax_coords = get_jax_coords(x)
277
+ jax_coords = existing_jax_coords | jax_coords
278
+ x = x.drop_vars(existing_jax_coords.keys())
279
+
280
+ # We need to ensure that xarray doesn't try to create an index for
281
+ # coordinates with the same name as a dimension, since this will fail if
282
+ # given a wrapped JAX tracer.
283
+ # It appears the only way to avoid this is to name them differently to any
284
+ # dimension name, then rename them back afterwards.
285
+ renamed_jax_coords = {}
286
+ for name, coord in jax_coords.items():
287
+ if isinstance(coord, xarray.DataArray):
288
+ coord = coord.variable
289
+ if isinstance(coord, xarray.Variable):
290
+ coord = coord.copy(deep=False) # Copy before mutating attrs.
291
+ else:
292
+ # Must wrap as Variable with the correct dims first if this has not
293
+ # already been done, otherwise xarray.Dataset will assume the dimension
294
+ # name is also __NONINDEX_{n}.
295
+ coord = Variable((name,), coord)
296
+
297
+ # We set an attr on each jax_coord identifying it as such. These attrs on
298
+ # the coord Variable gets reflected on the coord DataArray exposed too, and
299
+ # when set on coordinates they generally get preserved under the default
300
+ # keep_attrs setting.
301
+ # These attrs are used by jax.tree_util registered flatten/unflatten to
302
+ # determine which coords need to be treated as leaves of the flattened
303
+ # structure vs static data.
304
+ coord.attrs[_JAX_COORD_ATTR_NAME] = True
305
+ renamed_jax_coords[f'__NONINDEX_{name}'] = coord
306
+
307
+ x = x.assign_coords(coords=coords | renamed_jax_coords)
308
+
309
+ rename_back_mapping = {f'__NONINDEX_{name}': name for name in jax_coords}
310
+ if isinstance(x, xarray.Dataset):
311
+ # Using 'rename' doesn't work if renaming to the same name as a dimension.
312
+ return x.rename_vars(rename_back_mapping)
313
+ else: # DataArray
314
+ return x.rename(rename_back_mapping)
315
+
316
+
317
+ def get_jax_coords(x: DatasetOrDataArray) -> Mapping[Hashable, Any]:
318
+ return {
319
+ name: coord_var
320
+ for name, coord_var in x.coords.variables.items()
321
+ if coord_var.attrs.get(_JAX_COORD_ATTR_NAME, False)}
322
+
323
+
324
+ def assign_jax_coords(
325
+ x: DatasetOrDataArray,
326
+ jax_coords: Optional[Mapping[Hashable, Any]] = None,
327
+ **jax_coords_kwargs
328
+ ) -> DatasetOrDataArray:
329
+ """Assigns only jax_coords, with same API as xarray's assign_coords."""
330
+ return assign_coords(x, jax_coords=jax_coords or jax_coords_kwargs)
331
+
332
+
333
+ def wrap(value):
334
+ """Wraps JAX arrays for use in xarray, passing through other values."""
335
+ if isinstance(value, jax.Array):
336
+ return JaxArrayWrapper(value)
337
+ else:
338
+ return value
339
+
340
+
341
+ def unwrap(value, require_jax=False):
342
+ """Unwraps wrapped JAX arrays used in xarray, passing through other values."""
343
+ if isinstance(value, JaxArrayWrapper):
344
+ return value.jax_array
345
+ elif isinstance(value, jax.Array):
346
+ return value
347
+ elif require_jax:
348
+ raise TypeError(f'Expected JAX array, found {type(value)}.')
349
+ else:
350
+ return value
351
+
352
+
353
+ def _wrapped(func):
354
+ """Surrounds a function with JAX array unwrapping/wrapping."""
355
+ def wrapped_func(*args, **kwargs):
356
+ args, kwargs = tree.map_structure(unwrap, (args, kwargs))
357
+ result = func(*args, **kwargs)
358
+ return tree.map_structure(wrap, result)
359
+ return wrapped_func
360
+
361
+
362
+ def unwrap_data(
363
+ value: Union[xarray.Variable, xarray.DataArray],
364
+ require_jax: bool = False
365
+ ) -> Union[jax.Array, np.ndarray]:
366
+ """The unwrapped (see unwrap) data of an xarray.Variable or DataArray."""
367
+ return unwrap(value.data, require_jax=require_jax)
368
+
369
+
370
+ def unwrap_vars(
371
+ dataset: Mapping[Hashable, xarray.DataArray],
372
+ require_jax: bool = False
373
+ ) -> Mapping[str, Union[jax.Array, np.ndarray]]:
374
+ """The unwrapped data (see unwrap) of the variables in a dataset."""
375
+ # xarray types variable names as Hashable, but in practice they're invariably
376
+ # strings and we convert to str to allow for a more useful return type.
377
+ return {str(name): unwrap_data(var, require_jax=require_jax)
378
+ for name, var in dataset.items()}
379
+
380
+
381
+ def unwrap_coords(
382
+ dataset: Union[xarray.Dataset, xarray.DataArray],
383
+ require_jax: bool = False
384
+ ) -> Mapping[str, Union[jax.Array, np.ndarray]]:
385
+ """The unwrapped data (see unwrap) of the coords in a Dataset or DataArray."""
386
+ return {str(name): unwrap_data(var, require_jax=require_jax)
387
+ for name, var in dataset.coords.items()}
388
+
389
+
390
+ def jax_data(value: Union[xarray.Variable, xarray.DataArray]) -> jax.Array:
391
+ """Like unwrap_data, but will complain if not a jax array."""
392
+ # Implementing this separately so we can give a more specific return type
393
+ # for it.
394
+ return cast(jax.Array, unwrap_data(value, require_jax=True))
395
+
396
+
397
+ def jax_vars(
398
+ dataset: Mapping[Hashable, xarray.DataArray]) -> Mapping[str, jax.Array]:
399
+ """Like unwrap_vars, but will complain if vars are not all jax arrays."""
400
+ return cast(Mapping[str, jax.Array], unwrap_vars(dataset, require_jax=True))
401
+
402
+
403
+ class JaxArrayWrapper(np.lib.mixins.NDArrayOperatorsMixin):
404
+ """Wraps a JAX array into a duck-typed array suitable for use with xarray.
405
+
406
+ This uses an older duck-typed array protocol based on __array_ufunc__ and
407
+ __array_function__ which works with numpy and xarray. (In newer versions
408
+ of xarray it implements xarray.namedarray._typing._array_function.)
409
+
410
+ This is in the process of being superseded by the Python array API standard
411
+ (https://data-apis.org/array-api/latest/index.html), but JAX hasn't
412
+ implemented it yet. Once they have, we should be able to get rid of
413
+ this wrapper and use JAX arrays directly with xarray.
414
+
415
+ """
416
+
417
+ def __init__(self, jax_array):
418
+ self.jax_array = jax_array
419
+
420
+ def __array_ufunc__(self, ufunc, method, *args, **kwargs):
421
+ for x in args:
422
+ if not isinstance(x, (jax.typing.ArrayLike, type(self))):
423
+ return NotImplemented
424
+ if method != '__call__':
425
+ return NotImplemented
426
+ try:
427
+ # Get the corresponding jax.numpy function to the NumPy ufunc:
428
+ func = getattr(jnp, ufunc.__name__)
429
+ except AttributeError:
430
+ return NotImplemented
431
+ # There may be an 'out' kwarg requesting an in-place operation, e.g. when
432
+ # this is called via __iadd__ (+=), __imul__ (*=) etc. JAX doesn't support
433
+ # in-place operations so we just remove this argument and have the ufunc
434
+ # return a fresh JAX array instead.
435
+ kwargs.pop('out', None)
436
+ return _wrapped(func)(*args, **kwargs)
437
+
438
+ def __array_function__(self, func, types, args, kwargs):
439
+ try:
440
+ # Get the corresponding jax.numpy function to the NumPy function:
441
+ func = getattr(jnp, func.__name__)
442
+ except AttributeError:
443
+ return NotImplemented
444
+ return _wrapped(func)(*args, **kwargs)
445
+
446
+ def __repr__(self):
447
+ return f'xarray_jax.JaxArrayWrapper({repr(self.jax_array)})'
448
+
449
+ # NDArrayOperatorsMixin already proxies most __dunder__ operator methods.
450
+ # We need to proxy through a few more methods in a similar way:
451
+
452
+ # Essential array properties:
453
+
454
+ @property
455
+ def shape(self):
456
+ return self.jax_array.shape
457
+
458
+ @property
459
+ def dtype(self):
460
+ return self.jax_array.dtype
461
+
462
+ @property
463
+ def ndim(self):
464
+ return self.jax_array.ndim
465
+
466
+ @property
467
+ def size(self):
468
+ return self.jax_array.size
469
+
470
+ @property
471
+ def real(self):
472
+ return self.jax_array.real
473
+
474
+ @property
475
+ def imag(self):
476
+ return self.jax_array.imag
477
+
478
+ # Array methods not covered by NDArrayOperatorsMixin:
479
+
480
+ # Allows conversion to numpy array using np.asarray etc. Warning: doing this
481
+ # will fail in a jax.jit-ed function.
482
+ def __array__(self, dtype=None, context=None):
483
+ return np.asarray(self.jax_array, dtype=dtype)
484
+
485
+ __getitem__ = _wrapped(lambda array, *args: array.__getitem__(*args))
486
+ # We drop the kwargs on this as they are not supported by JAX, but xarray
487
+ # uses at least one of them (the copy arg).
488
+ astype = _wrapped(lambda array, *args, **kwargs: array.astype(*args))
489
+
490
+ # There are many more methods which are more canonically available via (j)np
491
+ # functions, e.g. .sum() available via jnp.sum, and also mean, max, min,
492
+ # argmax, argmin etc. We don't attempt to proxy through all of these as
493
+ # methods, since this doesn't appear to be expected from a duck-typed array
494
+ # implementation. But there are a few which xarray calls as methods, so we
495
+ # proxy those:
496
+ transpose = _wrapped(jnp.transpose)
497
+ reshape = _wrapped(jnp.reshape)
498
+ all = _wrapped(jnp.all)
499
+
500
+
501
+ def apply_ufunc(func, *args, require_jax=False, **apply_ufunc_kwargs):
502
+ """Like xarray.apply_ufunc but for jax-specific ufuncs.
503
+
504
+ Many numpy ufuncs will work fine out of the box with xarray_jax and
505
+ JaxArrayWrapper, since JaxArrayWrapper quacks (mostly) like a numpy array and
506
+ will convert many numpy operations to jax ops under the hood. For these
507
+ situations, xarray.apply_ufunc should work fine.
508
+
509
+ But sometimes you need a jax-specific ufunc which needs to be given a
510
+ jax array as input or return a jax array as output. In that case you should
511
+ use this helper as it will remove any JaxArrayWrapper before calling the func,
512
+ and wrap the result afterwards before handing it back to xarray.
513
+
514
+ Args:
515
+ func: A function that works with jax arrays (e.g. using functions from
516
+ jax.numpy) but otherwise meets the spec for the func argument to
517
+ xarray.apply_ufunc.
518
+ *args: xarray arguments to be mapped to arguments for func
519
+ (see xarray.apply_ufunc).
520
+ require_jax: Whether to require that inputs are based on jax arrays or allow
521
+ those based on plain numpy arrays too.
522
+ **apply_ufunc_kwargs: See xarray.apply_ufunc.
523
+
524
+ Returns:
525
+ Corresponding xarray results (see xarray.apply_ufunc).
526
+ """
527
+ def wrapped_func(*maybe_wrapped_args):
528
+ unwrapped_args = [unwrap(a, require_jax) for a in maybe_wrapped_args]
529
+ result = func(*unwrapped_args)
530
+ # Result can be an array or a tuple of arrays, this handles both:
531
+ return jax.tree_util.tree_map(wrap, result)
532
+ return xarray.apply_ufunc(wrapped_func, *args, **apply_ufunc_kwargs)
533
+
534
+
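A minimal usage sketch for this helper, adapted from `test_apply_ufunc` in the test file added below:

```python
import jax.numpy as jnp
from graphcast import xarray_jax

da = xarray_jax.DataArray(
    data=jnp.asarray([[1, 2], [3, 4]]),
    dims=('x', 'y'),
    coords={'x': [0, 1], 'y': [2, 3]})
# The jax-specific ufunc receives an unwrapped jax array, and its result is
# wrapped again before being handed back to xarray:
summed = xarray_jax.apply_ufunc(
    lambda a: jnp.sum(a, axis=-1), da, input_core_dims=[['x']])
# summed is a DataArray over 'y' with values [4, 6].
```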
535
+ def pmap(fn: Callable[..., Any],
536
+ dim: str,
537
+ axis_name: Optional[str] = None,
538
+ devices: ... = None,
539
+ backend: ... = None) -> Callable[..., Any]:
540
+ """Wraps a subset of jax.pmap functionality to handle xarray input/output.
541
+
542
+ Constraints:
543
+ * Any Dataset or DataArray passed to the function must have `dim` as the
544
+ first dimension. This will be checked. You can ensure this if necessary
545
+ by calling `.transpose(dim, ...)` beforehand.
546
+ * All args and return values will be mapped over the first dimension;
547
+ in_axes=0 and out_axes=0 are used.
548
+ * No support for static_broadcasted_argnums, donate_argnums etc.
549
+
550
+ Args:
551
+ fn: Function to be pmap'd which takes and returns trees which may contain
552
+ xarray Dataset/DataArray. Any Dataset/DataArrays passed as input must use
553
+ `dim` as the first dimension on all arrays.
554
+ dim: The xarray dimension name corresponding to the first dimension that is
555
+ pmapped over (pmap is called with in_axes=0, out_axes=0).
556
+ axis_name: Used by jax to identify the mapped axis so that parallel
557
+ collectives can be applied. Defaults to same as `dim`.
558
+ devices:
559
+ backend:
560
+ See jax.pmap.
561
+
562
+ Returns:
563
+ A pmap'd version of `fn`, which takes and returns Dataset/DataArray with an
564
+ extra leading dimension `dim` relative to what the original `fn` sees.
565
+ """
566
+ input_treedef = None
567
+ output_treedef = None
568
+
569
+ def fn_passed_to_pmap(*flat_args):
570
+ assert input_treedef is not None
571
+ # Inside the pmap the original first dimension will no longer be present:
572
+ def check_and_remove_leading_dim(dims):
573
+ try:
574
+ index = dims.index(dim)
575
+ except ValueError:
576
+ index = None
577
+ if index != 0:
578
+ raise ValueError(f'Expected dim {dim} at index 0, found at {index}.')
579
+ return dims[1:]
580
+ with dims_change_on_unflatten(check_and_remove_leading_dim):
581
+ args = jax.tree_util.tree_unflatten(input_treedef, flat_args)
582
+ result = fn(*args)
583
+ nonlocal output_treedef
584
+ flat_result, output_treedef = jax.tree_util.tree_flatten(result)
585
+ return flat_result
586
+
587
+ pmapped_fn = jax.pmap(
588
+ fn_passed_to_pmap,
589
+ axis_name=axis_name or dim,
590
+ in_axes=0,
591
+ out_axes=0,
592
+ devices=devices,
593
+ backend=backend)
594
+
595
+ def result_fn(*args):
596
+ nonlocal input_treedef
597
+ flat_args, input_treedef = jax.tree_util.tree_flatten(args)
598
+ flat_result = pmapped_fn(*flat_args)
599
+ assert output_treedef is not None
600
+ # After the pmap an extra leading axis will be present, we need to add an
601
+ # xarray dimension for this when unflattening the result:
602
+ with dims_change_on_unflatten(lambda dims: (dim,) + dims):
603
+ return jax.tree_util.tree_unflatten(output_treedef, flat_result)
604
+
605
+ return result_fn
606
+
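A minimal usage sketch, adapted from `test_pmap` in the test file added below; the mapped dimension must be the leading dimension on every array and is hidden from the wrapped function:

```python
import jax
import jax.numpy as jnp
import numpy as np
from graphcast import xarray_jax

n = jax.local_device_count()
dataset = xarray_jax.Dataset(
    {'foo': (('device', 'lat', 'lon'), jnp.zeros((n, 3, 4), dtype=np.float32))})

def add_one(d):
  # Inside the pmapped function the 'device' dimension is no longer present.
  return d + 1

result = xarray_jax.pmap(add_one, dim='device')(dataset)
# result has the same structure as `dataset`, with 'device' restored as the
# leading dimension.
```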
607
+
608
+ # Register xarray datatypes with jax.tree_util.
609
+
610
+
611
+ DimsChangeFn = Callable[[Tuple[Hashable, ...]], Tuple[Hashable, ...]]
612
+ _DIMS_CHANGE_ON_UNFLATTEN_FN: contextvars.ContextVar[DimsChangeFn] = (
613
+ contextvars.ContextVar('dims_change_on_unflatten_fn'))
614
+
615
+
616
+ @contextlib.contextmanager
617
+ def dims_change_on_unflatten(dims_change_fn: DimsChangeFn):
618
+ """Can be used to change the dims used when unflattening arrays into xarrays.
619
+
620
+ This is useful when some axes were added to / removed from the underlying jax
621
+ arrays after they were flattened using jax.tree_util.tree_flatten, and you
622
+ want to unflatten them again afterwards using the original treedef but
623
+ adjusted for the added/removed dimensions.
624
+
625
+ It can also be used with jax.tree_util.tree_map, when it's called with a
626
+ function that adds/removes axes or otherwise changes the axis order.
627
+
628
+ When dimensions are removed, any coordinates using those removed dimensions
629
+ will also be removed on unflatten.
630
+
631
+ This is implemented as a context manager that sets some context-local state
632
+ affecting the behaviour of our unflatten functions, because it's not possible
633
+ to directly modify the treedef to change the dims/coords in it (and with
634
+ tree_map, the treedef isn't exposed to you anyway).
635
+
636
+ Args:
637
+ dims_change_fn: Maps a tuple of dimension names for the original
638
+ Variable/DataArray/Dataset that was flattened, to an updated tuple of
639
+ dimensions which should be used when unflattening.
640
+
641
+ Yields:
642
+ To a context manager in whose scope jax.tree_util.tree_unflatten and
643
+ jax.tree_util.tree_map will apply the dims_change_fn before reconstructing
644
+ xarrays from jax arrays.
645
+ """
646
+ token = _DIMS_CHANGE_ON_UNFLATTEN_FN.set(dims_change_fn)
647
+ try:
648
+ yield
649
+ finally:
650
+ _DIMS_CHANGE_ON_UNFLATTEN_FN.reset(token)
651
+
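A minimal sketch of the context manager in action, adapted from `test_map_added_dim` in the test file added below:

```python
import jax
import jax.numpy as jnp
import numpy as np
from graphcast import xarray_jax

da = xarray_jax.DataArray(
    data=jnp.ones((3, 4), dtype=np.float32),
    dims=('lat', 'lon'),
    coords={'lat': np.arange(3), 'lon': np.arange(4) * 10})
# tree_map adds a leading axis to the underlying jax array; the context
# manager tells the unflatten step what to name the new dimension:
with xarray_jax.dims_change_on_unflatten(lambda dims: ('batch',) + dims):
  batched = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0), da)
# batched.dims == ('batch', 'lat', 'lon')
```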
652
+
653
+ def _flatten_variable(v: xarray.Variable) -> Tuple[
654
+ Tuple[jax.typing.ArrayLike], Tuple[Hashable, ...]]:
655
+ """Flattens a Variable for jax.tree_util."""
656
+ children = (unwrap_data(v),)
657
+ aux = v.dims
658
+ return children, aux
659
+
660
+
661
+ def _unflatten_variable(
662
+ aux: Tuple[Hashable, ...],
663
+ children: Tuple[jax.typing.ArrayLike]) -> xarray.Variable:
664
+ """Unflattens a Variable for jax.tree_util."""
665
+ dims = aux
666
+ dims_change_fn = _DIMS_CHANGE_ON_UNFLATTEN_FN.get(None)
667
+ if dims_change_fn: dims = dims_change_fn(dims)
668
+ return Variable(dims=dims, data=children[0])
669
+
670
+
671
+ def _split_static_and_jax_coords(
672
+ coords: xarray.core.coordinates.Coordinates) -> Tuple[
673
+ Mapping[Hashable, xarray.Variable], Mapping[Hashable, xarray.Variable]]:
674
+ static_coord_vars = {}
675
+ jax_coord_vars = {}
676
+ for name, coord in coords.items():
677
+ if coord.attrs.get(_JAX_COORD_ATTR_NAME, False):
678
+ jax_coord_vars[name] = coord.variable
679
+ else:
680
+ assert not isinstance(coord, (jax.Array, JaxArrayWrapper))
681
+ static_coord_vars[name] = coord.variable
682
+ return static_coord_vars, jax_coord_vars
683
+
684
+
685
+ def _drop_with_none_of_dims(
686
+ coord_vars: Mapping[Hashable, xarray.Variable],
687
+ dims: Tuple[Hashable]) -> Mapping[Hashable, xarray.Variable]:
688
+ return {name: var for name, var in coord_vars.items()
689
+ if set(var.dims) <= set(dims)}
690
+
691
+
692
+ class _HashableCoords(collections.abc.Mapping):
693
+ """Wraps a dict of xarray Variables as hashable, used for static coordinates.
694
+
695
+ This needs to be hashable so that when an xarray.Dataset is passed to a
696
+ jax.jit'ed function, jax can check whether it's seen an array with the
697
+ same static coordinates(*) before or whether it needs to recompile the
698
+ function for the new values of the static coordinates.
699
+
700
+ (*) note jax_coords are not included in this; their value can be different
701
+ on different calls without triggering a recompile.
702
+ """
703
+
704
+ def __init__(self, coord_vars: Mapping[Hashable, xarray.Variable]):
705
+ self._variables = coord_vars
706
+
707
+ def __repr__(self) -> str:
708
+ return f'_HashableCoords({repr(self._variables)})'
709
+
710
+ def __getitem__(self, key: Hashable) -> xarray.Variable:
711
+ return self._variables[key]
712
+
713
+ def __len__(self) -> int:
714
+ return len(self._variables)
715
+
716
+ def __iter__(self) -> Iterator[Hashable]:
717
+ return iter(self._variables)
718
+
719
+ def __hash__(self):
720
+ if not hasattr(self, '_hash'):
721
+ self._hash = hash(frozenset((name, var.data.tobytes())
722
+ for name, var in self._variables.items()))
723
+ return self._hash
724
+
725
+ def __eq__(self, other):
726
+ if self is other:
727
+ return True
728
+ elif not isinstance(other, type(self)):
729
+ return NotImplemented
730
+ elif self._variables is other._variables:
731
+ return True
732
+ else:
733
+ return self._variables.keys() == other._variables.keys() and all(
734
+ variable.equals(other._variables[name])
735
+ for name, variable in self._variables.items())
736
+
737
+
738
+ def _flatten_data_array(v: xarray.DataArray) -> Tuple[
739
+ # Children (data variable, jax_coord_vars):
740
+ Tuple[xarray.Variable, Mapping[Hashable, xarray.Variable]],
741
+ # Static auxiliary data (name, static_coord_vars):
742
+ Tuple[Optional[Hashable], _HashableCoords]]:
743
+ """Flattens a DataArray for jax.tree_util."""
744
+ static_coord_vars, jax_coord_vars = _split_static_and_jax_coords(v.coords)
745
+ children = (v.variable, jax_coord_vars)
746
+ aux = (v.name, _HashableCoords(static_coord_vars))
747
+ return children, aux
748
+
749
+
750
+ def _unflatten_data_array(
751
+ aux: Tuple[Optional[Hashable], _HashableCoords],
752
+ children: Tuple[xarray.Variable, Mapping[Hashable, xarray.Variable]],
753
+ ) -> xarray.DataArray:
754
+ """Unflattens a DataArray for jax.tree_util."""
755
+ variable, jax_coord_vars = children
756
+ name, static_coord_vars = aux
757
+ # Drop static coords which have dims not present in any of the data_vars.
758
+ # These would generally be dims that were dropped by a dims_change_fn, but
759
+ # because static coordinates don't go through dims_change_fn on unflatten, we
760
+ # just drop them where this causes a problem.
761
+ # Since jax_coords go through the dims_change_fn on unflatten we don't need
762
+ # to do this for jax_coords.
763
+ static_coord_vars = _drop_with_none_of_dims(static_coord_vars, variable.dims)
764
+ return DataArray(
765
+ variable, name=name, coords=static_coord_vars, jax_coords=jax_coord_vars)
766
+
767
+
768
+ def _flatten_dataset(dataset: xarray.Dataset) -> Tuple[
769
+ # Children (data variables, jax_coord_vars):
770
+ Tuple[Mapping[Hashable, xarray.Variable],
771
+ Mapping[Hashable, xarray.Variable]],
772
+ # Static auxiliary data (static_coord_vars):
773
+ _HashableCoords]:
774
+ """Flattens a Dataset for jax.tree_util."""
775
+ variables = {name: data_array.variable
776
+ for name, data_array in dataset.data_vars.items()}
777
+ static_coord_vars, jax_coord_vars = _split_static_and_jax_coords(
778
+ dataset.coords)
779
+ children = (variables, jax_coord_vars)
780
+ aux = _HashableCoords(static_coord_vars)
781
+ return children, aux
782
+
783
+
784
+ def _unflatten_dataset(
785
+ aux: _HashableCoords,
786
+ children: Tuple[Mapping[Hashable, xarray.Variable],
787
+ Mapping[Hashable, xarray.Variable]],
788
+ ) -> xarray.Dataset:
789
+ """Unflattens a Dataset for jax.tree_util."""
790
+ data_vars, jax_coord_vars = children
791
+ static_coord_vars = aux
792
+ dataset = xarray.Dataset(data_vars)
793
+ # Drop static coords which have dims not present in any of the data_vars.
794
+ # See corresponding comment in _unflatten_data_array.
795
+ static_coord_vars = _drop_with_none_of_dims(static_coord_vars, dataset.dims) # pytype: disable=wrong-arg-types
796
+ return assign_coords(
797
+ dataset, coords=static_coord_vars, jax_coords=jax_coord_vars)
798
+
799
+
800
+ jax.tree_util.register_pytree_node(
801
+ xarray.Variable, _flatten_variable, _unflatten_variable)
802
+ # This is a subclass of Variable but still needs registering separately.
803
+ # Flatten/unflatten for IndexVariable is a bit of a corner case but we do
804
+ # need to support it.
805
+ jax.tree_util.register_pytree_node(
806
+ xarray.IndexVariable, _flatten_variable, _unflatten_variable)
807
+ jax.tree_util.register_pytree_node(
808
+ xarray.DataArray, _flatten_data_array, _unflatten_data_array)
809
+ jax.tree_util.register_pytree_node(
810
+ xarray.Dataset, _flatten_dataset, _unflatten_dataset)
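With these registrations in place, xarray containers can be passed straight through jax transformations. A minimal sketch, adapted from the jit tests in the file added below:

```python
import jax
import jax.numpy as jnp
import numpy as np
from graphcast import xarray_jax

dataset = xarray_jax.Dataset(
    data_vars={'foo': (('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))},
    coords={'lat': np.arange(3) * 10, 'lon': np.arange(4) * 10})
add_one = jax.jit(lambda d: d + 1)
# Static (numpy-backed) coordinates become part of the aux structure jax uses
# to decide whether to re-trace; the jax-backed data is traced as usual.
out = add_one(dataset)
```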
graphcast/xarray_jax_test.py ADDED
@@ -0,0 +1,526 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for xarray_jax."""
15
+
16
+ from absl.testing import absltest
17
+ import chex
18
+ from graphcast import xarray_jax
19
+ import jax
20
+ import jax.numpy as jnp
21
+ import numpy as np
22
+ import xarray
23
+
24
+
25
+ class XarrayJaxTest(absltest.TestCase):
26
+
27
+ def test_jax_array_wrapper_with_numpy_api(self):
28
+ # This is just a side benefit of making things work with xarray, but the
29
+ # JaxArrayWrapper does allow you to manipulate JAX arrays using the
30
+ # standard numpy API, without converting them to numpy in the process:
31
+ ones = jnp.ones((3, 4), dtype=np.float32)
32
+ x = xarray_jax.JaxArrayWrapper(ones)
33
+ x = np.abs((x + 2) * (x - 3))
34
+ x = x[:-1, 1:3]
35
+ x = np.concatenate([x, x + 1], axis=0)
36
+ x = np.transpose(x, (1, 0))
37
+ x = np.reshape(x, (-1,))
38
+ x = x.astype(np.int32)
39
+ self.assertIsInstance(x, xarray_jax.JaxArrayWrapper)
40
+ # An explicit conversion gets us out of JAX-land however:
41
+ self.assertIsInstance(np.asarray(x), np.ndarray)
42
+
43
+ def test_jax_xarray_variable(self):
44
+ def ops_via_xarray(inputs):
45
+ x = xarray_jax.Variable(('lat', 'lon'), inputs)
46
+ # We'll apply a sequence of operations just to test that the end result is
47
+ # still a JAX array, i.e. we haven't converted to numpy at any point.
48
+ x = np.abs((x + 2) * (x - 3))
49
+ x = x.isel({'lat': slice(0, -1), 'lon': slice(1, 3)})
50
+ x = xarray.Variable.concat([x, x + 1], dim='lat')
51
+ x = x.transpose('lon', 'lat')
52
+ x = x.stack(channels=('lon', 'lat'))
53
+ x = x.sum()
54
+ return xarray_jax.jax_data(x)
55
+
56
+ # Check it doesn't leave jax-land when passed concrete values:
57
+ ones = jnp.ones((3, 4), dtype=np.float32)
58
+ result = ops_via_xarray(ones)
59
+ self.assertIsInstance(result, jax.Array)
60
+
61
+ # And that you can JIT it and compute gradients through it. These will
62
+ # involve passing jax tracers through the xarray computation:
63
+ jax.jit(ops_via_xarray)(ones)
64
+ jax.grad(ops_via_xarray)(ones)
65
+
66
+ def test_jax_xarray_data_array(self):
67
+ def ops_via_xarray(inputs):
68
+ x = xarray_jax.DataArray(dims=('lat', 'lon'),
69
+ data=inputs,
70
+ coords={'lat': np.arange(3) * 10,
71
+ 'lon': np.arange(4) * 10})
72
+ x = np.abs((x + 2) * (x - 3))
73
+ x = x.sel({'lat': slice(0, 20)})
74
+ y = xarray_jax.DataArray(dims=('lat', 'lon'),
75
+ data=ones,
76
+ coords={'lat': np.arange(3, 6) * 10,
77
+ 'lon': np.arange(4) * 10})
78
+ x = xarray.concat([x, y], dim='lat')
79
+ x = x.transpose('lon', 'lat')
80
+ x = x.stack(channels=('lon', 'lat'))
81
+ x = x.unstack()
82
+ x = x.sum()
83
+ return xarray_jax.jax_data(x)
84
+
85
+ ones = jnp.ones((3, 4), dtype=np.float32)
86
+ result = ops_via_xarray(ones)
87
+ self.assertIsInstance(result, jax.Array)
88
+
89
+ jax.jit(ops_via_xarray)(ones)
90
+ jax.grad(ops_via_xarray)(ones)
91
+
92
+ def test_jax_xarray_dataset(self):
93
+ def ops_via_xarray(foo, bar):
94
+ x = xarray_jax.Dataset(
95
+ data_vars={'foo': (('lat', 'lon'), foo),
96
+ 'bar': (('time', 'lat', 'lon'), bar)},
97
+ coords={
98
+ 'time': np.arange(2),
99
+ 'lat': np.arange(3) * 10,
100
+ 'lon': np.arange(4) * 10})
101
+ x = np.abs((x + 2) * (x - 3))
102
+ x = x.sel({'lat': slice(0, 20)})
103
+ y = xarray_jax.Dataset(
104
+ data_vars={'foo': (('lat', 'lon'), foo),
105
+ 'bar': (('time', 'lat', 'lon'), bar)},
106
+ coords={
107
+ 'time': np.arange(2),
108
+ 'lat': np.arange(3, 6) * 10,
109
+ 'lon': np.arange(4) * 10})
110
+ x = xarray.concat([x, y], dim='lat')
111
+ x = x.transpose('lon', 'lat', 'time')
112
+ x = x.stack(channels=('lon', 'lat'))
113
+ x = (x.foo + x.bar).sum()
114
+ return xarray_jax.jax_data(x)
115
+
116
+ foo = jnp.ones((3, 4), dtype=np.float32)
117
+ bar = jnp.ones((2, 3, 4), dtype=np.float32)
118
+ result = ops_via_xarray(foo, bar)
119
+ self.assertIsInstance(result, jax.Array)
120
+
121
+ jax.jit(ops_via_xarray)(foo, bar)
122
+ jax.grad(ops_via_xarray)(foo, bar)
123
+
124
+ def test_jit_function_with_xarray_variable_arguments_and_return(self):
125
+ function = jax.jit(lambda v: v + 1)
126
+ with self.subTest('jax input'):
127
+ inputs = xarray_jax.Variable(
128
+ ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
129
+ _ = function(inputs)
130
+ # We test running the jitted function a second time, to exercise logic in
131
+ # jax which checks if the structure of the inputs (including dimension
132
+ # names and coordinates) is the same as it was for the previous call and
133
+ # so whether it needs to re-trace-and-compile a new version of the
134
+ # function or not. This can run into problems if the 'aux' structure
135
+ # returned by the registered flatten function is not hashable/comparable.
136
+ outputs = function(inputs)
137
+ self.assertEqual(outputs.dims, inputs.dims)
138
+ with self.subTest('numpy input'):
139
+ inputs = xarray.Variable(
140
+ ('lat', 'lon'), np.ones((3, 4), dtype=np.float32))
141
+ _ = function(inputs)
142
+ outputs = function(inputs)
143
+ self.assertEqual(outputs.dims, inputs.dims)
144
+
145
+ def test_jit_problem_if_convert_to_plain_numpy_array(self):
146
+ inputs = xarray_jax.DataArray(
147
+ data=jnp.ones((2,), dtype=np.float32), dims=('foo',))
148
+ with self.assertRaises(jax.errors.TracerArrayConversionError):
149
+ # Calling .values on a DataArray converts its values to numpy:
150
+ jax.jit(lambda data_array: data_array.values)(inputs)
151
+
152
+ def test_grad_function_with_xarray_variable_arguments(self):
153
+ x = xarray_jax.Variable(('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
154
+ # For grad we still need a JAX scalar as the output:
155
+ jax.grad(lambda v: xarray_jax.jax_data(v.sum()))(x)
156
+
157
+ def test_jit_function_with_xarray_data_array_arguments_and_return(self):
158
+ inputs = xarray_jax.DataArray(
159
+ data=jnp.ones((3, 4), dtype=np.float32),
160
+ dims=('lat', 'lon'),
161
+ coords={'lat': np.arange(3),
162
+ 'lon': np.arange(4) * 10})
163
+ fn = jax.jit(lambda v: v + 1)
164
+ _ = fn(inputs)
165
+ outputs = fn(inputs)
166
+ self.assertEqual(outputs.dims, inputs.dims)
167
+ chex.assert_trees_all_equal(outputs.coords, inputs.coords)
168
+
169
+ def test_jit_function_with_data_array_and_jax_coords(self):
170
+ inputs = xarray_jax.DataArray(
171
+ data=jnp.ones((3, 4), dtype=np.float32),
172
+ dims=('lat', 'lon'),
173
+ coords={'lat': np.arange(3)},
174
+ jax_coords={'lon': jnp.arange(4) * 10})
175
+ # Verify the jax_coord 'lon' retains jax data, and has not been created
176
+ # as an index coordinate:
177
+ self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
178
+ self.assertNotIn('lon', inputs.indexes)
179
+
180
+ @jax.jit
181
+ def fn(v):
182
+ # The non-JAX coord is passed with numpy array data and an index:
183
+ self.assertIsInstance(v.coords['lat'].data, np.ndarray)
184
+ self.assertIn('lat', v.indexes)
185
+
186
+ # The jax_coord is passed with JAX array data:
187
+ self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper)
188
+ self.assertNotIn('lon', v.indexes)
189
+
190
+ # Use the jax coord in the computation:
191
+ v = v + v.coords['lon']
192
+
193
+ # Return with an updated jax coord:
194
+ return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1)
195
+
196
+ _ = fn(inputs)
197
+ outputs = fn(inputs)
198
+
199
+ # Verify the jax_coord 'lon' has jax data in the output too:
200
+ self.assertIsInstance(
201
+ outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
202
+ self.assertNotIn('lon', outputs.indexes)
203
+
204
+ self.assertEqual(outputs.dims, inputs.dims)
205
+ chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat'])
206
+ # Check our computations with the coordinate values worked:
207
+ chex.assert_trees_all_equal(
208
+ outputs.coords['lon'].data, (inputs.coords['lon']+1).data)
209
+ chex.assert_trees_all_equal(
210
+ outputs.data, (inputs + inputs.coords['lon']).data)
211
+
212
+ def test_jit_function_with_xarray_dataset_arguments_and_return(self):
213
+ foo = jnp.ones((3, 4), dtype=np.float32)
214
+ bar = jnp.ones((2, 3, 4), dtype=np.float32)
215
+ inputs = xarray_jax.Dataset(
216
+ data_vars={'foo': (('lat', 'lon'), foo),
217
+ 'bar': (('time', 'lat', 'lon'), bar)},
218
+ coords={
219
+ 'time': np.arange(2),
220
+ 'lat': np.arange(3) * 10,
221
+ 'lon': np.arange(4) * 10})
222
+ fn = jax.jit(lambda v: v + 1)
223
+ _ = fn(inputs)
224
+ outputs = fn(inputs)
225
+ self.assertEqual({'foo', 'bar'}, outputs.data_vars.keys())
226
+ self.assertEqual(inputs.foo.dims, outputs.foo.dims)
227
+ self.assertEqual(inputs.bar.dims, outputs.bar.dims)
228
+ chex.assert_trees_all_equal(outputs.coords, inputs.coords)
229
+
230
+ def test_jit_function_with_dataset_and_jax_coords(self):
231
+ foo = jnp.ones((3, 4), dtype=np.float32)
232
+ bar = jnp.ones((2, 3, 4), dtype=np.float32)
233
+ inputs = xarray_jax.Dataset(
234
+ data_vars={'foo': (('lat', 'lon'), foo),
235
+ 'bar': (('time', 'lat', 'lon'), bar)},
236
+ coords={
237
+ 'time': np.arange(2),
238
+ 'lat': np.arange(3) * 10,
239
+ },
240
+ jax_coords={'lon': jnp.arange(4) * 10}
241
+ )
242
+ # Verify the jax_coord 'lon' retains jax data, and has not been created
243
+ # as an index coordinate:
244
+ self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
245
+ self.assertNotIn('lon', inputs.indexes)
246
+
247
+ @jax.jit
248
+ def fn(v):
249
+ # The non-JAX coords are passed with numpy array data and an index:
250
+ self.assertIsInstance(v.coords['lat'].data, np.ndarray)
251
+ self.assertIn('lat', v.indexes)
252
+
253
+ # The jax_coord is passed with JAX array data:
254
+ self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper)
255
+ self.assertNotIn('lon', v.indexes)
256
+
257
+ # Use the jax coord in the computation:
258
+ v = v + v.coords['lon']
259
+
260
+ # Return with an updated jax coord:
261
+ return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1)
262
+
263
+ _ = fn(inputs)
264
+ outputs = fn(inputs)
265
+
266
+ # Verify the jax_coord 'lon' has jax data in the output too:
267
+ self.assertIsInstance(
268
+ outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
269
+ self.assertNotIn('lon', outputs.indexes)
270
+
271
+ self.assertEqual(outputs.dims, inputs.dims)
272
+ chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat'])
273
+ # Check our computations with the coordinate values worked:
274
+ chex.assert_trees_all_equal(
275
+ (outputs.coords['lon']).data,
276
+ (inputs.coords['lon']+1).data,
277
+ )
278
+ outputs_dict = {key: outputs[key].data for key in outputs}
279
+ inputs_and_inputs_coords_dict = {
280
+ key: (inputs + inputs.coords['lon'])[key].data
281
+ for key in inputs + inputs.coords['lon']
282
+ }
283
+ chex.assert_trees_all_equal(outputs_dict, inputs_and_inputs_coords_dict)
284
+
285
+ def test_flatten_unflatten_variable(self):
286
+ variable = xarray_jax.Variable(
287
+ ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
288
+ children, aux = xarray_jax._flatten_variable(variable)
289
+ # Check auxiliary info is hashable/comparable (important for jax.jit):
290
+ hash(aux)
291
+ self.assertEqual(aux, aux)
292
+ roundtrip = xarray_jax._unflatten_variable(aux, children)
293
+ self.assertTrue(variable.equals(roundtrip))
294
+
295
+ def test_flatten_unflatten_data_array(self):
296
+ data_array = xarray_jax.DataArray(
297
+ data=jnp.ones((3, 4), dtype=np.float32),
298
+ dims=('lat', 'lon'),
299
+ coords={'lat': np.arange(3)},
300
+ jax_coords={'lon': np.arange(4) * 10},
301
+ )
302
+ children, aux = xarray_jax._flatten_data_array(data_array)
303
+ # Check auxiliary info is hashable/comparable (important for jax.jit):
304
+ hash(aux)
305
+ self.assertEqual(aux, aux)
306
+ roundtrip = xarray_jax._unflatten_data_array(aux, children)
307
+ self.assertTrue(data_array.equals(roundtrip))
308
+
309
+ def test_flatten_unflatten_dataset(self):
310
+ foo = jnp.ones((3, 4), dtype=np.float32)
311
+ bar = jnp.ones((2, 3, 4), dtype=np.float32)
312
+ dataset = xarray_jax.Dataset(
313
+ data_vars={'foo': (('lat', 'lon'), foo),
314
+ 'bar': (('time', 'lat', 'lon'), bar)},
315
+ coords={
316
+ 'time': np.arange(2),
317
+ 'lat': np.arange(3) * 10},
318
+ jax_coords={
319
+ 'lon': np.arange(4) * 10})
320
+ children, aux = xarray_jax._flatten_dataset(dataset)
321
+ # Check auxiliary info is hashable/comparable (important for jax.jit):
322
+ hash(aux)
323
+ self.assertEqual(aux, aux)
324
+ roundtrip = xarray_jax._unflatten_dataset(aux, children)
325
+ self.assertTrue(dataset.equals(roundtrip))
326
+
327
+ def test_flatten_unflatten_added_dim(self):
328
+ data_array = xarray_jax.DataArray(
329
+ data=jnp.ones((3, 4), dtype=np.float32),
330
+ dims=('lat', 'lon'),
331
+ coords={'lat': np.arange(3),
332
+ 'lon': np.arange(4) * 10})
333
+ leaves, treedef = jax.tree_util.tree_flatten(data_array)
334
+ leaves = [jnp.expand_dims(x, 0) for x in leaves]
335
+ with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims):
336
+ with_new_dim = jax.tree_util.tree_unflatten(treedef, leaves)
337
+ self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
338
+ xarray.testing.assert_identical(
339
+ jax.device_get(data_array),
340
+ jax.device_get(with_new_dim.isel(new=0)))
341
+
342
+ def test_map_added_dim(self):
343
+ data_array = xarray_jax.DataArray(
344
+ data=jnp.ones((3, 4), dtype=np.float32),
345
+ dims=('lat', 'lon'),
346
+ coords={'lat': np.arange(3),
347
+ 'lon': np.arange(4) * 10})
348
+ with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims):
349
+ with_new_dim = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0),
350
+ data_array)
351
+ self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
352
+ xarray.testing.assert_identical(
353
+ jax.device_get(data_array),
354
+ jax.device_get(with_new_dim.isel(new=0)))
355
+
356
+ def test_map_remove_dim(self):
357
+ foo = jnp.ones((1, 3, 4), dtype=np.float32)
358
+ bar = jnp.ones((1, 2, 3, 4), dtype=np.float32)
359
+ dataset = xarray_jax.Dataset(
360
+ data_vars={'foo': (('batch', 'lat', 'lon'), foo),
361
+ 'bar': (('batch', 'time', 'lat', 'lon'), bar)},
362
+ coords={
363
+ 'batch': np.array([123]),
364
+ 'time': np.arange(2),
365
+ 'lat': np.arange(3) * 10,
366
+ 'lon': np.arange(4) * 10})
367
+ with xarray_jax.dims_change_on_unflatten(lambda dims: dims[1:]):
368
+ with_removed_dim = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, 0),
369
+ dataset)
370
+ self.assertEqual(('lat', 'lon'), with_removed_dim['foo'].dims)
371
+ self.assertEqual(('time', 'lat', 'lon'), with_removed_dim['bar'].dims)
372
+ self.assertNotIn('batch', with_removed_dim.dims)
373
+ self.assertNotIn('batch', with_removed_dim.coords)
374
+ xarray.testing.assert_identical(
375
+ jax.device_get(dataset.isel(batch=0, drop=True)),
376
+ jax.device_get(with_removed_dim))
377
+
378
+ def test_pmap(self):
379
+ devices = jax.local_device_count()
380
+ foo = jnp.zeros((devices, 3, 4), dtype=np.float32)
381
+ bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32)
382
+ dataset = xarray_jax.Dataset({
383
+ 'foo': (('device', 'lat', 'lon'), foo),
384
+ 'bar': (('device', 'time', 'lat', 'lon'), bar)})
385
+
386
+ def func(d):
387
+ self.assertNotIn('device', d.dims)
388
+ return d + 1
389
+ func = xarray_jax.pmap(func, dim='device')
390
+
391
+ result = func(dataset)
392
+ xarray.testing.assert_identical(
393
+ jax.device_get(dataset + 1),
394
+ jax.device_get(result))
395
+
396
+ # Can call it again with a different argument structure (it will recompile
397
+ # under the hood but should work):
398
+ dataset = dataset.drop_vars('foo')
399
+ result = func(dataset)
400
+ xarray.testing.assert_identical(
401
+ jax.device_get(dataset + 1),
402
+ jax.device_get(result))
403
+
404
+ def test_pmap_with_jax_coords(self):
405
+ devices = jax.local_device_count()
406
+ foo = jnp.zeros((devices, 3, 4), dtype=np.float32)
407
+ bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32)
408
+ time = jnp.zeros((devices, 2), dtype=np.float32)
409
+ dataset = xarray_jax.Dataset(
410
+ {'foo': (('device', 'lat', 'lon'), foo),
411
+ 'bar': (('device', 'time', 'lat', 'lon'), bar)},
412
+ coords={
413
+ 'lat': np.arange(3),
414
+ 'lon': np.arange(4),
415
+ },
416
+ jax_coords={
417
+ # Currently any jax_coords need a leading device dimension to use
418
+ # with pmap, same as for data_vars.
419
+ # TODO(matthjw): have pmap automatically broadcast to all devices
420
+ # where the device dimension not present.
421
+ 'time': xarray_jax.Variable(('device', 'time'), time),
422
+ }
423
+ )
424
+
425
+ def func(d):
426
+ self.assertNotIn('device', d.dims)
427
+ self.assertNotIn('device', d.coords['time'].dims)
428
+
429
+ # The jax_coord 'time' should be passed in backed by a JAX array, but
430
+ # not as an index coordinate.
431
+ self.assertIsInstance(d.coords['time'].data, xarray_jax.JaxArrayWrapper)
432
+ self.assertNotIn('time', d.indexes)
433
+
434
+ return d + 1
435
+ func = xarray_jax.pmap(func, dim='device')
436
+
437
+ result = func(dataset)
438
+ xarray.testing.assert_identical(
439
+ jax.device_get(dataset + 1),
440
+ jax.device_get(result))
441
+
442
+ # Can call it again with a different argument structure (it will recompile
443
+ # under the hood but should work):
444
+ dataset = dataset.drop_vars('foo')
445
+ result = func(dataset)
446
+ xarray.testing.assert_identical(
447
+ jax.device_get(dataset + 1),
448
+ jax.device_get(result))
449
+
450
+ def test_pmap_with_tree_mix_of_xarray_and_jax_array(self):
451
+ devices = jax.local_device_count()
452
+ data_array = xarray_jax.DataArray(
453
+ data=jnp.ones((devices, 3, 4), dtype=np.float32),
454
+ dims=('device', 'lat', 'lon'))
455
+ plain_array = jnp.ones((devices, 2), dtype=np.float32)
456
+ inputs = {'foo': data_array,
457
+ 'bar': plain_array}
458
+
459
+ def func(x):
460
+ return x['foo'] + 1, x['bar'] + 1
461
+
462
+ func = xarray_jax.pmap(func, dim='device')
463
+ result_foo, result_bar = func(inputs)
464
+ xarray.testing.assert_identical(
465
+ jax.device_get(inputs['foo'] + 1),
466
+ jax.device_get(result_foo))
467
+ np.testing.assert_array_equal(
468
+ jax.device_get(inputs['bar'] + 1),
469
+ jax.device_get(result_bar))
470
+
471
+ def test_pmap_complains_when_dim_not_first(self):
472
+ devices = jax.local_device_count()
473
+ data_array = xarray_jax.DataArray(
474
+ data=jnp.ones((3, devices, 4), dtype=np.float32),
475
+ dims=('lat', 'device', 'lon'))
476
+
477
+ func = xarray_jax.pmap(lambda x: x+1, dim='device')
478
+ with self.assertRaisesRegex(
479
+ ValueError, 'Expected dim device at index 0, found at 1'):
480
+ func(data_array)
481
+
482
+ def test_apply_ufunc(self):
483
+ inputs = xarray_jax.DataArray(
484
+ data=jnp.asarray([[1, 2], [3, 4]]),
485
+ dims=('x', 'y'),
486
+ coords={'x': [0, 1],
487
+ 'y': [2, 3]})
488
+ result = xarray_jax.apply_ufunc(
489
+ lambda x: jnp.sum(x, axis=-1),
490
+ inputs,
491
+ input_core_dims=[['x']])
492
+ expected_result = xarray_jax.DataArray(
493
+ data=[4, 6],
494
+ dims=('y',),
495
+ coords={'y': [2, 3]})
496
+ xarray.testing.assert_identical(expected_result, jax.device_get(result))
497
+
498
+ def test_apply_ufunc_multiple_return_values(self):
499
+ def ufunc(array):
500
+ return jnp.min(array, axis=-1), jnp.max(array, axis=-1)
501
+ inputs = xarray_jax.DataArray(
502
+ data=jnp.asarray([[1, 4], [3, 2]]),
503
+ dims=('x', 'y'),
504
+ coords={'x': [0, 1],
505
+ 'y': [2, 3]})
506
+ result = xarray_jax.apply_ufunc(
507
+ ufunc, inputs, input_core_dims=[['x']], output_core_dims=[[], []])
508
+ expected = (
509
+ # Mins:
510
+ xarray_jax.DataArray(
511
+ data=[1, 2],
512
+ dims=('y',),
513
+ coords={'y': [2, 3]}
514
+ ),
515
+ # Maxes:
516
+ xarray_jax.DataArray(
517
+ data=[3, 4],
518
+ dims=('y',),
519
+ coords={'y': [2, 3]}
520
+ )
521
+ )
522
+ xarray.testing.assert_identical(expected[0], jax.device_get(result[0]))
523
+ xarray.testing.assert_identical(expected[1], jax.device_get(result[1]))
524
+
525
+ if __name__ == '__main__':
526
+ absltest.main()
graphcast/xarray_tree.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utilities for working with trees of xarray.DataArray (including Datasets).
15
+
16
+ Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library;
17
+ it won't work as a leaf node since it implements Mapping, but also won't work
18
+ as an internal node since tree doesn't know how to re-create it properly.
19
+
20
+ To fix this, we reimplement a subset of `map_structure`, exposing its
21
+ constituent DataArrays as leaf nodes. This means it can be mapped over as a
22
+ generic container of DataArrays, while still preserving the result as a Dataset
23
+ where possible.
24
+
25
+ This is useful because in a few places we need to handle a general
26
+ Mapping[str, DataArray] (where the coordinates might not be compatible across
27
+ the constituent DataArrays) but also the special case of a Dataset nicely.
28
+
29
+ For the result of, e.g., map_structure(fn, dataset): if fn returns None for
30
+ some of the child DataArrays, they will be omitted from the returned dataset. If
31
+ any values other than DataArrays or None are returned, then we don't attempt to
32
+ return a Dataset and just return a plain dict of the results. Similarly if
33
+ DataArrays are returned but with non-matching coordinates, it will just return a
34
+ plain dict of DataArrays.
35
+
36
+ Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py,
37
+ but `jax.tree_util.tree_map` is distinct from `xarray_tree.map_structure`,
38
+ as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the
39
+ latter exposes DataArrays as leaf nodes.
40
+ """
41
+
42
+ from typing import Any, Callable
43
+
44
+ import xarray
45
+
46
+
47
+ def map_structure(func: Callable[..., Any], *structures: Any) -> Any:
48
+ """Maps func through given structures with xarrays. See tree.map_structure."""
49
+ if not callable(func):
50
+ raise TypeError(f'func must be callable, got: {func}')
51
+ if not structures:
52
+ raise ValueError('Must provide at least one structure')
53
+
54
+ first = structures[0]
55
+ if isinstance(first, xarray.Dataset):
56
+ data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
57
+ if all(isinstance(a, (type(None), xarray.DataArray))
58
+ for a in data.values()):
59
+ data_arrays = [v.rename(k) for k, v in data.items() if v is not None]
60
+ try:
61
+ return xarray.merge(data_arrays, join='exact')
62
+ except ValueError: # Exact join not possible.
63
+ pass
64
+ return data
65
+ if isinstance(first, dict):
66
+ return {k: map_structure(func, *[s[k] for s in structures])
67
+ for k in first.keys()}
68
+ if isinstance(first, (list, tuple, set)):
69
+ return type(first)(map_structure(func, *s) for s in zip(*structures))
70
+ return func(*structures)
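A minimal usage sketch, mirroring the tests in the file added below:

```python
import numpy as np
import xarray
from graphcast import xarray_tree

dataset = xarray.Dataset(
    data_vars={'foo': (('x', 'y'), np.zeros((2, 3))),
               'bar': (('x',), np.zeros(2))},
    coords={'x': [1, 2], 'y': [10, 20, 30]})
# Each data variable is visited as a DataArray; since the results share
# compatible coordinates, the output is reassembled into a Dataset:
result = xarray_tree.map_structure(lambda da: da + 1, dataset)
```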
graphcast/xarray_tree_test.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright 2023 DeepMind Technologies Limited.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS-IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for xarray_tree."""
15
+
16
+ from absl.testing import absltest
17
+ from graphcast import xarray_tree
18
+ import numpy as np
19
+ import xarray
20
+
21
+
22
+ TEST_DATASET = xarray.Dataset(
23
+ data_vars={
24
+ "foo": (("x", "y"), np.zeros((2, 3))),
25
+ "bar": (("x",), np.zeros((2,))),
26
+ },
27
+ coords={
28
+ "x": [1, 2],
29
+ "y": [10, 20, 30],
30
+ }
31
+ )
32
+
33
+
34
+ class XarrayTreeTest(absltest.TestCase):
35
+
36
+ def test_map_structure_maps_over_leaves_but_preserves_dataset_type(self):
37
+ def fn(leaf):
38
+ self.assertIsInstance(leaf, xarray.DataArray)
39
+ result = leaf + 1
40
+ # Removing the name from the returned DataArray to test that we don't rely
41
+ # on it being present to restore the correct names in the result:
42
+ result = result.rename(None)
43
+ return result
44
+
45
+ result = xarray_tree.map_structure(fn, TEST_DATASET)
46
+ self.assertIsInstance(result, xarray.Dataset)
47
+ self.assertSameElements({"foo", "bar"}, result.keys())
48
+
49
+ def test_map_structure_on_data_arrays(self):
50
+ data_arrays = dict(TEST_DATASET)
51
+ result = xarray_tree.map_structure(lambda x: x+1, data_arrays)
52
+ self.assertIsInstance(result, dict)
53
+ self.assertSameElements({"foo", "bar"}, result.keys())
54
+
55
+ def test_map_structure_on_dataset_plain_dict_when_coords_incompatible(self):
56
+ def fn(leaf):
57
+ # Returns DataArrays that can't be exactly merged back into a Dataset
58
+ # due to the coordinates not matching:
59
+ if leaf.name == "foo":
60
+ return xarray.DataArray(
61
+ data=np.zeros(2), dims=("x",), coords={"x": [1, 2]})
62
+ else:
63
+ return xarray.DataArray(
64
+ data=np.zeros(2), dims=("x",), coords={"x": [3, 4]})
65
+
66
+ result = xarray_tree.map_structure(fn, TEST_DATASET)
67
+ self.assertIsInstance(result, dict)
68
+ self.assertSameElements({"foo", "bar"}, result.keys())
69
+
70
+ def test_map_structure_on_dataset_drops_vars_with_none_return_values(self):
71
+ def fn(leaf):
72
+ return leaf if leaf.name == "foo" else None
73
+
74
+ result = xarray_tree.map_structure(fn, TEST_DATASET)
75
+ self.assertIsInstance(result, xarray.Dataset)
76
+ self.assertSameElements({"foo"}, result.keys())
77
+
78
+ def test_map_structure_on_dataset_returns_plain_dict_other_return_types(self):
79
+ def fn(leaf):
80
+ self.assertIsInstance(leaf, xarray.DataArray)
81
+ return "not a DataArray"
82
+
83
+ result = xarray_tree.map_structure(fn, TEST_DATASET)
84
+ self.assertEqual({"foo": "not a DataArray",
85
+ "bar": "not a DataArray"}, result)
86
+
87
+ def test_map_structure_two_args_different_variable_orders(self):
88
+ dataset_different_order = TEST_DATASET[["bar", "foo"]]
89
+ def fn(arg1, arg2):
90
+ self.assertEqual(arg1.name, arg2.name)
91
+ xarray_tree.map_structure(fn, TEST_DATASET, dataset_different_order)
92
+
93
+
94
+ if __name__ == "__main__":
95
+ absltest.main()