"""Defines jax-compatible DataArrays."""
from __future__ import annotations
from typing import Any, Dict, List, Literal, Sequence, Tuple, Union
import h5py
import jax
import jax.numpy as jnp
import numpy as np
import pydantic.v1 as pd
import xarray as xr
from jax.tree_util import register_pytree_node_class
from .....components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
from .....exceptions import AdjointError, DataError, Tidy3dKeyError
# Relative cutoff used when filtering small entries of a data array:
# a value is set to zero if abs(val) <= VALUE_FILTER_THRESHOLD * max(abs(val))
VALUE_FILTER_THRESHOLD = 1e-6
# JaxDataArrays are written to json as JAX_DATA_ARRAY_TAG
JAX_DATA_ARRAY_TAG = "<<JaxDataArray>>"
[docs]
@register_pytree_node_class
class JaxDataArray(Tidy3dBaseModel):
    """A :class:`.DataArray`-like class that only wraps xarray for jax compatibility.

    Stores the raw ``values`` together with a ``coords`` dictionary that maps each
    dimension name to the list of coordinate values along that dimension.
    """

    # raw data of the array; marked as a jax field
    values: Any = pd.Field(
        ...,
        title="Values",
        description="Nested list containing the raw values, which can be tracked by jax.",
        jax_field=True,
    )
    # dimension name -> list of coordinate values, in the axis order of ``values``
    coords: Dict[str, list] = pd.Field(
        ...,
        title="Coords",
        description="Dictionary storing the coordinates, namely ``(direction, f, mode_index)``.",
    )
[docs]
def to_tidy3d(self: JaxDataArray) -> xr.DataArray:
    """Build an ``xr.DataArray`` holding the values and coordinates of this object."""
    plain_coords = {}
    for name, entries in self.coords.items():
        # round-trip through numpy so every coordinate ends up as a plain list
        plain_coords[name] = np.array(entries).tolist()
    raw_values = np.array(self.values)
    return xr.DataArray(raw_values, coords=plain_coords, dims=self.coords.keys())
[docs]
@classmethod
def from_tidy3d(cls, tidy3d_obj: xr.DataArray) -> JaxDataArray:
    """Create a :class:`.JaxDataArray` from the data and coordinates of an ``xr.DataArray``."""
    plain_coords = {}
    for name, coord in tidy3d_obj.coords.items():
        # each coordinate is stored as a plain python list
        plain_coords[name] = np.array(coord).tolist()
    return cls(values=tidy3d_obj.data, coords=plain_coords)
@pd.validator("values", always=True)
def _convert_values_to_np(cls, val):
    """Turn list-valued ``values`` (e.g. read from file) into a numpy array.

    Anything that is not a ``list`` is passed through unchanged.
    """
    return np.array(val) if isinstance(val, list) else val
@pd.validator("coords", always=True)
@skip_if_fields_missing(["values"])
def _coords_match_values(cls, val, values):
    """Check that each coordinate's length equals the size of the matching axis of ``values``.

    Coordinates and axes are paired in order; pairing stops at the shorter of the two.
    """
    raw_values = values.get("values")
    # fall back to jax when the numpy conversion raises a TypeError
    try:
        data_shape = np.array(raw_values).shape
    except TypeError:
        data_shape = jnp.array(raw_values).shape
    for (coord_name, coord_entries), axis_size in zip(val.items(), data_shape):
        num_entries = len(coord_entries)
        if num_entries == axis_size:
            continue
        raise ValueError(
            f"JaxDataArray coord {coord_name} has {num_entries} elements, "
            "which doesn't match the values array "
            f"with size {axis_size} along that dimension."
        )
    return val
@pd.validator("coords", always=True)
def _convert_coords_to_list(cls, val):
    """Return the coordinates as a ``Dict[str, list]``, converting each entry with ``list``."""
    as_lists = {}
    for name, entries in val.items():
        as_lists[name] = list(entries)
    return as_lists
[docs]
def __eq__(self, other) -> bool:
    """Check if two ``JaxDataArray`` instances hold equal values.

    Only ``values`` are compared; the coordinates are not taken into account.

    Parameters
    ----------
    other
        Object to compare against. It must expose a ``values`` attribute.

    Returns
    -------
    bool
        Result of ``jnp.array_equal`` on both ``values``, or ``NotImplemented`` if
        ``other`` has no ``values`` attribute, so that python falls back to its
        default comparison instead of raising ``AttributeError``.
    """
    try:
        other_values = other.values
    except AttributeError:
        return NotImplemented
    return jnp.array_equal(self.values, other_values)
[docs]
def to_hdf5(self, fname: h5py.File, group_path: str) -> None:
    """Save this :class:`.JaxDataArray` to an open hdf5 file under a given group path.

    Parameters
    ----------
    fname : h5py.File
        Open, writable hdf5 handle; ``create_group`` is called on it directly.
    group_path : str
        Path of the group to create. It receives one dataset ``"values"``, one dataset
        per coordinate, and a dataset ``"dims"`` listing the coordinate names in order.
    """
    sub_group = fname.create_group(group_path)
    sub_group["values"] = self.values
    dims = []
    for key, val in self.coords.items():
        dims.append(key)
        val = np.array(val)
        # h5py cannot store numpy unicode ("U") arrays of any length, so string
        # coordinates are written as a list of python strings instead
        if val.dtype.kind == "U":
            sub_group[key] = val.tolist()
        else:
            sub_group[key] = val
    sub_group["dims"] = dims
[docs]
@classmethod
def from_hdf5(cls, fname: str, group_path: str) -> JaxDataArray:
    """Load a :class:`.JaxDataArray` from the group at ``group_path`` in the hdf5 file ``fname``."""
    coords = {}
    with h5py.File(fname, "r") as hdf5_file:
        group = hdf5_file[group_path]
        values = np.array(group["values"])
        # "dims" lists the names of the datasets that hold the coordinates
        for dim in group["dims"]:
            coord_array = np.array(group[dim])
            if coord_array.dtype == "O":
                # object arrays hold byte strings, decode them to str
                coords[dim] = [raw.decode() for raw in coord_array.tolist()]
            else:
                coords[dim] = coord_array.tolist()
    return cls(values=values, coords=coords)
@cached_property
def as_ndarray(self) -> np.ndarray:
    """Numpy view of ``self.values``; converts only when it is not already an ``np.ndarray``."""
    raw = self.values
    if isinstance(raw, np.ndarray):
        return raw
    return np.array(raw)
@cached_property
def as_jnp_array(self) -> jnp.ndarray:
    """Jax view of ``self.values``; converts only when it is not already a ``jnp.ndarray``."""
    raw = self.values
    if isinstance(raw, jnp.ndarray):
        return raw
    return jnp.array(raw)
@cached_property
def shape(self) -> tuple:
    """Shape of ``self.values``, read from its jax-array form."""
    jax_values = self.as_jnp_array
    return jax_values.shape
@cached_property
def as_list(self) -> list:
    """``self.values`` as a (nested) python list, obtained through its numpy form."""
    numpy_values = self.as_ndarray
    return numpy_values.tolist()
@cached_property
def real(self) -> JaxDataArray:
    """Real part of self.

    Returns
    -------
    JaxDataArray
        Copy of self whose values are ``jnp.real`` of the original values.
    """
    new_values = jnp.real(self.as_jnp_array)
    return self.copy(update=dict(values=new_values))
@cached_property
def imag(self) -> JaxDataArray:
    """Imaginary part of self.

    Returns
    -------
    JaxDataArray
        Copy of self whose values are ``jnp.imag`` of the original values.
    """
    new_values = jnp.imag(self.as_jnp_array)
    return self.copy(update=dict(values=new_values))
[docs]
def conj(self) -> JaxDataArray:
    """Return a copy of self with complex-conjugated values."""
    conjugated = jnp.conj(self.as_jnp_array)
    update = dict(values=conjugated)
    return self.copy(update=update)
[docs]
def __abs__(self) -> JaxDataArray:
    """Return a copy of self holding the element-wise absolute value of the values."""
    magnitudes = jnp.abs(self.as_jnp_array)
    return self.updated_copy(values=magnitudes)
[docs]
def __pow__(self, power: int) -> JaxDataArray:
    """Return a copy of self with every value raised to ``power``."""
    base = self.as_jnp_array
    return self.updated_copy(values=base**power)
[docs]
def __add__(self, other: JaxDataArray) -> JaxDataArray:
    """Add ``other`` to self; a :class:`.JaxDataArray` operand contributes its jax values."""
    addend = other.as_jnp_array if isinstance(other, JaxDataArray) else other
    return self.updated_copy(values=self.as_jnp_array + addend)
[docs]
def __neg__(self) -> JaxDataArray:
    """Return a copy of self with all values negated."""
    negated = -self.as_jnp_array
    return self.updated_copy(values=negated)
[docs]
def __sub__(self, other) -> JaxDataArray:
    """Subtract ``other`` from self, implemented as adding its negative."""
    negated_other = -other
    return self + negated_other
[docs]
def __radd__(self, other) -> JaxDataArray:
    """Reflected addition (``other + self``); delegates to ``self + other``."""
    total = self + other
    return total
[docs]
def __mul__(self, other: JaxDataArray) -> JaxDataArray:
    """Multiply self with something else.

    Parameters
    ----------
    other
        Either a :class:`.JaxDataArray` (values are multiplied element-wise), an
        ``xr.DataArray`` (broadcast against the dimensions of self), or anything that
        can directly be multiplied with a jax array.

    Returns
    -------
    JaxDataArray
        Copy of self holding the product as its values.
    """
    if isinstance(other, JaxDataArray):
        new_values = self.as_jnp_array * other.as_jnp_array
    elif isinstance(other, xr.DataArray):
        dim_names = list(self.coords.keys())
        # handle case where other is missing dims present in self
        new_shape = list(self.shape)
        for dim_index, dim in enumerate(dim_names):
            if dim not in other.dims:
                other = other.expand_dims(dim=dim)
                new_shape[dim_index] = 1
        # put the axes of other in the order of self before reshaping; a bare reshape
        # would silently scramble the data if the dims came in a different order
        other = other.transpose(*dim_names, ...)
        other_values = other.values.reshape(new_shape)
        new_values = self.as_jnp_array * other_values
    else:
        new_values = self.as_jnp_array * other
    return self.updated_copy(values=new_values)
[docs]
def __rmul__(self, other) -> JaxDataArray:
    """Reflected multiplication (``other * self``); delegates to ``self * other``."""
    product = self * other
    return product
[docs]
def sum(self, dim: str = None):
    """Sum the values, over everything or along one or several named dimensions.

    With ``dim=None`` the result of ``jnp.sum`` over all values is returned. With a
    single name, that dimension is summed out and dropped from the coordinates. Any
    other iterable of names is reduced one name at a time.
    """
    if dim is None:
        return jnp.sum(self.as_jnp_array)
    if not isinstance(dim, str):
        # several dimensions: reduce them one after the other
        reduced = self.copy()
        for single_dim in dim:
            reduced = reduced.sum(dim=single_dim)
        return reduced
    dim_names = list(self.coords.keys())
    summed = jnp.sum(self.values, axis=dim_names.index(dim))
    remaining = {name: entries for name, entries in self.coords.items() if name != dim}
    return self.updated_copy(values=summed, coords=remaining)
[docs]
def squeeze(self, dim: str = None, drop: bool = True) -> JaxDataArray:
    """Remove dimensions of length one.

    Parameters
    ----------
    dim : str = None
        Name of the dimension to squeeze. If ``None``, all length-one dimensions
        are removed.
    drop : bool = True
        Accepted but not used by this implementation.

    Returns
    -------
    JaxDataArray
        Copy of self with the squeezed values and the matching coordinates removed.
    """
    if dim is None:
        new_values = jnp.squeeze(self.as_jnp_array)
        # ``jnp.squeeze`` only removes axes of length exactly one, so every other
        # axis (including empty ones) must keep its coordinate
        new_coords = {
            key: val
            for (key, val), dim_size in zip(self.coords.items(), self.values.shape)
            if dim_size != 1
        }
    else:
        axis = list(self.coords.keys()).index(dim)
        new_values = jnp.array(jnp.squeeze(self.as_jnp_array, axis=axis))
        new_coords = self.coords.copy()
        new_coords.pop(dim)
    return self.updated_copy(values=new_values, coords=new_coords)
[docs]
def get_coord_list(self, coord_name: str) -> list:
    """Return the coordinate list stored under ``coord_name``.

    Raises ``Tidy3dKeyError`` if ``coord_name`` is not a key of ``self.coords``.
    """
    all_coords = self.coords
    if coord_name in all_coords:
        return all_coords.get(coord_name)
    raise Tidy3dKeyError(f"Could not select '{coord_name}', not found in coords dict.")
[docs]
def isel_single(self, coord_name: str, coord_index: int) -> JaxDataArray:
    """Index into a single coordinate of the :class:`.JaxDataArray` by position.

    A scalar ``coord_index`` removes the coordinate; an array-like index keeps it with
    the selected entries. If no coordinate is left, the bare values are returned.
    """
    axis = list(self.coords.keys()).index(coord_name)
    selected_values = jnp.take(self.as_jnp_array, indices=coord_index, axis=axis)
    remaining_coords = self.coords.copy()
    index_array = np.array(coord_index)
    if len(index_array.shape) >= 1:
        # array-like index: the coordinate survives with only the selected entries
        all_entries = self.coords[coord_name]
        remaining_coords[coord_name] = [all_entries[i] for i in index_array.tolist()]
    else:
        # scalar index: the coordinate is selected out entirely
        remaining_coords.pop(coord_name)
    if remaining_coords:
        # some coordinates remain, wrap the selection in another JaxDataArray
        return self.copy(update=dict(values=selected_values, coords=remaining_coords))
    if selected_values.shape:
        raise AdjointError(
            "All coordinates selected out, but raw data values are still multi-dimensional."
            " If you encountered this error, please raise an issue on the Tidy3D "
            "front end github repository so we can look into the source of the bug."
        )
    # nothing left to label, hand back the bare values
    return selected_values
[docs]
def isel(self, **isel_kwargs) -> JaxDataArray:
    """Index into the :class:`.JaxDataArray` by position, one coordinate at a time.

    Raises ``DataError`` if any index is negative or not smaller than the number of
    entries of its coordinate.
    """
    selected = self.copy()
    for coord_name, raw_index in isel_kwargs.items():
        index_array = np.array(raw_index)
        num_entries = len(selected.get_coord_list(coord_name))
        too_small = np.any(index_array < 0)
        if too_small or np.any(index_array >= num_entries):
            raise DataError(
                f"'isel' kwarg '{coord_name}={index_array}' is out of range "
                f"for the coordinate '{coord_name}' with {num_entries} values."
            )
        selected = selected.isel_single(coord_name=coord_name, coord_index=index_array)
    return selected
[docs]
def sel(
    self, indexers: dict = None, method: Literal[None, "nearest"] = None, **sel_kwargs
) -> JaxDataArray:
    """Select a value from the :class:`.JaxDataArray` by indexing into coordinates by value.

    Parameters
    ----------
    indexers : dict = None
        Dictionary with keys matching the coordinates of :class:`.JaxDataArray` and values
        given by scalars or lists. Merged with ``sel_kwargs``; on a name clash the
        keyword argument takes precedence.
    method : Literal[None, "nearest"] = None
        Method to use for matching coordinate values:
        - None (default): only exact matches
        - nearest: use nearest valid index value
    sel_kwargs : dict
        Keyword arguments with names matching the coordinates of :class:`.JaxDataArray` and values
        given by scalars or lists, e.g. `da.sel(x=0.1, y=[0.2, 0.3])`.

    Returns
    -------
    JaxDataArray
        JaxDataArray with extracted values.

    Raises
    ------
    NotImplementedError
        If ``method`` is neither ``None`` nor ``"nearest"``.
    DataError
        If, for a numeric coordinate, no index is selected or, with exact matching,
        any requested value has no matching coordinate.
    """
    if method not in [None, "nearest"]:
        raise NotImplementedError(f"Unkown selection method: {method}.")
    # honor the dict form of the selection as well as the keyword form
    selections = {**(indexers or {}), **sel_kwargs}
    isel_kwargs = {}
    for coord_name, vals in selections.items():
        coord_list = self.get_coord_list(coord_name)
        try:  # handle non-numeric types (e.g. str)
            coord_array = jnp.asarray(coord_list)
        except TypeError:
            isel_kwargs[coord_name] = self._indices_literal(coord_list, vals)
            continue
        vals_ary = jnp.atleast_1d(vals)
        # dist[i, j] is the distance between requested value i and coordinate j
        dist = jnp.abs(coord_array[None] - vals_ary[:, None])
        all_found = True
        if method is None:
            is_match = jnp.isclose(dist, 0)
            # every requested value needs at least one matching coordinate
            all_found = bool(jnp.all(jnp.any(is_match, axis=1)))
            indices = jnp.where(is_match)[1]
        elif method == "nearest":
            indices = jnp.argmin(dist, axis=1)
        if indices.size == 0 or not all_found:
            raise DataError(
                f"Could not select '{coord_name}={vals_ary}', some values were not found."
            )
        if np.isscalar(vals):
            indices = jnp.squeeze(indices)
        isel_kwargs[coord_name] = indices
    return self.isel(**isel_kwargs)
def _indices_literal(self, coord_list: list, values: Union[Any, Sequence[Any]]) -> np.ndarray:
    """Look up the positions of non-numeric `values` inside `coord_list`.

    Parameters
    ----------
    coord_list : list
        All entries of one coordinate.
    values : Union[Any, Sequence[Any]]
        A single value or several values to locate.

    Returns
    -------
    numpy.ndarray
        Positions of `values` in `coord_list`; squeezed if `values` is a scalar.
    """
    found = []
    for entry in np.atleast_1d(values):
        if entry in coord_list:
            found.append(coord_list.index(entry))
            continue
        raise DataError(f"Could not select '{entry}' from '{coord_list}', value not found.")
    return np.squeeze(found) if np.isscalar(values) else found
[docs]
def assign_coords(self, coords: dict = None, **coords_kwargs) -> JaxDataArray:
    """Return a copy of self with coordinates replaced or added.

    Entries of ``coords`` win over ``coords_kwargs``, which win over the existing ones.
    """
    merged = {**self.coords, **coords_kwargs}
    if coords:
        for name, entries in coords.items():
            merged[name] = entries
    as_lists = {}
    for name, entries in merged.items():
        # store every coordinate as a plain python list
        as_lists[name] = np.array(entries).tolist()
    return self.updated_copy(coords=as_lists)
[docs]
def multiply_at(self, value: complex, coord_name: str, indices: List[int]) -> JaxDataArray:
    """Return a copy of self with the entries at ``indices`` along ``coord_name`` scaled by ``value``."""
    target_axis = list(self.coords.keys()).index(coord_name)
    # bring the target axis to the front so leading-axis indexing applies
    front_first = jnp.moveaxis(self.as_jnp_array, target_axis, 0)
    scaled = front_first.at[indices].multiply(value)
    # put the axis back where it was
    restored = jnp.moveaxis(scaled, 0, target_axis)
    return self.updated_copy(values=restored)
[docs]
def interp_single(self, key: str, val: float) -> JaxDataArray:
    """Interpolate into a single dimension of self.

    Note: this interpolation works by finding the index of the value into the coords list.
    Instead of an integer value, we use interpolation to get a floating point index.
    The floor() of this value is the 'minus' index and the ceil() gives the 'plus' index.
    We then apply coefficients linearly based on how close to `plus` or minus we are.
    This is a workaround to `jnp.interp` not allowing multi-dimensional interpolation.

    Parameters
    ----------
    key : str
        Name of the coordinate to interpolate along.
    val : float
        Scalar or array of positions along ``key``. Gradients are stopped on it.

    Returns
    -------
    JaxDataArray
        Interpolated data. A scalar ``val`` removes ``key`` from the coordinates, an
        array ``val`` (of any length) keeps it with ``val`` as the new entries. If no
        coordinate remains, the bare interpolated values are returned.

    Raises
    ------
    Tidy3dKeyError
        If ``key`` is not one of the coordinates.
    """
    val = jax.lax.stop_gradient(val)
    # get the coordinates associated with this key.
    if key not in self.coords:
        raise Tidy3dKeyError(f"Key '{key}' not found in JaxDataArray coords.")
    coords_1d = jnp.array(self.coords[key])
    axis = list(self.coords.keys()).index(key)
    # get floating point index of the value into these coordinates
    coord_indices = jnp.arange(len(coords_1d))
    index_interp = jnp.interp(x=val, xp=coords_1d, fp=coord_indices)
    # strip out the linear interpolation coefficients from the float index
    index_minus = np.array(index_interp).astype(int)
    index_plus = index_minus + 1
    coeff_plus = index_interp - index_minus
    # if any plus_index is out of range, set it in range (coeff will be 0 anyway)
    if index_plus.shape:
        index_plus[index_plus >= len(coord_indices)] = len(coord_indices) - 1
    else:
        # same bound as the array branch: the last valid index is len - 1
        if index_plus >= len(coord_indices):
            index_plus = index_minus
            coeff_plus = 0.0
    coeff_minus = 1 - coeff_plus

    def get_values_at_index(key: str, index: int) -> jnp.array:
        """grab values array at index into coordinate key."""
        values_sel = self.isel(**{key: index})
        if isinstance(values_sel, JaxDataArray):
            return values_sel.values
        return values_sel

    # return weighted average of this object along these dimensions
    values_minus = get_values_at_index(key=key, index=index_minus)
    if np.any(coeff_plus > 0):
        values_plus = get_values_at_index(key=key, index=index_plus)
        if coeff_minus.shape:
            # shape the coefficients so that they broadcast along ``axis`` only
            coeff_shape = np.ones(len(values_minus.shape), dtype=int)
            coeff_shape[axis] = len(coeff_minus)
            coeff_minus = coeff_minus.reshape(coeff_shape)
            coeff_plus = coeff_plus.reshape(coeff_shape)
        values_interp = coeff_minus * values_minus + coeff_plus * values_plus
    else:
        values_interp = values_minus
    # construct a new JaxDataArray to return
    coords_interp = self.coords.copy()
    if not index_minus.shape:
        # scalar selection: the axis was selected out of the values
        coords_interp.pop(key)
    else:
        # array selection keeps the axis in the values (even for one element),
        # so the coordinate has to stay as well
        coords_interp[key] = np.atleast_1d(val).tolist()
    if coords_interp:
        return JaxDataArray(values=values_interp, coords=coords_interp)
    return values_interp
[docs]
def interp(self, kwargs=None, assume_sorted=None, **interp_kwargs) -> JaxDataArray:
    """Linearly interpolate the :class:`.JaxDataArray`, one coordinate after another.

    ``kwargs`` and ``assume_sorted`` are accepted for signature compatibility and ignored.
    """
    result = self.copy()
    for coord_name, positions in interp_kwargs.items():
        result = result.interp_single(key=coord_name, val=positions)
    return result
@cached_property
def nonzero_val_coords(self) -> Tuple[List[complex], Dict[str, Any]]:
"""The value and coordinate associated with the only non-zero element of ``self.values``."""
values = np.nan_to_num(self.as_ndarray)
# filter out values that are very small relative to maximum
values_filtered = values.copy()
max_value = np.max(np.abs(values_filtered))
val_cutoff = VALUE_FILTER_THRESHOLD * max_value
values_filtered[np.abs(values_filtered) <= val_cutoff] = 0.0
nonzero_inds = np.nonzero(values_filtered)
nonzero_values = values_filtered[nonzero_inds].tolist()
nonzero_coords = {}
for nz_inds, (coord_name, coord_list) in zip(nonzero_inds, self.coords.items()):
coord_array = np.array(coord_list)
nonzero_coords[coord_name] = coord_array[nz_inds].tolist()
return nonzero_values, nonzero_coords
[docs]
def tree_flatten(self) -> Tuple[list, dict]:
    """Hand the values to jax and stash the coords for reconstruction.

    NOTE(review): the values are returned as the children object itself rather than
    wrapped in a tuple -- confirm this is the intended pytree layout.
    """
    children = self.values
    aux_data = self.coords
    return children, aux_data
[docs]
@classmethod
def tree_unflatten(cls, aux_data, children) -> JaxDataArray:
    """Rebuild an instance using ``children`` as values and ``aux_data`` as coords."""
    fields = dict(values=children, coords=aux_data)
    return cls(**fields)