tidy3d.plugins.adjoint.JaxPolySlab
- class JaxPolySlab
Bases: JaxGeometry, PolySlab, JaxObject
A PolySlab registered with jax.
- Parameters:
attrs (dict = {}) – Dictionary storing arbitrary metadata for a Tidy3D object. This dictionary can be freely used by the user for storing data without affecting the operation of Tidy3D, as it is not used internally. Note that, unlike regular Tidy3D fields, attrs are mutable. For example, the following is allowed for setting an attr: obj.attrs['foo'] = bar. Also note that Tidy3D will raise a TypeError if attrs contain objects that cannot be serialized. One can check if attrs are serializable by calling obj.json().
axis (Literal[0, 1, 2] = 2) – Specifies the dimension of the planar (slab normal) axis: (0, 1, 2) -> (x, y, z).
sidewall_angle (ConstrainedFloatValue = 0.0) – [units = rad]. Angle of the sidewall. sidewall_angle=0 (default) specifies a vertical wall; 0 < sidewall_angle < np.pi/2 specifies a shrinking cross section along the axis direction; and -np.pi/2 < sidewall_angle < 0 specifies an expanding cross section along the axis direction.
reference_plane (Literal['bottom', 'middle', 'top'] = middle) – The position of the plane where the supplied cross section is defined. The plane is perpendicular to the axis and is located at the bottom, middle, or top of the geometry with respect to the axis. For example, if axis=1, bottom refers to the negative side of the y-axis, and top refers to the positive side of the y-axis.
slab_bounds (Tuple[float, float]) – [units = um]. Minimum and maximum positions of the slab along the axis dimension.
dilation (float = 0.0) – [units = um]. Dilation of the supplied polygon, obtained by shifting each edge outward along its normal by this distance; a negative value corresponds to erosion.
vertices (Union[ArrayLike[dtype=float, ndim=2], Box]) – [units = um]. List of (d1, d2) defining the 2-dimensional positions of the polygon face vertices at the reference_plane. The dimension indices should be in ascending order: e.g., if the slab normal axis is axis=y, the vertex coordinates are given as (x, z).
vertices_jax (Tuple[Tuple[Union[float, NumpyArrayType, Array, JVPTracer, object], Union[float, NumpyArrayType, Array, JVPTracer, object]], ...]) – [units = um]. Jax-traced list of (d1, d2) defining the 2-dimensional positions of the polygon face vertices at the reference_plane, using the same ascending dimension order as vertices.
slab_bounds_jax (Tuple[Union[float, NumpyArrayType, Array, JVPTracer, object], Union[float, NumpyArrayType, Array, JVPTracer, object]]) – [units = um]. Jax-traced pair (h1, h2) defining the minimum and maximum positions of the slab along the axis dimension.
sidewall_angle_jax (Union[float, NumpyArrayType, Array, JVPTracer, object] = 0.0) – [units = rad]. Jax-traced float defining the sidewall angle of the slab along the axis dimension.
dilation_jax (Union[float, NumpyArrayType, Array, JVPTracer, object] = 0.0) – [units = um]. Jax-traced float defining the dilation.
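The parameters sidewall_angle, reference_plane, slab_bounds, and dilation can be combined into a single outward edge offset that depends on the position z along axis. The sketch below is a minimal illustration under stated assumptions, not tidy3d's implementation. It assumes two things. First, the supplied polygon is dilated by dilation at the reference plane. Second, away from that plane, each wall moves inward by tan(sidewall_angle) per um travelled in the +axis direction. Together these reproduce the documented shrinking behavior (positive angle) and expanding behavior (negative angle). The helper names reference_position and edge_offset are hypothetical.

```python
import math
from typing import Tuple


def reference_position(slab_bounds: Tuple[float, float], reference_plane: str = "middle") -> float:
    """Position along `axis` of the plane where the vertices are defined (hypothetical helper)."""
    zmin, zmax = slab_bounds
    if zmin > zmax:
        raise ValueError("slab_bounds must be ordered as (min, max)")
    positions = {"bottom": zmin, "middle": 0.5 * (zmin + zmax), "top": zmax}
    return positions[reference_plane]


def edge_offset(
    z: float,
    slab_bounds: Tuple[float, float],
    sidewall_angle: float = 0.0,
    dilation: float = 0.0,
    reference_plane: str = "middle",
) -> float:
    """Outward shift (um) of every polygon edge at position z along `axis`.

    Assumed model: the polygon is dilated by `dilation` at the reference plane,
    and each wall moves inward by tan(sidewall_angle) per um in the +axis
    direction (positive angle: shrinking; negative angle: expanding).
    """
    if not -math.pi / 2 < sidewall_angle < math.pi / 2:
        raise ValueError("sidewall_angle must lie strictly between -pi/2 and pi/2")
    z_ref = reference_position(slab_bounds, reference_plane)
    return dilation - (z - z_ref) * math.tan(sidewall_angle)
```

Take a 0.22 um slab with a 10 degree sidewall defined at the bottom. Here edge_offset(0.22, (0.0, 0.22), math.radians(10), reference_plane="bottom") is about -0.039 um, so in this model the top face is eroded relative to the supplied polygon.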
Attributes
reference_polygon – The polygon at the reference plane.
Methods
edge_contrib(vertex_grad, vertex_stat, ...) – Gradient w.r.t. a change in vertex_grad, which is connected to vertex_stat.
no_sidewall(val) – Warn if a sidewall angle is present.
store_vjp(grad_data_fwd, grad_data_adj, ...) – Stores the gradient of the vertices given forward and adjoint field data.
store_vjp_parallel(e_mult_xyz, d_mult_xyz, ...) – Stores the gradient of the vertices given forward and adjoint field data.
store_vjp_sequential(e_mult_xyz, d_mult_xyz, ...) – Stores the gradient of the vertices given forward and adjoint field data.
vertex_vjp(i_vertex, e_mult_xyz, d_mult_xyz, ...) – Compute the vjp for the vertex with index i_vertex.
vertices_to_array(vertices_tuple) – Converts a list of tuples (vertices) to a jax array.
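no_sidewall(val) takes a single val argument, which suggests a pydantic-style field validator on sidewall_angle. Below is a minimal, hypothetical sketch of the documented behavior: it warns, rather than fails, when a non-zero sidewall angle is present. The tolerance and the warning text are assumptions.

```python
import math
import warnings


def no_sidewall(val: float) -> float:
    """Hypothetical stand-in for JaxPolySlab.no_sidewall: warn on a slanted wall."""
    # Treat tiny round-off values as a vertical wall (tolerance is an assumption).
    if not math.isclose(val, 0.0, abs_tol=1e-12):
        warnings.warn(
            f"JaxPolySlab received a non-zero sidewall_angle ({val} rad).",
            stacklevel=2,
        )
    # A validator returns the (unchanged) value so the field is still set.
    return val
```

Returning val unchanged lets the field keep its value after the check. Raising an exception instead would turn the warning into a hard error.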
- vertices_jax
- slab_bounds_jax
- sidewall_angle_jax
- dilation_jax
- static vertices_to_array(vertices_tuple)
Converts a list of tuples (vertices) to a jax array.
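vertices_to_array turns the tuple-of-tuples storage used by vertices_jax into a single (N, 2) jax array. With jax installed, jax.numpy.array(vertices_tuple) produces an array of that shape from a nested tuple. The stdlib sketch below reproduces only the shape bookkeeping, so it runs without jax; the function names are hypothetical. It also shows how each (d1, d2) pair maps back to 3D coordinates under the ascending-order rule described for vertices.

```python
from typing import List, Sequence, Tuple

Vertex = Tuple[float, float]


def vertices_to_rows(vertices_tuple: Sequence[Vertex]) -> List[List[float]]:
    """Stdlib stand-in for vertices_to_array: N (d1, d2) pairs -> N x 2 rows."""
    rows = []
    for i, vertex in enumerate(vertices_tuple):
        if len(vertex) != 2:
            raise ValueError(f"vertex {i} has {len(vertex)} coordinates, expected 2")
        rows.append([float(vertex[0]), float(vertex[1])])
    return rows


def to_3d(vertex: Vertex, axis: int, position: float) -> Tuple[float, float, float]:
    """Place an in-plane (d1, d2) vertex at `position` along the slab normal `axis`."""
    in_plane = [dim for dim in range(3) if dim != axis]  # ascending, e.g. axis=1 -> [0, 2]
    coords = [0.0, 0.0, 0.0]
    coords[axis] = position
    coords[in_plane[0]], coords[in_plane[1]] = vertex
    return tuple(coords)
```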
- property reference_polygon
The polygon at the reference plane.
- Returns:
The vertices of the polygon at the reference plane.
- Return type:
ArrayLike[float, float]
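The dilation parameter describes shifting each polygon edge along its outward normal. The source only says reference_polygon is the polygon at the reference plane. Assuming it is the supplied vertices after this shift, the geometric operation can be sketched with mitered corners. Each vertex moves by the vector v satisfying v·n_prev = v·n_next = dilation, where n_prev and n_next are the unit outward normals of its two adjacent edges. This is a hypothetical stdlib sketch, not tidy3d's implementation. It ignores the self-intersections that large erosions can create.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def signed_area(vertices: List[Point]) -> float:
    """Shoelace formula; positive for counter-clockwise vertex order."""
    n = len(vertices)
    return 0.5 * sum(
        vertices[i][0] * vertices[(i + 1) % n][1] - vertices[(i + 1) % n][0] * vertices[i][1]
        for i in range(n)
    )


def outward_normal(p: Point, q: Point, ccw: bool) -> Point:
    """Unit normal of edge p -> q pointing out of the polygon."""
    dx, dy = q[0] - p[0], q[1] - p[1]
    length = math.hypot(dx, dy)
    nx, ny = dy / length, -dx / length  # right-hand normal: outward for CCW order
    return (nx, ny) if ccw else (-nx, -ny)


def dilate(vertices: List[Point], dilation: float) -> List[Point]:
    """Shift every edge outward by `dilation` (negative erodes), mitered corners.

    Assumes no 180-degree spikes, where the two adjacent normals are opposite.
    """
    n = len(vertices)
    ccw = signed_area(vertices) > 0
    shifted = []
    for i in range(n):
        n_prev = outward_normal(vertices[i - 1], vertices[i], ccw)
        n_next = outward_normal(vertices[i], vertices[(i + 1) % n], ccw)
        cos_turn = n_prev[0] * n_next[0] + n_prev[1] * n_next[1]
        # v = d * (n_prev + n_next) / (1 + cos_turn) satisfies v.n_prev = v.n_next = d.
        scale = dilation / (1.0 + cos_turn)
        shifted.append((vertices[i][0] + scale * (n_prev[0] + n_next[0]),
                        vertices[i][1] + scale * (n_prev[1] + n_next[1])))
    return shifted
```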
- edge_contrib(vertex_grad, vertex_stat, is_next, e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)
Gradient w.r.t. a change in vertex_grad, which is connected to vertex_stat.
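The name and arguments of edge_contrib suggest a standard shape-derivative pattern for polygons. When vertex_grad moves and vertex_stat stays fixed, points on the edge between them move by a weight that falls linearly from 1 at vertex_grad to 0 at vertex_stat. Only the normal component of that motion changes the shape. The sketch below (hypothetical names) integrates a scalar boundary sensitivity g along the edge with that weight. In the real method, g would be assembled from e_mult_xyz, d_mult_xyz, eps_in, and eps_out, which is not reproduced here. The outward normal is passed explicitly rather than inferred from the edge orientation, which the is_next flag presumably encodes.

```python
import math
from typing import Callable, Tuple

Point = Tuple[float, float]


def edge_contrib_sketch(
    vertex_grad: Point,
    vertex_stat: Point,
    normal: Point,
    sensitivity: Callable[[float, float], float],
    num_samples: int = 64,
) -> Point:
    """Contribution of one edge to dF/d(vertex_grad).

    Computes integral over the edge of g(r) * (1 - t) * normal ds, where t runs
    from 0 at vertex_grad to 1 at vertex_stat, using the midpoint rule.
    """
    (x0, y0), (x1, y1) = vertex_grad, vertex_stat
    ds = math.hypot(x1 - x0, y1 - y0) / num_samples
    total = 0.0
    for k in range(num_samples):
        t = (k + 0.5) / num_samples  # midpoint of the k-th sub-segment
        x, y = x0 + t * (x1 - x0), y0 + t * (y1 - y0)
        total += sensitivity(x, y) * (1.0 - t) * ds  # linear "hat" weight
    return (total * normal[0], total * normal[1])
```

With a constant sensitivity of 1, the result is half the edge length times the normal. This is the expected share of a linearly weighted edge that each endpoint receives.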
- vertex_vjp(i_vertex, e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)
Compute the vjp for the vertex with index i_vertex.
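A per-vertex VJP can be assembled from the two edges adjacent to the vertex. That vertex_vjp works this way is an assumption suggested by the is_next argument of edge_contrib. The toy check below uses a unit boundary sensitivity, so the shape derivative reduces to the derivative of the polygon area with respect to the vertex. It compares that edge-based result with a finite difference of the shoelace formula. All names are hypothetical.

```python
from typing import List, Tuple

Point = Tuple[float, float]


def shoelace_area(vertices: List[Point]) -> float:
    """Signed polygon area; positive for counter-clockwise order."""
    n = len(vertices)
    return 0.5 * sum(
        vertices[i][0] * vertices[(i + 1) % n][1] - vertices[(i + 1) % n][0] * vertices[i][1]
        for i in range(n)
    )


def vertex_vjp_area(vertices: List[Point], i_vertex: int) -> Point:
    """Area gradient w.r.t. one vertex, summed over its two adjacent edges.

    Assumes counter-clockwise order and unit sensitivity, so each edge
    contributes 0.5 * length * outward_normal = 0.5 * (dy, -dx).
    """
    n = len(vertices)
    prev_pt, pt, next_pt = vertices[i_vertex - 1], vertices[i_vertex], vertices[(i_vertex + 1) % n]
    gx = gy = 0.0
    for p, q in ((prev_pt, pt), (pt, next_pt)):  # edge before, then edge after
        dx, dy = q[0] - p[0], q[1] - p[1]
        gx += 0.5 * dy
        gy -= 0.5 * dx
    return gx, gy


def finite_difference_area(vertices: List[Point], i_vertex: int, h: float = 1e-6) -> Point:
    """Central-difference reference for the same gradient."""
    grad = []
    for dim in (0, 1):
        plus = [list(v) for v in vertices]
        minus = [list(v) for v in vertices]
        plus[i_vertex][dim] += h
        minus[i_vertex][dim] -= h
        grad.append((shoelace_area(plus) - shoelace_area(minus)) / (2.0 * h))
    return grad[0], grad[1]
```

The polygon area is linear in each individual vertex coordinate. The central difference is therefore exact up to rounding, which makes it a convenient sanity check for per-vertex gradient code.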
- store_vjp(grad_data_fwd, grad_data_adj, grad_data_eps, sim_bounds, wvl_mat, eps_out, eps_in, num_proc=1)
Stores the gradient of the vertices given forward and adjoint field data.
- store_vjp_sequential(e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in)
Stores the gradient of the vertices given forward and adjoint field data.
- store_vjp_parallel(e_mult_xyz, d_mult_xyz, sim_bounds, wvl_mat, eps_out, eps_in, num_proc=1)
Stores the gradient of the vertices given forward and adjoint field data.
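Judging by their names and the num_proc argument, the three store_vjp variants differ only in how the per-vertex work is scheduled. Below is a minimal, hypothetical dispatcher that illustrates that split. With num_proc == 1 it runs the vertices one after another; larger values fan the work out to a pool. The sketch uses a thread pool so it runs anywhere without pickling concerns. For CPU-bound pure-Python work, a real speed-up would need a process pool, which num_proc suggests the library uses. The sketch also returns the gradients rather than storing them on the object.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Tuple

Grad = Tuple[float, float]


def store_vjp_sketch(
    num_vertices: int,
    vertex_vjp: Callable[[int], Grad],
    num_proc: int = 1,
) -> List[Grad]:
    """Hypothetical dispatcher mirroring store_vjp / _sequential / _parallel."""
    if num_proc < 1:
        raise ValueError("num_proc must be >= 1")
    indices = range(num_vertices)
    if num_proc == 1:
        # Sequential path: one vertex after another.
        return [vertex_vjp(i) for i in indices]
    # Parallel path: map preserves input order, so result[i] belongs to vertex i.
    with ThreadPoolExecutor(max_workers=num_proc) as pool:
        return list(pool.map(vertex_vjp, indices))
```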
- __hash__()
Hash method.