"""Tools for generating an S matrix automatically from tidy3d simulation and port definitions."""

from __future__ import annotations

import os
from typing import List, Tuple, Optional, Dict

import numpy as np
import pydantic.v1 as pd

from ...constants import HERTZ
from ...components.simulation import Simulation
from ...components.geometry.base import Box
from ...components.mode import ModeSpec
from ...components.monitor import ModeMonitor
from ...components.source import ModeSource, GaussianPulse
from ...components.data.sim_data import SimulationData
from ...components.data.data_array import DataArray
from ...components.types import Direction, Ax, Complex, FreqArray
from ...components.viz import add_ax_if_none, equal_aspect
from ...components.base import Tidy3dBaseModel, cached_property
from ...exceptions import SetupError, Tidy3dKeyError
from ...log import log
from ...web.api.container import BatchData, Batch

# Minimum width of the Gaussian source pulse, expressed as a fraction of its
# central frequency (the actual width is the larger of this and the frequency span).
FWIDTH_FRAC = 1.0 / 10

# Directory used for data and batch files when none is specified.
DEFAULT_DATA_DIR = "."
class Port(Box):
    """Specifies a port in the scattering matrix."""

    # which propagation direction counts as incoming at this port
    direction: Direction = pd.Field(
        default=...,
        title="Direction",
        description="'+' or '-', defining which direction is considered 'input'.",
    )

    # mode-solver settings shared by the port's source and monitor
    mode_spec: ModeSpec = pd.Field(
        default=ModeSpec(),
        title="Mode Specification",
        description="Specifies how the mode solver will solve for the modes of the port.",
    )

    # identifier used to look the port up and to name monitors/tasks
    name: str = pd.Field(
        default=...,
        title="Name",
        description="Unique name for the port.",
        min_length=1,
    )
# One index of the scattering matrix, ``(port name, mode index)``: the 'i' in S_ij.
MatrixIndex = Tuple[str, pd.NonNegativeInt]

# One element of the scattering matrix, ``(output index, input index)``: the 'ij' in S_ij.
Element = Tuple[MatrixIndex, MatrixIndex]
class ComponentModeler(Tidy3dBaseModel):
    """
    Tool for modeling devices and computing scattering matrix elements.

    .. TODO missing basic example

    See Also
    --------

    **Notebooks**
        * `Computing the scattering matrix of a device <../../notebooks/SMatrix.html>`_
    """

    simulation: Simulation = pd.Field(
        ...,
        title="Simulation",
        description="Simulation describing the device without any sources present.",
    )

    ports: Tuple[Port, ...] = pd.Field(
        (),
        title="Ports",
        description="Collection of ports describing the scattering matrix elements. "
        "For each input mode, one simulation will be run with a modal source.",
    )

    freqs: FreqArray = pd.Field(
        ...,
        title="Frequencies",
        description="Array or list of frequencies at which to evaluate the scattering matrix.",
        units=HERTZ,
    )

    folder_name: str = pd.Field(
        "default",
        title="Folder Name",
        description="Name of the folder for the tasks on web.",
    )

    element_mappings: Tuple[Tuple[Element, Element, Complex], ...] = pd.Field(
        (),
        title="Element Mappings",
        description="Mapping between elements of the scattering matrix, "
        "as specified by pairs of ``(port name, mode index)`` matrix indices, where the "
        "first element of the pair is the output and the second element of the pair is the input. "
        "Each item of ``element_mappings`` is a tuple of ``(element1, element2, c)``, where "
        "the scattering matrix ``Smatrix[element2]`` is set equal to ``c * Smatrix[element1]``. "
        "If all elements of a given column of the scattering matrix are defined by "
        "``element_mappings``, the simulation corresponding to this column "
        "is skipped automatically.",
    )

    run_only: Optional[Tuple[MatrixIndex, ...]] = pd.Field(
        None,
        title="Run Only",
        description="If given, a tuple of matrix indices, specified by ``(port name, mode index)``,"
        " to run only, excluding the other columns from the scattering matrix. "
        "If this option is used, "
        "the data corresponding to other inputs will be missing in the resulting matrix.",
    )
    # ``run_only`` lists the scattering matrix indices the user wants to use as sources;
    # any source index that is not listed is not simulated.

    verbose: bool = pd.Field(
        False,
        title="Verbosity",
        description="Whether the :class:`.ComponentModeler` should print status and progressbars.",
    )

    callback_url: Optional[str] = pd.Field(
        None,
        title="Callback URL",
        description="Http PUT url to receive simulation finish event. "
        "The body content is a json file with fields "
        "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.",
    )

    path_dir: str = pd.Field(
        DEFAULT_DATA_DIR,
        title="Directory Path",
        description="Base directory where data and batch will be downloaded.",
    )

    @pd.validator("simulation", always=True)
    def _sim_has_no_sources(cls, val):
        """Make sure simulation has no sources as they interfere with tool."""
        if len(val.sources) > 0:
            raise SetupError("'ComponentModeler.simulation' must not have any sources.")
        return val

    @cached_property
    def sim_dict(self) -> Dict[str, Simulation]:
        """Generate all the :class:`Simulation` objects for the S matrix calculation.

        One simulation is created per index in ``matrix_indices_run_sim``. Each holds a
        single mode source on a copy of the port shifted along its normal, plus one mode
        monitor per port in addition to the original simulation's monitors.
        Keys are the task names produced by ``_task_name``.
        """
        # The monitor set is identical for every simulation, so build it once.
        mode_monitors = [self.to_monitor(port=port) for port in self.ports]
        new_mnts = tuple(self.simulation.monitors) + tuple(mode_monitors)

        sim_dict = {}
        for port_name, mode_index in self.matrix_indices_run_sim:
            port = self.get_port_by_name(port_name=port_name)
            port_source = self.shift_port(port=port)
            mode_source = self.to_source(port=port_source, mode_index=mode_index)
            sim_copy = self.simulation.copy(update=dict(sources=[mode_source], monitors=new_mnts))
            task_name = self._task_name(port=port, mode_index=mode_index)
            sim_dict[task_name] = sim_copy
        return sim_dict

    @cached_property
    def matrix_indices_monitor(self) -> Tuple[MatrixIndex, ...]:
        """Tuple of all the possible matrix indices (port, mode_index) in the Component Modeler."""
        return tuple(
            (port.name, mode_index)
            for port in self.ports
            for mode_index in range(port.mode_spec.num_modes)
        )

    @cached_property
    def matrix_indices_source(self) -> Tuple[MatrixIndex, ...]:
        """Tuple of all the source matrix indices (port, mode_index) in the Component Modeler."""
        if self.run_only is not None:
            return self.run_only
        return self.matrix_indices_monitor

    @cached_property
    def matrix_indices_run_sim(self) -> Tuple[MatrixIndex, ...]:
        """Tuple of the source matrix indices (port, mode_index) that need a simulation.

        A source index (a column of the scattering matrix) is skipped only when every
        element of its column is the target of an entry in ``element_mappings``.
        """
        if not self.element_mappings:
            return self.matrix_indices_source

        # all the (i, j) pairs in `S_ij` that are set by `element_mappings`
        elements_determined_by_map = {
            element_out for (_, element_out, _) in self.element_mappings
        }

        # keep each column that still has at least one element not covered by the map
        return tuple(
            col_index
            for col_index in self.matrix_indices_source
            if not all(
                (row_index, col_index) in elements_determined_by_map
                for row_index in self.matrix_indices_monitor
            )
        )

    def get_port_by_name(self, port_name: str) -> Port:
        """Get the first port whose name is ``port_name``.

        Raises
        ------
        Tidy3dKeyError
            If no port has the given name.
        """
        for port in self.ports:
            if port.name == port_name:
                return port
        raise Tidy3dKeyError(f'Port "{port_name}" not found.')

    def to_monitor(self, port: Port) -> ModeMonitor:
        """Creates a mode monitor from a given port, recording at ``freqs``."""
        return ModeMonitor(
            center=port.center,
            size=port.size,
            freqs=self.freqs,
            mode_spec=port.mode_spec,
            name=port.name,
        )

    def to_source(self, port: Port, mode_index: int) -> ModeSource:
        """Creates a single mode source from a given port.

        The Gaussian pulse is centered at the mean of ``freqs``. Its width is the span of
        ``freqs`` or ``FWIDTH_FRAC`` times the central frequency, whichever is larger.
        """
        freq0 = np.mean(self.freqs)
        fdiff = max(self.freqs) - min(self.freqs)
        fwidth = max(fdiff, freq0 * FWIDTH_FRAC)
        return ModeSource(
            center=port.center,
            size=port.size,
            source_time=GaussianPulse(freq0=freq0, fwidth=fwidth),
            mode_spec=port.mode_spec,
            mode_index=mode_index,
            direction=port.direction,
            name=port.name,
        )

    def _shift_value_signed(self, port: Port) -> float:
        """How far (signed) to shift the source from the monitor.

        The source is moved two grid cells away from the port, against the port's
        ``direction``, and snapped to a grid-cell center.

        Raises
        ------
        SetupError
            If the port lies outside the grid or is too close to the boundary to shift.
        """
        # get the grid boundaries and centers along the port normal from the simulation
        normal_axis = port.size.index(0.0)
        grid = self.simulation.grid
        grid_boundaries = np.asarray(grid.boundaries.to_list[normal_axis])
        grid_centers = grid.centers.to_list[normal_axis]

        # get the index of the grid cell where the port lies
        port_position = port.center[normal_axis]
        port_pos_gt_grid_bounds = np.flatnonzero(port_position > grid_boundaries)

        # no port index can be determined
        if port_pos_gt_grid_bounds.size == 0:
            raise SetupError(f"Port position '{port_position}' outside of simulation bounds.")
        # scalar index of the last boundary below the port (not a length-1 array)
        port_index = int(port_pos_gt_grid_bounds[-1])

        # shift the port to the left
        if port.direction == "+":
            shifted_index = port_index - 2
            if shifted_index < 0:
                raise SetupError(
                    f"Port {port.name} normal is too close to boundary "
                    f"on -{'xyz'[normal_axis]} side."
                )

        # shift the port to the right
        else:
            shifted_index = port_index + 2
            if shifted_index >= len(grid_centers):
                raise SetupError(
                    f"Port {port.name} normal is too close to boundary "
                    f"on +{'xyz'[normal_axis]} side."
                )

        new_pos = grid_centers[shifted_index]
        return float(new_pos - port_position)

    def shift_port(self, port: Port) -> Port:
        """Generate a new port shifted by the shift amount in normal direction."""
        shift_value = self._shift_value_signed(port=port)
        normal_axis = port.size.index(0.0)
        center_shifted = list(port.center)
        center_shifted[normal_axis] += shift_value
        return port.copy(update=dict(center=center_shifted))

    @staticmethod
    def _task_name(port: Port, mode_index: int) -> str:
        """The name of a task, determined by the port of the source and mode index."""
        return f"smatrix_{port.name}_{mode_index}"

    def _sim_with_port_sources(self) -> Simulation:
        """Copy of the simulation with a mode-0 source at every (unshifted) port, for plotting."""
        plot_sources = [self.to_source(port=port, mode_index=0) for port in self.ports]
        return self.simulation.copy(update=dict(sources=plot_sources))

    @equal_aspect
    @add_ax_if_none
    def plot_sim(
        self, x: float = None, y: float = None, z: float = None, ax: Ax = None, **kwargs
    ) -> Ax:
        """Plot a :class:`Simulation` with all sources added for each port, for troubleshooting."""
        sim_plot = self._sim_with_port_sources()
        return sim_plot.plot(x=x, y=y, z=z, ax=ax, **kwargs)

    @equal_aspect
    @add_ax_if_none
    def plot_sim_eps(
        self, x: float = None, y: float = None, z: float = None, ax: Ax = None, **kwargs
    ) -> Ax:
        """Plot permittivity of the :class:`Simulation` with all sources added for each port."""
        sim_plot = self._sim_with_port_sources()
        return sim_plot.plot_eps(x=x, y=y, z=z, ax=ax, **kwargs)

    @cached_property
    def batch(self) -> Batch:
        """Batch associated with this component modeler.

        Loaded from ``_batch_path`` if that file exists, otherwise built from ``sim_dict``.
        """
        # first try loading the batch from file, if it exists
        batch_path = self._batch_path
        if os.path.exists(batch_path):
            return Batch.from_file(fname=batch_path)

        return Batch(
            simulations=self.sim_dict,
            folder_name=self.folder_name,
            callback_url=self.callback_url,
            verbose=self.verbose,
        )

    @cached_property
    def batch_path(self) -> str:
        """Path to the batch saved to file, inside this modeler's ``path_dir``."""
        return self.batch._batch_path(path_dir=self.path_dir)

    def get_path_dir(self, path_dir: str) -> str:
        """Resolve the directory to use for data and batch files.

        Returns ``path_dir`` (with a deprecation warning) when it differs from both
        ``DEFAULT_DATA_DIR`` and the ``path_dir`` field; otherwise returns the field value.
        """
        if path_dir not in (DEFAULT_DATA_DIR, self.path_dir):
            log.warning(
                f"'ComponentModeler' method was supplied a 'path_dir' of '{path_dir}' "
                f"when its internal 'path_dir' field was set to '{self.path_dir}'. "
                "The passed value will be deprecated in later versions. "
                "Please set the internal 'path_dir' field to the desired value and "
                "remove the 'path_dir' from the method argument. "
                f"Using supplied '{path_dir}'."
            )
            return path_dir

        return self.path_dir

    @cached_property
    def _batch_path(self) -> str:
        """Where we store the batch for this ComponentModeler instance after the run."""
        hash_str = self._hash_self()
        return os.path.join(self.path_dir, "batch" + hash_str + ".json")

    def _run_sims(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
        """Run :class:`Simulations` for each port and return the batch after saving."""
        batch = self.batch

        batch_data = batch.run(path_dir=path_dir)
        batch.to_file(self._batch_path)
        return batch_data

    def _normalization_factor(self, port_source: Port, sim_data: SimulationData) -> complex:
        """Compute the normalization amplitude based on the measured input mode amplitude.

        Reads the amplitude recorded at the source port's monitor, in the port's input
        direction, for the mode index of the simulation's first source, at ``freqs``.
        """
        port_monitor_data = sim_data[port_source.name]
        mode_index = sim_data.simulation.sources[0].mode_index

        normalize_amps = port_monitor_data.amps.sel(
            f=np.array(self.freqs),
            direction=port_source.direction,
            mode_index=mode_index,
        )

        return normalize_amps.values

    @cached_property
    def max_mode_index(self) -> Tuple[int, int]:
        """maximum mode indices for the smatrix dataset for the out and in ports, respectively."""

        def get_max_mode_indices(matrix_elements: Tuple[MatrixIndex, ...]) -> int:
            """Get the maximum mode index for a list of (port name, mode index)."""
            return max(mode_index for _, mode_index in matrix_elements)

        max_mode_index_out = get_max_mode_indices(self.matrix_indices_monitor)
        max_mode_index_in = get_max_mode_indices(self.matrix_indices_source)

        return max_mode_index_out, max_mode_index_in

    @cached_property
    def port_names(self) -> Tuple[List[str], List[str]]:
        """List of port names for outputs and inputs, respectively."""

        def get_port_names(matrix_elements: Tuple[MatrixIndex, ...]) -> List[str]:
            """Get the unique port names, in first-seen order, from (port name, mode index)."""
            return list(dict.fromkeys(port_name for port_name, _ in matrix_elements))

        port_names_in = get_port_names(self.matrix_indices_source)
        port_names_out = get_port_names(self.matrix_indices_monitor)

        return port_names_out, port_names_in

    def _construct_smatrix(self, batch_data: BatchData) -> SMatrixDataArray:
        """Post process `BatchData` to generate scattering matrix.

        For each simulated column, every element is the outgoing mode amplitude at the
        monitor port divided by the incoming amplitude measured at the source port.
        Elements targeted by ``element_mappings`` are then filled in from their mapped
        counterparts. All other elements remain zero.
        """
        max_mode_index_out, max_mode_index_in = self.max_mode_index
        num_modes_out = max_mode_index_out + 1
        num_modes_in = max_mode_index_in + 1
        port_names_out, port_names_in = self.port_names

        values = np.zeros(
            (len(port_names_out), len(port_names_in), num_modes_out, num_modes_in, len(self.freqs)),
            dtype=complex,
        )
        coords = dict(
            port_out=port_names_out,
            port_in=port_names_in,
            mode_index_out=range(num_modes_out),
            mode_index_in=range(num_modes_in),
            f=np.array(self.freqs),
        )
        s_matrix = SMatrixDataArray(values, coords=coords)

        # loop through source ports (columns) that were actually simulated
        for port_name_in, mode_index_in in self.matrix_indices_run_sim:
            port_in = self.get_port_by_name(port_name=port_name_in)
            sim_data = batch_data[self._task_name(port=port_in, mode_index=mode_index_in)]
            # the normalization depends only on this column's simulation, so compute it once
            source_norm = np.array(self._normalization_factor(port_in, sim_data))

            for port_name_out, mode_index_out in self.matrix_indices_monitor:
                port_out = self.get_port_by_name(port_name=port_name_out)

                # directly compute the element; outgoing waves travel against the
                # port's 'input' direction
                mode_amps_data = sim_data[port_out.name].amps
                dir_out = "-" if port_out.direction == "+" else "+"
                amp = mode_amps_data.sel(
                    f=coords["f"], direction=dir_out, mode_index=mode_index_out
                )
                s_matrix_elements = np.array(amp.data) / source_norm
                s_matrix.loc[
                    dict(
                        port_in=port_name_in,
                        mode_index_in=mode_index_in,
                        port_out=port_name_out,
                        mode_index_out=mode_index_out,
                    )
                ] = s_matrix_elements

        # element can be determined by user-defined mapping (applied in the given order)
        for (row_in, col_in), (row_out, col_out), mult_by in self.element_mappings:
            port_out_from, mode_index_out_from = row_in
            port_in_from, mode_index_in_from = col_in
            coords_from = dict(
                port_in=port_in_from,
                mode_index_in=mode_index_in_from,
                port_out=port_out_from,
                mode_index_out=mode_index_out_from,
            )

            port_out_to, mode_index_out_to = row_out
            port_in_to, mode_index_in_to = col_out
            coords_to = dict(
                port_in=port_in_to,
                mode_index_in=mode_index_in_to,
                port_out=port_out_to,
                mode_index_out=mode_index_out_to,
            )
            s_matrix.loc[coords_to] = mult_by * s_matrix.loc[coords_from].values

        return s_matrix

    def run(self, path_dir: str = DEFAULT_DATA_DIR) -> SMatrixDataArray:
        """Solves for the scattering matrix of the system."""
        path_dir = self.get_path_dir(path_dir)
        batch_data = self._run_sims(path_dir=path_dir)
        return self._construct_smatrix(batch_data=batch_data)

    def load(self, path_dir: str = DEFAULT_DATA_DIR) -> SMatrixDataArray:
        """Load a scattering matrix from saved `BatchData` object."""
        path_dir = self.get_path_dir(path_dir)

        batch_data = BatchData.load(path_dir=path_dir)
        return self._construct_smatrix(batch_data=batch_data)