tidy3d.plugins.adjoint.JaxDataArray

class JaxDataArray

Bases: Tidy3dBaseModel

A DataArray-like class that only wraps xarray for jax compatibility.

Parameters:
  • values (Optional[Any]) – Nested list containing the raw values, which can be tracked by jax.

  • coords (Mapping[str, list]) – Dictionary storing the coordinates, namely (direction, f, mode_index).

Attributes

as_jnp_array

self.values as a jax array.

as_list

self.values as a numpy array converted to a list.

as_ndarray

self.values as a numpy array.

imag

Imaginary part of self.

nonzero_val_coords

The value and coordinate associated with the only non-zero element of self.values.

real

Real part of self.

shape

Shape of self.values.

Methods

assign_coords([coords])

Assign new coordinates to this object.

conj()

Complex conjugate of self.

from_hdf5(fname, group_path)

Load a DataArray from an hdf5 file with a given path to the group.

from_tidy3d(tidy3d_obj)

Convert xr.DataArray instance to JaxDataArray.

get_coord_list(coord_name)

Get a coordinate list by name.

interp([kwargs, assume_sorted])

Linearly interpolate the JaxDataArray at the given coordinate values.

interp_single(key, val)

Interpolate into a single dimension of self.

isel(**isel_kwargs)

Select a value from the JaxDataArray by indexing into coordinates by index.

isel_single(coord_name, coord_index)

Select a value corresponding to a single coordinate from the JaxDataArray.

multiply_at(value, coord_name, indices)

Multiply self by value at the given indices along coord_name.

sel([indexers, method])

Select a value from the JaxDataArray by indexing into coordinate values.

squeeze([dim, drop])

Remove any dimensions of length 1.

sum([dim])

Sum (optionally along a single or multiple dimensions).

to_hdf5(fname, group_path)

Save an xr.DataArray to the hdf5 file with a given path to the group.

to_tidy3d()

Convert JaxDataArray instance to xr.DataArray instance.

tree_flatten()

Jax works on the values, stash the coords for reconstruction.

tree_unflatten(aux_data, children)

How to unflatten the values and coords.

values
coords
to_tidy3d()

Convert JaxDataArray instance to xr.DataArray instance.

classmethod from_tidy3d(tidy3d_obj)

Convert xr.DataArray instance to JaxDataArray.

__eq__(other)

Check if two JaxDataArray instances are equal.

to_hdf5(fname, group_path)

Save an xr.DataArray to the hdf5 file with a given path to the group.

classmethod from_hdf5(fname, group_path)

Load a DataArray from an hdf5 file with a given path to the group.

property as_ndarray

self.values as a numpy array.

property as_jnp_array

self.values as a jax array.

property shape

Shape of self.values.

property as_list

self.values as a numpy array converted to a list.

property real

Real part of self.

property imag

Imaginary part of self.

conj()

Complex conjugate of self.

__abs__()

Absolute value of self’s values.

__pow__(power)

Values raised to a power.

__add__(other)

Sum self with something else.

__neg__()

Negative of self.

__sub__(other)

Subtract something else from self.

__radd__(other)

Sum self with something else.

__mul__(other)

Multiply self with something else.

__rmul__(other)

Multiply self with something else.

sum(dim=None)

Sum (optionally along a single or multiple dimensions).

squeeze(dim=None, drop=True)

Remove any dimensions of length 1.
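
As a rough illustration of squeezing data stored as nested lists with named coordinates, the sketch below drops every dimension whose coordinate list has exactly one entry. It is a dependency-free stand-in, not the library's implementation; `squeeze_nested` is a hypothetical helper, and it assumes the order of keys in `coords` matches the nesting order of `values`.

```python
def squeeze_nested(values, coords):
    """Drop every dimension of length 1 from nested-list data (hypothetical helper).

    Assumes the key order of `coords` matches the nesting order of `values`.
    """
    dims = list(coords)

    def drop(vals, depth):
        if depth == len(dims):
            return vals  # reached a scalar entry
        if len(coords[dims[depth]]) == 1:
            # Singleton dimension: unwrap it and continue one level deeper.
            return drop(vals[0], depth + 1)
        return [drop(v, depth + 1) for v in vals]

    new_coords = {name: c for name, c in coords.items() if len(c) > 1}
    return drop(values, 0), new_coords


example_coords = {"direction": ["+"], "mode_index": [0], "f": [1.0, 2.0, 3.0]}
example_values = [[[0.1, 0.2, 0.3]]]  # shape (1, 1, 3)
squeezed_values, squeezed_coords = squeeze_nested(example_values, example_coords)
```

Here only the `f` dimension survives, so the result is a flat list of three values with a single coordinate.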

get_coord_list(coord_name)

Get a coordinate list by name.

isel_single(coord_name, coord_index)

Select a value corresponding to a single coordinate from the JaxDataArray.

isel(**isel_kwargs)

Select a value from the JaxDataArray by indexing into coordinates by index.

sel(indexers=None, method='nearest', **sel_kwargs)

Select a value from the JaxDataArray by indexing into coordinate values.
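
Selecting by coordinate value amounts to translating a value into an integer index and then indexing as isel does. The sketch below shows that translation for one dimension. How the real method treats values that are not exactly present in the coordinate list is not stated on this page; the nearest-match rule here is an assumption suggested by the default method='nearest', and `nearest_index` is a hypothetical helper.

```python
def nearest_index(coord_values, target):
    """Index of the coordinate closest to `target` (ties go to the first match)."""
    return min(range(len(coord_values)), key=lambda i: abs(coord_values[i] - target))


freqs = [1.0e14, 2.0e14, 3.0e14]   # coordinate values along "f"
amps = [0.1, 0.5, 0.9]             # data values along the same dimension

index = nearest_index(freqs, 2.2e14)  # closest coordinate is 2.0e14
selected = amps[index]
```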

assign_coords(coords=None, **coords_kwargs)

Assign new coordinates to this object.

multiply_at(value, coord_name, indices)

Multiply self by value at the given indices along coord_name.

interp_single(key, val)

Interpolate into a single dimension of self.

Note: this interpolation works by locating val within the coordinate list as a floating-point index rather than an integer one. The floor() of that index is the ‘minus’ index and the ceil() is the ‘plus’ index. The values at these two indices are then combined linearly, weighted by how close the floating-point index is to each. This works around jnp.interp not supporting multi-dimensional interpolation.
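
The floating-point index scheme in the note can be sketched in plain Python for the one-dimensional case. This is an illustration of the idea only, not the library's code: `fractional_index` and `interp_1d` are hypothetical helpers, the coordinates are assumed to be strictly increasing, and clamping out-of-range values to the end points is an assumption made for the sketch.

```python
import math


def fractional_index(coords, val):
    """Locate `val` in strictly increasing `coords` as a floating-point index."""
    if val <= coords[0]:
        return 0.0  # clamp below the range (assumption of this sketch)
    if val >= coords[-1]:
        return float(len(coords) - 1)  # clamp above the range (assumption)
    for i in range(len(coords) - 1):
        lo, hi = coords[i], coords[i + 1]
        if lo <= val <= hi:
            return i + (val - lo) / (hi - lo)
    raise ValueError("coords must be strictly increasing")


def interp_1d(coords, values, val):
    """Blend the 'minus' and 'plus' neighbours using the fractional index."""
    idx = fractional_index(coords, val)
    i_minus = math.floor(idx)
    i_plus = math.ceil(idx)
    w_plus = idx - i_minus      # weight of the 'plus' neighbour
    w_minus = 1.0 - w_plus      # weight of the 'minus' neighbour
    return w_minus * values[i_minus] + w_plus * values[i_plus]


coords = [0.0, 1.0, 2.0]
values = [0.0, 10.0, 40.0]
midpoint = interp_1d(coords, values, 1.5)  # halfway between 10.0 and 40.0
```

Because only the two weights depend on val, the same blend applies unchanged when each entry of `values` is itself a multi-dimensional slice, which is what lets this approach sidestep a one-dimensional interpolation routine.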

interp(kwargs=None, assume_sorted=None, **interp_kwargs)

Linearly interpolate the JaxDataArray at the given coordinate values.

property nonzero_val_coords

The value and coordinate associated with the only non-zero element of self.values.
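
To make "the only non-zero element" concrete, the sketch below scans nested-list data for its single non-zero entry and reports the coordinate value along each dimension. The return format of the real property is not documented on this page, so the (value, dict) pair used here is an assumption, and `get_nested` and `find_single_nonzero` are hypothetical helpers.

```python
from itertools import product


def get_nested(values, index):
    """Follow a tuple of integer indices into nested lists."""
    for i in index:
        values = values[i]
    return values


def find_single_nonzero(values, coords):
    """Return (value, {dim: coordinate}) for the sole non-zero entry."""
    dims = list(coords)
    shape = [len(coords[d]) for d in dims]
    hits = [
        idx
        for idx in product(*(range(n) for n in shape))
        if get_nested(values, idx) != 0
    ]
    if len(hits) != 1:
        raise ValueError(f"expected exactly one non-zero element, found {len(hits)}")
    (idx,) = hits
    return get_nested(values, idx), {d: coords[d][i] for d, i in zip(dims, idx)}


grid_coords = {"direction": ["+", "-"], "f": [1.0, 2.0, 3.0]}
grid_values = [[0.0, 0.0, 0.0], [0.0, 3.5, 0.0]]
value, where = find_single_nonzero(grid_values, grid_coords)
```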

tree_flatten()

Jax works on the values, stash the coords for reconstruction.

classmethod tree_unflatten(aux_data, children)

How to unflatten the values and coords.
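
The flatten and unflatten pair follows JAX's pytree protocol: the traced leaves (the values) are returned as children, and everything else (the coords) travels as auxiliary data. The sketch below shows the round trip on a stand-in class without importing jax; `WrappedArray` is hypothetical, and the exact layout of `aux_data` in the real class is an assumption. In JAX, a class with these two methods can be registered with `jax.tree_util.register_pytree_node_class`, which is left out here to keep the example dependency-free.

```python
class WrappedArray:
    """Hypothetical stand-in holding values plus coordinate metadata."""

    def __init__(self, values, coords):
        self.values = values
        self.coords = coords

    def tree_flatten(self):
        # Children are what a tracer would operate on; aux_data is stashed as-is.
        children = (self.values,)
        aux_data = {"coords": self.coords}
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from (possibly transformed) children and stashed coords.
        (values,) = children
        return cls(values=values, coords=aux_data["coords"])


original = WrappedArray(values=[[1.0, 2.0]], coords={"direction": ["+"], "f": [1e14, 2e14]})
children, aux_data = original.tree_flatten()

# Stand-in for a transformation that only touches the values.
doubled = ([[2.0 * v for v in row] for row in children[0]],)
rebuilt = WrappedArray.tree_unflatten(aux_data, doubled)
```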

__hash__()

Hash method.