tidy3d.plugins.adjoint.JaxPolySlab

class JaxPolySlab

Bases: JaxGeometry, PolySlab, JaxObject

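A PolySlab registered with JAX (as a pytree), so that gradients with respect to its vertices can be computed with the adjoint plugin.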

Parameters:
  • axis (Literal[0, 1, 2] = 2) – Specifies the dimension of the slab's normal axis, (0, 1, 2) -> (x, y, z). The polygon lies in the plane perpendicular to this axis.

  • sidewall_angle (ConstrainedFloatValue = 0.0) – [units = rad]. Angle of the sidewall. sidewall_angle=0 (default) specifies a vertical wall; 0<sidewall_angle<np.pi/2 specifies a shrinking cross section along the axis direction; and -np.pi/2<sidewall_angle<0 specifies an expanding cross section along the axis direction.

  • reference_plane (Literal['bottom', 'middle', 'top'] = middle) – The position of the plane on which the supplied cross section is defined. The plane is perpendicular to the axis and is located at the bottom, middle, or top of the geometry with respect to the axis. E.g., if axis=1, bottom refers to the negative-y side of the slab and top refers to the positive-y side.

  • slab_bounds (Tuple[float, float]) – [units = um]. Minimum and maximum positions of the slab along the axis dimension.

  • dilation (float = 0.0) – [units = um]. Dilation of the supplied polygon, obtained by shifting each edge outward along its normal by this distance; a negative value corresponds to erosion.

  • vertices (ArrayLike[dtype=float, ndim=2]) – [units = um]. Array of (d1, d2) pairs giving the 2D positions of the polygon vertices at the reference_plane. The two in-plane coordinates are ordered by ascending axis index: e.g., if the slab normal is axis=1 (y), each vertex is given as (x, z).

  • vertices_jax (Tuple[Tuple[Union[float, tidy3d.plugins.adjoint.components.types.NumpyArrayType, jax.Array, jax._src.interpreters.ad.JVPTracer, object], Union[float, tidy3d.plugins.adjoint.components.types.NumpyArrayType, jax.Array, jax._src.interpreters.ad.JVPTracer, object]], ...]) – [units = um]. JAX-traced list of (d1, d2) pairs giving the 2D positions of the polygon vertices at the reference_plane, using the same ordering convention as vertices (e.g., (x, z) when axis=1).
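The geometric conventions in the parameter list can be checked with a small, dependency-free sketch. It assumes counter-clockwise vertices, applies dilation as an exact edge offset (each new vertex is where the two shifted neighboring edges meet), and models the sidewall as an extra offset of -tan(sidewall_angle) per unit distance above the reference plane, which matches the shrink/expand behavior described for sidewall_angle. None of these helper names are part of tidy3d, and tidy3d's own polygon handling may treat corners differently.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def in_plane_axes(axis: int) -> Tuple[int, int]:
    """Axes of the (d1, d2) vertex coordinates: the two axes other than `axis`, ascending."""
    d1, d2 = (i for i in range(3) if i != axis)
    return d1, d2


def to_3d(vertex: Point, axis: int, position: float) -> Tuple[float, float, float]:
    """Place a (d1, d2) vertex in 3D at coordinate `position` along `axis`."""
    coords = [0.0, 0.0, 0.0]
    d1, d2 = in_plane_axes(axis)
    coords[axis] = position
    coords[d1], coords[d2] = vertex
    return (coords[0], coords[1], coords[2])


def reference_coordinate(slab_bounds: Tuple[float, float], reference_plane: str) -> float:
    """Position along `axis` of the plane on which the vertices are defined."""
    zmin, zmax = slab_bounds
    return {"bottom": zmin, "middle": 0.5 * (zmin + zmax), "top": zmax}[reference_plane]


def offset_polygon(vertices: Sequence[Point], distance: float) -> List[Point]:
    """Shift every edge of a counter-clockwise polygon outward along its normal by `distance`.

    Each new vertex is where the two shifted edges meeting at v intersect:
    v + distance * (n1 + n2) / (1 + n1 . n2), with n1, n2 the unit outward normals.
    A negative `distance` erodes the polygon instead.
    """
    n = len(vertices)
    result = []
    for i in range(n):
        normals = []
        for p, q in ((vertices[i - 1], vertices[i]), (vertices[i], vertices[(i + 1) % n])):
            dx, dy = q[0] - p[0], q[1] - p[1]
            length = math.hypot(dx, dy)
            normals.append((dy / length, -dx / length))  # outward for counter-clockwise order
        (ax, ay), (bx, by) = normals
        scale = distance / (1.0 + ax * bx + ay * by)
        result.append((vertices[i][0] + scale * (ax + bx), vertices[i][1] + scale * (ay + by)))
    return result


def cross_section_at(vertices, slab_bounds, reference_plane, sidewall_angle, dilation, position):
    """Cross section at `position` along the axis.

    Assumed taper model: edges move inward by tan(sidewall_angle) per unit distance
    above the reference plane, on top of the uniform `dilation`.
    """
    dz = position - reference_coordinate(slab_bounds, reference_plane)
    return offset_polygon(vertices, dilation - dz * math.tan(sidewall_angle))
```

For example, with axis=1 a vertex (1.0, 2.0) on a reference plane at y = 0.5 maps to the 3D point (1.0, 0.5, 2.0), and offsetting the unit square by 0.1 moves each corner diagonally outward by 0.1 in both coordinates.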

Methods

  • edge_contrib(vertex_grad, vertex_stat, ...) – Gradient with respect to a change in vertex_grad, contributed by the edge connecting it to vertex_stat.

  • limit_number_of_vertices(val) – Limit the maximum number of vertices.

  • no_dilation(val) – Don't allow dilation.

  • no_sidewall(val) – Don't allow sidewall.

  • store_vjp(grad_data_fwd, grad_data_adj, ...) – Stores the gradient of the vertices given forward and adjoint field data.

  • store_vjp_parallel(e_mult_xyz, d_mult_xyz, ...) – Stores the gradient of the vertices given forward and adjoint field data, distributing the per-vertex work over num_proc processes.

  • store_vjp_sequential(e_mult_xyz, d_mult_xyz, ...) – Stores the gradient of the vertices given forward and adjoint field data, computing one vertex at a time.

  • vertex_vjp(i_vertex, e_mult_xyz, d_mult_xyz, ...) – Compute the vector-Jacobian product (VJP) for a single vertex, i_vertex.

vertices_jax

classmethod no_sidewall(val)

Validator that disallows a sidewall, i.e. a nonzero sidewall_angle.

classmethod no_dilation(val)

Validator that disallows a nonzero dilation.

classmethod limit_number_of_vertices(val)

Validator that limits the maximum number of vertices.
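Read literally, the three validators are checks on a single field value that either return it unchanged or reject it. A plain-Python sketch under that reading follows; the exception type, the messages, and MAX_NUM_VERTICES are hypothetical, and the real validators may warn rather than raise.

```python
import math
from typing import Sequence

# Hypothetical cap, chosen only for illustration; tidy3d defines its own limit.
MAX_NUM_VERTICES = 1000


def no_sidewall(val: float) -> float:
    """Accept only a vertical sidewall (sidewall_angle == 0)."""
    if not math.isclose(val, 0.0, abs_tol=1e-12):
        raise ValueError(f"sidewall_angle must be 0, got {val}")
    return val


def no_dilation(val: float) -> float:
    """Accept only zero dilation."""
    if not math.isclose(val, 0.0, abs_tol=1e-12):
        raise ValueError(f"dilation must be 0, got {val}")
    return val


def limit_number_of_vertices(val: Sequence) -> Sequence:
    """Reject polygons with more than MAX_NUM_VERTICES vertices."""
    if len(val) > MAX_NUM_VERTICES:
        raise ValueError(f"at most {MAX_NUM_VERTICES} vertices allowed, got {len(val)}")
    return val
```

In the class these appear as classmethods taking the field value, the usual shape of pydantic field validators, so they would presumably run when a JaxPolySlab is constructed.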

edge_contrib(vertex_grad, vertex_stat, is_next, e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)

Gradient with respect to a change in the position of vertex_grad, contributed by the edge connecting vertex_grad to vertex_stat.
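A worked form of a single edge contribution, under stated assumptions: suppose the objective F responds to an outward normal displacement δn of the boundary through a surface sensitivity density g. In the standard adjoint shape-derivative treatment, g is built from products of forward and adjoint tangential E and normal D fields weighted by the inside and outside permittivities, which is presumably what e_mult_xyz, d_mult_xyz, eps_in, and eps_out carry; that identification is an assumption. Moving vertex_grad by δv while vertex_stat stays fixed displaces the point at fraction s along the edge by (1 - s)δv, giving:

```latex
\begin{aligned}
\delta F_{\mathrm{edge}} &= \int_{\mathrm{edge}} g\,\delta n\,d\ell
  = L \int_0^1 (1-s)\, g\big(\mathbf{x}(s)\big)\,\big(\delta\mathbf{v}_{\mathrm{grad}}\cdot\hat{\mathbf{n}}\big)\,ds, \\
\frac{\partial F_{\mathrm{edge}}}{\partial \mathbf{v}_{\mathrm{grad}}}
  &= L\,\hat{\mathbf{n}} \int_0^1 (1-s)\, g\big(\mathbf{x}(s)\big)\,ds,
\end{aligned}
```

where x(s) = (1 - s) v_grad + s v_stat, L = |v_stat - v_grad| is the edge length, and n̂ is the edge's outward unit normal. Each vertex touches two edges, so its full gradient is the sum of two such terms; the is_next argument presumably records whether vertex_stat follows or precedes vertex_grad in the vertex list, which fixes the edge's orientation and hence the sign of n̂.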

vertex_vjp(i_vertex, e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)

Compute the vector-Jacobian product (VJP) for a single vertex, indexed by i_vertex.
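To make the per-vertex VJP concrete, the sketch below sums the edge integral over the two edges that meet at vertex i, using 3-point Gauss-Legendre quadrature along each edge. It assumes the boundary sensitivity g is supplied as a plain Python function of position; in JaxPolySlab that density would instead come from the field data and permittivities, which this sketch does not model. All names here are hypothetical, not the tidy3d implementation.

```python
import math
from typing import Callable, Sequence, Tuple

Point = Tuple[float, float]

# 3-point Gauss-Legendre rule mapped to [0, 1]; exact for polynomials up to degree 5.
_NODES = (0.5 - 0.5 * math.sqrt(0.6), 0.5, 0.5 + 0.5 * math.sqrt(0.6))
_WEIGHTS = (5.0 / 18.0, 4.0 / 9.0, 5.0 / 18.0)


def signed_area(vertices: Sequence[Point]) -> float:
    """Shoelace formula; positive for counter-clockwise vertex order."""
    n = len(vertices)
    return 0.5 * sum(
        vertices[i][0] * vertices[(i + 1) % n][1] - vertices[(i + 1) % n][0] * vertices[i][1]
        for i in range(n)
    )


def edge_contribution(v_grad: Point, v_stat: Point, is_next: bool, ccw: bool,
                      g: Callable[[float, float], float]) -> Point:
    """Gradient w.r.t. v_grad from the edge joining v_grad and v_stat (v_stat held fixed)."""
    # Edge in traversal order: v_grad -> v_stat if v_stat is the next vertex.
    p, q = (v_grad, v_stat) if is_next else (v_stat, v_grad)
    dx, dy = q[0] - p[0], q[1] - p[1]
    length = math.hypot(dx, dy)
    # (dy, -dx) points outward for counter-clockwise polygons; flip it otherwise.
    sign = 1.0 if ccw else -1.0
    nx, ny = sign * dy / length, -sign * dx / length
    # A point at fraction s from v_grad moves by (1 - s) times the displacement of v_grad.
    integral = 0.0
    for s, w in zip(_NODES, _WEIGHTS):
        x = (1 - s) * v_grad[0] + s * v_stat[0]
        y = (1 - s) * v_grad[1] + s * v_stat[1]
        integral += w * (1 - s) * g(x, y)
    return (length * integral * nx, length * integral * ny)


def vertex_gradient(vertices: Sequence[Point], i: int,
                    g: Callable[[float, float], float]) -> Point:
    """Sum the contributions of the two edges that meet at vertex i."""
    n = len(vertices)
    ccw = signed_area(vertices) > 0
    prev_c = edge_contribution(vertices[i], vertices[i - 1], False, ccw, g)
    next_c = edge_contribution(vertices[i], vertices[(i + 1) % n], True, ccw, g)
    return (prev_c[0] + next_c[0], prev_c[1] + next_c[1])
```

As a sanity check, g ≡ 1 turns the boundary integral into the shape derivative of the polygon area, so the result should match the shoelace gradient ((y_next - y_prev)/2, (x_prev - x_next)/2) for counter-clockwise vertices; for the unit square and vertex (1, 1) both give (0.5, 0.5).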

store_vjp(grad_data_fwd, grad_data_adj, grad_data_eps, sim_bounds, wvl_mat, eps_out, eps_in, num_proc=1)

Stores the gradient of the vertices given forward and adjoint field data.

store_vjp_sequential(e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)

Stores the gradient of the vertices given forward and adjoint field data, computing one vertex at a time.

store_vjp_parallel(e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in, num_proc=1)

Stores the gradient of the vertices given forward and adjoint field data, distributing the per-vertex work over num_proc processes.
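The split between store_vjp, store_vjp_sequential, and store_vjp_parallel suggests a dispatch on num_proc. The following is a hedged sketch of that pattern only: the function names are hypothetical, the per-vertex VJP is passed in as a callable, and a thread pool stands in for whatever worker pool tidy3d actually uses so that the example stays self-contained.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Tuple

Grad = Tuple[float, float]


def vjp_sequential(vertex_vjp: Callable[[int], Grad], num_vertices: int) -> List[Grad]:
    """Evaluate the per-vertex VJP one vertex at a time."""
    return [vertex_vjp(i) for i in range(num_vertices)]


def vjp_parallel(vertex_vjp: Callable[[int], Grad], num_vertices: int,
                 num_proc: int) -> List[Grad]:
    """Evaluate the per-vertex VJPs concurrently; map() keeps results in vertex order."""
    with ThreadPoolExecutor(max_workers=num_proc) as pool:
        return list(pool.map(vertex_vjp, range(num_vertices)))


def store_vjp_sketch(vertex_vjp: Callable[[int], Grad], num_vertices: int,
                     num_proc: Optional[int] = 1) -> List[Grad]:
    """Pick the sequential or parallel path based on num_proc."""
    if num_proc is not None and num_proc > 1:
        return vjp_parallel(vertex_vjp, num_vertices, num_proc)
    return vjp_sequential(vertex_vjp, num_vertices)
```

For CPU-bound field integrals, concurrent.futures.ProcessPoolExecutor is the usual drop-in replacement for the thread pool; it requires the per-vertex function and its arguments to be picklable.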

__hash__()

Hash method.