tidy3d.plugins.adjoint.JaxPolySlab

class JaxPolySlab

Bases: JaxGeometry, PolySlab, JaxObject

A PolySlab registered with jax.

Parameters:
  • attrs (dict = {}) – Dictionary storing arbitrary metadata for a Tidy3D object. The user can freely use this dictionary to store data without affecting the operation of Tidy3D, since it is not used internally. Note that, unlike regular Tidy3D fields, attrs are mutable. For example, the following is allowed for setting an attr: obj.attrs['foo'] = bar. Also note that Tidy3D will raise a TypeError if attrs contain objects that cannot be serialized. One can check whether attrs are serializable by calling obj.json().

  • axis (Literal[0, 1, 2] = 2) – Specifies the dimension of the slab's normal axis, i.e. the axis along which the polygon is extruded: (0, 1, 2) -> (x, y, z).

  • sidewall_angle (ConstrainedFloatValue = 0.0) – [units = rad]. Angle of the sidewall. sidewall_angle=0 (default) specifies a vertical wall; 0<sidewall_angle<np.pi/2 specifies a shrinking cross section along the axis direction; and -np.pi/2<sidewall_angle<0 specifies an expanding cross section along the axis direction.

  • reference_plane (Literal['bottom', 'middle', 'top'] = middle) – The position of the plane on which the supplied cross section is defined. The plane is perpendicular to the axis and is located at the bottom, middle, or top of the geometry with respect to the axis. E.g., if axis=1, bottom refers to the negative side of the y-axis and top refers to the positive side of the y-axis.

  • slab_bounds (Tuple[float, float]) – [units = um]. Minimum and maximum positions of the slab along axis dimension.

  • dilation (float = 0.0) – [units = um]. Dilation of the supplied polygon, obtained by shifting each edge outward along its normal by the given distance; a negative value corresponds to erosion.

  • vertices (ArrayLike[dtype=float, ndim=2]) – [units = um]. List of (d1, d2) pairs defining the 2D positions of the polygon face vertices at the reference_plane. The two in-plane dimensions are listed in ascending order: e.g., if the slab normal axis is axis=y, the vertex coordinates are given as (x, z).

  • vertices_jax (Tuple[Tuple[Union[float, NumpyArrayType, Array, JVPTracer, object], Union[float, NumpyArrayType, Array, JVPTracer, object]], ...]) – [units = um]. JAX-traced list of (d1, d2) pairs defining the 2D positions of the polygon face vertices at the reference_plane. The two in-plane dimensions are listed in ascending order: e.g., if the slab normal axis is axis=y, the vertex coordinates are given as (x, z).

  • slab_bounds_jax (Tuple[Union[float, NumpyArrayType, Array, JVPTracer, object], Union[float, NumpyArrayType, Array, JVPTracer, object]]) – [units = um]. JAX-traced pair (h1, h2) defining the minimum and maximum positions of the slab along the axis dimension.

  • sidewall_angle_jax (Union[float, NumpyArrayType, Array, JVPTracer, object] = 0.0) – [units = rad]. JAX-traced float defining the sidewall angle of the slab along the axis dimension.
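The ordering rule for vertices and vertices_jax (in-plane dimensions in ascending order) can be made concrete with a small sketch. The helpers in_plane_axes and to_3d below are hypothetical illustrations, not part of the tidy3d API; they only show how a (d1, d2) vertex and a position along axis combine into an (x, y, z) point under that rule.

```python
def in_plane_axes(axis):
    """Return the two in-plane axis indices, in ascending order (hypothetical helper)."""
    return tuple(i for i in range(3) if i != axis)


def to_3d(vertex, axis, position):
    """Embed a 2D (d1, d2) vertex into 3D at `position` along `axis` (hypothetical helper)."""
    d1_axis, d2_axis = in_plane_axes(axis)
    point = [0.0, 0.0, 0.0]
    point[axis] = position  # coordinate along the slab normal
    point[d1_axis], point[d2_axis] = vertex  # in-plane coordinates, ascending order
    return tuple(point)


# axis=1 (y): vertices are given as (x, z)
print(in_plane_axes(1))  # (0, 2)
print(to_3d((1.0, 2.0), axis=1, position=0.5))  # (1.0, 0.5, 2.0)
```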

Attributes

reference_polygon

The polygon at the reference plane.

attrs

Methods

edge_contrib(vertex_grad, vertex_stat, ...)

Edge contribution to the gradient with respect to a change in vertex_grad, for the edge connecting vertex_grad to vertex_stat.

no_dilation(val)

Don't allow dilation.

no_sidewall(val)

Warn if a sidewall angle is present.

store_vjp(grad_data_fwd, grad_data_adj, ...)

Stores the gradient of the vertices given forward and adjoint field data.

store_vjp_parallel(e_mult_xyz, d_mult_xyz, ...)

Parallel variant of store_vjp (num_proc workers): stores the gradient of the vertices given forward and adjoint field data.

store_vjp_sequential(e_mult_xyz, d_mult_xyz, ...)

Sequential variant of store_vjp: stores the gradient of the vertices given forward and adjoint field data.

vertex_vjp(i_vertex, e_mult_xyz, d_mult_xyz, ...)

Compute the VJP for the vertex at index i_vertex.

vertices_to_array(vertices_tuple)

Converts a list of tuples (vertices) to a jax array.


vertices_jax
slab_bounds_jax
sidewall_angle_jax
classmethod no_sidewall(val)

Warn if a sidewall angle is present.

classmethod no_dilation(val)

Don’t allow dilation.
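The two validators differ in severity: a sidewall angle is accepted with a warning, while any dilation is rejected. The sketch below illustrates that split with plain classmethods and the standard warnings module. The class name, the messages, and the use of ValueError are assumptions for illustration; the plugin's actual validators, messages, and exception types are not documented here.

```python
import warnings


class ValidatorSketch:
    """Hypothetical stand-ins for the no_sidewall / no_dilation validators."""

    @classmethod
    def no_sidewall(cls, val):
        # Accept the value, but warn when a non-zero sidewall angle is set.
        if val != 0.0:
            warnings.warn(f"sidewall_angle={val} is non-zero.")
        return val

    @classmethod
    def no_dilation(cls, val):
        # Reject any non-zero dilation outright (ValueError is an assumed stand-in).
        if val != 0.0:
            raise ValueError("dilation is not supported for JaxPolySlab.")
        return val
```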

static vertices_to_array(vertices_tuple)

Converts a list of tuples (vertices) to a jax array.
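A minimal sketch of the conversion, assuming the input is a tuple of (d1, d2) pairs as stored in vertices_jax. It uses NumPy so that it runs without JAX installed; for traced values one would presumably call jax.numpy.array in the same way, but that substitution is an assumption rather than a description of the method's source. The inverse helper array_to_vertices is hypothetical.

```python
import numpy as np


def vertices_to_array(vertices_tuple):
    """Stack a tuple of (d1, d2) pairs into an (N, 2) float array."""
    return np.array(vertices_tuple, dtype=float)


def array_to_vertices(array):
    """Hypothetical inverse: (N, 2) array back to a tuple of (d1, d2) tuples."""
    return tuple(tuple(float(c) for c in row) for row in array)


verts = ((0, 0), (1, 0), (0, 1))
arr = vertices_to_array(verts)
print(arr.shape)  # (3, 2)
```

Keeping a tuple-of-tuples form alongside the array form is convenient when the vertices need to be hashable or stored as an immutable field, while the array form is what vectorized geometry and gradient code works with.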

property reference_polygon

The polygon at the reference plane.

Returns:

The vertices of the polygon at the reference plane.

Return type:

ArrayLike[float, float]
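The reference polygon is the cross section at reference_plane; with a non-zero sidewall angle, the cross section at any other position along the axis is offset from it. The sketch below assumes the offset is -dist * tan(sidewall_angle), where dist is the signed distance from the reference plane. This is consistent with the sidewall_angle description (a positive angle shrinks the cross section along the axis direction) but is an illustration, not the library's code. To avoid general polygon offsetting, it handles only an axis-aligned rectangle, whose half-widths each grow by the offset.

```python
import math


def reference_position(slab_bounds, reference_plane):
    """Axis coordinate of the reference plane for 'bottom', 'middle', or 'top'."""
    zmin, zmax = slab_bounds
    return {"bottom": zmin, "middle": 0.5 * (zmin + zmax), "top": zmax}[reference_plane]


def dilation_at(position, slab_bounds, reference_plane, sidewall_angle):
    """Assumed edge offset of the cross section at `position` (negative = erosion)."""
    dist = position - reference_position(slab_bounds, reference_plane)
    return -dist * math.tan(sidewall_angle)


def rectangle_half_widths_at(half_widths, position, slab_bounds, reference_plane, sidewall_angle):
    """Half-widths of an axis-aligned rectangular cross section at `position`.

    Offsetting every edge outward by d adds d to each half-width; the result is
    only meaningful while both half-widths stay positive.
    """
    d = dilation_at(position, slab_bounds, reference_plane, sidewall_angle)
    hx, hy = half_widths
    return (hx + d, hy + d)


# Slab from 0 to 1 um, reference at the bottom, tan(angle) = 0.1:
# the top face is eroded by 0.1 um on every edge.
print(rectangle_half_widths_at((1.0, 0.5), 1.0, (0.0, 1.0), "bottom", math.atan(0.1)))
```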

edge_contrib(vertex_grad, vertex_stat, is_next, e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)

Edge contribution to the gradient with respect to a change in vertex_grad, for the edge connecting vertex_grad to vertex_stat.

vertex_vjp(i_vertex, e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)

Compute the VJP for the vertex at index i_vertex.
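The split between edge_contrib and vertex_vjp reflects the fact that moving one vertex only deforms the two edges attached to it, so that vertex's gradient is a sum of two edge terms (one with the next vertex, one with the previous). The plugin computes these terms from forward and adjoint fields; the sketch below swaps in a much simpler objective, the polygon area from the shoelace formula, purely to show the same decomposition and the VJP scaling by an incoming cotangent. The function names mirror the documented ones, but the signatures and bodies are hypothetical.

```python
def shoelace_area(vertices):
    """Signed polygon area (positive for counter-clockwise vertices)."""
    n = len(vertices)
    total = 0.0
    for i in range(n):
        x0, y0 = vertices[i]
        x1, y1 = vertices[(i + 1) % n]
        total += x0 * y1 - x1 * y0
    return 0.5 * total


def edge_contrib(vertex_grad, vertex_stat, is_next):
    """Contribution to d(area)/d(vertex_grad) from the edge joining it to vertex_stat.

    For the area, this term depends only on the stationary neighbour, so
    vertex_grad is unused; it is kept to mirror the documented signature.
    """
    x_s, y_s = vertex_stat
    if is_next:  # vertex_stat follows vertex_grad around the polygon
        return (0.5 * y_s, -0.5 * x_s)
    return (-0.5 * y_s, 0.5 * x_s)  # vertex_stat precedes vertex_grad


def vertex_vjp(i_vertex, vertices, cotangent=1.0):
    """VJP of J(area) w.r.t. one vertex, given the cotangent dJ/d(area)."""
    n = len(vertices)
    v = vertices[i_vertex]
    gx_n, gy_n = edge_contrib(v, vertices[(i_vertex + 1) % n], is_next=True)
    gx_p, gy_p = edge_contrib(v, vertices[(i_vertex - 1) % n], is_next=False)
    return (cotangent * (gx_n + gx_p), cotangent * (gy_n + gy_p))


square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(vertex_vjp(0, square))  # (-0.5, -0.5): moving the corner outward grows the area
```

A finite-difference check on each vertex coordinate is a quick way to validate such per-edge bookkeeping before replacing the toy objective with a field-based one.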

store_vjp(grad_data_fwd, grad_data_adj, grad_data_eps, sim_bounds, wvl_mat, eps_out, eps_in, num_proc=1)

Stores the gradient of the vertices given forward and adjoint field data.

store_vjp_sequential(e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)

Sequential variant of store_vjp: stores the gradient of the vertices given forward and adjoint field data.

store_vjp_parallel(e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in, num_proc=1)

Parallel variant of store_vjp (num_proc workers): stores the gradient of the vertices given forward and adjoint field data.
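store_vjp takes a num_proc argument, and there are sequential and parallel store methods. A plausible reading is that num_proc selects between them; the sketch below illustrates that dispatch under this assumption. compute_vertex_grads and vertex_fn are hypothetical names, and a thread pool is used only to keep the example self-contained; the plugin may use a different parallel mechanism (for example, separate processes).

```python
from concurrent.futures import ThreadPoolExecutor


def compute_vertex_grads(num_vertices, vertex_fn, num_proc=1):
    """Evaluate vertex_fn(i) for every vertex index, optionally in parallel.

    vertex_fn stands in for a per-vertex gradient computation (e.g. one
    vertex_vjp call). Results are returned in vertex order in both paths.
    """
    if num_proc <= 1:
        # Sequential path: one vertex after another.
        return [vertex_fn(i) for i in range(num_vertices)]
    # Parallel path: num_proc workers; Executor.map preserves input order.
    with ThreadPoolExecutor(max_workers=num_proc) as pool:
        return list(pool.map(vertex_fn, range(num_vertices)))


print(compute_vertex_grads(5, lambda i: i * i, num_proc=3))  # [0, 1, 4, 9, 16]
```

Because each vertex's gradient is independent of the others once the field data are fixed, the per-vertex work parallelizes without coordination; the only requirement is that results are collected back in vertex order.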

__hash__()

Hash method.