tidy3d.plugins.adjoint.JaxDataArray
- class JaxDataArray
Bases: Tidy3dBaseModel
A DataArray-like class that only wraps xarray for jax compatibility.
- Parameters:
attrs (dict = {}) – Dictionary storing arbitrary metadata for a Tidy3D object. The user can freely use this dictionary to store data without affecting the operation of Tidy3D, since Tidy3D does not use it internally. Note that, unlike regular Tidy3D fields, attrs are mutable. For example, the following is allowed for setting an attr: obj.attrs['foo'] = bar. Also note that Tidy3D will raise a TypeError if attrs contain objects that cannot be serialized. One can check whether attrs are serializable by calling obj.json().
values (Optional[Any]) – Nested list containing the raw values, which can be tracked by jax.
coords (Mapping[str, list]) – Dictionary storing the coordinates, namely (direction, f, mode_index).
Attributes
as_jnp_array – self.values as a jax array.
as_list – self.values as a numpy array converted to a list.
as_ndarray – self.values as a numpy array.
imag – Imaginary part of self.
nonzero_val_coords – The value and coordinate associated with the only non-zero element of self.values.
real – Real part of self.
shape – Shape of self.values.
Methods
assign_coords([coords]) – Assign new coordinates to this object.
conj() – Complex conjugate of self.
from_hdf5(fname, group_path) – Load a DataArray from an hdf5 file with a given path to the group.
from_tidy3d(tidy3d_obj) – Convert an xr.DataArray instance to a JaxDataArray.
get_coord_list(coord_name) – Get a coordinate list by name.
interp([kwargs, assume_sorted]) – Linearly interpolate the JaxDataArray at the given coordinate values.
interp_single(key, val) – Interpolate into a single dimension of self.
isel(**isel_kwargs) – Select a value from the JaxDataArray by indexing into coordinates by index.
isel_single(coord_name, coord_index) – Select a value corresponding to a single coordinate from the JaxDataArray.
multiply_at(value, coord_name, indices) – Multiply self by value at indices into .
sel([indexers, method]) – Select a value from the JaxDataArray by indexing into coordinates by value.
squeeze([dim, drop]) – Remove any length-one dims.
sum([dim]) – Sum (optionally along a single or multiple dimensions).
to_hdf5(fname, group_path) – Save an xr.DataArray to the hdf5 file with a given path to the group.
to_tidy3d() – Convert a JaxDataArray instance to an xr.DataArray instance.
Jax works on the values; the coords are stashed for reconstruction.
tree_unflatten(aux_data, children) – How to unflatten the values and coords.
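The last two entries describe how the class plugs into JAX's pytree machinery. The values are exposed to jax as leaves, and the coords travel alongside as auxiliary data.

The sketch below illustrates that flatten/unflatten contract using a hypothetical `MiniDataArray` class, not tidy3d's class. JAX's `jax.tree_util.register_pytree_node_class` decorator expects two methods: `tree_flatten`, returning `(children, aux_data)`, and a classmethod `tree_unflatten(aux_data, children)`. The sketch deliberately avoids importing jax so that it runs on its own.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class MiniDataArray:
    """Hypothetical stand-in for a jax-compatible data array (not tidy3d's class)."""

    values: Any
    coords: Dict[str, List[Any]]

    def tree_flatten(self):
        # Children are what jax traces and transforms: here only the raw values.
        children = (self.values,)
        # The coords are stashed as auxiliary data; tuples keep them hashable.
        aux_data = tuple((name, tuple(coord)) for name, coord in self.coords.items())
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from the (possibly transformed) values and the stashed coords.
        (values,) = children
        coords = {name: list(coord) for name, coord in aux_data}
        return cls(values=values, coords=coords)


da = MiniDataArray(values=[[1.0, 2.0], [3.0, 4.0]], coords={"x": [0, 1], "y": [10, 20]})
children, aux_data = da.tree_flatten()
# Stand-in for a jax transformation: it only ever sees and changes the children.
doubled = tuple([[2 * v for v in row] for row in leaf] for leaf in children)
rebuilt = MiniDataArray.tree_unflatten(aux_data, doubled)
```

With jax installed, decorating such a class with `@jax.tree_util.register_pytree_node_class` would let transformations like `jax.jit` accept these objects while tracing only the values.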
Inherited Common Usage
- values
- coords
- to_tidy3d()
Convert a JaxDataArray instance to an xr.DataArray instance.
- classmethod from_tidy3d(tidy3d_obj)
Convert an xr.DataArray instance to a JaxDataArray.
- to_hdf5(fname, group_path)
Save an xr.DataArray to the hdf5 file with a given path to the group.
- classmethod from_hdf5(fname, group_path)
Load a DataArray from an hdf5 file with a given path to the group.
- property as_ndarray
self.values as a numpy array.
- property as_jnp_array
self.values as a jax array.
- property shape
Shape of self.values.
- property as_list
self.values as a numpy array converted to a list.
- property real
Real part of self.
- property imag
Imaginary part of self.
- isel_single(coord_name, coord_index)
Select a value corresponding to a single coordinate from the JaxDataArray.
- isel(**isel_kwargs)
Select a value from the JaxDataArray by indexing into coordinates by index.
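Index-based selection amounts to mapping each coordinate name to an array axis and taking one index along it. The numpy sketch below assumes that the axes of `values` follow the key order of a `coords` dict. The helper names `isel_single_np` and `isel_np` are hypothetical and are not part of tidy3d.

The sketch also drops the selected dimension, as xarray does for scalar indices. Whether tidy3d's version does the same is not stated here.

```python
import numpy as np


def isel_single_np(values, coords, coord_name, coord_index):
    """Take one index along the axis named coord_name and drop that axis (hypothetical helper)."""
    axis = list(coords).index(coord_name)  # axis order follows the coords dict order
    new_values = np.take(values, coord_index, axis=axis)
    new_coords = {name: coord for name, coord in coords.items() if name != coord_name}
    return new_values, new_coords


def isel_np(values, coords, **isel_kwargs):
    """Apply isel_single_np once per keyword, e.g. isel_np(v, c, x=0, f=1)."""
    for coord_name, coord_index in isel_kwargs.items():
        values, coords = isel_single_np(values, coords, coord_name, coord_index)
    return values, coords


values = np.arange(6).reshape(2, 3)  # axis 0 is "x", axis 1 is "f"
coords = {"x": [0.0, 1.0], "f": [1e14, 2e14, 3e14]}
vals, crds = isel_np(values, coords, x=1)  # vals holds [3, 4, 5]; crds keeps only "f"
```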
- sel(indexers=None, method=None, **sel_kwargs)
Select a value from the JaxDataArray by indexing into coordinates by value.
- Parameters:
sel_kwargs (dict) – Keyword arguments with names matching the coordinates of the JaxDataArray and values given by scalars or lists, e.g. da.sel(x=0.1, y=[0.2, 0.3]).
method (Literal[None, "nearest"] = None) – Method to use for matching coordinate values:
None (default): only exact matches
nearest: use nearest valid index value
- Returns:
JaxDataArray with extracted values.
- Return type:
JaxDataArray
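Value-based selection reduces to converting each coordinate value into an index. After that, it works the same as index-based selection. The pure-Python sketch below shows that lookup for the two documented `method` options.

The function name `value_to_index` is hypothetical. Raising `KeyError` when an exact match is missing is also an assumption about error handling, not documented tidy3d behavior.

```python
from typing import Optional, Sequence


def value_to_index(coord_values: Sequence[float], value: float, method: Optional[str] = None) -> int:
    """Map a coordinate value to its index along one dimension (hypothetical helper)."""
    if method is None:
        # Exact matching only: the value must appear in the coordinate list.
        if value not in coord_values:
            raise KeyError(f"{value!r} not found in coordinates")
        return list(coord_values).index(value)
    if method == "nearest":
        # Pick the index whose coordinate is closest to the requested value.
        return min(range(len(coord_values)), key=lambda i: abs(coord_values[i] - value))
    raise ValueError(f"unsupported method {method!r}")


xs = [0.0, 0.1, 0.2, 0.3]
print(value_to_index(xs, 0.2))                     # → 2
print(value_to_index(xs, 0.26, method="nearest"))  # → 3
```

A list argument such as y=[0.2, 0.3] would map each element to an index in the same way.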
- interp_single(key, val)
Interpolate into a single dimension of self.
Note: this interpolation works by finding the index of the value in the coords list. Instead of an integer index, interpolation is used to get a floating-point index. The floor() of this value is the ‘minus’ index, and the ceil() gives the ‘plus’ index. The values at these two indices are then weighted linearly according to how close the floating-point index is to each. This is a workaround for jnp.interp not supporting multi-dimensional interpolation.
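The fractional-index approach described in the note can be sketched with numpy in three steps:

1. `np.interp` maps the requested coordinate onto a floating-point index.
2. `floor`/`ceil` give the ‘minus’ and ‘plus’ indices.
3. The two slices are blended linearly along the chosen axis.

The helper name `interp_single_np` and its arguments are hypothetical. The sketch assumes sorted, increasing coordinates and uses numpy rather than jax arrays.

```python
import numpy as np


def interp_single_np(values, coord_values, axis, val):
    """Linearly interpolate `values` along `axis` at coordinate `val` (hypothetical helper).

    Assumes coord_values is sorted in increasing order. np.interp clamps
    out-of-range `val` to the first or last index, so there is no extrapolation.
    """
    coord_values = np.asarray(coord_values, dtype=float)
    # Floating-point index: where `val` falls between the integer positions 0..N-1.
    index_float = np.interp(val, coord_values, np.arange(len(coord_values)))
    index_minus = int(np.floor(index_float))
    index_plus = int(np.ceil(index_float))
    # The 'plus' weight grows as the floating-point index approaches index_plus.
    weight_plus = index_float - index_minus
    weight_minus = 1.0 - weight_plus
    values_minus = np.take(values, index_minus, axis=axis)
    values_plus = np.take(values, index_plus, axis=axis)
    return weight_minus * values_minus + weight_plus * values_plus


values = np.array([[0.0, 10.0], [2.0, 30.0], [4.0, 50.0]])
f = [1.0, 2.0, 3.0]  # coordinates along axis 0
result = interp_single_np(values, f, axis=0, val=2.5)
```

When the requested value lands exactly on a coordinate, floor and ceil coincide and the 'plus' weight is zero, so the stored slice is returned unchanged.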
- interp(kwargs=None, assume_sorted=None, **interp_kwargs)
Linearly interpolate the JaxDataArray at the given coordinate values.
- property nonzero_val_coords
The value and coordinate associated with the only non-zero element of self.values.
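Finding the value and coordinates of the only non-zero element is a short numpy exercise. The sketch below assumes a coords dict whose key order matches the array axes.

The helper name `nonzero_val_coords_np` is hypothetical. Raising `ValueError` when there is not exactly one non-zero element is likewise an assumption, not documented tidy3d behavior.

```python
import numpy as np


def nonzero_val_coords_np(values, coords):
    """Return (value, {coord_name: coord_value}) for the single non-zero element."""
    values = np.asarray(values)
    nonzero_indices = np.argwhere(values != 0)  # one row of indices per non-zero entry
    if len(nonzero_indices) != 1:
        raise ValueError(f"expected exactly one non-zero element, found {len(nonzero_indices)}")
    index = tuple(int(i) for i in nonzero_indices[0])
    # Pair each axis index with the coordinate value at that position.
    coord_values = {name: coords[name][i] for name, i in zip(coords, index)}
    return values[index].item(), coord_values


values = np.zeros((2, 3))
values[1, 2] = 0.5
coords = {"x": [-1.0, 1.0], "f": [1e14, 2e14, 3e14]}
val, where = nonzero_val_coords_np(values, coords)  # 0.5 at x=1.0, f=3e14
```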
- __hash__()
Hash method.