"""Collections of DataArrays."""from__future__importannotationsfromabcimportABC,abstractmethodfromtypingimportUnion,Dict,Callable,Any,Tupleimportxarrayasxrimportnumpyasnpimportpydantic.v1aspdfrommatplotlib.triimportTriangulationfrommatplotlibimportpyplotaspltimportnumbersfrom.data_arrayimportDataArray,DATA_ARRAY_MAPfrom.data_arrayimportScalarFieldDataArray,ScalarFieldTimeDataArray,ScalarModeFieldDataArrayfrom.data_arrayimportModeIndexDataArray,GroupIndexDataArray,ModeDispersionDataArrayfrom.data_arrayimportTriangleMeshDataArrayfrom.data_arrayimportTimeDataArrayfrom.data_arrayimportPointDataArray,IndexedDataArray,CellDataArray,SpatialDataArrayfrom.data_arrayimportEMEScalarFieldDataArray,EMEScalarModeFieldDataArrayfrom..vizimportequal_aspect,add_ax_if_none,plot_params_gridfrom..baseimportTidy3dBaseModel,cached_propertyfrom..baseimportskip_if_fields_missingfrom..typesimportAxis,Bound,ArrayLike,Ax,Coordinate,Literal,annotate_typefrom...packagingimportvtk,requires_vtkfrom...exceptionsimportDataError,ValidationError,Tidy3dNotImplementedErrorfrom...constantsimportPICOSECOND_PER_NANOMETER_PER_KILOMETER,inffrom...logimportlogDEFAULT_MAX_SAMPLES_PER_STEP=10_000DEFAULT_MAX_CELLS_PER_STEP=10_000DEFAULT_TOLERANCE_CELL_FINDING=1e-6classDataset(Tidy3dBaseModel,ABC):"""Abstract base class for objects that store collections of `:class:`.DataArray`s."""
class AbstractFieldDataset(Dataset, ABC):
    """Collection of scalar fields with some symmetry properties."""

    @property
    @abstractmethod
    def field_components(self) -> Dict[str, DataArray]:
        """Maps the field components to their associated data."""
    def apply_phase(self, phase: float) -> AbstractFieldDataset:
        """Create a copy where all elements are phase-shifted by a value (in radians)."""
        if phase == 0.0:
            return self
        phasor = np.exp(1j * phase)
        field_components_shifted = {}
        for fld_name, fld_cmp in self.field_components.items():
            fld_cmp_shifted = phasor * fld_cmp
            field_components_shifted[fld_name] = fld_cmp_shifted
        return self.updated_copy(**field_components_shifted)
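    # A minimal usage sketch for 'apply_phase' (hedged; assumes 'data' is a
    # frequency-domain 'FieldDataset' like the one built in the 'FieldDataset'
    # docstring example below):
    #
    #     rotated = data.apply_phase(np.pi / 2)
    #     np.allclose(rotated.Ex, np.exp(1j * np.pi / 2) * data.Ex)  # -> True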
    @property
    @abstractmethod
    def grid_locations(self) -> Dict[str, str]:
        """Maps field components to the string key of their grid locations on the Yee lattice."""

    @property
    @abstractmethod
    def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]:
        """Maps field components to their (positive) symmetry eigenvalues."""
    def package_colocate_results(self, centered_fields: Dict[str, ScalarFieldDataArray]) -> Any:
        """How to package the dictionary of fields computed via self.colocate()."""
        return xr.Dataset(centered_fields)
    def colocate(self, x=None, y=None, z=None) -> xr.Dataset:
        """Colocate all of the data at a set of x, y, z coordinates.

        Parameters
        ----------
        x : Optional[array-like] = None
            x coordinates of locations. If not supplied, does not try to colocate on this
            dimension.
        y : Optional[array-like] = None
            y coordinates of locations. If not supplied, does not try to colocate on this
            dimension.
        z : Optional[array-like] = None
            z coordinates of locations. If not supplied, does not try to colocate on this
            dimension.

        Returns
        -------
        xr.Dataset
            Dataset containing all fields at the same spatial locations.
            For more details refer to `xarray's Documentation <https://tinyurl.com/cyca3krz>`_.

        Note
        ----
        For many operations (such as flux calculations and plotting), it is important that
        the fields are colocated at the same spatial locations. Be sure to apply this method
        to your field data in those cases.
        """
        if hasattr(self, "monitor") and self.monitor.colocate:
            with log as consolidated_logger:
                consolidated_logger.warning(
                    "Colocating data that has already been colocated during the solver "
                    "run. For most accurate results when colocating to custom coordinates set "
                    "'Monitor.colocate' to 'False' to use the raw data on the Yee grid "
                    "and avoid double interpolation. Note: the default value was changed to 'True' "
                    "in Tidy3D version 2.4.0."
                )

        # convert supplied coordinates to array and assign string mapping to them
        supplied_coord_map = {k: np.array(v) for k, v in zip("xyz", (x, y, z)) if v is not None}

        # dict of data arrays to combine in dataset and return
        centered_fields = {}

        # loop through field components
        for field_name, field_data in self.field_components.items():
            # loop through x, y, z dimensions and raise an error if only one element along dim
            for coord_name, coords_supplied in supplied_coord_map.items():
                coord_data = np.array(field_data.coords[coord_name])
                if coord_data.size == 1:
                    raise DataError(
                        f"colocate given {coord_name}={coords_supplied}, but "
                        f"data only has one coordinate at {coord_name}={coord_data[0]}. "
                        "Therefore, can't colocate along this dimension. "
                        f"Supply {coord_name}=None to skip it."
                    )

            centered_fields[field_name] = field_data.interp(
                **supplied_coord_map, kwargs={"bounds_error": True}
            )

        # combine all centered fields in a dataset
        return self.package_colocate_results(centered_fields)
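# A minimal usage sketch for 'colocate' (hedged; 'data' is the 'FieldDataset'
# constructed in the 'FieldDataset' docstring example below, and the sampling
# coordinates are illustrative assumptions):
#
#     xs = np.linspace(-1, 1, 5)
#     zs = np.linspace(-3, 3, 7)
#     ds = data.colocate(x=xs, z=zs)  # 'xr.Dataset' with 'Ex' and 'Hz' on shared coords
#
# Dimensions with only a single coordinate cannot be interpolated; pass 'None'
# (the default) to skip them.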
EMScalarFieldType = Union[
    ScalarFieldDataArray,
    ScalarFieldTimeDataArray,
    ScalarModeFieldDataArray,
    EMEScalarModeFieldDataArray,
    EMEScalarFieldDataArray,
]


class ElectromagneticFieldDataset(AbstractFieldDataset, ABC):
    """Stores a collection of E and H fields with x, y, z components."""

    Ex: EMScalarFieldType = pd.Field(
        None,
        title="Ex",
        description="Spatial distribution of the x-component of the electric field.",
    )
    Ey: EMScalarFieldType = pd.Field(
        None,
        title="Ey",
        description="Spatial distribution of the y-component of the electric field.",
    )
    Ez: EMScalarFieldType = pd.Field(
        None,
        title="Ez",
        description="Spatial distribution of the z-component of the electric field.",
    )
    Hx: EMScalarFieldType = pd.Field(
        None,
        title="Hx",
        description="Spatial distribution of the x-component of the magnetic field.",
    )
    Hy: EMScalarFieldType = pd.Field(
        None,
        title="Hy",
        description="Spatial distribution of the y-component of the magnetic field.",
    )
    Hz: EMScalarFieldType = pd.Field(
        None,
        title="Hz",
        description="Spatial distribution of the z-component of the magnetic field.",
    )

    @property
    def field_components(self) -> Dict[str, DataArray]:
        """Maps the field components to their associated data."""
        fields = {
            "Ex": self.Ex,
            "Ey": self.Ey,
            "Ez": self.Ez,
            "Hx": self.Hx,
            "Hy": self.Hy,
            "Hz": self.Hz,
        }
        return {field_name: field for field_name, field in fields.items() if field is not None}

    @property
    def grid_locations(self) -> Dict[str, str]:
        """Maps field components to the string key of their grid locations on the Yee lattice."""
        return dict(Ex="Ex", Ey="Ey", Ez="Ez", Hx="Hx", Hy="Hy", Hz="Hz")

    @property
    def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]:
        """Maps field components to their (positive) symmetry eigenvalues."""
        return dict(
            Ex=lambda dim: -1 if (dim == 0) else +1,
            Ey=lambda dim: -1 if (dim == 1) else +1,
            Ez=lambda dim: -1 if (dim == 2) else +1,
            Hx=lambda dim: +1 if (dim == 0) else -1,
            Hy=lambda dim: +1 if (dim == 1) else -1,
            Hz=lambda dim: +1 if (dim == 2) else -1,
        )
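# Hedged illustration of the 'symmetry_eigenvalues' map above: each entry is a
# function of the symmetry axis (0, 1, or 2). For instance, 'Ex' flips sign
# under a mirror in x but not in y or z, while 'Hx' does the opposite:
#
#     data.symmetry_eigenvalues["Ex"](0)  # -> -1
#     data.symmetry_eigenvalues["Ex"](1)  # -> +1
#     data.symmetry_eigenvalues["Hx"](0)  # -> +1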
class FieldDataset(ElectromagneticFieldDataset):
    """Dataset storing a collection of the scalar components of E and H fields in the freq. domain.

    Example
    -------
    >>> x = [-1,1]
    >>> y = [-2,0,2]
    >>> z = [-3,-1,1,3]
    >>> f = [2e14, 3e14]
    >>> coords = dict(x=x, y=y, z=z, f=f)
    >>> scalar_field = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
    >>> data = FieldDataset(Ex=scalar_field, Hz=scalar_field)
    """

    Ex: ScalarFieldDataArray = pd.Field(
        None,
        title="Ex",
        description="Spatial distribution of the x-component of the electric field.",
    )
    Ey: ScalarFieldDataArray = pd.Field(
        None,
        title="Ey",
        description="Spatial distribution of the y-component of the electric field.",
    )
    Ez: ScalarFieldDataArray = pd.Field(
        None,
        title="Ez",
        description="Spatial distribution of the z-component of the electric field.",
    )
    Hx: ScalarFieldDataArray = pd.Field(
        None,
        title="Hx",
        description="Spatial distribution of the x-component of the magnetic field.",
    )
    Hy: ScalarFieldDataArray = pd.Field(
        None,
        title="Hy",
        description="Spatial distribution of the y-component of the magnetic field.",
    )
    Hz: ScalarFieldDataArray = pd.Field(
        None,
        title="Hz",
        description="Spatial distribution of the z-component of the magnetic field.",
    )
class FieldTimeDataset(ElectromagneticFieldDataset):
    """Dataset storing a collection of the scalar components of E and H fields in the time domain.

    Example
    -------
    >>> x = [-1,1]
    >>> y = [-2,0,2]
    >>> z = [-3,-1,1,3]
    >>> t = [0, 1e-12, 2e-12]
    >>> coords = dict(x=x, y=y, z=z, t=t)
    >>> scalar_field = ScalarFieldTimeDataArray(np.random.random((2,3,4,3)), coords=coords)
    >>> data = FieldTimeDataset(Ex=scalar_field, Hz=scalar_field)
    """

    Ex: ScalarFieldTimeDataArray = pd.Field(
        None,
        title="Ex",
        description="Spatial distribution of the x-component of the electric field.",
    )
    Ey: ScalarFieldTimeDataArray = pd.Field(
        None,
        title="Ey",
        description="Spatial distribution of the y-component of the electric field.",
    )
    Ez: ScalarFieldTimeDataArray = pd.Field(
        None,
        title="Ez",
        description="Spatial distribution of the z-component of the electric field.",
    )
    Hx: ScalarFieldTimeDataArray = pd.Field(
        None,
        title="Hx",
        description="Spatial distribution of the x-component of the magnetic field.",
    )
    Hy: ScalarFieldTimeDataArray = pd.Field(
        None,
        title="Hy",
        description="Spatial distribution of the y-component of the magnetic field.",
    )
    Hz: ScalarFieldTimeDataArray = pd.Field(
        None,
        title="Hz",
        description="Spatial distribution of the z-component of the magnetic field.",
    )
    def apply_phase(self, phase: float) -> AbstractFieldDataset:
        """Create a copy where all elements are phase-shifted by a value (in radians)."""
        if phase != 0.0:
            raise ValueError("Can't apply phase to time-domain field data, which is real-valued.")
        return self
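    # Hedged sketch of the behavior above ('data' as in the class docstring
    # example): time-domain fields are real-valued, so only a zero phase is
    # accepted:
    #
    #     data.apply_phase(0.0)        # returns 'data' unchanged
    #     data.apply_phase(np.pi / 2)  # raises ValueError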
class ModeSolverDataset(ElectromagneticFieldDataset):
    """Dataset storing scalar components of E and H fields as a function of freq. and mode_index.

    Example
    -------
    >>> from tidy3d import ModeSpec
    >>> x = [-1,1]
    >>> y = [0]
    >>> z = [-3,-1,1,3]
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> field_coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index)
    >>> field = ScalarModeFieldDataArray((1+1j)*np.random.random((2,1,4,2,5)), coords=field_coords)
    >>> index_coords = dict(f=f, mode_index=mode_index)
    >>> index_data = ModeIndexDataArray((1+1j) * np.random.random((2,5)), coords=index_coords)
    >>> data = ModeSolverDataset(
    ...     Ex=field,
    ...     Ey=field,
    ...     Ez=field,
    ...     Hx=field,
    ...     Hy=field,
    ...     Hz=field,
    ...     n_complex=index_data
    ... )
    """

    Ex: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Ex",
        description="Spatial distribution of the x-component of the electric field of the mode.",
    )
    Ey: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Ey",
        description="Spatial distribution of the y-component of the electric field of the mode.",
    )
    Ez: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Ez",
        description="Spatial distribution of the z-component of the electric field of the mode.",
    )
    Hx: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Hx",
        description="Spatial distribution of the x-component of the magnetic field of the mode.",
    )
    Hy: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Hy",
        description="Spatial distribution of the y-component of the magnetic field of the mode.",
    )
    Hz: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Hz",
        description="Spatial distribution of the z-component of the magnetic field of the mode.",
    )

    n_complex: ModeIndexDataArray = pd.Field(
        ...,
        title="Propagation Index",
        description="Complex-valued effective propagation constants associated with the mode.",
    )

    n_group_raw: GroupIndexDataArray = pd.Field(
        None,
        alias="n_group",  # This is for backwards compatibility only when loading old data
        title="Group Index",
        description="Index associated with group velocity of the mode.",
    )

    dispersion_raw: ModeDispersionDataArray = pd.Field(
        None,
        title="Dispersion",
        description="Dispersion parameter for the mode.",
        units=PICOSECOND_PER_NANOMETER_PER_KILOMETER,
    )

    @property
    def field_components(self) -> Dict[str, DataArray]:
        """Maps the field components to their associated data."""
        fields = {
            "Ex": self.Ex,
            "Ey": self.Ey,
            "Ez": self.Ez,
            "Hx": self.Hx,
            "Hy": self.Hy,
            "Hz": self.Hz,
        }
        return {field_name: field for field_name, field in fields.items() if field is not None}

    @property
    def n_eff(self) -> ModeIndexDataArray:
        """Real part of the propagation index."""
        return self.n_complex.real

    @property
    def k_eff(self) -> ModeIndexDataArray:
        """Imaginary part of the propagation index."""
        return self.n_complex.imag

    @property
    def n_group(self) -> GroupIndexDataArray:
        """Group index."""
        if self.n_group_raw is None:
            log.warning(
                "The group index was not computed. To calculate group index, pass "
                "'group_index_step = True' in the 'ModeSpec'.",
                log_once=True,
            )
        return self.n_group_raw

    @property
    def dispersion(self) -> ModeDispersionDataArray:
        r"""Dispersion parameter.

        .. math::

           D = -\frac{\lambda}{c_0} \frac{{\rm d}^2 n_{\text{eff}}}{{\rm d}\lambda^2}
        """
        if self.dispersion_raw is None:
            log.warning(
                "The dispersion was not computed. To calculate dispersion, pass "
                "'group_index_step = True' in the 'ModeSpec'.",
                log_once=True,
            )
        return self.dispersion_raw
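    # Hedged access sketch for the derived quantities above ('data' as in the
    # class docstring example):
    #
    #     neff = data.n_eff.sel(mode_index=0)  # real part of 'n_complex'
    #     keff = data.k_eff.sel(mode_index=0)  # imaginary part of 'n_complex'
    #     ng = data.n_group  # 'None' (with a warning) unless 'group_index_step' was set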
    def plot_field(self, *args, **kwargs):
        """Warn user to use the :class:`.ModeSolver` ``plot_field`` function now."""
        raise DeprecationWarning(
            "The 'plot_field()' method was moved to the 'ModeSolver' object. "
            "Once the 'ModeSolver' is constructed, one may call '.plot_field()' on the object and "
            "the modes will be computed and displayed with 'Simulation' overlay."
        )
class PermittivityDataset(AbstractFieldDataset):
    """Dataset storing the diagonal components of the permittivity tensor.

    Example
    -------
    >>> x = [-1,1]
    >>> y = [-2,0,2]
    >>> z = [-3,-1,1,3]
    >>> f = [2e14, 3e14]
    >>> coords = dict(x=x, y=y, z=z, f=f)
    >>> sclr_fld = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
    >>> data = PermittivityDataset(eps_xx=sclr_fld, eps_yy=sclr_fld, eps_zz=sclr_fld)
    """

    @property
    def field_components(self) -> Dict[str, ScalarFieldDataArray]:
        """Maps the field components to their associated data."""
        return dict(eps_xx=self.eps_xx, eps_yy=self.eps_yy, eps_zz=self.eps_zz)

    @property
    def grid_locations(self) -> Dict[str, str]:
        """Maps field components to the string key of their grid locations on the Yee lattice."""
        return dict(eps_xx="Ex", eps_yy="Ey", eps_zz="Ez")

    @property
    def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]:
        """Maps field components to their (positive) symmetry eigenvalues."""
        return dict(eps_xx=None, eps_yy=None, eps_zz=None)

    eps_xx: ScalarFieldDataArray = pd.Field(
        ...,
        title="Epsilon xx",
        description="Spatial distribution of the xx-component of the relative permittivity.",
    )
    eps_yy: ScalarFieldDataArray = pd.Field(
        ...,
        title="Epsilon yy",
        description="Spatial distribution of the yy-component of the relative permittivity.",
    )
    eps_zz: ScalarFieldDataArray = pd.Field(
        ...,
        title="Epsilon zz",
        description="Spatial distribution of the zz-component of the relative permittivity.",
    )
class TriangleMeshDataset(Dataset):
    """Dataset for storing triangular surface data."""

    surface_mesh: TriangleMeshDataArray = pd.Field(
        ...,
        title="Surface mesh data",
        description="Dataset containing the surface triangles and corresponding face indices "
        "for a surface mesh.",
    )


class TimeDataset(Dataset):
    """Dataset for storing a function of time."""

    values: TimeDataArray = pd.Field(
        ..., title="Values", description="Values as a function of time."
    )


class UnstructuredGridDataset(Dataset, np.lib.mixins.NDArrayOperatorsMixin, ABC):
    """Abstract base for datasets that store unstructured grid data."""

    points: PointDataArray = pd.Field(
        ...,
        title="Grid Points",
        description="Coordinates of points composing the unstructured grid.",
    )

    values: IndexedDataArray = pd.Field(
        ...,
        title="Point Values",
        description="Values stored at the grid points.",
    )

    cells: CellDataArray = pd.Field(
        ...,
        title="Grid Cells",
        description="Cells composing the unstructured grid specified as connections between grid "
        "points.",
    )

    @property
    def name(self) -> str:
        """Dataset name."""
        # we redirect name to values.name
        return self.values.name

    @property
    def is_complex(self) -> bool:
        """Whether the stored values are complex-valued."""
        return np.iscomplexobj(self.values)

    @property
    def _double_type(self):
        """Corresponding double data type."""
        return np.complex128 if self.is_complex else np.float64

    @pd.validator("points", always=True)
    def points_right_dims(cls, val):
        """Check that point coordinates have the right dimensionality."""
        # currently support only the standard axis ordering, that is 01(2)
        axis_coords_expected = np.arange(cls._point_dims())
        axis_coords_given = val.axis.data
        if np.any(axis_coords_given != axis_coords_expected):
            raise ValidationError(
                f"Points array is expected to have {axis_coords_expected} coord values along 'axis'"
                f" (given: {axis_coords_given})."
            )
        return val

    @property
    def is_uniform(self):
        """Whether each element is of equal value in ``values``."""
        return self.values.is_uniform

    @pd.validator("points", always=True)
    def points_right_indexing(cls, val):
        """Check that points are indexed correctly."""
        indices_expected = np.arange(len(val.data))
        indices_given = val.index.data
        if np.any(indices_expected != indices_given):
            raise ValidationError(
                "Coordinate 'index' of array 'points' is expected to have values (0, 1, 2, ...). "
                "This can be easily achieved, for example, by using "
                "PointDataArray(data, dims=['index', 'axis'])."
            )
        return val

    @pd.validator("values", always=True)
    def values_right_indexing(cls, val):
        """Check that data values are indexed correctly."""
        # currently support only simple ordered indexing of points, that is, 0, 1, 2, ...
        indices_expected = np.arange(len(val.data))
        indices_given = val.index.data
        if np.any(indices_expected != indices_given):
            raise ValidationError(
                "Coordinate 'index' of array 'values' is expected to have values (0, 1, 2, ...). "
                "This can be easily achieved, for example, by using "
                "IndexedDataArray(data, dims=['index'])."
            )
        return val
""This can be easily achieved, for example, by using ""IndexedDataArray(data, dims=['index']).")returnval@pd.validator("values",always=True)@skip_if_fields_missing(["points"])defnumber_of_values_matches_points(cls,val,values):"""Check that the number of data values matches the number of grid points."""num_values=len(val)points=values.get("points")num_points=len(points)ifnum_points!=num_values:raiseValidationError(f"The number of data values ({num_values}) does not match the number of grid "f"points ({num_points}).")returnval@pd.validator("cells",always=True)defmatch_cells_to_vtk_type(cls,val):"""Check that cell connections does not have duplicate points."""ifvtkisNone:returnval# using val.astype(np.int32/64) directly causes issues when dataarray are later checked ==returnCellDataArray(val.data.astype(vtk["id_type"],copy=False),coords=val.coords)@pd.validator("cells",always=True)defcells_right_type(cls,val):"""Check that cell are of the right type."""# only supporting the standard ordering of cell vertices 012(3)vertex_coords_expected=np.arange(cls._cell_num_vertices())vertex_coords_given=val.vertex_index.dataifnp.any(vertex_coords_given!=vertex_coords_expected):raiseValidationError(f"Cell connections array is expected to have {vertex_coords_expected} coord values"f" along 'vertex_index' (given: {vertex_coords_given}).")returnval@pd.validator("cells",always=True)@skip_if_fields_missing(["points"])defcheck_cell_vertex_range(cls,val,values):"""Check that cell connections use only defined points."""all_point_indices_used=val.data.ravel()# skip validation if zero size dataiflen(all_point_indices_used)>0:min_index_used=np.min(all_point_indices_used)max_index_used=np.max(all_point_indices_used)points=values.get("points")num_points=len(points)ifmax_index_used>num_points-1ormin_index_used<0:raiseValidationError("Cell connections array uses undefined point indices in the range "f"[{min_index_used}, {max_index_used}]. The valid range of point indices is "f"[0, {num_points-1}].")returnval@classmethoddef_find_degenerate_cells(cls,cells:CellDataArray):"""Find explicitly degenerate cells if any. That is, cells that use the same point indices for their different vertices. """indices=cells.data# skip validation if zero size datadegenerate_cell_inds=set()iflen(indices)>0:foriinrange(cls._cell_num_vertices()-1):forjinrange(i+1,cls._cell_num_vertices()):degenerate_cell_inds=degenerate_cell_inds.union(np.where(indices[:,i]==indices[:,j])[0])returndegenerate_cell_inds@classmethoddef_remove_degenerate_cells(cls,cells:CellDataArray):"""Remove explicitly degenerate cells if any. That is, cells that use the same point indices for their different vertices. """degenerate_cells=cls._find_degenerate_cells(cells=cells)iflen(degenerate_cells)>0:data=np.delete(cells.values,list(degenerate_cells),axis=0)cell_index=np.delete(cells.cell_index.values,list(degenerate_cells))returnCellDataArray(data=data,coords=dict(cell_index=cell_index,vertex_index=cells.vertex_index))returncells@classmethoddef_remove_unused_points(cls,points:PointDataArray,values:IndexedDataArray,cells:CellDataArray):"""Remove unused points if any. That is, points that are not used in any grid cell. 
"""used_indices=np.unique(cells.values.ravel())num_points=len(points)iflen(used_indices)!=num_pointsornp.any(np.diff(used_indices)!=1):min_index=np.min(used_indices)map_len=np.max(used_indices)-min_index+1index_map=np.zeros(map_len)index_map[used_indices-min_index]=np.arange(len(used_indices))cells=CellDataArray(data=index_map[cells.data-min_index],coords=cells.coords)points=PointDataArray(points.data[used_indices,:],dims=["index","axis"])values=IndexedDataArray(values.data[used_indices],dims=["index"])returnpoints,values,cellsdefclean(self,remove_degenerate_cells=True,remove_unused_points=True):"""Remove degenerate cells and/or unused points."""ifremove_degenerate_cells:cells=self._remove_degenerate_cells(cells=self.cells)else:cells=self.cellsifremove_unused_points:points,values,cells=self._remove_unused_points(self.points,self.values,cells)else:points=self.pointsvalues=self.valuesreturnself.updated_copy(points=points,values=values,cells=cells)@pd.validator("cells",always=True)defwarn_degenerate_cells(cls,val):"""Check that cell connections does not have duplicate points."""degenerate_cells=cls._find_degenerate_cells(val)num_degenerate_cells=len(degenerate_cells)ifnum_degenerate_cells>0:log.warning(f"Unstructured grid contains {num_degenerate_cells} degenerate cell(s). ""Such cells can be removed by using function ""'.clean(remove_degenerate_cells: bool = True, remove_unused_points: bool = True)'. ""For example, 'dataset = dataset.clean()'.")returnval@pd.root_validator(pre=True,allow_reuse=True)def_warn_if_none(cls,values):"""Warn if any of data arrays are not loaded."""no_data_fields=[]forfield_namein["points","cells","values"]:field=values.get(field_name)ifisinstance(field,str)andfieldinDATA_ARRAY_MAP.keys():no_data_fields.append(field_name)iflen(no_data_fields)>0:formatted_names=[f"'{fname}'"forfnameinno_data_fields]log.warning(f"Loading {', '.join(formatted_names)} without data. Constructing an empty dataset.")values["points"]=PointDataArray(np.zeros((0,cls._point_dims())),dims=["index","axis"])values["cells"]=CellDataArray(np.zeros((0,cls._cell_num_vertices())),dims=["cell_index","vertex_index"])values["values"]=IndexedDataArray(np.zeros(0),dims=["index"])returnvalues@pd.root_validator(skip_on_failure=True,allow_reuse=True)def_warn_unused_points(cls,values):"""Warn if some points are unused."""point_indices=set(np.arange(len(values["values"].data)))used_indices=set(values["cells"].values.ravel())ifnotpoint_indices.issubset(used_indices):log.warning("Unstructured grid dataset contains unused points. 
""Consider calling 'clean()' to remove them.")returnvaluesdefrename(self,name:str)->UnstructuredGridDataset:"""Return a renamed array."""returnself.updated_copy(values=self.values.rename(name))def__array_ufunc__(self,ufunc,method,*inputs,**kwargs):"""Override of numpy functions."""out=kwargs.get("out",())forxininputs+out:# Only support operations with the same class or a scalarifnotisinstance(x,(numbers.Number,type(self))):returnTidy3dNotImplementedError# Defer to the implementation of the ufunc on unwrapped values.inputs=tuple(x.valuesifisinstance(x,UnstructuredGridDataset)elsexforxininputs)ifout:kwargs["out"]=tuple(x.valuesifisinstance(x,UnstructuredGridDataset)elsexforxinout)result=getattr(ufunc,method)(*inputs,**kwargs)iftype(result)istuple:# multiple return valuesreturntuple(self.updated_copy(values=x)forxinresult)elifmethod=="at":# no return valuereturnNoneelse:# one return valuereturnself.updated_copy(values=result)@propertydefreal(self)->UnstructuredGridDataset:"""Real part of dataset."""returnself.updated_copy(values=self.values.real)@propertydefimag(self)->UnstructuredGridDataset:"""Imaginary part of dataset."""returnself.updated_copy(values=self.values.imag)@propertydefabs(self)->UnstructuredGridDataset:"""Absolute value of dataset."""returnself.updated_copy(values=self.values.abs)@cached_propertydefbounds(self)->Bound:"""Grid bounds."""returntuple(np.min(self.points.data,axis=0)),tuple(np.max(self.points.data,axis=0))@classmethod@abstractmethoddef_point_dims(cls)->pd.PositiveInt:"""Dimensionality of stored grid point coordinates."""@cached_property@abstractmethoddef_points_3d_array(self):"""3D coordinates of grid points."""@classmethod@abstractmethoddef_cell_num_vertices(cls)->pd.PositiveInt:"""Number of vertices in a cell."""@classmethod@abstractmethod@requires_vtkdef_vtk_cell_type(cls):"""VTK cell type to use in the VTK representation."""@cached_propertydef_vtk_offsets(self)->ArrayLike:"""Offsets array to use in the VTK representation."""offsets=np.arange(len(self.cells)+1)*self._cell_num_vertices()ifvtkisNone:returnoffsetsreturnoffsets.astype(vtk["id_type"],copy=False)@property@requires_vtkdef_vtk_cells(self):"""VTK cell array to use in the VTK representation."""cells=vtk["mod"].vtkCellArray()cells.SetData(vtk["numpy_to_vtkIdTypeArray"](self._vtk_offsets),vtk["numpy_to_vtkIdTypeArray"](self.cells.data.ravel()),)returncells@property@requires_vtkdef_vtk_points(self):"""VTK point array to use in the VTK representation."""pts=vtk["mod"].vtkPoints()pts.SetData(vtk["numpy_to_vtk"](self._points_3d_array))returnpts@property@requires_vtkdef_vtk_obj(self):"""A VTK representation (vtkUnstructuredGrid) of the grid."""grid=vtk["mod"].vtkUnstructuredGrid()grid.SetPoints(self._vtk_points)grid.SetCells(self._vtk_cell_type(),self._vtk_cells)ifself.is_complex:# vtk doesn't support complex numbers# so we will store our complex array as a two-component vtk arraydata_values=self.values.values.view("(2,)float")else:data_values=self.values.valuespoint_data_vtk=vtk["numpy_to_vtk"](data_values)point_data_vtk.SetName(self.values.name)grid.GetPointData().AddArray(point_data_vtk)returngrid@requires_vtkdef_plane_slice_raw(self,axis:Axis,pos:float):"""Slice data with a plane and return the resulting VTK object."""ifpos>self.bounds[1][axis]orpos<self.bounds[0][axis]:raiseDataError(f"Slicing plane (axis: {axis}, pos: {pos}) does not intersect the unstructured grid "f"(extent along axis {axis}: {self.bounds[0][axis]}, {self.bounds[1][axis]}).")origin=[0,0,0]origin[axis]=posnormal=[0,0,0]# orientation of 
        normal[axis] = -1
        if pos > (self.bounds[0][axis] + self.bounds[1][axis]) / 2:
            normal[axis] = 1

        # create cutting plane
        plane = vtk["mod"].vtkPlane()
        plane.SetOrigin(origin[0], origin[1], origin[2])
        plane.SetNormal(normal[0], normal[1], normal[2])

        # create cutter
        cutter = vtk["mod"].vtkPlaneCutter()
        cutter.SetPlane(plane)
        cutter.SetInputData(self._vtk_obj)
        cutter.InterpolateAttributesOn()
        cutter.Update()

        # clean up the slice
        cleaner = vtk["mod"].vtkCleanPolyData()
        cleaner.SetInputData(cutter.GetOutput())
        cleaner.Update()

        return cleaner.GetOutput()

    @abstractmethod
    @requires_vtk
    def plane_slice(
        self, axis: Axis, pos: float
    ) -> Union[SpatialDataArray, UnstructuredGridDataset]:
        """Slice data with a plane and return the Tidy3D representation of the result
        (``UnstructuredGridDataset``).

        Parameters
        ----------
        axis : Axis
            The normal direction of the slicing plane.
        pos : float
            Position of the slicing plane along its normal direction.

        Returns
        -------
        Union[SpatialDataArray, UnstructuredGridDataset]
            The resulting slice.
        """

    @staticmethod
    @requires_vtk
    def _read_vtkUnstructuredGrid(fname: str):
        """Load a :class:`vtkUnstructuredGrid` from a file."""
        reader = vtk["mod"].vtkXMLUnstructuredGridReader()
        reader.SetFileName(fname)
        reader.Update()
        grid = reader.GetOutput()

        return grid

    @classmethod
    @abstractmethod
    @requires_vtk
    def _from_vtk_obj(
        cls,
        vtk_obj,
        field: str = None,
        remove_degenerate_cells: bool = False,
        remove_unused_points: bool = False,
    ) -> UnstructuredGridDataset:
        """Initialize from a vtk object."""

    @classmethod
    @requires_vtk
    def from_vtu(
        cls,
        file: str,
        field: str = None,
        remove_degenerate_cells: bool = False,
        remove_unused_points: bool = False,
    ) -> UnstructuredGridDataset:
        """Load unstructured data from a vtu file.

        Parameters
        ----------
        file : str
            Full path to the .vtu file to load the unstructured data from.
        field : str = None
            Name of the field to load.
        remove_degenerate_cells : bool = False
            Remove explicitly degenerate cells.
        remove_unused_points : bool = False
            Remove unused points.

        Returns
        -------
        UnstructuredGridDataset
            Unstructured data.
        """
        grid = cls._read_vtkUnstructuredGrid(file)
        return cls._from_vtk_obj(
            grid,
            field=field,
            remove_degenerate_cells=remove_degenerate_cells,
            remove_unused_points=remove_unused_points,
        )

    @requires_vtk
    def to_vtu(self, fname: str):
        """Exports unstructured grid data into a .vtu file.

        Parameters
        ----------
        fname : str
            Full path to the .vtu file to save the unstructured data to.
        """
        writer = vtk["mod"].vtkXMLUnstructuredGridWriter()
        writer.SetFileName(fname)
        writer.SetInputData(self._vtk_obj)
        writer.Write()

    @classmethod
    @requires_vtk
    def _get_values_from_vtk(
        cls, vtk_obj, num_points: pd.PositiveInt, field: str = None
    ) -> IndexedDataArray:
        """Get point data values from a VTK object."""

        point_data = vtk_obj.GetPointData()
        num_point_arrays = point_data.GetNumberOfArrays()

        if num_point_arrays == 0:
            log.warning(
                "No point data is found in a VTK object. '.values' will be initialized to zeros."
            )
            values_numpy = np.zeros(num_points)
            values_name = None
        else:
            if field is not None:
                array_vtk = point_data.GetAbstractArray(field)
            else:
                array_vtk = point_data.GetAbstractArray(0)

            # currently we assume there is only one point data array provided in the VTK object
            if num_point_arrays > 1 and field is None:
                array_name = array_vtk.GetName()
                log.warning(
                    f"{num_point_arrays} point data arrays are found in a VTK object. "
                    f"Only the first array (name: {array_name}) will be used to initialize "
                    "'.values' while the rest will be ignored."
                )
"f"Only the first array (name: {array_name}) will be used to initialize ""'.values' while the rest will be ignored.")# currently we assume data is real or complex scalarnum_components=array_vtk.GetNumberOfComponents()ifnum_components>2:raiseDataError("Found point data array in a VTK object is expected to have maximum 2 ""components (1 is for real data, 2 is for complex data). "f"Found {num_components} components.")# check that number of values matches number of grid pointsnum_tuples=array_vtk.GetNumberOfTuples()ifnum_tuples!=num_points:raiseDataError(f"The length of found point data array ({num_tuples}) does not match the number"f" of grid points ({num_points}).")values_numpy=vtk["vtk_to_numpy"](array_vtk)values_name=array_vtk.GetName()# vtk doesn't support complex numbers# we store our complex array as a two-component vtk array# so here we convert that into a single component complex arrayifnum_components==2:values_numpy=values_numpy.view("complex")[:,0]values=IndexedDataArray(values_numpy,coords=dict(index=np.arange(len(values_numpy))),name=values_name)returnvalues@requires_vtkdefbox_clip(self,bounds:Bound)->UnstructuredGridDataset:"""Clip the unstructured grid using a box defined by ``bounds``. Parameters ---------- bounds : Tuple[float, float, float], Tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns ------- UnstructuredGridDataset Clipped grid. """# make and run a VTK clipperclipper=vtk["mod"].vtkBoxClipDataSet()clipper.SetOrientation(0)clipper.SetBoxClip(bounds[0][0],bounds[1][0],bounds[0][1],bounds[1][1],bounds[0][2],bounds[1][2])clipper.SetInputData(self._vtk_obj)clipper.GenerateClipScalarsOn()clipper.GenerateClippedOutputOff()clipper.Update()clip=clipper.GetOutput()# clean grid from unused pointsgrid_cleaner=vtk["mod"].vtkRemoveUnusedPoints()grid_cleaner.SetInputData(clip)grid_cleaner.GenerateOriginalPointIdsOff()grid_cleaner.Update()clean_clip=grid_cleaner.GetOutput()# no intersection checkifclean_clip.GetNumberOfPoints()==0:raiseDataError("Clipping box does not intersect the unstructured grid.")returnself._from_vtk_obj(clean_clip,remove_degenerate_cells=True,remove_unused_points=True)definterp(self,x:Union[float,ArrayLike],y:Union[float,ArrayLike],z:Union[float,ArrayLike],fill_value:Union[float,Literal["extrapolate"]]=None,use_vtk:bool=False,method:Literal["linear","nearest"]="linear",max_samples_per_step:int=DEFAULT_MAX_SAMPLES_PER_STEP,max_cells_per_step:int=DEFAULT_MAX_CELLS_PER_STEP,rel_tol:float=DEFAULT_TOLERANCE_CELL_FINDING,)->SpatialDataArray:"""Interpolate data at provided x, y, and z. Parameters ---------- x : Union[float, ArrayLike] x-coordinates of sampling points. y : Union[float, ArrayLike] y-coordinates of sampling points. z : Union[float, ArrayLike] z-coordinates of sampling points. fill_value : Union[float, Literal["extrapolate"]] = 0 Value to use when filling points without interpolated values. If ``"extrapolate"`` then nearest values are used. Note: in a future version the default value will be changed to ``"extrapolate"``. use_vtk : bool = False Use vtk's interpolation functionality or Tidy3D's own implementation. Note: this option will be removed in a future version. method: Literal["linear", "nearest"] = "linear" Interpolation method to use. max_samples_per_step : int = 1e4 Max number of points to interpolate at per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. 
        max_cells_per_step : int = 1e4
            Max number of cells to interpolate from per iteration (used only if
            ``use_vtk=False``). Using a higher number may speed up calculations but, at the
            same time, it increases RAM usage.
        rel_tol : float = 1e-6
            Relative tolerance when determining whether a point belongs to a cell.

        Returns
        -------
        SpatialDataArray
            Interpolated data.
        """
        if fill_value is None:
            log.warning(
                "Default parameter setting 'fill_value=0' will be changed to "
                "'fill_value=``extrapolate``' in a future version."
            )
            fill_value = 0

        # calculate the resulting array shape
        x = np.atleast_1d(x)
        y = np.atleast_1d(y)
        z = np.atleast_1d(z)

        if method == "nearest":
            interpolated_values = self._interp_nearest(x=x, y=y, z=z)
        else:
            if fill_value == "extrapolate":
                fill_value_actual = np.nan
            else:
                fill_value_actual = fill_value

            if use_vtk:
                log.warning("Note that option 'use_vtk=True' will be removed in future versions.")
                interpolated_values = self._interp_vtk(x=x, y=y, z=z, fill_value=fill_value_actual)
            else:
                interpolated_values = self._interp_py(
                    x=x,
                    y=y,
                    z=z,
                    fill_value=fill_value_actual,
                    max_samples_per_step=max_samples_per_step,
                    max_cells_per_step=max_cells_per_step,
                    rel_tol=rel_tol,
                )

        if fill_value == "extrapolate" and method != "nearest":
            interpolated_values = self._fill_nans_from_nearests(interpolated_values, x=x, y=y, z=z)

        return SpatialDataArray(
            interpolated_values, coords=dict(x=x, y=y, z=z), name=self.values.name
        )

    def _interp_nearest(
        self,
        x: ArrayLike,
        y: ArrayLike,
        z: ArrayLike,
    ) -> ArrayLike:
        """Interpolate data at provided x, y, and z using Scipy's nearest neighbor interpolator.

        Parameters
        ----------
        x : ArrayLike
            x-coordinates of sampling points.
        y : ArrayLike
            y-coordinates of sampling points.
        z : ArrayLike
            z-coordinates of sampling points.

        Returns
        -------
        ArrayLike
            Interpolated data.
        """
        from scipy.interpolate import NearestNDInterpolator

        # use scipy's nearest neighbor interpolator
        X, Y, Z = np.meshgrid(x, y, z, indexing="ij")
        interp = NearestNDInterpolator(self._points_3d_array, self.values.values)
        values = interp(X, Y, Z)

        return values

    def _fill_nans_from_nearests(
        self,
        values: ArrayLike,
        x: ArrayLike,
        y: ArrayLike,
        z: ArrayLike,
    ) -> ArrayLike:
        """Replace nan's in ``values`` with nearest data points.

        Parameters
        ----------
        values : ArrayLike
            3D array containing nan's.
        x : ArrayLike
            x-coordinates of sampling points.
        y : ArrayLike
            y-coordinates of sampling points.
        z : ArrayLike
            z-coordinates of sampling points.

        Returns
        -------
        ArrayLike
            Data without nan's.
        """
        # locate all nans
        nans = np.isnan(values)

        if np.sum(nans) > 0:
            from scipy.interpolate import NearestNDInterpolator

            # use scipy's nearest neighbor interpolator
            X, Y, Z = np.meshgrid(x, y, z, indexing="ij")
            interp = NearestNDInterpolator(self._points_3d_array, self.values.values)
            values_to_replace_nans = interp(X[nans], Y[nans], Z[nans])
            values[nans] = values_to_replace_nans

        return values

    @requires_vtk
    def _interp_vtk(
        self,
        x: ArrayLike,
        y: ArrayLike,
        z: ArrayLike,
        fill_value: float,
    ) -> ArrayLike:
        """Interpolate data at provided x, y, and z using the vtk package.

        Parameters
        ----------
        x : Union[float, ArrayLike]
            x-coordinates of sampling points.
        y : Union[float, ArrayLike]
            y-coordinates of sampling points.
        z : Union[float, ArrayLike]
            z-coordinates of sampling points.
        fill_value : float = 0
            Value to use when filling points without interpolated values.

        Returns
        -------
        ArrayLike
            Interpolated data.
        """
"""shape=(len(x),len(y),len(z))# create a VTK rectilinear grid to sample ontostructured_grid=vtk["mod"].vtkRectilinearGrid()structured_grid.SetDimensions(shape)structured_grid.SetXCoordinates(vtk["numpy_to_vtk"](x))structured_grid.SetYCoordinates(vtk["numpy_to_vtk"](y))structured_grid.SetZCoordinates(vtk["numpy_to_vtk"](z))# create and execute VTK interpolatorinterpolator=vtk["mod"].vtkResampleWithDataSet()interpolator.SetInputData(structured_grid)interpolator.SetSourceData(self._vtk_obj)interpolator.Update()interpolated=interpolator.GetOutput()# get results in a numpy representationarray_id=0ifself.values.nameisNoneelseself.values.namevalues_numpy=vtk["vtk_to_numpy"](interpolated.GetPointData().GetAbstractArray(array_id))# fill points without interpolated valuesiffill_value!=0:mask=vtk["vtk_to_numpy"](interpolated.GetPointData().GetAbstractArray("vtkValidPointMask"))values_numpy[mask!=1]=fill_value# VTK arrays are the z-y-x order, reorder interpolation results to x-y-z ordervalues_reordered=np.transpose(np.reshape(values_numpy,shape[::-1]),(2,1,0))returnvalues_reordered@abstractmethoddef_interp_py(self,x:ArrayLike,y:ArrayLike,z:ArrayLike,fill_value:float,max_samples_per_step:int,max_cells_per_step:int,rel_tol:float,)->ArrayLike:"""Dimensionality-specific function (2D and 3D) to interpolate data at provided x, y, and z using vectorized python implementation. Parameters ---------- x : Union[float, ArrayLike] x-coordinates of sampling points. y : Union[float, ArrayLike] y-coordinates of sampling points. z : Union[float, ArrayLike] z-coordinates of sampling points. fill_value : float Value to use when filling points without interpolated values. max_samples_per_step : int Max number of points to interpolate at per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. max_cells_per_step : int Max number of cells to interpolate from per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. rel_tol : float Relative tolerance when determining whether a point belongs to a cell. Returns ------- ArrayLike Interpolated data. """def_interp_py_general(self,x:ArrayLike,y:ArrayLike,z:ArrayLike,fill_value:float,max_samples_per_step:int,max_cells_per_step:int,rel_tol:float,axis_ignore:Union[Axis,None],)->ArrayLike:"""A general function (2D and 3D) to interpolate data at provided x, y, and z using vectorized python implementation. Parameters ---------- x : Union[float, ArrayLike] x-coordinates of sampling points. y : Union[float, ArrayLike] y-coordinates of sampling points. z : Union[float, ArrayLike] z-coordinates of sampling points. fill_value : float Value to use when filling points without interpolated values. max_samples_per_step : int Max number of points to interpolate at per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. max_cells_per_step : int Max number of cells to interpolate from per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. rel_tol : float Relative tolerance when determining whether a point belongs to a cell. axis_ignore : Union[Axis, None] When interpolating from a 2D dataset, must specify normal axis. Returns ------- ArrayLike Interpolated data. 
"""# get dimensionality of datanum_dims=self._point_dims()ifnum_dims==2andaxis_ignoreisNone:raiseDataError("Must porvide 'axis_ignore' when interpolating from a 2d dataset.")xyz_grid=[x,y,z]ifaxis_ignoreisnotNone:xyz_grid.pop(axis_ignore)# get numpy arrays for points and cellscell_connections=(self.cells.values)# (num_cells, num_cell_vertices), num_cell_vertices=num_cell_facespoints=self.points.values# (num_points, num_dims)num_cells=len(cell_connections)num_points=len(points)# compute tolerances based on total size of unstructured gridbounds=self.boundssize=np.subtract(bounds[1],bounds[0])tol=size*rel_toldiag_tol=np.linalg.norm(tol)# compute (index) positions of unstructured points w.r.t. target Cartesian grid points# (i.e. between which Cartesian grid points a given unstructured grid point is located)# we perturb grid values in both directions to make sure we don't miss any points# due to numerical precisionxyz_pos_l=np.zeros((num_dims,num_points),dtype=int)xyz_pos_r=np.zeros((num_dims,num_points),dtype=int)fordiminrange(num_dims):xyz_pos_l[dim]=np.searchsorted(xyz_grid[dim]+tol[dim],points[:,dim])xyz_pos_r[dim]=np.searchsorted(xyz_grid[dim]-tol[dim],points[:,dim])# let's allocate an array for resulting values# every time we process a chunk of samples, we will write into this arrayinterpolated_values=fill_value+np.zeros([len(xyz_comp)forxyz_compinxyz_grid],dtype=self.values.dtype)processed_cells_global=0# to ovoid OOM for large datasets, we process only certain number of cells at a timewhileprocessed_cells_global<num_cells:target_processed_cells_global=min(num_cells,processed_cells_global+max_cells_per_step)connections_to_process=cell_connections[processed_cells_global:target_processed_cells_global]# now we transfer this information to each cell. That is, each cell knows how its vertices# positioned relative to Cartesian grid points.# (num_dims, num_cells, num_vertices=num_cell_faces)xyz_pos_l_per_cell=xyz_pos_l[:,connections_to_process]xyz_pos_r_per_cell=xyz_pos_r[:,connections_to_process]# taking min/max among all cell vertices (per each dimension separately)# we get min and max indices of Cartesian grid points that may receive their values# from a given cell.# (num_dims, num_cells)cell_ind_min=np.min(xyz_pos_l_per_cell,axis=2)cell_ind_max=np.max(xyz_pos_r_per_cell,axis=2)# calculate number of Cartesian grid points where we will perform interpolation for a given# cell. 
            # Note that this number is much larger than actually needed, because essentially for
            # each cell we consider all Cartesian grid points that fall into the cell's bounding
            # box. We use the word "sample" to represent such Cartesian grid points.
            # (num_cells,)
            num_samples_per_cell = np.prod(cell_ind_max - cell_ind_min, axis=0)

            # find cells that have non-zero number of samples
            # we use "ne" as a shortcut for "non empty"
            ne_cells = num_samples_per_cell > 0  # (num_cells,)
            num_ne_cells = np.sum(ne_cells)

            # indices of cells with non-zero number of samples in the original list of cells
            # (num_cells,)
            ne_cell_inds = np.arange(processed_cells_global, target_processed_cells_global)[
                ne_cells
            ]

            # restrict to non-empty cells only
            num_samples_per_ne_cell = num_samples_per_cell[ne_cells]
            cum_num_samples_per_ne_cell = np.cumsum(num_samples_per_ne_cell)
            ne_cell_ind_min = cell_ind_min[:, ne_cells]
            ne_cell_ind_max = cell_ind_max[:, ne_cells]

            # Next we need to perform the actual interpolation at all sample points.
            # This is a computationally expensive operation and, because we try to do everything
            # in vectorized form, it can require a lot of memory, sometimes even causing OOM
            # errors. To avoid that, we impose restrictions on how many cells/samples can be
            # processed at a time, effectively performing these operations in chunks.
            # Note that currently this is done sequentially, but it could be relatively easy
            # to parallelize.

            # start counters of how many cells/samples have been processed
            processed_samples = 0
            processed_cells = 0
            while processed_cells < num_ne_cells:
                # how many cells we would like to process by the end of this step
                target_processed_cells = min(num_ne_cells, processed_cells + max_cells_per_step)

                # find how many cells we can process based on the number of allowed samples
                target_processed_samples = processed_samples + max_samples_per_step
                target_processed_cells_from_samples = (
                    np.searchsorted(cum_num_samples_per_ne_cell, target_processed_samples) + 1
                )

                # take min between the two
                target_processed_cells = min(
                    target_processed_cells, target_processed_cells_from_samples
                )

                # select cells and corresponding samples to process
                step_ne_cell_ind_min = ne_cell_ind_min[:, processed_cells:target_processed_cells]
                step_ne_cell_ind_max = ne_cell_ind_max[:, processed_cells:target_processed_cells]
                step_ne_cell_inds = ne_cell_inds[processed_cells:target_processed_cells]

                # process selected cells and points
                xyz_inds, interpolated = self._interp_py_chunk(
                    xyz_grid=xyz_grid,
                    cell_inds=step_ne_cell_inds,
                    cell_ind_min=step_ne_cell_ind_min,
                    cell_ind_max=step_ne_cell_ind_max,
                    sdf_tol=diag_tol,
                )

                if num_dims == 3:
                    interpolated_values[xyz_inds[0], xyz_inds[1], xyz_inds[2]] = interpolated
                else:
                    interpolated_values[xyz_inds[0], xyz_inds[1]] = interpolated

                processed_cells = target_processed_cells
                processed_samples = cum_num_samples_per_ne_cell[target_processed_cells - 1]

            processed_cells_global = target_processed_cells_global

        # in case of a 2d grid, broadcast results along the normal direction assuming
        # translational invariance
        if num_dims == 2:
            orig_shape = [len(x), len(y), len(z)]
            flat_shape = orig_shape.copy()
            flat_shape[axis_ignore] = 1
            interpolated_values = np.reshape(interpolated_values, flat_shape)
            interpolated_values = np.broadcast_to(
                interpolated_values, (len(x), len(y), len(z))
            ).copy()

        return interpolated_values

    def _interp_py_chunk(
        self,
        xyz_grid: Tuple[ArrayLike[float], ...],
        cell_inds: ArrayLike[int],
        cell_ind_min: ArrayLike[int],
        cell_ind_max: ArrayLike[int],
        sdf_tol: float,
    ) -> Tuple[Tuple[ArrayLike, ...], ArrayLike]:
        """For each cell listed in ``cell_inds`` perform interpolation at a rectilinear subarray
        of xyz_grid given by a (3D) index span (cell_ind_min, cell_ind_max).

        Parameters
        ----------
        xyz_grid : Tuple[ArrayLike[float], ...]
            x, y, and z coordinates defining the rectilinear grid.
        cell_inds : ArrayLike[int]
            Indices of cells to perform interpolation from.
        cell_ind_min : ArrayLike[int]
            Starting x, y, and z indices of points for interpolation for each cell.
        cell_ind_max : ArrayLike[int]
            End x, y, and z indices of points for interpolation for each cell.
        sdf_tol : float
            Effective zero level set value, below which a point is considered to be inside a cell.

        Returns
        -------
        Tuple[Tuple[ArrayLike, ...], ArrayLike]
            x, y, and z indices of interpolated values and the values themselves.
        """
        # get dimensionality of data
        num_dims = self._point_dims()
        num_cell_faces = self._cell_num_vertices()

        # get mesh info as numpy arrays
        points = self.points.values  # (num_points, num_dims)
        data_values = self.values.values  # (num_points,)
        cell_connections = self.cells.values[cell_inds]

        # compute number of samples to generate per cell
        num_samples_per_cell = np.prod(cell_ind_max - cell_ind_min, axis=0)

        # at this point we know how many samples we need to perform per each cell and we also
        # know span indices of these samples (in x, y, and z arrays)

        # we would like to perform all interpolations in a vectorized form, however, we have
        # a different number of interpolation samples for different cells. Thus, we need to
        # arrange all samples in a linear way (flatten). Basically, we want to have data in this
        # form:
        # cell_ind | x_ind | y_ind | z_ind
        # --------------------------------
        #    0     |  23   |   5   |  11
        #    0     |  23   |   5   |  12
        #    0     |  23   |   6   |  11
        #    0     |  23   |   6   |  12
        #    1     |  41   |  11   |   0
        #    1     |  42   |  11   |   0
        #   ...    |  ...  |  ...  |  ...

        # to do that we start with performing arange for each cell, but in a vectorized way
        # this gives us something like this
        # [0, 1, 2, 3, 0, 1, 0, 1, 2, 3, 4, 5, 6, ...]
        # |<-cell 0->|<-cell 1->|<-   cell 2   ->|<- ...
        num_cells = len(num_samples_per_cell)
        num_samples_cumul = num_samples_per_cell.cumsum()
        num_samples_total = num_samples_cumul[-1]

        # one big arange array
        inds_flat = np.arange(num_samples_total)

        # now subtract previous number of samples
        inds_flat[num_samples_per_cell[0]:] -= np.repeat(
            num_samples_cumul[:-1], num_samples_per_cell[1:]
        )

        # convert flat indices into 3d/2d indices as:
        # x_ind = [23, 23, 23, 23, 41, 41, ...]
        # y_ind = [ 5,  5,  5,  5,  6,  6, ...]
        # z_ind = [11, 12, 11, 12,  0,  0, ...]
        #         |<-   cell 0   ->|<- cell 1 ->|<- ...
        num_samples_y = np.repeat(cell_ind_max[1] - cell_ind_min[1], num_samples_per_cell)
        # note: in 2d x, y correspond to (x, y, z).pop(normal_axis)
        if num_dims == 3:
            num_samples_z = np.repeat(cell_ind_max[2] - cell_ind_min[2], num_samples_per_cell)
            inds_flat, z_inds = np.divmod(inds_flat, num_samples_z)
        x_inds, y_inds = np.divmod(inds_flat, num_samples_y)

        start_inds = np.repeat(cell_ind_min, num_samples_per_cell, axis=1)
        x_inds = x_inds + start_inds[0]
        y_inds = y_inds + start_inds[1]
        if num_dims == 3:
            z_inds = z_inds + start_inds[2]

        # finally, we repeat cell indices a corresponding number of times to obtain how
        # (x_ind, y_ind, z_ind) map to cell indices. So, now we have four arrays:
        # x_ind    = [23, 23, 23, 23, 41, 41, ...]
        # y_ind    = [ 5,  5,  5,  5,  6,  6, ...]
        # z_ind    = [11, 12, 11, 12,  0,  0, ...]
        # cell_map = [ 0,  0,  0,  0,  1,  1, ...]
        #            |<-   cell 0   ->|<- cell 1 ->|<- ...
        step_cell_map = np.repeat(np.arange(num_cells), num_samples_per_cell)

        # let's put these arrays aside for a moment and perform the second preparatory step
        # specifically, for each face of each cell we will compute the normal vector and the
        # distance to the opposing cell vertex.
        # This will allow us to quickly calculate the SDF of a cell at each sample point as well
        # as perform linear interpolation.

        # first, we collect coordinates of cell vertices into a single array
        # (num_cells, num_cell_vertices, num_dims)
        cell_vertices = np.float64(points[cell_connections, :])

        # array for resulting normals and distances
        normal = np.zeros((num_cell_faces, num_cells, num_dims))
        dist = np.zeros((num_cell_faces, num_cells))

        # loop face by face
        # note that by face_ind we denote both the index of a face in a cell and the index of the
        # opposing vertex
        for face_ind in range(num_cell_faces):
            # select vertices forming the given face
            face_pinds = list(np.arange(num_cell_faces))
            face_pinds.pop(face_ind)

            # calculate normal to the face
            # in 3D: cross product of two vectors lying in the face plane
            # in 2D: (-ty, tx) for a vector (tx, ty) along the face
            p0 = cell_vertices[:, face_pinds[0]]
            p01 = cell_vertices[:, face_pinds[1]] - p0
            p0Opp = cell_vertices[:, face_ind] - p0
            if num_dims == 3:
                p02 = cell_vertices[:, face_pinds[2]] - p0
                n = np.cross(p01, p02)
            else:
                n = np.roll(p01, 1, axis=1)
                n[:, 0] = -n[:, 0]
            n_norm = np.linalg.norm(n, axis=1)
            n = n / n_norm[:, None]

            # compute distance to the opposing vertex by taking a dot product between the normal
            # and a vector connecting the opposing vertex and the face
            d = np.einsum("ij,ij->i", n, p0Opp)

            # the obtained normal direction is arbitrary here. We will orient it such that it
            # points away from the triangle (and the distance to the opposing vertex is negative).
            to_flip = d > 0
            d[to_flip] *= -1
            n[to_flip, :] *= -1

            # set distances in degenerate triangles to something positive to ignore later
            dist_zero = d == 0
            if any(dist_zero):
                d[dist_zero] = 1

            # record obtained info
            normal[face_ind] = n
            dist[face_ind] = d

        # now we are all set up to proceed with the actual interpolation at each sample point
        # the main idea here is that:
        # - we use `cell_map` to grab normals and distances
        #   of cells in which the given sample point is (potentially) located.
        # - use `x_ind, y_ind, z_ind` to find actual coordinates of a given sample point
        # - combine the above two to calculate cell SDF and interpolated value at a given sample
        #   point
        # - having cell SDF at the sample point actually tells us whether it is inside the cell
        #   (keep value) or outside of it (discard interpolated value)

        # to perform SDF calculation and interpolation we will loop face by face and record
        # their contributions. That is,
        # cell_sdf = max(face0_sdf, face1_sdf, ...)
        # interpolated_value = value0 * face0_sdf / dist0_sdf + ...
        # (because face0_sdf / dist0_sdf is the linear shape function for vertex0)
        sdf = -inf * np.ones(num_samples_total)
        interpolated = np.zeros(num_samples_total, dtype=self._double_type)

        # coordinates of each sample point
        sample_xyz = np.zeros((num_samples_total, num_dims))
        sample_xyz[:, 0] = xyz_grid[0][x_inds]
        sample_xyz[:, 1] = xyz_grid[1][y_inds]
        if num_dims == 3:
            sample_xyz[:, 2] = xyz_grid[2][z_inds]

        # loop face by face
        for face_ind in range(num_cell_faces):
            # find a vector connecting sample point and face
            if face_ind == 0:
                vertex_ind = 1  # anything other than 0
                vec = sample_xyz - cell_vertices[step_cell_map, vertex_ind, :]
            if face_ind == 1:
                # since three faces share a point, only do this once
                vertex_ind = 0  # it belongs to every face 1, 2, and 3
                vec = sample_xyz - cell_vertices[step_cell_map, 0, :]

            # compute distance from every sample point to the face of the corresponding cell
            # using a dot product
            tmp = normal[face_ind, step_cell_map, :] * vec
            d = np.sum(tmp, axis=1)

            # take max between distances to obtain the overall SDF of a cell
            sdf = np.maximum(sdf, d)
            # perform linear interpolation. Here we use the fact that when computing the face SDF
            # at a given point and dividing it by the distance to the opposing vertex, we get
            # a linear shape function for that vertex. So, we just need to multiply that by
            # the data value at that vertex to find its contribution to the interpolated value.
            # (decomposed in an attempt to reduce memory consumption)
            tmp = self._double_type(data_values[cell_connections[step_cell_map, face_ind]])
            tmp *= d
            tmp /= dist[face_ind, step_cell_map]

            # ignore degenerate cells
            dist_zero = dist[face_ind, step_cell_map] > 0
            if any(dist_zero):
                sdf[dist_zero] = 10 * sdf_tol

            interpolated += tmp

        # The resulting array of interpolated values contains multiple candidate values for
        # every Cartesian point because bounding boxes of cells overlap.
        # Thus, we need to keep only those that come from the cell actually containing a given
        # point. This can be easily determined by the sign of the cell SDF sampled at that point.
        valid_samples = sdf < sdf_tol

        interpolated_valid = interpolated[valid_samples]
        xyz_valid_inds = []
        xyz_valid_inds.append(x_inds[valid_samples])
        xyz_valid_inds.append(y_inds[valid_samples])
        if num_dims == 3:
            xyz_valid_inds.append(z_inds[valid_samples])

        return xyz_valid_inds, interpolated_valid

    @abstractmethod
    @requires_vtk
    def sel(
        self,
        x: Union[float, ArrayLike] = None,
        y: Union[float, ArrayLike] = None,
        z: Union[float, ArrayLike] = None,
    ) -> Union[UnstructuredGridDataset, SpatialDataArray]:
        """Extract/interpolate data along one or more Cartesian directions. At least one of
        x, y, and z must be provided.

        Parameters
        ----------
        x : Union[float, ArrayLike] = None
            x-coordinate of the slice.
        y : Union[float, ArrayLike] = None
            y-coordinate of the slice.
        z : Union[float, ArrayLike] = None
            z-coordinate of the slice.

        Returns
        -------
        Union[TriangularGridDataset, SpatialDataArray]
            Extracted data.
        """

    @requires_vtk
    def sel_inside(self, bounds: Bound) -> UnstructuredGridDataset:
        """Return a new UnstructuredGridDataset that contains the minimal amount of data
        necessary to cover a spatial region defined by ``bounds``.

        Parameters
        ----------
        bounds : Tuple[float, float, float], Tuple[float, float, float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.

        Returns
        -------
        UnstructuredGridDataset
            Extracted spatial data array.
        """
        if any(bmin > bmax for bmin, bmax in zip(*bounds)):
            raise DataError(
                "Min and max bounds must be packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``."
            )

        data_bounds = self.bounds
        tol = 1e-6

        # For extracting cells covering the target region we use vtk's filter that extracts cells
        # based on a provided implicit function. However, when we provide to it the implicit
        # function of the entire box, it has a couple of issues coming from the fact that the
        # algorithm eliminates every cell for which the implicit function has positive sign at
        # all vertices. As a result, sometimes cells that overlap with the target domain are
        # still eliminated.
        # Two common cases:
        # - near corners of the target domain
        # - target domain is very thin
        # That's why we perform selection by sequentially eliminating cells on the outer side of
        # each of the 6 surfaces of the bounding box separately.
        tmp = self._vtk_obj
        for direction in range(2):
            for dim in range(3):
                sign = -1 + 2 * direction
                plane_pos = bounds[direction][dim]

                # Dealing with the situation when the target region does not intersect any cell:
                # in this case we shift the target region so that it barely touches at least some
                # of the cells
                if sign < 0 and plane_pos > data_bounds[1][dim] - tol:
                    plane_pos = data_bounds[1][dim] - tol
                if sign > 0 and plane_pos < data_bounds[0][dim] + tol:
                    plane_pos = data_bounds[0][dim] + tol

                # if all cells are on the inside side of the plane for a given surface
                # we don't need to check for intersection
                if plane_pos <= data_bounds[1][dim] and plane_pos >= data_bounds[0][dim]:
                    plane = vtk["mod"].vtkPlane()
                    center = [0, 0, 0]
                    normal = [0, 0, 0]
                    center[dim] = plane_pos
                    normal[dim] = sign
                    plane.SetOrigin(center)
                    plane.SetNormal(normal)

                    extractor = vtk["mod"].vtkExtractGeometry()
                    extractor.SetImplicitFunction(plane)
                    extractor.ExtractInsideOn()
                    extractor.ExtractBoundaryCellsOn()
                    extractor.SetInputData(tmp)
                    extractor.Update()
                    tmp = extractor.GetOutput()

        return self._from_vtk_obj(tmp, remove_degenerate_cells=True, remove_unused_points=True)

    def does_cover(self, bounds: Bound) -> bool:
        """Check whether the data fully covers the spatial region specified by ``bounds``.

        If the data contains only one point along a given direction, then it is assumed the data
        is constant along that direction and coverage is not checked.

        Parameters
        ----------
        bounds : Tuple[float, float, float], Tuple[float, float, float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.

        Returns
        -------
        bool
            Full cover check outcome.
        """
        return all(
            (dmin <= smin and dmax >= smax)
            for dmin, dmax, smin, smax in zip(
                self.bounds[0], self.bounds[1], bounds[0], bounds[1]
            )
        )

    @requires_vtk
    def reflect(
        self, axis: Axis, center: float, reflection_only: bool = False
    ) -> UnstructuredGridDataset:
        """Reflect unstructured data across the plane defined by parameters ``axis`` and
        ``center``. By default the original data is preserved; setting ``reflection_only`` to
        ``True`` will produce only the reflected data.

        Parameters
        ----------
        axis : Literal[0, 1, 2]
            Normal direction of the reflection plane.
        center : float
            Location of the reflection plane along its normal direction.
        reflection_only : bool = False
            Return only reflected data.

        Returns
        -------
        UnstructuredGridDataset
            Data after the reflection is performed.
        """
        reflector = vtk["mod"].vtkReflectionFilter()
        reflector.SetPlane([reflector.USE_X, reflector.USE_Y, reflector.USE_Z][axis])
        reflector.SetCenter(center)
        reflector.SetCopyInput(not reflection_only)
        reflector.SetInputData(self._vtk_obj)
        reflector.Update()

        return self._from_vtk_obj(reflector.GetOutput())
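# Hedged usage sketch for the interpolation and geometry helpers above
# ('tri_grid' as constructed in the 'TriangularGridDataset' docstring example
# below; methods decorated with '@requires_vtk' need the optional 'vtk'
# package, and the sampling coordinates are illustrative assumptions):
#
#     vals = tri_grid.interp(x=[0.2, 0.7], y=[0.0], z=[0.1, 0.9])  # 'SpatialDataArray'
#     sub = tri_grid.sel_inside(((0.0, 0.0, 0.0), (0.5, 0.0, 0.5)))
#     mirrored = tri_grid.reflect(axis=0, center=0.0)  # original + reflected copy
#     tri_grid.does_cover(((0.0, 0.0, 0.0), (1.0, 0.0, 1.0)))  # -> True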
class TriangularGridDataset(UnstructuredGridDataset):
    """Dataset for storing triangular grid data. Data values are associated with the nodes of
    the grid.

    Note
    ----
    To use full functionality of unstructured datasets one must install ``vtk`` package
    (``pip install tidy3d[vtk]`` or ``pip install vtk``). Otherwise the functionality of
    unstructured datasets is limited to creation, writing to/loading from a file, and arithmetic
    manipulations.

    Example
    -------
    >>> tri_grid_points = PointDataArray(
    ...     [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    ...     coords=dict(index=np.arange(4), axis=np.arange(2)),
    ... )
    >>>
    >>> tri_grid_cells = CellDataArray(
    ...     [[0, 1, 2], [1, 2, 3]],
    ...     coords=dict(cell_index=np.arange(2), vertex_index=np.arange(3)),
    ... )
    >>>
    >>> tri_grid_values = IndexedDataArray(
    ...     [1.0, 2.0, 3.0, 4.0], coords=dict(index=np.arange(4)),
    ... )
    >>>
    >>> tri_grid = TriangularGridDataset(
    ...     normal_axis=1,
    ...     normal_pos=0,
    ...     points=tri_grid_points,
    ...     cells=tri_grid_cells,
    ...     values=tri_grid_values,
    ... )
    """

    normal_axis: Axis = pd.Field(
        ...,
        title="Grid Axis",
        description="Orientation of the grid.",
    )

    normal_pos: float = pd.Field(
        ...,
        title="Position",
        description="Coordinate of the grid along the normal direction.",
    )

    @cached_property
    def bounds(self) -> Bound:
        """Grid bounds."""
        bounds_2d = super().bounds
        bounds_3d = self._points_2d_to_3d(bounds_2d)
        return tuple(bounds_3d[0]), tuple(bounds_3d[1])

    @classmethod
    def _point_dims(cls) -> pd.PositiveInt:
        """Dimensionality of stored grid point coordinates."""
        return 2

    def _points_2d_to_3d(self, pts: ArrayLike) -> ArrayLike:
        """Convert 2d points into 3d points."""
        return np.insert(pts, obj=self.normal_axis, values=self.normal_pos, axis=1)

    @cached_property
    def _points_3d_array(self) -> ArrayLike:
        """3D representation of grid points."""
        return self._points_2d_to_3d(self.points.data)

    @classmethod
    def _cell_num_vertices(cls) -> pd.PositiveInt:
        """Number of vertices in a cell."""
        return 3

    @classmethod
    @requires_vtk
    def _vtk_cell_type(cls):
        """VTK cell type to use in the VTK representation."""
        return vtk["mod"].VTK_TRIANGLE

    @classmethod
    @requires_vtk
    def _from_vtk_obj(
        cls,
        vtk_obj,
        field=None,
        remove_degenerate_cells: bool = False,
        remove_unused_points: bool = False,
    ):
        """Initialize from a vtkUnstructuredGrid instance."""

        # get points and cells data from vtk object
        if isinstance(vtk_obj, vtk["mod"].vtkPolyData):
            cells_vtk = vtk_obj.GetPolys()
        elif isinstance(vtk_obj, vtk["mod"].vtkUnstructuredGrid):
            cells_vtk = vtk_obj.GetCells()

        cells_numpy = vtk["vtk_to_numpy"](cells_vtk.GetConnectivityArray())

        cell_offsets = vtk["vtk_to_numpy"](cells_vtk.GetOffsetsArray())
        if not np.all(np.diff(cell_offsets) == cls._cell_num_vertices()):
            raise DataError(
                "Only triangular 'vtkUnstructuredGrid' or 'vtkPolyData' can be converted into "
                "'TriangularGridDataset'."
            )

        points_numpy = vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData())

        # data values are read directly into Tidy3D array
        values = cls._get_values_from_vtk(vtk_obj, len(points_numpy), field)

        # detect zero size dimension
        bounds = np.max(points_numpy, axis=0) - np.min(points_numpy, axis=0)
        zero_dims = np.where(np.isclose(bounds, 0))[0]

        if len(zero_dims) != 1:
            raise DataError(
                f"Provided vtk grid does not represent a two dimensional grid. "
                f"Found zero-size dimensions: {zero_dims}."
            )

        normal_axis = zero_dims[0]
        normal_pos = points_numpy[0][normal_axis]
        tan_dims = [0, 1, 2]
        tan_dims.remove(normal_axis)

        # convert 3d coordinates into 2d
        points_2d_numpy = points_numpy[:, tan_dims]

        # create Tidy3D points and cells arrays
        num_cells = len(cells_numpy) // cls._cell_num_vertices()
        cells_numpy = np.reshape(cells_numpy, (num_cells, cls._cell_num_vertices()))

        cells = CellDataArray(
            cells_numpy,
            coords=dict(
                cell_index=np.arange(num_cells),
                vertex_index=np.arange(cls._cell_num_vertices()),
            ),
        )

        points = PointDataArray(
            points_2d_numpy,
            coords=dict(index=np.arange(len(points_numpy)), axis=np.arange(cls._point_dims())),
        )

        if remove_degenerate_cells:
            cells = cls._remove_degenerate_cells(cells=cells)

        if remove_unused_points:
            points, values, cells = cls._remove_unused_points(
                points=points, values=values, cells=cells
            )

        return cls(
            normal_axis=normal_axis,
            normal_pos=normal_pos,
            points=points,
            cells=cells,
            values=values,
        )
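
    # Example (illustrative sketch): although points are stored in 2D, 'bounds' reports 3D
    # extents, with a degenerate span along 'normal_axis'. With 'tri_grid' from the class
    # docstring (normal_axis=1, normal_pos=0) one would expect:
    #
    # >>> tri_grid.bounds  # doctest: +SKIP
    # ((0.0, 0.0, 0.0), (1.0, 0.0, 1.0))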
    @requires_vtk
    def plane_slice(self, axis: Axis, pos: float) -> SpatialDataArray:
        """Slice data with a plane and return the resulting line as a SpatialDataArray.

        Parameters
        ----------
        axis : Axis
            The normal direction of the slicing plane.
        pos : float
            Position of the slicing plane along its normal direction.

        Returns
        -------
        SpatialDataArray
            The resulting slice.
        """

        if axis == self.normal_axis:
            raise DataError(
                f"Triangular grid (normal: {self.normal_axis}) cannot be sliced by a parallel "
                "plane."
            )

        # perform slicing in vtk and get unprocessed points and values
        slice_vtk = self._plane_slice_raw(axis=axis, pos=pos)
        points_numpy = vtk["vtk_to_numpy"](slice_vtk.GetPoints().GetData())
        values = self._get_values_from_vtk(slice_vtk, len(points_numpy))

        # axis of the resulting line
        slice_axis = 3 - self.normal_axis - axis

        # sort found intersections in ascending order
        sorting = np.argsort(points_numpy[:, slice_axis], kind="mergesort")

        # assemble coords for SpatialDataArray
        coords = [None, None, None]
        coords[axis] = [pos]
        coords[self.normal_axis] = [self.normal_pos]
        coords[slice_axis] = points_numpy[sorting, slice_axis]
        coords_dict = dict(zip("xyz", coords))

        # reshape values from a 1d array into a 3d array
        new_shape = [1, 1, 1]
        new_shape[slice_axis] = len(values)
        values_reshaped = np.reshape(values.data[sorting], new_shape)

        return SpatialDataArray(values_reshaped, coords=coords_dict, name=self.values.name)
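
    # Example (illustrative sketch): slicing 'tri_grid' from the class docstring with the
    # plane x = 0.5; since the grid normal is y, the result is a line of data along z.
    #
    # >>> line = tri_grid.plane_slice(axis=0, pos=0.5)  # doctest: +SKIP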
    @property
    def _triangulation_obj(self) -> Triangulation:
        """Matplotlib triangular representation of the grid to use in plotting."""
        return Triangulation(self.points[:, 0], self.points[:, 1], self.cells)
    @equal_aspect
    @add_ax_if_none
    def plot(
        self,
        ax: Ax = None,
        field: bool = True,
        grid: bool = True,
        cbar: bool = True,
        cmap: str = "viridis",
        vmin: float = None,
        vmax: float = None,
        shading: Literal["gouraud", "flat"] = "gouraud",
        cbar_kwargs: Dict = None,
        pcolor_kwargs: Dict = None,
    ) -> Ax:
        """Plot the data field and/or the unstructured grid.

        Parameters
        ----------
        ax : matplotlib.axes._subplots.Axes = None
            matplotlib axes to plot on, if not specified, one is created.
        field : bool = True
            Whether to plot the data field.
        grid : bool = True
            Whether to plot the unstructured grid.
        cbar : bool = True
            Display colorbar (only if ``field == True``).
        cmap : str = "viridis"
            Color map to use for plotting.
        vmin : float = None
            The lower bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        vmax : float = None
            The upper bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        shading : Literal["gouraud", "flat"] = "gouraud"
            Type of shading to use when plotting the data field.
        cbar_kwargs : Dict = {}
            Additional parameters passed to colorbar object.
        pcolor_kwargs : Dict = {}
            Additional parameters passed to ``ax.tripcolor``.

        Returns
        -------
        matplotlib.axes._subplots.Axes
            The supplied or created matplotlib axes.
        """
        if cbar_kwargs is None:
            cbar_kwargs = {}
        if pcolor_kwargs is None:
            pcolor_kwargs = {}
        if not (field or grid):
            raise DataError("Nothing to plot ('field == False', 'grid == False').")

        # plot data field if requested
        if field:
            plot_obj = ax.tripcolor(
                self._triangulation_obj,
                self.values.data,
                shading=shading,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
                **pcolor_kwargs,
            )

            if cbar:
                label_kwargs = {}
                if "label" not in cbar_kwargs:
                    label_kwargs["label"] = self.values.name
                plt.colorbar(plot_obj, **cbar_kwargs, **label_kwargs)

        # plot grid if requested
        if grid:
            ax.triplot(
                self._triangulation_obj,
                color=plot_params_grid.edgecolor,
                linewidth=plot_params_grid.linewidth,
            )

        # set labels and titles
        ax_labels = ["x", "y", "z"]
        normal_axis_name = ax_labels.pop(self.normal_axis)
        ax.set_xlabel(ax_labels[0])
        ax.set_ylabel(ax_labels[1])
        ax.set_title(f"{normal_axis_name} = {self.normal_pos}")

        return ax
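
    # Example (illustrative sketch): plotting the data field with the grid overlaid.
    # Assumes 'tri_grid' from the class docstring; extra kwargs are passed through to
    # matplotlib.
    #
    # >>> ax = tri_grid.plot(field=True, grid=True, cbar=True, cmap="magma")  # doctest: +SKIP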
    def interp(
        self,
        x: Union[float, ArrayLike],
        y: Union[float, ArrayLike],
        z: Union[float, ArrayLike],
        fill_value: Union[float, Literal["extrapolate"]] = None,
        use_vtk: bool = False,
        method: Literal["linear", "nearest"] = "linear",
        ignore_normal_pos: bool = True,
        max_samples_per_step: int = DEFAULT_MAX_SAMPLES_PER_STEP,
        max_cells_per_step: int = DEFAULT_MAX_CELLS_PER_STEP,
        rel_tol: float = DEFAULT_TOLERANCE_CELL_FINDING,
    ) -> SpatialDataArray:
        """Interpolate data at provided x, y, and z. Note that data is assumed to be invariant
        along the dataset's normal direction.

        Parameters
        ----------
        x : Union[float, ArrayLike]
            x-coordinates of sampling points.
        y : Union[float, ArrayLike]
            y-coordinates of sampling points.
        z : Union[float, ArrayLike]
            z-coordinates of sampling points.
        fill_value : Union[float, Literal["extrapolate"]] = 0
            Value to use when filling points without interpolated values. If ``"extrapolate"``
            then nearest values are used. Note: in a future version the default value will be
            changed to ``"extrapolate"``.
        use_vtk : bool = False
            Use vtk's interpolation functionality or Tidy3D's own implementation. Note: this
            option will be removed in a future version.
        method : Literal["linear", "nearest"] = "linear"
            Interpolation method to use.
        ignore_normal_pos : bool = True
            (Deprecated) Assume data is invariant along the normal direction to the grid plane.
        max_samples_per_step : int = 1e4
            Max number of points to interpolate at per iteration (used only if
            ``use_vtk=False``). Using a higher number may speed up calculations but, at the same
            time, it increases RAM usage.
        max_cells_per_step : int = 1e4
            Max number of cells to interpolate from per iteration (used only if
            ``use_vtk=False``). Using a higher number may speed up calculations but, at the same
            time, it increases RAM usage.
        rel_tol : float = 1e-6
            Relative tolerance when determining whether a point belongs to a cell.

        Returns
        -------
        SpatialDataArray
            Interpolated data.
        """

        if fill_value is None:
            log.warning(
                "Default parameter setting 'fill_value=0' will be changed to "
                "'fill_value=\"extrapolate\"' in a future version."
            )
            fill_value = 0

        if not ignore_normal_pos:
            log.warning(
                "Parameter 'ignore_normal_pos' is deprecated. It is always assumed that data "
                "contained in 'TriangularGridDataset' is invariant in the normal direction. "
                "That is, 'ignore_normal_pos=True' is used."
            )

        x = np.atleast_1d(x)
        y = np.atleast_1d(y)
        z = np.atleast_1d(z)

        xyz = [x, y, z]
        xyz[self.normal_axis] = [self.normal_pos]
        interp_inplane = super().interp(
            **dict(zip("xyz", xyz)),
            fill_value=fill_value,
            use_vtk=use_vtk,
            method=method,
            max_samples_per_step=max_samples_per_step,
            max_cells_per_step=max_cells_per_step,
        )
        interp_broadcasted = np.broadcast_to(
            interp_inplane, [len(np.atleast_1d(comp)) for comp in [x, y, z]]
        )

        return SpatialDataArray(
            interp_broadcasted, coords=dict(x=x, y=y, z=z), name=self.values.name
        )
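
    # Example (illustrative sketch): interpolating 'tri_grid' onto a small Cartesian grid.
    # Values are broadcast along the normal direction (y here), so any y values are accepted.
    #
    # >>> xs = np.linspace(0.0, 1.0, 5)
    # >>> data = tri_grid.interp(x=xs, y=[0.0, 1.0], z=xs, fill_value=0.0)  # doctest: +SKIP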
    def _interp_py(
        self,
        x: ArrayLike,
        y: ArrayLike,
        z: ArrayLike,
        fill_value: float,
        max_samples_per_step: int,
        max_cells_per_step: int,
        rel_tol: float,
    ) -> ArrayLike:
        """2D-specific function to interpolate data at provided x, y, and z using a vectorized
        python implementation.

        Parameters
        ----------
        x : Union[float, ArrayLike]
            x-coordinates of sampling points.
        y : Union[float, ArrayLike]
            y-coordinates of sampling points.
        z : Union[float, ArrayLike]
            z-coordinates of sampling points.
        fill_value : float
            Value to use when filling points without interpolated values.
        max_samples_per_step : int
            Max number of points to interpolate at per iteration. Using a higher number may
            speed up calculations but, at the same time, it increases RAM usage.
        max_cells_per_step : int
            Max number of cells to interpolate from per iteration. Using a higher number may
            speed up calculations but, at the same time, it increases RAM usage.
        rel_tol : float
            Relative tolerance when determining whether a point belongs to a cell.

        Returns
        -------
        ArrayLike
            Interpolated data.
        """
        return self._interp_py_general(
            x=x,
            y=y,
            z=z,
            fill_value=fill_value,
            max_samples_per_step=max_samples_per_step,
            max_cells_per_step=max_cells_per_step,
            rel_tol=rel_tol,
            axis_ignore=self.normal_axis,
        )
    @requires_vtk
    def sel(
        self,
        x: Union[float, ArrayLike] = None,
        y: Union[float, ArrayLike] = None,
        z: Union[float, ArrayLike] = None,
    ) -> SpatialDataArray:
        """Extract/interpolate data along one or more Cartesian directions. At least one of
        x, y, and z must be provided.

        Parameters
        ----------
        x : Union[float, ArrayLike] = None
            x-coordinate of the slice.
        y : Union[float, ArrayLike] = None
            y-coordinate of the slice.
        z : Union[float, ArrayLike] = None
            z-coordinate of the slice.

        Returns
        -------
        SpatialDataArray
            Extracted data.
        """

        xyz = [x, y, z]
        axes = [ind for ind, comp in enumerate(xyz) if comp is not None]
        num_provided = len(axes)

        if self.normal_axis in axes:
            if xyz[self.normal_axis] != self.normal_pos:
                raise DataError(
                    f"No data for {'xyz'[self.normal_axis]} = {xyz[self.normal_axis]} "
                    f"(unstructured grid is defined at "
                    f"{'xyz'[self.normal_axis]} = {self.normal_pos})."
                )

            if num_provided < 3:
                num_provided -= 1
                axes.remove(self.normal_axis)

        if num_provided == 0:
            raise DataError("At least one of 'x', 'y', and 'z' must be specified.")

        if num_provided == 1:
            axis = axes[0]
            return self.plane_slice(axis=axis, pos=xyz[axis])

        if num_provided == 2:
            pos = [x, y, z]
            pos[self.normal_axis] = [self.normal_pos]
            return self.interp(x=pos[0], y=pos[1], z=pos[2])

        if num_provided == 3:
            return self.interp(x=x, y=y, z=z)
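
    # Example (illustrative sketch): 'sel' dispatches based on how many coordinates are given.
    # With 'tri_grid' (normal axis y), a single in-plane coordinate gives a plane slice, while
    # a full in-plane point triggers interpolation.
    #
    # >>> line = tri_grid.sel(x=0.5)  # doctest: +SKIP
    # >>> point = tri_grid.sel(x=0.5, z=0.5)  # doctest: +SKIP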
    @requires_vtk
    def reflect(
        self, axis: Axis, center: float, reflection_only: bool = False
    ) -> UnstructuredGridDataset:
        """Reflect unstructured data across the plane defined by parameters ``axis`` and
        ``center``. By default the original data is preserved; setting ``reflection_only`` to
        ``True`` will produce only the reflected data.

        Parameters
        ----------
        axis : Literal[0, 1, 2]
            Normal direction of the reflection plane.
        center : float
            Location of the reflection plane along its normal direction.
        reflection_only : bool = False
            Return only reflected data.

        Returns
        -------
        UnstructuredGridDataset
            Data after reflection is performed.
        """

        # disallow reflecting along the normal direction unless only the reflection is requested
        if axis == self.normal_axis:
            if reflection_only:
                return self.updated_copy(normal_pos=2 * center - self.normal_pos)
            raise DataError(
                "Reflection in the normal direction to the grid is prohibited unless "
                "'reflection_only=True'."
            )

        return super().reflect(axis=axis, center=center, reflection_only=reflection_only)
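
    # Example (illustrative sketch): reflecting across the grid's own plane is only allowed
    # with 'reflection_only=True', which simply relocates the plane to 2 * center - normal_pos.
    #
    # >>> moved = tri_grid.reflect(axis=1, center=1.0, reflection_only=True)  # doctest: +SKIP
    # >>> mirrored = tri_grid.reflect(axis=0, center=0.0)  # doctest: +SKIP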
    @requires_vtk
    def sel_inside(self, bounds: Bound) -> TriangularGridDataset:
        """Return a new ``TriangularGridDataset`` that contains the minimal amount of data
        necessary to cover a spatial region defined by ``bounds``.

        Parameters
        ----------
        bounds : Tuple[float, float, float], Tuple[float, float, float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.

        Returns
        -------
        TriangularGridDataset
            Extracted spatial data array.
        """
        if any(bmin > bmax for bmin, bmax in zip(*bounds)):
            raise DataError(
                "Min and max bounds must be packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``."
            )

        # expand along the normal direction since the dataset is invariant in it
        new_bounds = [list(bounds[0]), list(bounds[1])]
        new_bounds[0][self.normal_axis] = -inf
        new_bounds[1][self.normal_axis] = inf

        return super().sel_inside(new_bounds)
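
    # Example (illustrative sketch): bounds along the normal direction are ignored when
    # cropping, since the dataset is invariant in that direction.
    #
    # >>> cropped = tri_grid.sel_inside(((0.0, -1.0, 0.0), (0.5, 1.0, 0.5)))  # doctest: +SKIP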
    def does_cover(self, bounds: Bound) -> bool:
        """Check whether the data fully covers the spatial region specified by ``bounds``. If
        the data contains only one point along a given direction, then it is assumed the data is
        constant along that direction and coverage is not checked.

        Parameters
        ----------
        bounds : Tuple[float, float, float], Tuple[float, float, float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.

        Returns
        -------
        bool
            Full cover check outcome.
        """

        # collapse the normal direction, along which the dataset is invariant
        new_bounds = [list(bounds[0]), list(bounds[1])]
        new_bounds[0][self.normal_axis] = self.normal_pos
        new_bounds[1][self.normal_axis] = self.normal_pos

        return super().does_cover(new_bounds)
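
    # Example (illustrative sketch): coverage is checked in-plane only; the normal direction
    # always counts as covered.
    #
    # >>> tri_grid.does_cover(((0.0, -10.0, 0.0), (1.0, 10.0, 1.0)))  # doctest: +SKIP
    # True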
class TetrahedralGridDataset(UnstructuredGridDataset):
    """Dataset for storing tetrahedral grid data. Data values are associated with the nodes of
    the grid.

    Note
    ----
    To use full functionality of unstructured datasets one must install ``vtk`` package
    (``pip install tidy3d[vtk]`` or ``pip install vtk``). Otherwise the functionality of
    unstructured datasets is limited to creation, writing to/loading from a file, and arithmetic
    manipulations.

    Example
    -------
    >>> tet_grid_points = PointDataArray(
    ...     [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    ...     coords=dict(index=np.arange(4), axis=np.arange(3)),
    ... )
    >>>
    >>> tet_grid_cells = CellDataArray(
    ...     [[0, 1, 2, 3]],
    ...     coords=dict(cell_index=np.arange(1), vertex_index=np.arange(4)),
    ... )
    >>>
    >>> tet_grid_values = IndexedDataArray(
    ...     [1.0, 2.0, 3.0, 4.0], coords=dict(index=np.arange(4)),
    ... )
    >>>
    >>> tet_grid = TetrahedralGridDataset(
    ...     points=tet_grid_points,
    ...     cells=tet_grid_cells,
    ...     values=tet_grid_values,
    ... )
    """

    @classmethod
    def _point_dims(cls) -> pd.PositiveInt:
        """Dimensionality of stored grid point coordinates."""
        return 3

    @cached_property
    def _points_3d_array(self) -> ArrayLike:
        """3D coordinates of grid points."""
        return self.points.data

    @classmethod
    def _cell_num_vertices(cls) -> pd.PositiveInt:
        """Number of vertices in a cell."""
        return 4

    @classmethod
    @requires_vtk
    def _vtk_cell_type(cls):
        """VTK cell type to use in the VTK representation."""
        return vtk["mod"].VTK_TETRA

    @classmethod
    @requires_vtk
    def _from_vtk_obj(
        cls,
        grid,
        field=None,
        remove_degenerate_cells: bool = False,
        remove_unused_points: bool = False,
    ) -> TetrahedralGridDataset:
        """Initialize from a vtkUnstructuredGrid instance."""

        # read point, cells, and values info from a vtk instance
        cells_numpy = vtk["vtk_to_numpy"](grid.GetCells().GetConnectivityArray())
        points_numpy = vtk["vtk_to_numpy"](grid.GetPoints().GetData())
        values = cls._get_values_from_vtk(grid, len(points_numpy), field)

        # verify cell types
        cells_types = vtk["vtk_to_numpy"](grid.GetCellTypesArray())
        if not np.all(cells_types == cls._vtk_cell_type()):
            raise DataError("Only tetrahedral 'vtkUnstructuredGrid' is currently supported.")

        # pack point and cell information into Tidy3D arrays
        num_cells = len(cells_numpy) // cls._cell_num_vertices()
        cells_numpy = np.reshape(cells_numpy, (num_cells, cls._cell_num_vertices()))

        cells = CellDataArray(
            cells_numpy,
            coords=dict(
                cell_index=np.arange(num_cells),
                vertex_index=np.arange(cls._cell_num_vertices()),
            ),
        )

        points = PointDataArray(
            points_numpy,
            coords=dict(index=np.arange(len(points_numpy)), axis=np.arange(cls._point_dims())),
        )

        if remove_degenerate_cells:
            cells = cls._remove_degenerate_cells(cells=cells)

        if remove_unused_points:
            points, values, cells = cls._remove_unused_points(
                points=points, values=values, cells=cells
            )

        return cls(points=points, cells=cells, values=values)
    @requires_vtk
    def plane_slice(self, axis: Axis, pos: float) -> TriangularGridDataset:
        """Slice data with a plane and return the resulting :class:`.TriangularGridDataset`.

        Parameters
        ----------
        axis : Axis
            The normal direction of the slicing plane.
        pos : float
            Position of the slicing plane along its normal direction.

        Returns
        -------
        TriangularGridDataset
            The resulting slice.
        """

        slice_vtk = self._plane_slice_raw(axis=axis, pos=pos)

        return TriangularGridDataset._from_vtk_obj(
            slice_vtk, remove_degenerate_cells=True, remove_unused_points=True
        )
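
    # Example (illustrative sketch): slicing 'tet_grid' from the class docstring with the
    # plane z = 0.2 produces a TriangularGridDataset lying in that plane.
    #
    # >>> tri_slice = tet_grid.plane_slice(axis=2, pos=0.2)  # doctest: +SKIP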
    @requires_vtk
    def line_slice(self, axis: Axis, pos: Coordinate) -> SpatialDataArray:
        """Slice data with a line and return the resulting :class:`.SpatialDataArray`.

        Parameters
        ----------
        axis : Axis
            The axis of the slicing line.
        pos : Tuple[float, float, float]
            Position of the slicing line.

        Returns
        -------
        SpatialDataArray
            The resulting slice.
        """

        bounds = self.bounds
        start = list(pos)
        end = list(pos)

        start[axis] = bounds[0][axis]
        end[axis] = bounds[1][axis]

        # create cutting line
        line = vtk["mod"].vtkLineSource()
        line.SetPoint1(start)
        line.SetPoint2(end)
        line.SetResolution(1)

        # this should be done using vtkProbeLineFilter, but for some reason it crashes Python;
        # so, we use a workaround:
        # 1) extract cells that are intersected by the line (to speed up further slicing)
        # 2) do a plane slice along the first tangential direction
        # 3) do a second plane slice along the second tangential direction
        prober = vtk["mod"].vtkExtractCellsAlongPolyLine()
        prober.SetSourceConnection(line.GetOutputPort())
        prober.SetInputData(self._vtk_obj)
        prober.Update()

        extracted_cells_vtk = prober.GetOutput()

        if extracted_cells_vtk.GetNumberOfPoints() == 0:
            raise DataError("Slicing line does not intersect the unstructured grid.")

        extracted_cells = TetrahedralGridDataset._from_vtk_obj(
            extracted_cells_vtk, remove_degenerate_cells=True, remove_unused_points=True
        )

        tan_dims = [0, 1, 2]
        tan_dims.remove(axis)

        # first plane slice
        plane_slice = extracted_cells.plane_slice(axis=tan_dims[0], pos=pos[tan_dims[0]])
        # second plane slice
        line_slice = plane_slice.plane_slice(axis=tan_dims[1], pos=pos[tan_dims[1]])

        return line_slice
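
    # Example (illustrative sketch): extracting data along the x-directed line passing
    # through (y, z) = (0.2, 0.2); the x-component of 'pos' is replaced by the grid bounds.
    #
    # >>> line = tet_grid.line_slice(axis=0, pos=(0.0, 0.2, 0.2))  # doctest: +SKIP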
    @requires_vtk
    def sel(
        self,
        x: Union[float, ArrayLike] = None,
        y: Union[float, ArrayLike] = None,
        z: Union[float, ArrayLike] = None,
    ) -> Union[TriangularGridDataset, SpatialDataArray]:
        """Extract/interpolate data along one or more Cartesian directions. At least one of
        x, y, and z must be provided.

        Parameters
        ----------
        x : Union[float, ArrayLike] = None
            x-coordinate of the slice.
        y : Union[float, ArrayLike] = None
            y-coordinate of the slice.
        z : Union[float, ArrayLike] = None
            z-coordinate of the slice.

        Returns
        -------
        Union[TriangularGridDataset, SpatialDataArray]
            Extracted data.
        """

        xyz = [x, y, z]
        axes = [ind for ind, comp in enumerate(xyz) if comp is not None]

        num_provided = len(axes)

        if num_provided < 3 and any(not np.isscalar(comp) for comp in xyz if comp is not None):
            raise DataError(
                "Providing x, y, or z as array is only allowed for interpolation. That is, when "
                "all three x, y, and z are provided or method '.interp()' is used explicitly."
            )

        if num_provided == 0:
            raise DataError("At least one of 'x', 'y', and 'z' must be specified.")

        if num_provided == 1:
            axis = axes[0]
            return self.plane_slice(axis=axis, pos=xyz[axis])

        if num_provided == 2:
            axis = 3 - axes[0] - axes[1]
            xyz[axis] = 0
            return self.line_slice(axis=axis, pos=xyz)

        if num_provided == 3:
            return self.interp(x=x, y=y, z=z)
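
    # Example (illustrative sketch): the number of provided coordinates selects the operation:
    # one scalar -> plane slice, two scalars -> line slice, all three -> interpolation.
    #
    # >>> plane = tet_grid.sel(z=0.2)  # doctest: +SKIP
    # >>> line = tet_grid.sel(y=0.2, z=0.2)  # doctest: +SKIP
    # >>> vals = tet_grid.sel(x=[0.1, 0.2], y=0.2, z=0.2)  # doctest: +SKIP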
    def _interp_py(
        self,
        x: ArrayLike,
        y: ArrayLike,
        z: ArrayLike,
        fill_value: float,
        max_samples_per_step: int,
        max_cells_per_step: int,
        rel_tol: float,
    ) -> ArrayLike:
        """3D-specific function to interpolate data at provided x, y, and z using a vectorized
        python implementation.

        Parameters
        ----------
        x : Union[float, ArrayLike]
            x-coordinates of sampling points.
        y : Union[float, ArrayLike]
            y-coordinates of sampling points.
        z : Union[float, ArrayLike]
            z-coordinates of sampling points.
        fill_value : float
            Value to use when filling points without interpolated values.
        max_samples_per_step : int
            Max number of points to interpolate at per iteration. Using a higher number may
            speed up calculations but, at the same time, it increases RAM usage.
        max_cells_per_step : int
            Max number of cells to interpolate from per iteration. Using a higher number may
            speed up calculations but, at the same time, it increases RAM usage.
        rel_tol : float
            Relative tolerance when determining whether a point belongs to a cell.

        Returns
        -------
        ArrayLike
            Interpolated data.
        """
        return self._interp_py_general(
            x=x,
            y=y,
            z=z,
            fill_value=fill_value,
            max_samples_per_step=max_samples_per_step,
            max_cells_per_step=max_cells_per_step,
            rel_tol=rel_tol,
            axis_ignore=None,
        )
UnstructuredGridDatasetType = Union[TriangularGridDataset, TetrahedralGridDataset]
CustomSpatialDataType = Union[SpatialDataArray, UnstructuredGridDatasetType]
CustomSpatialDataTypeAnnotated = Union[
    SpatialDataArray, annotate_type(UnstructuredGridDatasetType)
]


def _get_numpy_array(
    data_array: Union[ArrayLike, DataArray, UnstructuredGridDataset],
) -> ArrayLike:
    """Get numpy representation of dataarray/dataset values."""
    if isinstance(data_array, UnstructuredGridDataset):
        return data_array.values.values
    if isinstance(data_array, xr.DataArray):
        return data_array.values
    return np.array(data_array)


def _zeros_like(
    data_array: Union[ArrayLike, xr.DataArray, UnstructuredGridDataset],
) -> Union[ArrayLike, xr.DataArray, UnstructuredGridDataset]:
    """Get a zeroed replica of dataarray/dataset."""
    if isinstance(data_array, UnstructuredGridDataset):
        return data_array.updated_copy(values=xr.zeros_like(data_array.values))
    if isinstance(data_array, xr.DataArray):
        return xr.zeros_like(data_array)
    return np.zeros_like(data_array)


def _ones_like(
    data_array: Union[ArrayLike, xr.DataArray, UnstructuredGridDataset],
) -> Union[ArrayLike, xr.DataArray, UnstructuredGridDataset]:
    """Get a unity replica of dataarray/dataset."""
    if isinstance(data_array, UnstructuredGridDataset):
        return data_array.updated_copy(values=xr.ones_like(data_array.values))
    if isinstance(data_array, xr.DataArray):
        return xr.ones_like(data_array)
    return np.ones_like(data_array)


def _check_same_coordinates(
    a: Union[ArrayLike, xr.DataArray, UnstructuredGridDataset],
    b: Union[ArrayLike, xr.DataArray, UnstructuredGridDataset],
) -> bool:
    """Check whether two arrays are defined at the same coordinates."""

    # we can have xarray.DataArray's of different types but still same coordinates;
    # we will deal with that case separately
    both_xarrays = isinstance(a, xr.DataArray) and isinstance(b, xr.DataArray)

    if (not both_xarrays) and type(a) != type(b):
        return False

    if isinstance(a, UnstructuredGridDataset):
        if not np.allclose(a.points, b.points) or not np.all(a.cells == b.cells):
            return False

        if isinstance(a, TriangularGridDataset):
            if a.normal_axis != b.normal_axis or a.normal_pos != b.normal_pos:
                return False

    elif isinstance(a, xr.DataArray):
        if a.coords.keys() != b.coords.keys() or a.coords != b.coords:
            return False

    else:
        if np.shape(a) != np.shape(b):
            return False

    return True
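
# Example (illustrative sketch): the helpers above let arithmetic code treat plain arrays,
# 'xr.DataArray's, and unstructured datasets uniformly. Assumes 'tri_grid' as in the
# 'TriangularGridDataset' docstring.
#
# >>> zeros = _zeros_like(tri_grid)  # doctest: +SKIP
# >>> _check_same_coordinates(zeros, tri_grid)  # doctest: +SKIP
# True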