# Source code for the jax-compatible DataArray module.

"""Defines jax-compatible DataArrays."""
from __future__ import annotations

from typing import Tuple, Any, Dict, List

import h5py
import pydantic.v1 as pd
import numpy as np
import jax.numpy as jnp
import jax
from jax.tree_util import register_pytree_node_class
import xarray as xr

from .....components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
from .....exceptions import DataError, Tidy3dKeyError, AdjointError

# condition setting when to set value in DataArray to zero:
# if abs(val) <= VALUE_FILTER_THRESHOLD * max(abs(val))
# This relative cutoff is read by ``JaxDataArray.nonzero_val_coords``; it was referenced
# there but never defined (NameError). NOTE(review): confirm 1e-6 is the intended cutoff.
VALUE_FILTER_THRESHOLD = 1e-6

# JaxDataArrays are written to json as JAX_DATA_ARRAY_TAG
JAX_DATA_ARRAY_TAG = "<<JaxDataArray>>"

@register_pytree_node_class
class JaxDataArray(Tidy3dBaseModel):
    """A :class:`.DataArray`-like class that only wraps xarray for jax compatibility.

    ``values`` holds the raw (possibly jax-traced) data; ``coords`` maps each dimension
    name to its coordinate list, in the same order as the axes of ``values``.
    """

    values: Any = pd.Field(
        ...,
        title="Values",
        description="Nested list containing the raw values, which can be tracked by jax.",
        jax_field=True,
    )

    coords: Dict[str, list] = pd.Field(
        ...,
        title="Coords",
        description="Dictionary storing the coordinates, namely ``(direction, f, mode_index)``.",
    )

    def to_tidy3d(self: JaxDataArray) -> xr.DataArray:
        """Convert :class:`.JaxDataArray` instance to ``xr.DataArray`` instance."""
        coords = {k: np.array(v).tolist() for k, v in self.coords.items()}
        return xr.DataArray(np.array(self.values), coords=coords, dims=list(self.coords))

    @classmethod
    def from_tidy3d(cls, tidy3d_obj: xr.DataArray) -> JaxDataArray:
        """Convert ``xr.DataArray`` instance to :class:`.JaxDataArray`.

        Coordinates are taken per dimension, in ``tidy3d_obj.dims`` order, so that the
        coords dict lines up with the axes of the values array. Non-dimension (e.g.
        scalar) coordinates are not carried over.
        """
        coords = {dim: np.array(tidy3d_obj[dim]).tolist() for dim in tidy3d_obj.dims}
        return cls(values=tidy3d_obj.values, coords=coords)

    @pd.validator("values", always=True)
    def _convert_values_to_np(cls, val):
        """Convert supplied values to numpy if they are list (from file)."""
        if isinstance(val, list):
            return np.array(val)
        return val

    @pd.validator("coords", always=True)
    @skip_if_fields_missing(["values"])
    def _coords_match_values(cls, val, values):
        """Make sure the coordinate dimensions and shapes match the values data."""
        _values = values.get("values")

        # get the shape, handling both regular and jax objects
        try:
            values_shape = np.array(_values).shape
        except TypeError:
            values_shape = jnp.array(_values).shape

        for (key, coord_val), size_dim in zip(val.items(), values_shape):
            if len(coord_val) != size_dim:
                raise ValueError(
                    f"JaxDataArray coord {key} has {len(coord_val)} elements, "
                    "which doesn't match the values array "
                    f"with size {size_dim} along that dimension."
                )

        return val

    @pd.validator("coords", always=True)
    def _convert_coords_to_list(cls, val):
        """Convert supplied coordinates to Dict[str, list]."""
        return {coord_name: list(coord_list) for coord_name, coord_list in val.items()}

    def __eq__(self, other) -> bool:
        """Check if two ``JaxDataArray`` instances are equal (compares ``values`` only)."""
        return jnp.array_equal(self.values, other.values)

    def to_hdf5(self, fname: h5py.Group, group_path: str) -> None:
        """Save the values and coords into a new group of an already-open hdf5 file/group.

        Parameters
        ----------
        fname : h5py.Group
            Open hdf5 file or group in which the new sub-group is created.
        group_path : str
            Name of the sub-group to create.
        """
        sub_group = fname.create_group(group_path)
        sub_group["values"] = self.values
        dims = []
        for key, val in self.coords.items():
            dims.append(key)
            val = np.array(val)
            # h5py cannot store numpy unicode ("U" kind, any length) arrays directly,
            # but it does accept a list of python ``str``; previously only "<U1" was handled.
            if val.dtype.kind == "U":
                sub_group[key] = val.tolist()
            else:
                sub_group[key] = val
        sub_group["dims"] = dims

    @classmethod
    def from_hdf5(cls, fname: str, group_path: str) -> JaxDataArray:
        """Load a :class:`.JaxDataArray` from the group ``group_path`` of hdf5 file ``fname``."""

        def _decode(item):
            """Decode ``bytes`` (how h5py may hand back stored strings) to ``str``."""
            return item.decode() if isinstance(item, bytes) else item

        with h5py.File(fname, "r") as f:
            sub_group = f[group_path]
            values = np.array(sub_group["values"])
            dims = [_decode(dim) for dim in sub_group["dims"]]
            coords = {}
            for dim in dims:
                coord_vals = np.array(sub_group[dim])
                if coord_vals.dtype == "O":
                    # object arrays hold strings written by ``to_hdf5``
                    coords[dim] = [_decode(item) for item in coord_vals.tolist()]
                else:
                    coords[dim] = coord_vals.tolist()
        return cls(values=values, coords=coords)

    @cached_property
    def as_ndarray(self) -> np.ndarray:
        """``self.values`` as a numpy array."""
        if not isinstance(self.values, np.ndarray):
            return np.array(self.values)
        return self.values

    @cached_property
    def as_jnp_array(self) -> jnp.ndarray:
        """``self.values`` as a jax array."""
        if not isinstance(self.values, jnp.ndarray):
            return jnp.array(self.values)
        return self.values

    @cached_property
    def shape(self) -> tuple:
        """Shape of self.values."""
        return self.as_jnp_array.shape

    @cached_property
    def as_list(self) -> list:
        """``self.values`` as a numpy array converted to a list."""
        return self.as_ndarray.tolist()

    @cached_property
    def real(self) -> JaxDataArray:
        """Real part of self."""
        new_values = jnp.real(self.as_jnp_array)
        return self.copy(update=dict(values=new_values))

    @cached_property
    def imag(self) -> JaxDataArray:
        """Imaginary part of self."""
        new_values = jnp.imag(self.as_jnp_array)
        return self.copy(update=dict(values=new_values))

    def conj(self) -> JaxDataArray:
        """Complex conjugate of self."""
        new_values = jnp.conj(self.as_jnp_array)
        return self.copy(update=dict(values=new_values))

    def __abs__(self) -> JaxDataArray:
        """Absolute value of self's values."""
        new_values = jnp.abs(self.as_jnp_array)
        return self.updated_copy(values=new_values)

    def __pow__(self, power: int) -> JaxDataArray:
        """Values raised to a power."""
        new_values = self.as_jnp_array**power
        return self.updated_copy(values=new_values)

    def __add__(self, other: JaxDataArray) -> JaxDataArray:
        """Sum self with something else."""
        if isinstance(other, JaxDataArray):
            new_values = self.as_jnp_array + other.as_jnp_array
        else:
            new_values = self.as_jnp_array + other
        return self.updated_copy(values=new_values)

    def __neg__(self) -> JaxDataArray:
        """Negative of self."""
        new_values = -self.as_jnp_array
        return self.updated_copy(values=new_values)

    def __sub__(self, other) -> JaxDataArray:
        """Subtraction: ``self - other``."""
        return self + (-other)

    def __rsub__(self, other) -> JaxDataArray:
        """Reflected subtraction: ``other - self`` (e.g. ``1.0 - arr``)."""
        return (-self) + other

    def __radd__(self, other) -> JaxDataArray:
        """Sum self with something else."""
        return self + other

    def __mul__(self, other: JaxDataArray) -> JaxDataArray:
        """Multiply self with something else.

        An ``xr.DataArray`` operand may be missing some of self's dims (they are
        broadcast) and may list its dims in a different order; it is aligned to
        self's dim order before multiplying.
        """
        if isinstance(other, JaxDataArray):
            new_values = self.as_jnp_array * other.as_jnp_array
        elif isinstance(other, xr.DataArray):
            # add size-1 axes for dims present in self but missing in other
            for dim in self.coords:
                if dim not in other.dims:
                    other = other.expand_dims(dim=dim)
            # put other's axes in the same order as self's so broadcasting lines up
            # (a plain reshape would silently misalign differently-ordered dims)
            other = other.transpose(*self.coords.keys())
            new_values = self.as_jnp_array * other.values
        else:
            new_values = self.as_jnp_array * other
        return self.updated_copy(values=new_values)

    def __rmul__(self, other) -> JaxDataArray:
        """Multiply self with something else."""
        return self * other

    def sum(self, dim: str = None):
        """Sum (optionally along a single or multiple dimensions).

        Returns the scalar total when ``dim`` is None, otherwise a
        :class:`.JaxDataArray` with the summed dimension(s) removed.
        """
        if dim is None:
            return jnp.sum(self.as_jnp_array)

        # dim is supplied
        if isinstance(dim, str):
            axis = list(self.coords.keys()).index(dim)
            new_values = jnp.sum(self.as_jnp_array, axis=axis)
            new_coords = self.coords.copy()
            new_coords.pop(dim)
            return self.updated_copy(values=new_values, coords=new_coords)

        # dim is iterative, recursively call sum with single dim
        ret = self.copy()
        for dim_i in dim:
            ret = ret.sum(dim=dim_i)
        return ret

    def squeeze(self, dim: str = None, drop: bool = True) -> JaxDataArray:
        """Remove size-1 dims (all of them, or only ``dim`` if given).

        ``drop`` is accepted for signature compatibility and is not used.
        """
        if dim is None:
            new_values = jnp.squeeze(self.as_jnp_array)
            new_coords = {
                key: val for (key, val), dim_size in zip(self.coords.items(), self.shape)
                if dim_size > 1
            }
        else:
            axis = list(self.coords.keys()).index(dim)
            new_values = jnp.array(jnp.squeeze(self.as_jnp_array, axis=axis))
            new_coords = self.coords.copy()
            new_coords.pop(dim)

        return self.updated_copy(values=new_values, coords=new_coords)

    def get_coord_list(self, coord_name: str) -> list:
        """Get a coordinate list by name."""
        if coord_name not in self.coords:
            raise Tidy3dKeyError(f"Could not select '{coord_name}', not found in coords dict.")
        return self.coords[coord_name]

    def isel_single(self, coord_name: str, coord_index: int) -> JaxDataArray:
        """Select a value corresponding to a single coordinate from the :class:`.JaxDataArray`.

        ``coord_index`` may be an int (the coordinate is dropped) or a sequence of ints
        (the coordinate is kept with the selected entries). If no coordinates remain, the
        raw selected values are returned instead of a :class:`.JaxDataArray`.
        """

        # select out the proper values and coordinates
        coord_axis = list(self.coords.keys()).index(coord_name)
        values = self.as_jnp_array
        new_values = jnp.take(values, indices=coord_index, axis=coord_axis)
        new_coords = self.coords.copy()

        # if the coord index has more than one item, keep that coordinate
        coord_index = np.array(coord_index)
        if len(coord_index.shape) >= 1:
            coord_list = self.coords[coord_name]
            new_coords[coord_name] = [coord_list[idx] for idx in coord_index.tolist()]
        else:
            new_coords.pop(coord_name)

        # return just the values if no coordinate remain
        if not new_coords:
            if new_values.shape:
                raise AdjointError(
                    "All coordinates selected out, but raw data values are still multi-dimensional."
                    " If you encountered this error, please raise an issue on the Tidy3D "
                    "front end github repository so we can look into the source of the bug."
                )

            return new_values

        # otherwise, return another JaxDataArray with the values and coords selected out
        return self.copy(update=dict(values=new_values, coords=new_coords))

    def isel(self, **isel_kwargs) -> JaxDataArray:
        """Select a value from the :class:`.JaxDataArray` by indexing into coordinates by index."""

        self_sel = self.copy()
        for coord_name, coord_index in isel_kwargs.items():
            coord_index = np.array(coord_index)
            coord_list = self_sel.get_coord_list(coord_name)
            if np.any(coord_index < 0) or np.any(coord_index >= len(coord_list)):
                raise DataError(
                    f"'isel' kwarg '{coord_name}={coord_index}' is out of range "
                    f"for the coordinate '{coord_name}' with {len(coord_list)} values."
                )
            self_sel = self_sel.isel_single(coord_name=coord_name, coord_index=coord_index)

        return self_sel

    def sel(self, indexers: dict = None, method: str = "nearest", **sel_kwargs) -> JaxDataArray:
        """Select a value from the :class:`.JaxDataArray` by indexing into coordinate values.

        Only exact matches are supported; ``indexers`` and ``method`` are accepted for
        signature compatibility with xarray but are not used.
        """

        isel_kwargs = {}
        for coord_name, sel_kwarg in sel_kwargs.items():
            coord_list = self.get_coord_list(coord_name)

            if isinstance(sel_kwarg, (tuple, list, np.ndarray)):
                sel_kwarg = list(sel_kwarg)
                isel_kwargs[coord_name] = []
                for _sel_kwarg in sel_kwarg:
                    if _sel_kwarg not in coord_list:
                        raise DataError(
                            f"Could not select '{coord_name}={_sel_kwarg}', value not found."
                        )
                    coord_index = coord_list.index(_sel_kwarg)
                    isel_kwargs[coord_name].append(coord_index)
            else:
                if sel_kwarg not in coord_list:
                    raise DataError(
                        f"Could not select '{coord_name}={sel_kwarg}', value not found."
                    )
                coord_index = coord_list.index(sel_kwarg)
                isel_kwargs[coord_name] = coord_index

        return self.isel(**isel_kwargs)

    def assign_coords(self, coords: dict = None, **coords_kwargs) -> JaxDataArray:
        """Assign new coordinates to this object (``coords`` entries win over kwargs)."""

        update_kwargs = self.coords.copy()

        for key, val in coords_kwargs.items():
            update_kwargs[key] = val

        if coords:
            for key, val in coords.items():
                update_kwargs[key] = val

        update_kwargs = {key: np.array(value).tolist() for key, value in update_kwargs.items()}
        return self.updated_copy(coords=update_kwargs)

    def multiply_at(self, value: complex, coord_name: str, indices: List[int]) -> JaxDataArray:
        """Multiply self by ``value`` at ``indices`` along the ``coord_name`` axis."""
        axis = list(self.coords.keys()).index(coord_name)
        scalar_data_arr = self.as_jnp_array
        # move the target axis to the front so ``.at[indices]`` indexes along it
        scalar_data_arr = jnp.moveaxis(scalar_data_arr, axis, 0)
        scalar_data_arr = scalar_data_arr.at[indices].multiply(value)
        scalar_data_arr = jnp.moveaxis(scalar_data_arr, 0, axis)
        return self.updated_copy(values=scalar_data_arr)

    def interp_single(self, key: str, val: float) -> JaxDataArray:
        """Interpolate into a single dimension of self.

        Note: this interpolation works by finding the index of the value into the coords list.
        Instead of an integer value, we use interpolation to get a floating point index.
        The floor() of this value is the 'minus' index and the ceil() gives the 'plus' index.
        We then apply coefficients linearly based on how close to `plus` or minus we are.
        This is a workaround to `jnp.interp` not allowing multi-dimensional interpolation.
        """

        val = jax.lax.stop_gradient(val)

        # get the coordinates associated with this key.
        if key not in self.coords:
            raise Tidy3dKeyError(f"Key '{key}' not found in JaxDataArray coords.")
        coords_1d = jnp.array(self.coords[key])
        axis = list(self.coords.keys()).index(key)

        # get floating point index of the value into these coordinates
        coord_indices = jnp.arange(len(coords_1d))
        index_interp = jnp.interp(x=val, xp=coords_1d, fp=coord_indices)

        # strip out the linear interpolation coefficients from the float index
        index_minus = np.array(index_interp).astype(int)
        index_plus = index_minus + 1
        coeff_plus = index_interp - index_minus

        # if any plus_index is out of range, set it in range (coeff will be 0 anyway);
        # valid indices are 0..num_coords-1, so the scalar check must use ``>=``
        num_coords = len(coord_indices)
        if index_plus.shape:
            index_plus[index_plus >= num_coords] = num_coords - 1
        elif index_plus >= num_coords:
            index_plus = index_minus
            coeff_plus = 0.0
        coeff_minus = 1 - coeff_plus

        def get_values_at_index(key: str, index: int) -> jnp.array:
            """grab values array at index into coordinate key."""
            values_sel = self.isel(**{key: index})
            if isinstance(values_sel, JaxDataArray):
                return values_sel.values
            return values_sel

        # return weighted average of this object along these dimensions
        values_minus = get_values_at_index(key=key, index=index_minus)
        if np.any(coeff_plus > 0):
            values_plus = get_values_at_index(key=key, index=index_plus)

            # broadcast per-point coefficients along the interpolated axis
            if coeff_minus.shape:
                coeff_shape = np.ones(len(values_minus.shape), dtype=int)
                coeff_shape[axis] = len(coeff_minus)
                coeff_minus = coeff_minus.reshape(coeff_shape)
                coeff_plus = coeff_plus.reshape(coeff_shape)

            values_interp = coeff_minus * values_minus + coeff_plus * values_plus
        else:
            values_interp = values_minus

        # construct a new JaxDataArray to return
        coords_interp = self.coords.copy()
        if jnp.array(index_interp).size <= 1:
            coords_interp.pop(key)
        else:
            coords_interp[key] = np.atleast_1d(val).tolist()

        if coords_interp:
            return JaxDataArray(values=values_interp, coords=coords_interp)
        return values_interp

    def interp(self, kwargs=None, assume_sorted=None, **interp_kwargs) -> JaxDataArray:
        """Linearly interpolate into the :class:`.JaxDataArray` at values into coordinates."""

        # note: kwargs does nothing, only used for making this subclass compatible with super
        ret_value = self.copy()
        for key, val in interp_kwargs.items():
            ret_value = ret_value.interp_single(key=key, val=val)
        return ret_value

    @cached_property
    def nonzero_val_coords(self) -> Tuple[List[complex], Dict[str, Any]]:
        """The value and coordinate associated with the only non-zero element of ``self.values``.

        Values with magnitude ``<= VALUE_FILTER_THRESHOLD * max(abs(values))`` are treated
        as zero. Empty values yield empty results instead of failing in ``np.max``.
        """

        values = np.nan_to_num(self.as_ndarray)
        if values.size == 0:
            return [], {coord_name: [] for coord_name in self.coords}

        # filter out values that are very small relative to maximum
        values_filtered = values.copy()
        max_value = np.max(np.abs(values_filtered))
        val_cutoff = VALUE_FILTER_THRESHOLD * max_value
        values_filtered[np.abs(values_filtered) <= val_cutoff] = 0.0

        nonzero_inds = np.nonzero(values_filtered)
        nonzero_values = values_filtered[nonzero_inds].tolist()

        nonzero_coords = {}
        for nz_inds, (coord_name, coord_list) in zip(nonzero_inds, self.coords.items()):
            coord_array = np.array(coord_list)
            nonzero_coords[coord_name] = coord_array[nz_inds].tolist()

        return nonzero_values, nonzero_coords

    def tree_flatten(self) -> Tuple[tuple, dict]:
        """Jax works on the values, stash the coords for reconstruction.

        The children must be an iterable of pytrees; the values array is returned as a
        single leaf (a 1-tuple) so jax does not iterate it row by row.
        """
        return (self.values,), self.coords

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> JaxDataArray:
        """How to unflatten the values and coords (inverse of ``tree_flatten``)."""
        (values,) = children
        return cls(values=values, coords=aux_data)