# Source code for tidy3d.plugins.adjoint.components.structure
"""Defines a jax-compatible structure and its conversion to a gradient monitor."""

from __future__ import annotations

from typing import Dict, List, Union

import numpy as np
import pydantic.v1 as pd
from jax.tree_util import register_pytree_node_class

from ....components.data.monitor_data import FieldData, PermittivityData
from ....components.geometry.utils import GeometryType
from ....components.medium import MediumType
from ....components.monitor import FieldMonitor
from ....components.structure import Structure
from ....components.types import TYPE_TAG_STR, Bound
from ....constants import C_0
from .base import JaxObject
from .geometry import JAX_GEOMETRY_MAP, JaxBox, JaxGeometryType
from .medium import JAX_MEDIUM_MAP, JaxMediumType

# For each field name, a lookup table from a regular tidy3d type to its jax-compatible type.
GEO_MED_MAPPINGS = dict(geometry=JAX_GEOMETRY_MAP, medium=JAX_MEDIUM_MAP)


class AbstractJaxStructure(Structure, JaxObject):
    """A :class:`.Structure` registered with jax.

    Shared base for the jax structure variants. Each subclass lists which of
    ``"geometry"`` and ``"medium"`` are differentiable in ``_differentiable_fields``;
    any field not listed there is kept as a regular ``tidy3d`` object.
    """

    _tidy3d_class = Structure

    # which of "geometry" or "medium" is differentiable for this class
    _differentiable_fields = ()

    geometry: Union[JaxGeometryType, GeometryType]
    medium: Union[JaxMediumType, MediumType]

    @pd.validator("medium", always=True)
    def _check_2d_geometry(cls, val, values):
        """Override validator checking 2D geometry, which triggers unnecessarily for gradients."""
        return val

    def _validate_web_adjoint(self) -> None:
        """Run validators for this component, only if using ``tda.web.run()``.

        Only the differentiable fields are validated, by delegating to their own
        ``_validate_web_adjoint`` methods.
        """
        if "geometry" in self._differentiable_fields:
            self.geometry._validate_web_adjoint()
        if "medium" in self._differentiable_fields:
            self.medium._validate_web_adjoint()

    @property
    def jax_fields(self):
        """The fields that are jax-traced for this class."""
        return dict(geometry=self.geometry, medium=self.medium)

    @property
    def exclude_fields(self):
        """Fields to exclude from the self dict (``"type"`` plus every key of ``jax_fields``)."""
        return {"type", *self.jax_fields}

    def to_structure(self) -> Structure:
        """Convert :class:`.JaxStructure` instance to :class:`.Structure`.

        Differentiable fields are converted with their ``to_tidy3d()`` method;
        the other fields are passed through unchanged.
        """
        self_dict = self.dict(exclude=self.exclude_fields)
        for key, component in self.jax_fields.items():
            if key in self._differentiable_fields:
                self_dict[key] = component.to_tidy3d()
            else:
                self_dict[key] = component
        return Structure.parse_obj(self_dict)

    @classmethod
    def from_structure(cls, structure: Structure) -> JaxStructure:
        """Convert :class:`.Structure` to :class:`.JaxStructure`.

        For each field that ``cls`` marks as differentiable, the jax type is found
        from the exact type of the component via ``GEO_MED_MAPPINGS``. A type with no
        jax counterpart raises ``KeyError``. Non-differentiable fields are copied as-is.
        """
        struct_dict = structure.dict(exclude={"type"})
        jax_fields = dict(geometry=structure.geometry, medium=structure.medium)
        for key, component in jax_fields.items():
            if key in cls._differentiable_fields:
                type_map = GEO_MED_MAPPINGS[key]
                jax_type = type_map[type(component)]
                struct_dict[key] = jax_type.from_tidy3d(component)
            else:
                struct_dict[key] = component
        return cls.parse_obj(struct_dict)

    def make_grad_monitors(self, freqs: List[float], name: str) -> FieldMonitor:
        """Return gradient monitor associated with this object."""
        if "geometry" not in self._differentiable_fields:
            # make a fake JaxBox spanning the static geometry's bounds,
            # to be able to call .make_grad_monitors on it
            rmin, rmax = self.geometry.bounds
            geometry = JaxBox.from_bounds(rmin=rmin, rmax=rmax)
        else:
            geometry = self.geometry
        return geometry.make_grad_monitors(freqs=freqs, name=name)

    def _get_medium_params(
        self,
        grad_data_eps: PermittivityData,
    ) -> Dict[str, float]:
        """Compute params in the material of this structure.

        The medium is evaluated at the highest frequency in ``grad_data_eps.eps_xx.f``.

        Returns
        -------
        dict
            ``eps_in``: the medium's ``eps_model`` at that frequency.
            ``wvl_mat``: ``C_0 / freq_max`` divided by a reference index. The index is
            ``sqrt(max(Re(eps_in)))`` floored at 1.0. ``wvl_mat`` is forwarded to the
            geometry/medium VJP, where it is used to determine integration points.
        """
        freq_max = float(max(grad_data_eps.eps_xx.f))
        eps_in = self.medium.eps_model(frequency=freq_max)
        # Clip at zero so media with negative Re(eps) (e.g. metals) do not hit
        # np.sqrt of a negative number (NaN + RuntimeWarning). The index is floored
        # at 1.0 anyway, so the result is unchanged.
        eps_real_max = np.maximum(np.max(np.real(eps_in)), 0.0)
        ref_ind = max(1.0, np.sqrt(eps_real_max))
        wvl_free_space = C_0 / freq_max
        wvl_mat = wvl_free_space / ref_ind
        return dict(wvl_mat=wvl_mat, eps_in=eps_in)

    def geometry_vjp(
        self,
        grad_data_fwd: FieldData,
        grad_data_adj: FieldData,
        grad_data_eps: PermittivityData,
        sim_bounds: Bound,
        eps_out: complex,
        num_proc: int = 1,
    ) -> JaxGeometryType:
        """Compute the VJP for the structure geometry.

        Delegates to ``self.geometry.store_vjp``, passing the material wavelength and
        permittivity computed by ``_get_medium_params``.
        """
        medium_params = self._get_medium_params(grad_data_eps=grad_data_eps)
        return self.geometry.store_vjp(
            grad_data_fwd=grad_data_fwd,
            grad_data_adj=grad_data_adj,
            grad_data_eps=grad_data_eps,
            sim_bounds=sim_bounds,
            wvl_mat=medium_params["wvl_mat"],
            eps_out=eps_out,
            eps_in=medium_params["eps_in"],
            num_proc=num_proc,
        )

    def medium_vjp(
        self,
        grad_data_fwd: FieldData,
        grad_data_adj: FieldData,
        grad_data_eps: PermittivityData,
        sim_bounds: Bound,
    ) -> JaxMediumType:
        """Compute the VJP for the structure medium.

        ``grad_data_eps`` is used only to compute ``wvl_mat``. The medium VJP receives
        ``self.geometry.inside`` so it can tell which points lie in this structure.
        """
        medium_params = self._get_medium_params(grad_data_eps=grad_data_eps)
        return self.medium.store_vjp(
            grad_data_fwd=grad_data_fwd,
            grad_data_adj=grad_data_adj,
            sim_bounds=sim_bounds,
            wvl_mat=medium_params["wvl_mat"],
            inside_fn=self.geometry.inside,
        )

    def store_vjp(
        self,
        grad_data_fwd: FieldData,
        grad_data_adj: FieldData,
        grad_data_eps: PermittivityData,
        sim_bounds: Bound,
        eps_out: complex,
        num_proc: int = 1,
    ) -> JaxStructure:
        """Returns the gradient of the structure parameters given forward and adjoint field data.

        Only the fields listed in ``_differentiable_fields`` get a VJP. The returned copy
        has those fields replaced by their VJP; all other fields are unchanged.
        """
        # nothing is differentiable for this class, so there is no VJP to store
        if not self._differentiable_fields:
            return self

        vjp_dict = {}

        if "geometry" in self._differentiable_fields:
            vjp_dict["geometry"] = self.geometry_vjp(
                grad_data_fwd=grad_data_fwd,
                grad_data_adj=grad_data_adj,
                grad_data_eps=grad_data_eps,
                sim_bounds=sim_bounds,
                eps_out=eps_out,
                num_proc=num_proc,
            )

        if "medium" in self._differentiable_fields:
            vjp_dict["medium"] = self.medium_vjp(
                grad_data_fwd=grad_data_fwd,
                grad_data_adj=grad_data_adj,
                grad_data_eps=grad_data_eps,
                sim_bounds=sim_bounds,
            )

        return self.updated_copy(**vjp_dict)
@register_pytree_node_class
class JaxStructure(AbstractJaxStructure, JaxObject):
    """A :class:`.Structure` registered with jax.

    Both the geometry and the medium are jax-compatible and differentiable.
    """

    geometry: JaxGeometryType = pd.Field(
        ...,
        title="Geometry",
        description="Geometry of the structure, which is jax-compatible.",
        jax_field=True,
        discriminator=TYPE_TAG_STR,
    )

    medium: JaxMediumType = pd.Field(
        ...,
        title="Medium",
        description="Medium of the structure, which is jax-compatible.",
        jax_field=True,
        discriminator=TYPE_TAG_STR,
    )

    # both fields are differentiable, so both receive a VJP
    _differentiable_fields = ("medium", "geometry")
@register_pytree_node_class
class JaxStructureStaticMedium(AbstractJaxStructure, JaxObject):
    """A :class:`.Structure` registered with jax, with a static medium.

    The geometry is jax-compatible and differentiable. The medium is a regular
    ``tidy3d`` medium and is not differentiated.
    """

    geometry: JaxGeometryType = pd.Field(
        ...,
        title="Geometry",
        description="Geometry of the structure, which is jax-compatible.",
        jax_field=True,
        discriminator=TYPE_TAG_STR,
    )

    medium: MediumType = pd.Field(
        ...,
        title="Medium",
        description="Regular ``tidy3d`` medium of the structure, non differentiable. "
        "Supports dispersive materials.",
        jax_field=False,
        discriminator=TYPE_TAG_STR,
    )

    # only the geometry is differentiable
    _differentiable_fields = ("geometry",)


@register_pytree_node_class
class JaxStructureStaticGeometry(AbstractJaxStructure, JaxObject):
    """A :class:`.Structure` registered with jax, with a static geometry.

    The medium is jax-compatible and differentiable. The geometry is a regular
    ``tidy3d`` geometry and is not differentiated.
    """

    geometry: GeometryType = pd.Field(
        ...,
        title="Geometry",
        description="Regular ``tidy3d`` geometry of the structure, non differentiable. "
        "Supports angled sidewalls and other complex geometries.",
        jax_field=False,
        discriminator=TYPE_TAG_STR,
    )

    medium: JaxMediumType = pd.Field(
        ...,
        title="Medium",
        description="Medium of the structure, which is jax-compatible.",
        jax_field=True,
        discriminator=TYPE_TAG_STR,
    )

    # only the medium is differentiable
    _differentiable_fields = ("medium",)


# Any of the jax-registered structure variants.
JaxStructureType = Union[JaxStructure, JaxStructureStaticMedium, JaxStructureStaticGeometry]