# Source code for tidy3d.plugins.adjoint.components.data.sim_data
"""Defines a jax-compatible SimulationData."""

from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pydantic.v1 as pd
import xarray as xr
from jax.tree_util import register_pytree_node_class

from .....components.data.monitor_data import FieldData, MonitorDataType, PermittivityData
from .....components.data.sim_data import SimulationData
from .....components.source.current import PointDipole
from .....components.source.time import GaussianPulse
from .....log import log
from ..base import JaxObject
from ..simulation import JaxInfo, JaxSimulation
from .monitor_data import JAX_MONITOR_DATA_MAP, JaxMonitorDataType
@register_pytree_node_class
class JaxSimulationData(SimulationData, JaxObject):
    """A :class:`.SimulationData` registered with jax."""

    output_data: Tuple[JaxMonitorDataType, ...] = pd.Field(
        (),
        title="Jax Data",
        description="Tuple of Jax-compatible data associated with output monitors.",
        jax_field=True,
    )

    grad_data: Tuple[FieldData, ...] = pd.Field(
        (),
        title="Gradient Field Data",
        description="Tuple of monitor data storing fields associated with the input structures.",
    )

    grad_eps_data: Tuple[PermittivityData, ...] = pd.Field(
        (),
        title="Gradient Permittivity Data",
        description="Tuple of monitor data storing epsilon associated with the input structures.",
    )

    simulation: JaxSimulation = pd.Field(
        ...,
        title="Simulation",
        description="The jax-compatible simulation corresponding to the data.",
    )

    task_id: Optional[str] = pd.Field(
        None,
        title="Task ID",
        description="Optional field storing the task_id for the original JaxSimulation.",
    )

    def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset:
        """Return ``xarray.Dataset`` of the Poynting vector at Yee cell centers.

        Calculated values represent the instantaneous Poynting vector for time-domain fields and
        the complex vector for frequency-domain: ``S = 1/2 E × conj(H)``.

        Only the available components are returned, e.g., if the indicated monitor doesn't
        include field component `"Ex"`, then `"Sy"` and `"Sz"` will not be calculated.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.

        Returns
        -------
        xarray.Dataset
            Dataset containing the Poynting vector calculated based on the field components
            colocated at the center locations of the Yee grid.

        Raises
        ------
        NotImplementedError
            If ``field_monitor_name`` refers to one of the ``output_data`` monitors.
        """
        # output monitors take part in differentiation, which is not supported for this quantity
        if field_monitor_name in self.output_monitor_data:
            raise NotImplementedError(
                "Adjoint support for differentiation with respect to Poynting vector not available."
            )

        return super().get_poynting_vector(field_monitor_name)

    @property
    def grad_data_symmetry(self) -> Tuple[FieldData, ...]:
        """``self.grad_data`` but with ``symmetry_expanded_copy`` applied."""
        return tuple(data.symmetry_expanded_copy for data in self.grad_data)

    @property
    def grad_eps_data_symmetry(self) -> Tuple[PermittivityData, ...]:
        """``self.grad_eps_data`` but with ``symmetry_expanded_copy`` applied."""
        return tuple(data.symmetry_expanded_copy for data in self.grad_eps_data)

    @property
    def output_monitor_data(self) -> Dict[str, JaxMonitorDataType]:
        """Dictionary of ``.output_data`` monitor ``.name`` to the corresponding data."""
        return {monitor_data.monitor.name: monitor_data for monitor_data in self.output_data}

    @property
    def monitor_data(self) -> Dict[str, Union[JaxMonitorDataType, MonitorDataType]]:
        """Dictionary of monitor ``.name`` to data for both ``.data`` and ``.output_data``.

        Entries from ``.output_data`` are applied last, so they take precedence over an entry
        of ``.data`` that has the same monitor name.
        """
        reg_mnt_data = {monitor_data.monitor.name: monitor_data for monitor_data in self.data}
        reg_mnt_data.update(self.output_monitor_data)
        return reg_mnt_data

    @staticmethod
    def split_data(
        mnt_data: List[MonitorDataType], jax_info: JaxInfo
    ) -> Dict[str, List[MonitorDataType]]:
        """Split list of monitor data into data, output_data, grad_data, and grad_eps_data.

        Parameters
        ----------
        mnt_data : List[MonitorDataType]
            Full list of monitor data, ordered as regular data, then output data, then
            gradient field data, then gradient permittivity data.
        jax_info : JaxInfo
            Supplies ``num_output_monitors``, ``num_grad_monitors`` and
            ``num_grad_eps_monitors``, i.e. the sizes of the trailing three groups.

        Returns
        -------
        Dict[str, List[MonitorDataType]]
            Lists stored under the keys ``"data"``, ``"output_data"``, ``"grad_data"`` and
            ``"grad_eps_data"``.
        """
        all_data = list(mnt_data)

        num_output = jax_info.num_output_monitors
        num_grad = jax_info.num_grad_monitors
        num_grad_eps = jax_info.num_grad_eps_monitors

        # whatever is not accounted for by ``jax_info`` is regular (user) data at the front
        output_start = len(all_data) - num_output - num_grad - num_grad_eps
        grad_start = output_start + num_output
        grad_eps_start = grad_start + num_grad

        return dict(
            data=all_data[:output_start],
            output_data=all_data[output_start:grad_start],
            grad_data=all_data[grad_start:grad_eps_start],
            grad_eps_data=all_data[grad_eps_start:],
        )

    @classmethod
    def from_sim_data(
        cls, sim_data: SimulationData, jax_info: JaxInfo, task_id: Optional[str] = None
    ) -> JaxSimulationData:
        """Construct a :class:`.JaxSimulationData` instance from a :class:`.SimulationData`.

        Parameters
        ----------
        sim_data : SimulationData
            Regular simulation data whose ``.data`` holds all monitor data in one sequence.
        jax_info : JaxInfo
            Information used to convert the simulation and to split ``sim_data.data``.
        task_id : Optional[str] = None
            Value stored in the ``task_id`` field of the result.

        Returns
        -------
        JaxSimulationData
            Instance parsed from the combined dictionary of fields.

        Raises
        ------
        KeyError
            If the type of an output monitor data is not a key of ``JAX_MONITOR_DATA_MAP``.
        """
        self_dict = sim_data.dict(exclude={"type", "simulation", "data"})

        # convert the simulation to JaxSimulation
        jax_sim = JaxSimulation.from_simulation(
            simulation=sim_data.simulation, jax_info=jax_info
        )

        # start with no data; the split data is merged in below
        self_dict["simulation"] = jax_sim
        self_dict["data"] = ()

        data_dict = cls.split_data(mnt_data=sim_data.data, jax_info=jax_info)

        # convert the output data to the proper jax type
        output_data_list = []
        for mnt_data in data_dict["output_data"]:
            mnt_data_type = type(mnt_data)
            if mnt_data_type not in JAX_MONITOR_DATA_MAP:
                raise KeyError(
                    f"MonitorData type '{mnt_data_type}' "
                    "not currently supported by adjoint plugin."
                )
            jax_mnt_data_type = JAX_MONITOR_DATA_MAP[mnt_data_type]
            output_data_list.append(jax_mnt_data_type.from_monitor_data(mnt_data))

        data_dict["output_data"] = output_data_list
        self_dict.update(data_dict)
        self_dict.update(dict(task_id=task_id))

        return cls.parse_obj(self_dict)

    @classmethod
    def split_fwd_sim_data(
        cls, sim_data: SimulationData, jax_info: JaxInfo
    ) -> Tuple[SimulationData, SimulationData]:
        """Split a :class:`.SimulationData` into two parts, containing user and gradient data.

        Parameters
        ----------
        sim_data : SimulationData
            Simulation data holding user, output and gradient monitor data together.
        jax_info : JaxInfo
            Information used to split both the monitor data and the monitors.

        Returns
        -------
        Tuple[SimulationData, SimulationData]
            First the copy with regular + output data and monitors, then the copy with
            gradient field + gradient permittivity data and monitors.
        """
        sim = sim_data.simulation

        data_dict = cls.split_data(mnt_data=sim_data.data, jax_info=jax_info)
        user_data = data_dict["data"] + data_dict["output_data"]
        adjoint_data = data_dict["grad_data"] + data_dict["grad_eps_data"]

        mnt_dict = JaxSimulation.split_monitors(monitors=sim.monitors, jax_info=jax_info)
        user_mnts = mnt_dict["monitors"] + mnt_dict["output_monitors"]
        adjoint_mnts = mnt_dict["grad_monitors"] + mnt_dict["grad_eps_monitors"]

        user_sim = sim.updated_copy(monitors=user_mnts)
        adjoint_sim = sim.updated_copy(monitors=adjoint_mnts)

        user_sim_data = sim_data.updated_copy(data=user_data, simulation=user_sim)
        adjoint_sim_data = sim_data.updated_copy(data=adjoint_data, simulation=adjoint_sim)

        return user_sim_data, adjoint_sim_data

    def make_adjoint_simulation(self, fwidth: float, run_time: float) -> JaxSimulation:
        """Make an adjoint simulation out of the data provided (generally, the vjp sim data).

        Parameters
        ----------
        fwidth : float
            Passed to ``to_adjoint_sources`` of every element of ``output_data``.
        run_time : float
            Run time of the adjoint simulation; replaced by ``2 / fwidth`` when no adjoint
            source is produced and a zero-amplitude mock source is used instead.

        Returns
        -------
        JaxSimulation
            Copy of ``self.simulation`` with adjoint sources, flipped-Bloch-vector boundaries,
            no regular or output monitors, and the gradient monitors of the forward simulation.
        """
        sim_fwd = self.simulation

        # grab boundary conditions with flipped bloch vectors (for adjoint)
        bc_adj = sim_fwd.boundary_spec.flipped_bloch_vecs

        # gather all adjoint sources; they are applied together with the boundary conditions
        # in a single update (at same time for BC validators to work)
        adj_srcs = [
            adj_source
            for mnt_data_vjp in self.output_data
            for adj_source in mnt_data_vjp.to_adjoint_sources(fwidth=fwidth)
        ]

        # in this case (no adjoint sources) give it an "empty" source
        if not adj_srcs:
            log.warning(
                "No adjoint sources, making a mock source with amplitude = 0. "
                "All gradients will be zero for anything depending on this simulation's data. "
                "This comes up when a simulation's data contributes to the value of an objective "
                "function but the contribution from each member of the data is 0. "
                "If this is intended (eg. if using 'jnp.max()' of several simulation results), "
                "please ignore. Otherwise, this can suggest a mistake in your objective function."
            )

            # set a zero-amplitude source
            adj_srcs.append(
                PointDipole(
                    center=sim_fwd.center,
                    polarization="Ez",
                    source_time=GaussianPulse(
                        freq0=sim_fwd.freqs_adjoint[0],
                        fwidth=sim_fwd._fwidth_adjoint,
                        amplitude=0.0,
                    ),
                )
            )

            # set a very short run time relative to the fwidth
            run_time = 2 / fwidth

        update_dict = dict(
            boundary_spec=bc_adj,
            sources=adj_srcs,
            monitors=(),
            output_monitors=(),
            run_time=run_time,
            normalize_index=None,  # normalize later, frequency-by-frequency
        )

        update_dict.update(
            sim_fwd.get_grad_monitors(
                input_structures=sim_fwd.input_structures,
                freqs_adjoint=sim_fwd.freqs_adjoint,
                include_eps_mnts=False,
            )
        )

        # set the ADJ grid spec wavelength to the FWD wavelength (for same meshing)
        grid_spec_fwd = sim_fwd.grid_spec
        if sim_fwd.sources and grid_spec_fwd.wavelength is None:
            wavelength_fwd = grid_spec_fwd.wavelength_from_sources(sim_fwd.sources)
            grid_spec_adj = grid_spec_fwd.updated_copy(wavelength=wavelength_fwd)
            update_dict.update(dict(grid_spec=grid_spec_adj))

        return sim_fwd.updated_copy(**update_dict)

    def normalize_adjoint_fields(self) -> JaxSimulationData:
        """Make copy of jax_sim_data with grad_data (fields) normalized by adjoint sources.

        For every frequency of every field component in ``grad_data``, the field is divided by
        the spectrum, evaluated at that frequency, of the source in ``self.simulation.sources``
        whose ``source_time.freq0`` equals the frequency and whose amplitude is positive. If
        several sources match, the last one in the tuple is used.

        Returns
        -------
        JaxSimulationData
            Copy of ``self`` whose ``grad_data`` holds the normalized field components.
        """
        # sources with zero amplitude (e.g. a mock source) never contribute a norm factor
        active_sources = [
            (source_index, source.source_time.freq0)
            for source_index, source in enumerate(self.simulation.sources)
            if source.source_time.amplitude > 0
        ]

        # the factor depends only on the frequency, so compute it once per distinct value
        norm_factor_cache: Dict[float, complex] = {}

        def norm_factor_at(freq: float) -> complex:
            """Spectrum of the (last) matching active source at ``freq``, or 0 if none match."""
            if freq not in norm_factor_cache:
                factor = 0j
                for source_index, freq0 in active_sources:
                    if freq0 == freq:
                        spectrum_fn = self.source_spectrum(source_index)
                        factor = complex(spectrum_fn([freq])[0])
                norm_factor_cache[freq] = factor
            return norm_factor_cache[freq]

        grad_data_norm = []
        for field_data in self.grad_data:
            field_components_norm = {}
            for field_name, field_component in field_data.field_components.items():
                freqs = field_component.coords["f"]
                norm_factor_f = np.array(
                    [norm_factor_at(float(freq)) for freq in freqs], dtype=complex
                )

                # NOTE(review): a frequency without a matching active source keeps a factor of
                # 0, so the division below yields inf/nan there -- confirm this is intended.
                norm_factor_f_darr = xr.DataArray(norm_factor_f, coords=dict(f=freqs))
                field_components_norm[field_name] = field_component / norm_factor_f_darr

            grad_data_norm.append(field_data.updated_copy(**field_components_norm))

        return self.updated_copy(grad_data=grad_data_norm)