# Source code for tidy3d.plugins.adjoint.components.medium
"""Defines jax-compatible mediums."""from__future__importannotationsfromabcimportABCfromtypingimportCallable,Dict,Optional,Tuple,Unionimportnumpyasnpimportpydantic.v1aspdimportxarrayasxrfromjax.tree_utilimportregister_pytree_node_classfrom....components.data.monitor_dataimportFieldDatafrom....components.geometry.baseimportGeometryfrom....components.mediumimportAnisotropicMedium,CustomMedium,Mediumfrom....components.typesimportBound,Literalfrom....constantsimportCONDUCTIVITYfrom....exceptionsimportSetupErrorfrom.baseimportWEB_ADJOINT_MESSAGE,JaxObjectfrom.data.data_arrayimportJaxDataArrayfrom.data.datasetimportJaxPermittivityDatasetfrom.typesimportJaxFloat# number of integration points per unit wavelength in materialPTS_PER_WVL_INTEGRATION=20# maximum number of pixels allowed in each component of a JaxCustomMediumMAX_NUM_CELLS_CUSTOM_MEDIUM=250_000classAbstractJaxMedium(ABC,JaxObject):"""Holds some utility functions for Jax medium types."""def_get_volume_disc(self,grad_data:FieldData,sim_bounds:Bound,wvl_mat:float)->Tuple[Dict[str,np.ndarray],float]:"""Get the coordinates and volume element for the inside of the corresponding structure."""# find intersecting volume between structure and simulationmnt_bounds=grad_data.monitor.geometry.boundsrmin,rmax=Geometry.bounds_intersection(mnt_bounds,sim_bounds)# assemble volume coordinates and differential volume elementd_vol=1.0vol_coords={}forcoord_name,min_edge,max_edgeinzip("xyz",rmin,rmax):size=max_edge-min_edge# don't discretize this dimension if there is no thickness along itifsize==0:vol_coords[coord_name]=[max_edge]continue# update the volume element valuenum_cells_dim=int(size*PTS_PER_WVL_INTEGRATION/wvl_mat)+1d_len=size/num_cells_dimd_vol*=d_len# construct the interpolation coordinates along this 
dimensioncoords_interp=np.linspace(min_edge+d_len/2,max_edge-d_len/2,num_cells_dim)vol_coords[coord_name]=coords_interpreturnvol_coords,d_vol@staticmethoddefmake_inside_mask(vol_coords:Dict[str,np.ndarray],inside_fn:Callable)->xr.DataArray:"""Make a 3D mask of where the volume coordinates are inside a supplied function."""meshgrid_args=[vol_coords[dim]fordimin"xyz"ifdiminvol_coords]vol_coords_meshgrid=np.meshgrid(*meshgrid_args,indexing="ij")inside_kwargs=dict(zip("xyz",vol_coords_meshgrid))values=inside_fn(**inside_kwargs)returnxr.DataArray(values,coords=vol_coords)defe_mult_volume(self,field:Literal["Ex","Ey","Ez"],grad_data_fwd:FieldData,grad_data_adj:FieldData,vol_coords:Dict[str,np.ndarray],d_vol:float,inside_fn:Callable,)->xr.DataArray:"""Get the E_fwd * E_adj * dV field distribution inside of the discretized volume."""e_fwd=grad_data_fwd.field_components[field]e_adj=grad_data_adj.field_components[field]e_dotted=e_fwd*e_adjinside_mask=self.make_inside_mask(vol_coords=vol_coords,inside_fn=inside_fn)isel_kwargs={key:[0]forkey,valueinvol_coords.items()ifisinstance(value,float)orlen(value)<=1}interp_kwargs={key:valueforkey,valueinvol_coords.items()ifkeynotinisel_kwargs}fields_eval=e_dotted.isel(**isel_kwargs).interp(**interp_kwargs,assume_sorted=True)inside_mask=inside_mask.isel(**isel_kwargs)mask_dV=inside_mask*d_volfields_eval=fields_eval.assign_coords(**mask_dV.coords)returnmask_dV*fields_evaldefd_eps_map(self,grad_data_fwd:FieldData,grad_data_adj:FieldData,sim_bounds:Bound,wvl_mat:float,inside_fn:Callable,)->xr.DataArray:"""Mapping of gradient w.r.t. permittivity at each point in discretized volume."""vol_coords,d_vol=self._get_volume_disc(grad_data=grad_data_fwd,sim_bounds=sim_bounds,wvl_mat=wvl_mat)e_mult_sum=0.0forfieldin("Ex","Ey","Ez"):e_mult_sum+=self.e_mult_volume(field=field,grad_data_fwd=grad_data_fwd,grad_data_adj=grad_data_adj,vol_coords=vol_coords,d_vol=d_vol,inside_fn=inside_fn,)returne_mult_sum
@register_pytree_node_class
class JaxMedium(Medium, AbstractJaxMedium):
    """A :class:`.Medium` registered with jax."""

    _tidy3d_class = Medium

    permittivity_jax: JaxFloat = pd.Field(
        1.0,
        title="Permittivity",
        description="Relative permittivity of the medium. May be a ``jax`` ``Array``.",
        stores_jax_for="permittivity",
    )

    conductivity_jax: JaxFloat = pd.Field(
        0.0,
        title="Conductivity",
        description="Electric conductivity. Defined such that the imaginary part of the complex "
        "permittivity at angular frequency omega is given by conductivity/omega.",
        units=CONDUCTIVITY,
        stores_jax_for="conductivity",
    )

    def store_vjp(
        self,
        grad_data_fwd: FieldData,
        grad_data_adj: FieldData,
        sim_bounds: Bound,
        wvl_mat: float,
        inside_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
    ) -> JaxMedium:
        """Return a copy of this medium holding the gradient of its parameters.

        The permittivity gradient map from :meth:`d_eps_map` is summed over ``x``, ``y``
        and ``z``. At each frequency, the result is split by ``eps_complex_to_eps_sigma``
        into permittivity and conductivity contributions, which are summed over frequency.

        Returns
        -------
        JaxMedium
            Copy with ``permittivity_jax`` and ``conductivity_jax`` set to the gradients.
        """
        grad_map = self.d_eps_map(
            grad_data_fwd=grad_data_fwd,
            grad_data_adj=grad_data_adj,
            sim_bounds=sim_bounds,
            wvl_mat=wvl_mat,
            inside_fn=inside_fn,
        )

        # integrate over the volume, leaving only the frequency dimension
        grad_eps_complex = grad_map.sum(dim=("x", "y", "z"))

        grad_eps_total = 0.0
        grad_sigma_total = 0.0
        for f_val in grad_map.coords["f"]:
            grad_eps_f, grad_sigma_f = self.eps_complex_to_eps_sigma(
                grad_eps_complex.sel(f=f_val), f_val
            )
            grad_eps_total += grad_eps_f
            grad_sigma_total += grad_sigma_f

        return self.copy(
            update={"permittivity_jax": grad_eps_total, "conductivity_jax": grad_sigma_total}
        )
@register_pytree_node_class
class JaxAnisotropicMedium(AnisotropicMedium, AbstractJaxMedium):
    """A :class:`.AnisotropicMedium` registered with jax."""

    _tidy3d_class = AnisotropicMedium

    xx: JaxMedium = pd.Field(
        ...,
        title="XX Component",
        description="Medium describing the xx-component of the diagonal permittivity tensor.",
        jax_field=True,
    )

    yy: JaxMedium = pd.Field(
        ...,
        title="YY Component",
        description="Medium describing the yy-component of the diagonal permittivity tensor.",
        jax_field=True,
    )

    zz: JaxMedium = pd.Field(
        ...,
        title="ZZ Component",
        description="Medium describing the zz-component of the diagonal permittivity tensor.",
        jax_field=True,
    )

    def store_vjp(
        self,
        grad_data_fwd: FieldData,
        grad_data_adj: FieldData,
        sim_bounds: Bound,
        wvl_mat: float,
        inside_fn: Callable,
    ) -> JaxAnisotropicMedium:
        """Returns the gradient of the medium parameters given forward and adjoint field data.

        For each diagonal component ``ii`` (``xx``, ``yy``, ``zz``), only the matching field
        component ``Ei`` is used. The product ``E_fwd * E_adj * dV`` is summed over the
        volume. At each frequency it is split into permittivity and conductivity
        contributions, and these are summed over frequency.

        Parameters
        ----------
        grad_data_fwd : FieldData
            Forward-simulation field data; its monitor geometry bounds the integration region.
        grad_data_adj : FieldData
            Adjoint-simulation field data.
        sim_bounds : Bound
            Bounds of the simulation domain.
        wvl_mat : float
            Wavelength in the material, used to set the integration sampling density.
        inside_fn : Callable
            Function marking which sample points lie inside the structure.

        Returns
        -------
        JaxAnisotropicMedium
            Copy whose ``xx``, ``yy``, ``zz`` media carry the gradients in their
            ``permittivity_jax`` and ``conductivity_jax`` fields.
        """
        # integrate the dot product of each E component over the volume, update vjp for epsilon
        vol_coords, d_vol = self._get_volume_disc(
            grad_data=grad_data_fwd, sim_bounds=sim_bounds, wvl_mat=wvl_mat
        )

        vjp_fields = {}
        for component in "xyz":
            field_name = "E" + component
            component_name = component + component

            e_mult_dim = self.e_mult_volume(
                field=field_name,
                grad_data_fwd=grad_data_fwd,
                grad_data_adj=grad_data_adj,
                vol_coords=vol_coords,
                d_vol=d_vol,
                inside_fn=inside_fn,
            )
            vjp_eps_complex_ii = e_mult_dim.sum(dim=("x", "y", "z"))

            # accumulate the real-permittivity and conductivity gradients over frequency
            vjp_eps_ii = 0.0
            vjp_sigma_ii = 0.0
            for freq in e_mult_dim.coords["f"]:
                vjp_eps_complex_ii_f = vjp_eps_complex_ii.sel(f=freq)
                _vjp_eps_ii, _vjp_sigma_ii = self.eps_complex_to_eps_sigma(
                    vjp_eps_complex_ii_f, freq
                )
                vjp_eps_ii += _vjp_eps_ii
                vjp_sigma_ii += _vjp_sigma_ii

            vjp_medium = self.components[component_name]
            vjp_fields[component_name] = vjp_medium.updated_copy(
                permittivity_jax=vjp_eps_ii,
                conductivity_jax=vjp_sigma_ii,
            )

        return self.copy(update=vjp_fields)
@register_pytree_node_class
class JaxCustomMedium(CustomMedium, AbstractJaxMedium):
    """A :class:`.CustomMedium` registered with ``jax``.

    Note
    ----
    The gradient calculation assumes uniform field across the pixel. Therefore, the accuracy
    degrades as the pixel size becomes large with respect to the field variation.
    """

    _tidy3d_class = CustomMedium

    eps_dataset: Optional[JaxPermittivityDataset] = pd.Field(
        None,
        title="Permittivity Dataset",
        description="User-supplied dataset containing complex-valued permittivity "
        "as a function of space. Permittivity distribution over the Yee-grid will be "
        "interpolated based on the data nearest to the grid location.",
        jax_field=True,
    )

    @pd.root_validator(pre=True)
    def _pre_deprecation_dataset(cls, values):
        """Don't allow permittivity as a field until we support it.

        The check uses ``is not None`` rather than truthiness because these fields may be
        supplied as array-like data. The truth value of such data is ambiguous (and raises)
        for more than one element, and is falsy for a single zero element.
        """
        if values.get("permittivity") is not None or values.get("conductivity") is not None:
            raise SetupError(
                "'permittivity' and 'conductivity' are not yet supported in adjoint plugin. "
                "Please continue to use the 'eps_dataset' field to define the component "
                "of the permittivity tensor."
            )
        return values

    def _validate_web_adjoint(self) -> None:
        """Run validators for this component, only if using ``tda.web.run()``."""
        self._is_not_too_large()

    def _is_not_too_large(self):
        """Ensure number of pixels does not surpass a set amount.

        Raises
        ------
        SetupError
            If any ``eps_ii`` component of ``eps_dataset`` has more than
            ``MAX_NUM_CELLS_CUSTOM_MEDIUM`` points on its ``x``/``y``/``z`` grid.
        """
        # ``eps_dataset`` is Optional; there is nothing to check without it
        if self.eps_dataset is None:
            return

        field_components = self.eps_dataset.field_components
        for field_dim in "xyz":
            field_name = f"eps_{field_dim}{field_dim}"
            data_array = field_components[field_name]
            coord_lens = [len(data_array.coords[key]) for key in "xyz"]
            num_cells_dim = np.prod(coord_lens)
            if num_cells_dim > MAX_NUM_CELLS_CUSTOM_MEDIUM:
                raise SetupError(
                    "For the adjoint plugin, each component of the 'JaxCustomMedium.eps_dataset' "
                    f"is restricted to have a maximum of {MAX_NUM_CELLS_CUSTOM_MEDIUM} cells. "
                    f"Detected {num_cells_dim} grid cells in the '{field_name}' component. "
                    + WEB_ADJOINT_MESSAGE
                )

    @pd.validator("eps_dataset", always=True)
    def _eps_dataset_single_frequency(cls, val):
        """Override of the inherited validator of the same name: passes the value through
        unchanged (still needed)."""
        return val

    @pd.validator("eps_dataset", always=True)
    def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(cls, val, values):
        """Override of the inherited validator of the same name: passes the value through
        unchanged."""
        return val

    def store_vjp(
        self,
        grad_data_fwd: FieldData,
        grad_data_adj: FieldData,
        sim_bounds: Bound,
        wvl_mat: float,
        inside_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
    ) -> JaxCustomMedium:
        """Returns the gradient of the medium parameters given forward and adjoint field data.

        For each diagonal component ``eps_ii`` of ``eps_dataset``, the product of the forward
        and adjoint ``Ei`` fields is evaluated at the dataset's own coordinates. It is
        weighted by the grid cell sizes, masked by ``inside_fn``, and summed over frequency.
        Along axes where the dataset has a single coordinate, the pixel (clipped to
        ``sim_bounds``) is instead sub-sampled at ``PTS_PER_WVL_INTEGRATION`` points per
        ``wvl_mat``, and those samples are summed back into the pixel.

        Parameters
        ----------
        grad_data_fwd : FieldData
            Forward-simulation field data; its monitor geometry bounds the region used.
        grad_data_adj : FieldData
            Adjoint-simulation field data.
        sim_bounds : Bound
            Bounds of the simulation domain.
        wvl_mat : float
            Wavelength in the material, used to set the sub-sampling density.
        inside_fn : Callable
            Function marking which sample points lie inside the structure.

        Returns
        -------
        JaxCustomMedium
            Copy whose ``eps_dataset`` holds the gradient. It has the same coordinates,
            shape and dtype (real data keeps only the real part) as the original data.
        """
        # get the boundaries of the intersection of the CustomMedium and the Simulation
        mnt_bounds = grad_data_fwd.monitor.geometry.bounds
        bounds_intersect = Geometry.bounds_intersection(mnt_bounds, sim_bounds)

        # get the grids associated with the user-supplied coordinates within these bounds
        grids = self.grids(bounds=bounds_intersect)

        # (min, max) of the simulation along each axis; invariant across the loops below
        sim_bounds_per_axis = np.array(sim_bounds).T

        vjp_field_components = {}
        for dim in "xyz":
            eps_field_name = f"eps_{dim}{dim}"

            # grab the original data and its coordinates
            orig_data_array = self.eps_dataset.field_components[eps_field_name]
            coords = orig_data_array.coords
            grid = grids[eps_field_name]
            grid_sizes = grid.sizes
            d_sizes = [grid_sizes.x, grid_sizes.y, grid_sizes.z]

            # construct the coordinates for interpolation and selection within the custom medium
            # TODO: extend this to all points within the volume.
            interp_coords = {}
            sum_axes = []
            for dim_index, dim_pt in enumerate("xyz"):
                coord_dim = coords[dim_pt]

                # several pixels along this dim: just evaluate at the original data coords
                if len(np.array(coord_dim)) != 1:
                    interp_coords[dim_pt] = coord_dim
                    continue

                # single pixel along this dim: discretize it like a regular volume, using
                # the part of the pixel that lies within the simulation bounds
                r_min_coords, r_max_coords = grid.boundaries.to_list[dim_index]
                r_min_sim, r_max_sim = sim_bounds_per_axis[dim_index]
                r_min = max(r_min_coords, r_min_sim)
                r_max = min(r_max_coords, r_max_sim)
                size = abs(r_max - r_min)

                if size > 0:
                    # discretize according to PTS_PER_WVL_INTEGRATION
                    num_cells_dim = int(size * PTS_PER_WVL_INTEGRATION / wvl_mat) + 1
                    d_len = size / num_cells_dim
                    coords_interp = np.linspace(
                        r_min + d_len / 2, r_max - d_len / 2, num_cells_dim
                    )
                else:
                    # zero size (e.g. sim.size = 0 along this dim): evaluate at the single
                    # position, with d_len = 1 so the length element normalizes out
                    d_len = 1.0
                    coords_interp = np.array([(r_min + r_max) / 2.0])

                # the sub-sample length replaces the pixel size along this dim
                d_sizes[dim_index] = np.array([d_len])
                interp_coords[dim_pt] = coords_interp
                # the sub-samples along this dim are summed back into the single pixel
                sum_axes.append(dim_pt)

            # outer product all dimensions to get a volume element mask
            d_vols = np.einsum("i, j, k -> ijk", *d_sizes)

            # grab the corresponding dotted fields at these interp_coords and sum over len-1 pixels
            field_name = "E" + dim
            e_dotted = (
                self.e_mult_volume(
                    field=field_name,
                    grad_data_fwd=grad_data_fwd,
                    grad_data_adj=grad_data_adj,
                    vol_coords=interp_coords,
                    d_vol=d_vols,
                    inside_fn=inside_fn,
                )
                .sum(sum_axes)
                .sum(dim="f")
            )

            # reshape values to the expected vjp shape to be more safe
            vjp_shape = tuple(len(coord) for coord in coords.values())
            vjp_values = e_dotted.values.reshape(vjp_shape)

            # make sure this has the same dtype as the original
            dtype_orig = np.array(orig_data_array.values).dtype
            if dtype_orig.kind == "f":
                vjp_values = vjp_values.real
            vjp_values = vjp_values.astype(dtype_orig)

            # construct a DataArray storing the vjp
            vjp_field_components[eps_field_name] = JaxDataArray(values=vjp_values, coords=coords)

        # package everything into dataset
        vjp_eps_dataset = JaxPermittivityDataset(**vjp_field_components)
        return self.copy(update=dict(eps_dataset=vjp_eps_dataset))