tidy3d.plugins.adjoint.JaxPolySlab
- class JaxPolySlab
Bases: JaxGeometry, PolySlab, JaxObject
A PolySlab registered with jax.
Parameters:
axis (Literal[0, 1, 2] = 2) – Specifies the dimension of the planar axis, i.e. the slab's normal direction: (0, 1, 2) -> (x, y, z).
sidewall_angle (ConstrainedFloatValue = 0.0) – [units = rad]. Angle of the sidewall. sidewall_angle=0 (default) specifies a vertical wall; 0 < sidewall_angle < np.pi/2 specifies a shrinking cross section along the axis direction; and -np.pi/2 < sidewall_angle < 0 specifies an expanding cross section along the axis direction.
reference_plane (Literal['bottom', 'middle', 'top'] = middle) – The position of the plane where the supplied cross section is defined. The plane is perpendicular to the axis and is located at the bottom, middle, or top of the geometry with respect to the axis. E.g. if axis=1, bottom refers to the negative side of the y-axis, and top refers to the positive side of the y-axis.
slab_bounds (Tuple[float, float]) – [units = um]. Minimum and maximum positions of the slab along the axis dimension.
dilation (float = 0.0) – [units = um]. Dilation of the supplied polygon, obtained by shifting each edge outward along its normal by the given distance; a negative value corresponds to erosion.
vertices (ArrayLike[dtype=float, ndim=2]) – [units = um]. List of (d1, d2) pairs defining the 2D positions of the polygon face vertices at the reference_plane. The in-plane dimensions are listed in ascending order: e.g. if the slab normal axis is axis=1 (y), the vertex coordinates are (x, z).
vertices_jax (Tuple[Tuple[Union[float, tidy3d.plugins.adjoint.components.types.NumpyArrayType, jax.Array, jax._src.interpreters.ad.JVPTracer, object], Union[float, tidy3d.plugins.adjoint.components.types.NumpyArrayType, jax.Array, jax._src.interpreters.ad.JVPTracer, object]], ...]) – [units = um]. Jax-traced list of (d1, d2) pairs defining the 2D positions of the polygon face vertices at the reference_plane. The in-plane dimensions are listed in ascending order: e.g. if the slab normal axis is axis=1 (y), the vertex coordinates are (x, z).
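To illustrate how sidewall_angle and reference_plane interact, here is a small plain-Python sketch. It assumes (this is not taken from the tidy3d source) that moving a distance h along the positive axis direction from the reference plane shifts every edge inward by h·tan(sidewall_angle). The helpers reference_position and edge_inset are hypothetical. The no_sidewall validator listed under Methods suggests JaxPolySlab itself rejects a nonzero sidewall angle, so this only illustrates the inherited PolySlab semantics.

```python
import math


def reference_position(slab_bounds, reference_plane="middle"):
    """Position of the reference plane along `axis` (hypothetical helper)."""
    zmin, zmax = slab_bounds
    if reference_plane == "bottom":
        return zmin
    if reference_plane == "top":
        return zmax
    if reference_plane == "middle":
        return 0.5 * (zmin + zmax)
    raise ValueError(f"unknown reference_plane: {reference_plane!r}")


def edge_inset(z, slab_bounds, sidewall_angle, reference_plane="middle"):
    """Assumed inward edge shift at position z, relative to the supplied cross section.

    Positive angles shrink the cross section toward +axis, negative angles
    expand it, and 0 gives a vertical wall (no shift anywhere).
    """
    if not -math.pi / 2 < sidewall_angle < math.pi / 2:
        raise ValueError("sidewall_angle must lie strictly between -pi/2 and pi/2")
    h = z - reference_position(slab_bounds, reference_plane)
    return h * math.tan(sidewall_angle)


bounds = (0.0, 0.22)  # slab spans 0 to 0.22 um along the axis
angle = math.radians(10)

# Cross section supplied at the bottom: the bottom face is unchanged,
# the top face is inset by 0.22 * tan(10 deg).
bottom_inset = edge_inset(0.0, bounds, angle, reference_plane="bottom")
top_inset = edge_inset(0.22, bounds, angle, reference_plane="bottom")
```

With reference_plane="top" the same angle gives a negative inset at the bottom face, i.e. the bottom is wider than the supplied cross section.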
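A minimal sketch of what dilation means geometrically, assuming a simple counter-clockwise polygon: every edge is offset along its outward unit normal, and each new vertex is the intersection of its two offset neighbouring edges. This is one standard way to offset a polygon, not necessarily tidy3d's implementation, and dilate_polygon and shoelace_area are hypothetical names. The no_dilation validator listed under Methods suggests JaxPolySlab rejects nonzero dilation, so this describes the inherited PolySlab parameter.

```python
import math


def _unit_outward_normal(a, b):
    """Outward unit normal of edge a -> b for a counter-clockwise polygon."""
    ex, ey = b[0] - a[0], b[1] - a[1]
    length = math.hypot(ex, ey)
    return (ey / length, -ex / length)


def dilate_polygon(vertices, dilation):
    """Shift every edge of a CCW polygon outward by `dilation` (negative erodes)."""
    n = len(vertices)
    out = []
    for i in range(n):
        prev_v, v, next_v = vertices[i - 1], vertices[i], vertices[(i + 1) % n]
        n1 = _unit_outward_normal(prev_v, v)  # incoming edge
        n2 = _unit_outward_normal(v, next_v)  # outgoing edge
        # p - v = d (n1 + n2) / (1 + n1.n2) satisfies (p - v).n1 = (p - v).n2 = d,
        # so p lies on both offset edges.
        scale = dilation / (1.0 + n1[0] * n2[0] + n1[1] * n2[1])
        out.append((v[0] + scale * (n1[0] + n2[0]), v[1] + scale * (n1[1] + n2[1])))
    return out


def shoelace_area(vertices):
    """Signed polygon area (positive for counter-clockwise vertices)."""
    n = len(vertices)
    return 0.5 * sum(
        vertices[i][0] * vertices[(i + 1) % n][1] - vertices[(i + 1) % n][0] * vertices[i][1]
        for i in range(n)
    )


square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]  # counter-clockwise
grown = dilate_polygon(square, 0.1)    # 1.2 x 1.2 square
shrunk = dilate_polygon(square, -0.1)  # 0.8 x 0.8 square (erosion)
```

The formula breaks down only when two adjacent edges point in exactly opposite directions (1 + n1·n2 = 0), which does not occur in a valid simple polygon.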
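The ascending-order convention for vertex coordinates can be made concrete with a small sketch. It finds the two in-plane dimensions for a given axis and places a 2D vertex into 3D at a chosen position along that axis. The helper names in_plane_axes and lift_vertex are hypothetical and not part of tidy3d.

```python
def in_plane_axes(axis):
    """Indices of the two in-plane dimensions for a given normal axis, ascending."""
    if axis not in (0, 1, 2):
        raise ValueError("axis must be 0, 1, or 2")
    return tuple(d for d in range(3) if d != axis)


def lift_vertex(vertex, axis, position):
    """Place a 2D vertex (d1, d2) at `position` along `axis`, returning (x, y, z)."""
    point = [0.0, 0.0, 0.0]
    d1, d2 = in_plane_axes(axis)
    point[d1], point[d2] = vertex
    point[axis] = position
    return tuple(point)


# axis=1 (y): vertices are given as (x, z) pairs.
xyz = lift_vertex((0.5, -0.25), axis=1, position=0.11)  # (x, y, z) = (0.5, 0.11, -0.25)
```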
Attributes
Methods
edge_contrib(vertex_grad, vertex_stat, ...) – Gradient with respect to a change in vertex_grad connected to vertex_stat.
Limit the maximum number of vertices.
no_dilation(val) – Don't allow dilation.
no_sidewall(val) – Don't allow a sidewall angle.
store_vjp(grad_data_fwd, grad_data_adj, ...) – Stores the gradient of the vertices given forward and adjoint field data.
store_vjp_parallel(e_mult_xyz, d_mult_xyz, ...) – Stores the gradient of the vertices given forward and adjoint field data.
store_vjp_sequential(e_mult_xyz, d_mult_xyz, ...) – Stores the gradient of the vertices given forward and adjoint field data.
vertex_vjp(i_vertex, e_mult_xyz, d_mult_xyz, ...) – Compute the vector-Jacobian product (VJP) for the vertex at index i_vertex.
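The validators in the table above can be read as simple checks on field values. A minimal sketch, assuming each check raises ValueError for a disallowed value (the usual way pydantic-style validators signal rejection). The function names and MAX_VERTICES are hypothetical; the actual vertex limit is not stated in this reference.

```python
MAX_VERTICES = 100  # hypothetical limit, not taken from tidy3d


def check_no_dilation(val):
    """Accept only dilation == 0 (sketch of the documented constraint)."""
    if val != 0.0:
        raise ValueError("nonzero dilation is not supported")
    return val


def check_no_sidewall(val):
    """Accept only sidewall_angle == 0, i.e. vertical walls."""
    if val != 0.0:
        raise ValueError("nonzero sidewall_angle is not supported")
    return val


def check_vertex_count(vertices, max_vertices=MAX_VERTICES):
    """Reject polygons with more than max_vertices vertices."""
    if len(vertices) > max_vertices:
        raise ValueError(f"too many vertices: {len(vertices)} > {max_vertices}")
    return vertices


def raises_value_error(check, value):
    """Return True if check(value) raises ValueError, else False."""
    try:
        check(value)
    except ValueError:
        return True
    return False
```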
- vertices_jax
- edge_contrib(vertex_grad, vertex_stat, is_next, e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)
Gradient with respect to a change in vertex_grad connected to vertex_stat.
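edge_contrib works from field data (e_mult_xyz, d_mult_xyz, eps_out, eps_in, ...), and its exact formula is not given here. The sketch below only illustrates the geometric idea behind a per-edge vertex gradient: moving vertex_grad by δ while vertex_stat stays fixed moves the edge point at fraction s from vertex_grad by (1 − s)·δ, so the edge contributes the integral of density·(1 − s)·n̂ along its length. A user-supplied scalar density replaces the field-based sensitivity; the function name and the meaning assumed for is_next (vertex_stat follows vertex_grad in counter-clockwise order) are hypothetical.

```python
import math


def edge_contrib_sketch(vertex_grad, vertex_stat, is_next, density, n_quad=64):
    """Gradient w.r.t. vertex_grad from the edge joining it to vertex_stat.

    Assumes a counter-clockwise polygon and a scalar boundary sensitivity
    density(x, y). The integral over s in [0, 1] uses the midpoint rule.
    """
    (gx, gy), (sx, sy) = vertex_grad, vertex_stat
    # Edge direction follows the polygon's CCW orientation.
    if is_next:
        ex, ey = sx - gx, sy - gy
    else:
        ex, ey = gx - sx, gy - sy
    length = math.hypot(ex, ey)
    nx, ny = ey / length, -ex / length  # outward unit normal for a CCW polygon
    total = 0.0
    for k in range(n_quad):
        s = (k + 0.5) / n_quad  # fraction of the way from vertex_grad to vertex_stat
        x = (1 - s) * gx + s * sx
        y = (1 - s) * gy + s * sy
        total += density(x, y) * (1 - s) / n_quad
    return (total * length * nx, total * length * ny)


def unit_density(x, y):
    """Unit sensitivity: the boundary integral then gives d(area)/d(vertex)."""
    return 1.0


# Unit CCW square: gradient of its area w.r.t. the corner (1, 1).
prev_part = edge_contrib_sketch((1.0, 1.0), (1.0, 0.0), is_next=False, density=unit_density)
next_part = edge_contrib_sketch((1.0, 1.0), (0.0, 1.0), is_next=True, density=unit_density)
grad = (prev_part[0] + next_part[0], prev_part[1] + next_part[1])  # approx. (0.5, 0.5)
```

With unit density, each edge contributes half its length times its outward normal, and the two edges together reproduce the derivative of the shoelace area formula.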
- vertex_vjp(i_vertex, e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)
Compute the vector-Jacobian product (VJP) for the vertex at index i_vertex.
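A per-vertex VJP is the upstream cotangent multiplied by the derivative of the output with respect to that one vertex. The sketch below uses polygon area as a stand-in objective rather than the electromagnetic objective tidy3d differentiates, and checks the analytic result against central finite differences, a common way to validate such gradients. All function names are hypothetical.

```python
def polygon_area(vertices):
    """Signed shoelace area (positive for counter-clockwise vertices)."""
    n = len(vertices)
    return 0.5 * sum(
        vertices[i][0] * vertices[(i + 1) % n][1] - vertices[(i + 1) % n][0] * vertices[i][1]
        for i in range(n)
    )


def vertex_vjp_sketch(vertices, i_vertex, cotangent):
    """Cotangent times d(area)/d(vertex i_vertex), from the shoelace formula."""
    n = len(vertices)
    x_prev, y_prev = vertices[i_vertex - 1]
    x_next, y_next = vertices[(i_vertex + 1) % n]
    return (cotangent * 0.5 * (y_next - y_prev), cotangent * 0.5 * (x_prev - x_next))


def finite_difference_vjp(vertices, i_vertex, cotangent, h=1e-6):
    """Central-difference estimate of the same quantity, for checking."""
    grads = []
    for d in range(2):  # d = 0: first in-plane coordinate, d = 1: second
        plus = [list(v) for v in vertices]
        minus = [list(v) for v in vertices]
        plus[i_vertex][d] += h
        minus[i_vertex][d] -= h
        grads.append(cotangent * (polygon_area(plus) - polygon_area(minus)) / (2 * h))
    return tuple(grads)


poly = [(0.0, 0.0), (2.0, 0.0), (2.0, 1.0), (0.5, 1.5)]  # counter-clockwise
grads = [vertex_vjp_sketch(poly, i, 1.0) for i in range(len(poly))]
```

Because translating the whole polygon leaves its area unchanged, the per-vertex gradients sum to zero, which is a cheap sanity check.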
- store_vjp(grad_data_fwd, grad_data_adj, grad_data_eps, sim_bounds, wvl_mat, eps_out, eps_in, num_proc=1)
Stores the gradient of the vertices given forward and adjoint field data.
- store_vjp_sequential(e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)
Stores the gradient of the vertices given forward and adjoint field data.
- store_vjp_parallel(e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in, num_proc=1)
Stores the gradient of the vertices given forward and adjoint field data.
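The store_vjp, store_vjp_sequential, and store_vjp_parallel entries suggest a dispatch pattern: compute one gradient per vertex, either in a loop or spread across num_proc workers. Below is a minimal sketch of that pattern under this assumption. tidy3d's actual data flow (e_mult_xyz, d_mult_xyz, and how results are stored on the object) is not shown, and every name is hypothetical. Threads stand in for worker processes so the sketch runs anywhere; a ProcessPoolExecutor would be the closer analogue for CPU-bound work.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def vertex_gradient(i_vertex, vertices):
    """Hypothetical per-vertex work item: d(area)/d(vertex i_vertex) of a polygon."""
    n = len(vertices)
    x_prev, y_prev = vertices[i_vertex - 1]
    x_next, y_next = vertices[(i_vertex + 1) % n]
    return (0.5 * (y_next - y_prev), 0.5 * (x_prev - x_next))


def gradients_sequential(vertices):
    """One vertex at a time, in order."""
    return [vertex_gradient(i, vertices) for i in range(len(vertices))]


def gradients_parallel(vertices, num_proc):
    """Spread vertices across num_proc workers; Executor.map keeps input order."""
    work = partial(vertex_gradient, vertices=vertices)
    with ThreadPoolExecutor(max_workers=num_proc) as pool:
        return list(pool.map(work, range(len(vertices))))


def gradients(vertices, num_proc=1):
    """Dispatch on num_proc, loosely mirroring the store_vjp split above."""
    if num_proc < 1:
        raise ValueError("num_proc must be at least 1")
    if num_proc == 1:
        return gradients_sequential(vertices)
    return gradients_parallel(vertices, num_proc)


square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
```

Keeping the per-vertex work in one function means both paths produce identical results, so the parallel path can be validated against the sequential one.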
- __hash__()
Hash method.