Source code for tidy3d.components.data.dataset

"""Collections of DataArrays."""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Union, Dict, Callable, Any, Tuple

import xarray as xr
import numpy as np
import pydantic.v1 as pd
from matplotlib.tri import Triangulation
from matplotlib import pyplot as plt
import numbers

from .data_array import DataArray, DATA_ARRAY_MAP
from .data_array import ScalarFieldDataArray, ScalarFieldTimeDataArray, ScalarModeFieldDataArray
from .data_array import ModeIndexDataArray, GroupIndexDataArray, ModeDispersionDataArray
from .data_array import TriangleMeshDataArray
from .data_array import TimeDataArray
from .data_array import PointDataArray, IndexedDataArray, CellDataArray, SpatialDataArray
from .data_array import EMEScalarFieldDataArray, EMEScalarModeFieldDataArray

from ..viz import equal_aspect, add_ax_if_none, plot_params_grid
from ..base import Tidy3dBaseModel, cached_property
from ..base import skip_if_fields_missing
from ..types import Axis, Bound, ArrayLike, Ax, Coordinate, Literal, annotate_type
from ...packaging import vtk, requires_vtk
from ...exceptions import DataError, ValidationError, Tidy3dNotImplementedError
from ...constants import PICOSECOND_PER_NANOMETER_PER_KILOMETER, inf
from ...log import log


# Default for the ``max_samples_per_step`` argument of ``UnstructuredGridDataset.interp``:
# max number of sampling points to interpolate at per iteration.
DEFAULT_MAX_SAMPLES_PER_STEP = 10_000
# Default for the ``max_cells_per_step`` argument of ``UnstructuredGridDataset.interp``:
# max number of grid cells to interpolate from per iteration.
DEFAULT_MAX_CELLS_PER_STEP = 10_000
# Default for the ``rel_tol`` argument of ``UnstructuredGridDataset.interp``:
# relative tolerance used when determining whether a point belongs to a cell.
DEFAULT_TOLERANCE_CELL_FINDING = 1e-6


class Dataset(Tidy3dBaseModel, ABC):
    """Abstract base class for objects that store collections of :class:`.DataArray` objects."""


class AbstractFieldDataset(Dataset, ABC):
    """Base class for datasets holding scalar field components with symmetry properties."""

    @property
    @abstractmethod
    def field_components(self) -> Dict[str, DataArray]:
        """Mapping from each field component name to the data stored for it."""

    def apply_phase(self, phase: float) -> AbstractFieldDataset:
        """Return a copy with every field component multiplied by ``exp(1j * phase)``.

        Parameters
        ----------
        phase : float
            Phase shift in radians. If it is exactly zero, ``self`` is returned as is.

        Returns
        -------
        AbstractFieldDataset
            The phase-shifted dataset (or ``self`` when ``phase`` is zero).
        """
        if phase == 0.0:
            return self

        phasor = np.exp(1j * phase)
        shifted_components = {
            name: phasor * component for name, component in self.field_components.items()
        }
        return self.updated_copy(**shifted_components)

    @property
    @abstractmethod
    def grid_locations(self) -> Dict[str, str]:
        """Mapping from each field component to the key of its location on the yee lattice."""

    @property
    @abstractmethod
    def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]:
        """Mapping from each field component to its (positive) symmetry eigenvalues."""

    def package_colocate_results(self, centered_fields: Dict[str, ScalarFieldDataArray]) -> Any:
        """Wrap the dictionary of fields produced by ``self.colocate()`` into the returned object.

        The base implementation bundles the fields into an ``xr.Dataset``.
        """
        return xr.Dataset(centered_fields)

    def colocate(self, x=None, y=None, z=None) -> xr.Dataset:
        """Interpolate all field components onto a common set of x, y, z coordinates.

        Parameters
        ----------
        x : Optional[array-like] = None
            x coordinates of locations.
            If not supplied, does not try to colocate on this dimension.
        y : Optional[array-like] = None
            y coordinates of locations.
            If not supplied, does not try to colocate on this dimension.
        z : Optional[array-like] = None
            z coordinates of locations.
            If not supplied, does not try to colocate on this dimension.

        Returns
        -------
        xr.Dataset
            Dataset containing all fields at the same spatial locations.
            For more details refer to `xarray's Documentation <https://tinyurl.com/cyca3krz>`_.

        Raises
        ------
        DataError
            If coordinates are supplied along a dimension in which a field component has
            only a single coordinate value.

        Note
        ----
        For many operations (such as flux calculations and plotting),
        it is important that the fields are colocated at the same spatial locations.
        Be sure to apply this method to your field data in those cases.
        """
        # data attached to a monitor that already colocated gets interpolated a second time here
        already_colocated = hasattr(self, "monitor") and self.monitor.colocate
        if already_colocated:
            with log as consolidated_logger:
                consolidated_logger.warning(
                    "Colocating data that has already been colocated during the solver "
                    "run. For most accurate results when colocating to custom coordinates set "
                    "'Monitor.colocate' to 'False' to use the raw data on the Yee grid "
                    "and avoid double interpolation. Note: the default value was changed to 'True' "
                    "in Tidy3D version 2.4.0."
                )

        # keep only the dimensions the caller actually supplied, stored as arrays
        supplied_coord_map = {}
        for dim_label, dim_values in zip("xyz", (x, y, z)):
            if dim_values is not None:
                supplied_coord_map[dim_label] = np.array(dim_values)

        interpolated = {}
        for component_name, component_data in self.field_components.items():
            # a dimension with a single coordinate cannot be interpolated along
            for coord_name, coords_supplied in supplied_coord_map.items():
                coord_data = np.array(component_data.coords[coord_name])
                if coord_data.size != 1:
                    continue
                raise DataError(
                    f"colocate given {coord_name}={coords_supplied}, but "
                    f"data only has one coordinate at {coord_name}={coord_data[0]}. "
                    "Therefore, can't colocate along this dimension. "
                    f"supply {coord_name}=None to skip it."
                )

            interpolated[component_name] = component_data.interp(
                **supplied_coord_map, kwargs={"bounds_error": True}
            )

        return self.package_colocate_results(interpolated)
# All scalar field array types accepted for an E or H component.
EMScalarFieldType = Union[
    ScalarFieldDataArray,
    ScalarFieldTimeDataArray,
    ScalarModeFieldDataArray,
    EMEScalarModeFieldDataArray,
    EMEScalarFieldDataArray,
]


class ElectromagneticFieldDataset(AbstractFieldDataset, ABC):
    """Collection of the x, y, z components of the E and H fields.

    Every component is optional and defaults to ``None``.
    """

    Ex: EMScalarFieldType = pd.Field(
        default=None,
        title="Ex",
        description="Spatial distribution of the x-component of the electric field.",
    )
    Ey: EMScalarFieldType = pd.Field(
        default=None,
        title="Ey",
        description="Spatial distribution of the y-component of the electric field.",
    )
    Ez: EMScalarFieldType = pd.Field(
        default=None,
        title="Ez",
        description="Spatial distribution of the z-component of the electric field.",
    )
    Hx: EMScalarFieldType = pd.Field(
        default=None,
        title="Hx",
        description="Spatial distribution of the x-component of the magnetic field.",
    )
    Hy: EMScalarFieldType = pd.Field(
        default=None,
        title="Hy",
        description="Spatial distribution of the y-component of the magnetic field.",
    )
    Hz: EMScalarFieldType = pd.Field(
        default=None,
        title="Hz",
        description="Spatial distribution of the z-component of the magnetic field.",
    )

    @property
    def field_components(self) -> Dict[str, DataArray]:
        """Mapping from component name to its data, omitting components that are ``None``."""
        component_names = ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz")
        present = {}
        for name in component_names:
            data = getattr(self, name)
            if data is not None:
                present[name] = data
        return present

    @property
    def grid_locations(self) -> Dict[str, str]:
        """Mapping from component name to the key of its location on the yee lattice.

        Each component is keyed by its own name.
        """
        return {name: name for name in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz")}

    @property
    def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]:
        """Mapping from component name to a function giving its (positive) symmetry eigenvalue.

        Each function takes an axis index: E components give ``-1`` along their own axis and
        ``+1`` otherwise, while H components give ``+1`` along their own axis and ``-1``
        otherwise.
        """
        return {
            "Ex": lambda dim: -1 if dim == 0 else +1,
            "Ey": lambda dim: -1 if dim == 1 else +1,
            "Ez": lambda dim: -1 if dim == 2 else +1,
            "Hx": lambda dim: +1 if dim == 0 else -1,
            "Hy": lambda dim: +1 if dim == 1 else -1,
            "Hz": lambda dim: +1 if dim == 2 else -1,
        }
class FieldDataset(ElectromagneticFieldDataset):
    """Frequency-domain dataset holding the scalar components of the E and H fields.

    Each component is an optional :class:`.ScalarFieldDataArray` that defaults to ``None``.

    Example
    -------
    >>> x = [-1,1]
    >>> y = [-2,0,2]
    >>> z = [-3,-1,1,3]
    >>> f = [2e14, 3e14]
    >>> coords = dict(x=x, y=y, z=z, f=f)
    >>> scalar_field = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
    >>> data = FieldDataset(Ex=scalar_field, Hz=scalar_field)
    """

    Ex: ScalarFieldDataArray = pd.Field(
        default=None,
        title="Ex",
        description="Spatial distribution of the x-component of the electric field.",
    )
    Ey: ScalarFieldDataArray = pd.Field(
        default=None,
        title="Ey",
        description="Spatial distribution of the y-component of the electric field.",
    )
    Ez: ScalarFieldDataArray = pd.Field(
        default=None,
        title="Ez",
        description="Spatial distribution of the z-component of the electric field.",
    )
    Hx: ScalarFieldDataArray = pd.Field(
        default=None,
        title="Hx",
        description="Spatial distribution of the x-component of the magnetic field.",
    )
    Hy: ScalarFieldDataArray = pd.Field(
        default=None,
        title="Hy",
        description="Spatial distribution of the y-component of the magnetic field.",
    )
    Hz: ScalarFieldDataArray = pd.Field(
        default=None,
        title="Hz",
        description="Spatial distribution of the z-component of the magnetic field.",
    )
class FieldTimeDataset(ElectromagneticFieldDataset):
    """Time-domain dataset holding the scalar components of the E and H fields.

    Each component is an optional :class:`.ScalarFieldTimeDataArray` that defaults to ``None``.

    Example
    -------
    >>> x = [-1,1]
    >>> y = [-2,0,2]
    >>> z = [-3,-1,1,3]
    >>> t = [0, 1e-12, 2e-12]
    >>> coords = dict(x=x, y=y, z=z, t=t)
    >>> scalar_field = ScalarFieldTimeDataArray(np.random.random((2,3,4,3)), coords=coords)
    >>> data = FieldTimeDataset(Ex=scalar_field, Hz=scalar_field)
    """

    Ex: ScalarFieldTimeDataArray = pd.Field(
        default=None,
        title="Ex",
        description="Spatial distribution of the x-component of the electric field.",
    )
    Ey: ScalarFieldTimeDataArray = pd.Field(
        default=None,
        title="Ey",
        description="Spatial distribution of the y-component of the electric field.",
    )
    Ez: ScalarFieldTimeDataArray = pd.Field(
        default=None,
        title="Ez",
        description="Spatial distribution of the z-component of the electric field.",
    )
    Hx: ScalarFieldTimeDataArray = pd.Field(
        default=None,
        title="Hx",
        description="Spatial distribution of the x-component of the magnetic field.",
    )
    Hy: ScalarFieldTimeDataArray = pd.Field(
        default=None,
        title="Hy",
        description="Spatial distribution of the y-component of the magnetic field.",
    )
    Hz: ScalarFieldTimeDataArray = pd.Field(
        default=None,
        title="Hz",
        description="Spatial distribution of the z-component of the magnetic field.",
    )

    def apply_phase(self, phase: float) -> AbstractFieldDataset:
        """Return ``self`` for a zero phase; any other phase is rejected.

        Parameters
        ----------
        phase : float
            Phase shift in radians. Only a value of zero is accepted.

        Raises
        ------
        ValueError
            If ``phase`` is not zero.
        """
        if phase == 0.0:
            return self
        raise ValueError("Can't apply phase to time-domain field data, which is real-valued.")
class ModeSolverDataset(ElectromagneticFieldDataset):
    """Dataset storing scalar components of E and H fields as a function of freq. and mode_index.

    In addition to the (optional) field components, it stores the required complex
    propagation index and, optionally, the group index and dispersion of each mode.

    Example
    -------
    >>> from tidy3d import ModeSpec
    >>> x = [-1,1]
    >>> y = [0]
    >>> z = [-3,-1,1,3]
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> field_coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index)
    >>> field = ScalarModeFieldDataArray((1+1j)*np.random.random((2,1,4,2,5)), coords=field_coords)
    >>> index_coords = dict(f=f, mode_index=mode_index)
    >>> index_data = ModeIndexDataArray((1+1j) * np.random.random((2,5)), coords=index_coords)
    >>> data = ModeSolverDataset(
    ...     Ex=field,
    ...     Ey=field,
    ...     Ez=field,
    ...     Hx=field,
    ...     Hy=field,
    ...     Hz=field,
    ...     n_complex=index_data
    ... )
    """

    Ex: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Ex",
        description="Spatial distribution of the x-component of the electric field of the mode.",
    )
    Ey: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Ey",
        description="Spatial distribution of the y-component of the electric field of the mode.",
    )
    Ez: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Ez",
        description="Spatial distribution of the z-component of the electric field of the mode.",
    )
    Hx: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Hx",
        description="Spatial distribution of the x-component of the magnetic field of the mode.",
    )
    Hy: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Hy",
        description="Spatial distribution of the y-component of the magnetic field of the mode.",
    )
    Hz: ScalarModeFieldDataArray = pd.Field(
        None,
        title="Hz",
        description="Spatial distribution of the z-component of the magnetic field of the mode.",
    )

    n_complex: ModeIndexDataArray = pd.Field(
        ...,
        title="Propagation Index",
        description="Complex-valued effective propagation constants associated with the mode.",
    )

    n_group_raw: GroupIndexDataArray = pd.Field(
        None,
        alias="n_group",  # This is for backwards compatibility only when loading old data
        title="Group Index",
        description="Index associated with group velocity of the mode.",
    )

    dispersion_raw: ModeDispersionDataArray = pd.Field(
        None,
        title="Dispersion",
        description="Dispersion parameter for the mode.",
        units=PICOSECOND_PER_NANOMETER_PER_KILOMETER,
    )

    @property
    def field_components(self) -> Dict[str, DataArray]:
        """Maps the field components to their associated data, skipping those that are ``None``."""
        fields = {
            "Ex": self.Ex,
            "Ey": self.Ey,
            "Ez": self.Ez,
            "Hx": self.Hx,
            "Hy": self.Hy,
            "Hz": self.Hz,
        }
        return {field_name: field for field_name, field in fields.items() if field is not None}

    @property
    def n_eff(self) -> ModeIndexDataArray:
        """Real part of the propagation index."""
        return self.n_complex.real

    @property
    def k_eff(self) -> ModeIndexDataArray:
        """Imaginary part of the propagation index."""
        return self.n_complex.imag

    @property
    def n_group(self) -> GroupIndexDataArray:
        """Group index.

        Returns ``n_group_raw``; if that is ``None``, a warning is logged and ``None`` is
        returned.
        """
        if self.n_group_raw is None:
            log.warning(
                "The group index was not computed. To calculate group index, pass "
                "'group_index_step = True' in the 'ModeSpec'.",
                log_once=True,
            )
        return self.n_group_raw

    @property
    def dispersion(self) -> ModeDispersionDataArray:
        r"""Dispersion parameter.

        Returns ``dispersion_raw``; if that is ``None``, a warning is logged and ``None`` is
        returned.

        .. math::

            D = -\frac{\lambda}{c_0} \frac{{\rm d}^2 n_{\text{eff}}}{{\rm d}\lambda^2}
        """
        if self.dispersion_raw is None:
            log.warning(
                "The dispersion was not computed. To calculate dispersion, pass "
                "'group_index_step = True' in the 'ModeSpec'.",
                log_once=True,
            )
        return self.dispersion_raw

    def plot_field(self, *args, **kwargs):
        """Warn user to use the :class:`.ModeSolver` ``plot_field`` function now.

        All arguments are ignored.

        Raises
        ------
        DeprecationWarning
            Always, since plotting was moved to the ``ModeSolver`` object.
        """
        # NOTE: the first fragment ends with a space so the sentences do not run together
        # ("object.Once") in the rendered message.
        raise DeprecationWarning(
            "The 'plot_field()' method was moved to the 'ModeSolver' object. "
            "Once the 'ModeSolver' is constructed, one may call '.plot_field()' on the object and "
            "the modes will be computed and displayed with 'Simulation' overlay."
        )
class PermittivityDataset(AbstractFieldDataset):
    """Dataset holding the three diagonal components of the relative permittivity tensor.

    All three components are required.

    Example
    -------
    >>> x = [-1,1]
    >>> y = [-2,0,2]
    >>> z = [-3,-1,1,3]
    >>> f = [2e14, 3e14]
    >>> coords = dict(x=x, y=y, z=z, f=f)
    >>> sclr_fld = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
    >>> data = PermittivityDataset(eps_xx=sclr_fld, eps_yy=sclr_fld, eps_zz=sclr_fld)
    """

    eps_xx: ScalarFieldDataArray = pd.Field(
        default=...,
        title="Epsilon xx",
        description="Spatial distribution of the xx-component of the relative permittivity.",
    )
    eps_yy: ScalarFieldDataArray = pd.Field(
        default=...,
        title="Epsilon yy",
        description="Spatial distribution of the yy-component of the relative permittivity.",
    )
    eps_zz: ScalarFieldDataArray = pd.Field(
        default=...,
        title="Epsilon zz",
        description="Spatial distribution of the zz-component of the relative permittivity.",
    )

    @property
    def field_components(self) -> Dict[str, ScalarFieldDataArray]:
        """Mapping from each permittivity component name to its data."""
        return {
            "eps_xx": self.eps_xx,
            "eps_yy": self.eps_yy,
            "eps_zz": self.eps_zz,
        }

    @property
    def grid_locations(self) -> Dict[str, str]:
        """Mapping from each component to the key of its location on the yee lattice.

        The xx, yy and zz components use the ``Ex``, ``Ey`` and ``Ez`` locations respectively.
        """
        return {"eps_xx": "Ex", "eps_yy": "Ey", "eps_zz": "Ez"}

    @property
    def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]:
        """Mapping from each component to its symmetry eigenvalue function.

        No functions are defined here: every component maps to ``None``.
        """
        return dict.fromkeys(("eps_xx", "eps_yy", "eps_zz"))
class TriangleMeshDataset(Dataset): """Dataset for storing triangular surface data.""" surface_mesh: TriangleMeshDataArray = pd.Field( ..., title="Surface mesh data", description="Dataset containing the surface triangles and corresponding face indices " "for a surface mesh.", ) class TimeDataset(Dataset): """Dataset for storing a function of time.""" values: TimeDataArray = pd.Field( ..., title="Values", description="Values as a function of time." ) class UnstructuredGridDataset(Dataset, np.lib.mixins.NDArrayOperatorsMixin, ABC): """Abstract base for datasets that store unstructured grid data.""" points: PointDataArray = pd.Field( ..., title="Grid Points", description="Coordinates of points composing the unstructured grid.", ) values: IndexedDataArray = pd.Field( ..., title="Point Values", description="Values stored at the grid points.", ) cells: CellDataArray = pd.Field( ..., title="Grid Cells", description="Cells composing the unstructured grid specified as connections between grid " "points.", ) @property def name(self) -> str: """Dataset name.""" # we redirect name to values.name return self.values.name @property def is_complex(self) -> bool: """Data type.""" return np.iscomplexobj(self.values) @property def _double_type(self): """Corresponding double data type.""" return np.complex128 if self.is_complex else np.float64 @pd.validator("points", always=True) def points_right_dims(cls, val): """Check that point coordinates have the right dimensionality.""" # currently support only the standard axis ordering, that is 01(2) axis_coords_expected = np.arange(cls._point_dims()) axis_coords_given = val.axis.data if np.any(axis_coords_given != axis_coords_expected): raise ValidationError( f"Points array is expected to have {axis_coords_expected} coord values along 'axis'" f" (given: {axis_coords_given})." 
) return val @property def is_uniform(self): """Whether each element is of equal value in ``values``.""" return self.values.is_uniform @pd.validator("points", always=True) def points_right_indexing(cls, val): """Check that points are indexed corrrectly.""" indices_expected = np.arange(len(val.data)) indices_given = val.index.data if np.any(indices_expected != indices_given): raise ValidationError( "Coordinate 'index' of array 'points' is expected to have values (0, 1, 2, ...). " "This can be easily achieved, for example, by using " "PointDataArray(data, dims=['index', 'axis'])." ) return val @pd.validator("values", always=True) def values_right_indexing(cls, val): """Check that data values are indexed correctly.""" # currently support only simple ordered indexing of points, that is, 0, 1, 2, ... indices_expected = np.arange(len(val.data)) indices_given = val.index.data if np.any(indices_expected != indices_given): raise ValidationError( "Coordinate 'index' of array 'values' is expected to have values (0, 1, 2, ...). " "This can be easily achieved, for example, by using " "IndexedDataArray(data, dims=['index'])." ) return val @pd.validator("values", always=True) @skip_if_fields_missing(["points"]) def number_of_values_matches_points(cls, val, values): """Check that the number of data values matches the number of grid points.""" num_values = len(val) points = values.get("points") num_points = len(points) if num_points != num_values: raise ValidationError( f"The number of data values ({num_values}) does not match the number of grid " f"points ({num_points})." 
) return val @pd.validator("cells", always=True) def match_cells_to_vtk_type(cls, val): """Check that cell connections does not have duplicate points.""" if vtk is None: return val # using val.astype(np.int32/64) directly causes issues when dataarray are later checked == return CellDataArray(val.data.astype(vtk["id_type"], copy=False), coords=val.coords) @pd.validator("cells", always=True) def cells_right_type(cls, val): """Check that cell are of the right type.""" # only supporting the standard ordering of cell vertices 012(3) vertex_coords_expected = np.arange(cls._cell_num_vertices()) vertex_coords_given = val.vertex_index.data if np.any(vertex_coords_given != vertex_coords_expected): raise ValidationError( f"Cell connections array is expected to have {vertex_coords_expected} coord values" f" along 'vertex_index' (given: {vertex_coords_given})." ) return val @pd.validator("cells", always=True) @skip_if_fields_missing(["points"]) def check_cell_vertex_range(cls, val, values): """Check that cell connections use only defined points.""" all_point_indices_used = val.data.ravel() # skip validation if zero size data if len(all_point_indices_used) > 0: min_index_used = np.min(all_point_indices_used) max_index_used = np.max(all_point_indices_used) points = values.get("points") num_points = len(points) if max_index_used > num_points - 1 or min_index_used < 0: raise ValidationError( "Cell connections array uses undefined point indices in the range " f"[{min_index_used}, {max_index_used}]. The valid range of point indices is " f"[0, {num_points-1}]." ) return val @classmethod def _find_degenerate_cells(cls, cells: CellDataArray): """Find explicitly degenerate cells if any. That is, cells that use the same point indices for their different vertices. 
""" indices = cells.data # skip validation if zero size data degenerate_cell_inds = set() if len(indices) > 0: for i in range(cls._cell_num_vertices() - 1): for j in range(i + 1, cls._cell_num_vertices()): degenerate_cell_inds = degenerate_cell_inds.union( np.where(indices[:, i] == indices[:, j])[0] ) return degenerate_cell_inds @classmethod def _remove_degenerate_cells(cls, cells: CellDataArray): """Remove explicitly degenerate cells if any. That is, cells that use the same point indices for their different vertices. """ degenerate_cells = cls._find_degenerate_cells(cells=cells) if len(degenerate_cells) > 0: data = np.delete(cells.values, list(degenerate_cells), axis=0) cell_index = np.delete(cells.cell_index.values, list(degenerate_cells)) return CellDataArray( data=data, coords=dict(cell_index=cell_index, vertex_index=cells.vertex_index) ) return cells @classmethod def _remove_unused_points( cls, points: PointDataArray, values: IndexedDataArray, cells: CellDataArray ): """Remove unused points if any. That is, points that are not used in any grid cell. 
""" used_indices = np.unique(cells.values.ravel()) num_points = len(points) if len(used_indices) != num_points or np.any(np.diff(used_indices) != 1): min_index = np.min(used_indices) map_len = np.max(used_indices) - min_index + 1 index_map = np.zeros(map_len) index_map[used_indices - min_index] = np.arange(len(used_indices)) cells = CellDataArray(data=index_map[cells.data - min_index], coords=cells.coords) points = PointDataArray(points.data[used_indices, :], dims=["index", "axis"]) values = IndexedDataArray(values.data[used_indices], dims=["index"]) return points, values, cells def clean(self, remove_degenerate_cells=True, remove_unused_points=True): """Remove degenerate cells and/or unused points.""" if remove_degenerate_cells: cells = self._remove_degenerate_cells(cells=self.cells) else: cells = self.cells if remove_unused_points: points, values, cells = self._remove_unused_points(self.points, self.values, cells) else: points = self.points values = self.values return self.updated_copy(points=points, values=values, cells=cells) @pd.validator("cells", always=True) def warn_degenerate_cells(cls, val): """Check that cell connections does not have duplicate points.""" degenerate_cells = cls._find_degenerate_cells(val) num_degenerate_cells = len(degenerate_cells) if num_degenerate_cells > 0: log.warning( f"Unstructured grid contains {num_degenerate_cells} degenerate cell(s). " "Such cells can be removed by using function " "'.clean(remove_degenerate_cells: bool = True, remove_unused_points: bool = True)'. " "For example, 'dataset = dataset.clean()'." 
) return val @pd.root_validator(pre=True, allow_reuse=True) def _warn_if_none(cls, values): """Warn if any of data arrays are not loaded.""" no_data_fields = [] for field_name in ["points", "cells", "values"]: field = values.get(field_name) if isinstance(field, str) and field in DATA_ARRAY_MAP.keys(): no_data_fields.append(field_name) if len(no_data_fields) > 0: formatted_names = [f"'{fname}'" for fname in no_data_fields] log.warning( f"Loading {', '.join(formatted_names)} without data. Constructing an empty dataset." ) values["points"] = PointDataArray( np.zeros((0, cls._point_dims())), dims=["index", "axis"] ) values["cells"] = CellDataArray( np.zeros((0, cls._cell_num_vertices())), dims=["cell_index", "vertex_index"] ) values["values"] = IndexedDataArray(np.zeros(0), dims=["index"]) return values @pd.root_validator(skip_on_failure=True, allow_reuse=True) def _warn_unused_points(cls, values): """Warn if some points are unused.""" point_indices = set(np.arange(len(values["values"].data))) used_indices = set(values["cells"].values.ravel()) if not point_indices.issubset(used_indices): log.warning( "Unstructured grid dataset contains unused points. " "Consider calling 'clean()' to remove them." ) return values def rename(self, name: str) -> UnstructuredGridDataset: """Return a renamed array.""" return self.updated_copy(values=self.values.rename(name)) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): """Override of numpy functions.""" out = kwargs.get("out", ()) for x in inputs + out: # Only support operations with the same class or a scalar if not isinstance(x, (numbers.Number, type(self))): return Tidy3dNotImplementedError # Defer to the implementation of the ufunc on unwrapped values. 
inputs = tuple(x.values if isinstance(x, UnstructuredGridDataset) else x for x in inputs) if out: kwargs["out"] = tuple( x.values if isinstance(x, UnstructuredGridDataset) else x for x in out ) result = getattr(ufunc, method)(*inputs, **kwargs) if type(result) is tuple: # multiple return values return tuple(self.updated_copy(values=x) for x in result) elif method == "at": # no return value return None else: # one return value return self.updated_copy(values=result) @property def real(self) -> UnstructuredGridDataset: """Real part of dataset.""" return self.updated_copy(values=self.values.real) @property def imag(self) -> UnstructuredGridDataset: """Imaginary part of dataset.""" return self.updated_copy(values=self.values.imag) @property def abs(self) -> UnstructuredGridDataset: """Absolute value of dataset.""" return self.updated_copy(values=self.values.abs) @cached_property def bounds(self) -> Bound: """Grid bounds.""" return tuple(np.min(self.points.data, axis=0)), tuple(np.max(self.points.data, axis=0)) @classmethod @abstractmethod def _point_dims(cls) -> pd.PositiveInt: """Dimensionality of stored grid point coordinates.""" @cached_property @abstractmethod def _points_3d_array(self): """3D coordinates of grid points.""" @classmethod @abstractmethod def _cell_num_vertices(cls) -> pd.PositiveInt: """Number of vertices in a cell.""" @classmethod @abstractmethod @requires_vtk def _vtk_cell_type(cls): """VTK cell type to use in the VTK representation.""" @cached_property def _vtk_offsets(self) -> ArrayLike: """Offsets array to use in the VTK representation.""" offsets = np.arange(len(self.cells) + 1) * self._cell_num_vertices() if vtk is None: return offsets return offsets.astype(vtk["id_type"], copy=False) @property @requires_vtk def _vtk_cells(self): """VTK cell array to use in the VTK representation.""" cells = vtk["mod"].vtkCellArray() cells.SetData( vtk["numpy_to_vtkIdTypeArray"](self._vtk_offsets), vtk["numpy_to_vtkIdTypeArray"](self.cells.data.ravel()), ) 
return cells @property @requires_vtk def _vtk_points(self): """VTK point array to use in the VTK representation.""" pts = vtk["mod"].vtkPoints() pts.SetData(vtk["numpy_to_vtk"](self._points_3d_array)) return pts @property @requires_vtk def _vtk_obj(self): """A VTK representation (vtkUnstructuredGrid) of the grid.""" grid = vtk["mod"].vtkUnstructuredGrid() grid.SetPoints(self._vtk_points) grid.SetCells(self._vtk_cell_type(), self._vtk_cells) if self.is_complex: # vtk doesn't support complex numbers # so we will store our complex array as a two-component vtk array data_values = self.values.values.view("(2,)float") else: data_values = self.values.values point_data_vtk = vtk["numpy_to_vtk"](data_values) point_data_vtk.SetName(self.values.name) grid.GetPointData().AddArray(point_data_vtk) return grid @requires_vtk def _plane_slice_raw(self, axis: Axis, pos: float): """Slice data with a plane and return the resulting VTK object.""" if pos > self.bounds[1][axis] or pos < self.bounds[0][axis]: raise DataError( f"Slicing plane (axis: {axis}, pos: {pos}) does not intersect the unstructured grid " f"(extent along axis {axis}: {self.bounds[0][axis]}, {self.bounds[1][axis]})." 
) origin = [0, 0, 0] origin[axis] = pos normal = [0, 0, 0] # orientation of normal is important for edge (literally) cases normal[axis] = -1 if pos > (self.bounds[0][axis] + self.bounds[1][axis]) / 2: normal[axis] = 1 # create cutting plane plane = vtk["mod"].vtkPlane() plane.SetOrigin(origin[0], origin[1], origin[2]) plane.SetNormal(normal[0], normal[1], normal[2]) # create cutter cutter = vtk["mod"].vtkPlaneCutter() cutter.SetPlane(plane) cutter.SetInputData(self._vtk_obj) cutter.InterpolateAttributesOn() cutter.Update() # clean up the slice cleaner = vtk["mod"].vtkCleanPolyData() cleaner.SetInputData(cutter.GetOutput()) cleaner.Update() return cleaner.GetOutput() @abstractmethod @requires_vtk def plane_slice( self, axis: Axis, pos: float ) -> Union[SpatialDataArray, UnstructuredGridDataset]: """Slice data with a plane and return the Tidy3D representation of the result (``UnstructuredGridDataset``). Parameters ---------- axis : Axis The normal direction of the slicing plane. pos : float Position of the slicing plane along its normal direction. Returns ------- Union[SpatialDataArray, UnstructuredGridDataset] The resulting slice. """ @staticmethod @requires_vtk def _read_vtkUnstructuredGrid(fname: str): """Load a :class:`vtkUnstructuredGrid` from a file.""" reader = vtk["mod"].vtkXMLUnstructuredGridReader() reader.SetFileName(fname) reader.Update() grid = reader.GetOutput() return grid @classmethod @abstractmethod @requires_vtk def _from_vtk_obj( cls, vtk_obj, field: str = None, remove_degenerate_cells: bool = False, remove_unused_points: bool = False, ) -> UnstructuredGridDataset: """Initialize from a vtk object.""" @classmethod @requires_vtk def from_vtu( cls, file: str, field: str = None, remove_degenerate_cells: bool = False, remove_unused_points: bool = False, ) -> UnstructuredGridDataset: """Load unstructured data from a vtu file. Parameters ---------- fname : str Full path to the .vtu file to load the unstructured data from. 
field : str = None Name of the field to load. remove_degenerate_cells : bool = False Remove explicitly degenerate cells. remove_unused_points : bool = False Remove unused points. Returns ------- UnstructuredGridDataset Unstructured data. """ grid = cls._read_vtkUnstructuredGrid(file) return cls._from_vtk_obj( grid, field=field, remove_degenerate_cells=remove_degenerate_cells, remove_unused_points=remove_unused_points, ) @requires_vtk def to_vtu(self, fname: str): """Exports unstructured grid data into a .vtu file. Parameters ---------- fname : str Full path to the .vtu file to save the unstructured data to. """ writer = vtk["mod"].vtkXMLUnstructuredGridWriter() writer.SetFileName(fname) writer.SetInputData(self._vtk_obj) writer.Write() @classmethod @requires_vtk def _get_values_from_vtk( cls, vtk_obj, num_points: pd.PositiveInt, field: str = None ) -> IndexedDataArray: """Get point data values from a VTK object.""" point_data = vtk_obj.GetPointData() num_point_arrays = point_data.GetNumberOfArrays() if num_point_arrays == 0: log.warning( "No point data is found in a VTK object. '.values' will be initialized to zeros." ) values_numpy = np.zeros(num_points) values_name = None else: if field is not None: array_vtk = point_data.GetAbstractArray(field) else: array_vtk = point_data.GetAbstractArray(0) # currently we assume there is only one point data array provided in the VTK object if num_point_arrays > 1 and field is None: array_name = array_vtk.GetName() log.warning( f"{num_point_arrays} point data arrays are found in a VTK object. " f"Only the first array (name: {array_name}) will be used to initialize " "'.values' while the rest will be ignored." ) # currently we assume data is real or complex scalar num_components = array_vtk.GetNumberOfComponents() if num_components > 2: raise DataError( "Found point data array in a VTK object is expected to have maximum 2 " "components (1 is for real data, 2 is for complex data). " f"Found {num_components} components." 
) # check that number of values matches number of grid points num_tuples = array_vtk.GetNumberOfTuples() if num_tuples != num_points: raise DataError( f"The length of found point data array ({num_tuples}) does not match the number" f" of grid points ({num_points})." ) values_numpy = vtk["vtk_to_numpy"](array_vtk) values_name = array_vtk.GetName() # vtk doesn't support complex numbers # we store our complex array as a two-component vtk array # so here we convert that into a single component complex array if num_components == 2: values_numpy = values_numpy.view("complex")[:, 0] values = IndexedDataArray( values_numpy, coords=dict(index=np.arange(len(values_numpy))), name=values_name ) return values @requires_vtk def box_clip(self, bounds: Bound) -> UnstructuredGridDataset: """Clip the unstructured grid using a box defined by ``bounds``. Parameters ---------- bounds : Tuple[float, float, float], Tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns ------- UnstructuredGridDataset Clipped grid. 
""" # make and run a VTK clipper clipper = vtk["mod"].vtkBoxClipDataSet() clipper.SetOrientation(0) clipper.SetBoxClip( bounds[0][0], bounds[1][0], bounds[0][1], bounds[1][1], bounds[0][2], bounds[1][2] ) clipper.SetInputData(self._vtk_obj) clipper.GenerateClipScalarsOn() clipper.GenerateClippedOutputOff() clipper.Update() clip = clipper.GetOutput() # clean grid from unused points grid_cleaner = vtk["mod"].vtkRemoveUnusedPoints() grid_cleaner.SetInputData(clip) grid_cleaner.GenerateOriginalPointIdsOff() grid_cleaner.Update() clean_clip = grid_cleaner.GetOutput() # no intersection check if clean_clip.GetNumberOfPoints() == 0: raise DataError("Clipping box does not intersect the unstructured grid.") return self._from_vtk_obj( clean_clip, remove_degenerate_cells=True, remove_unused_points=True ) def interp( self, x: Union[float, ArrayLike], y: Union[float, ArrayLike], z: Union[float, ArrayLike], fill_value: Union[float, Literal["extrapolate"]] = None, use_vtk: bool = False, method: Literal["linear", "nearest"] = "linear", max_samples_per_step: int = DEFAULT_MAX_SAMPLES_PER_STEP, max_cells_per_step: int = DEFAULT_MAX_CELLS_PER_STEP, rel_tol: float = DEFAULT_TOLERANCE_CELL_FINDING, ) -> SpatialDataArray: """Interpolate data at provided x, y, and z. Parameters ---------- x : Union[float, ArrayLike] x-coordinates of sampling points. y : Union[float, ArrayLike] y-coordinates of sampling points. z : Union[float, ArrayLike] z-coordinates of sampling points. fill_value : Union[float, Literal["extrapolate"]] = 0 Value to use when filling points without interpolated values. If ``"extrapolate"`` then nearest values are used. Note: in a future version the default value will be changed to ``"extrapolate"``. use_vtk : bool = False Use vtk's interpolation functionality or Tidy3D's own implementation. Note: this option will be removed in a future version. method: Literal["linear", "nearest"] = "linear" Interpolation method to use. 
max_samples_per_step : int = 1e4 Max number of points to interpolate at per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. max_cells_per_step : int = 1e4 Max number of cells to interpolate from per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. rel_tol : float = 1e-6 Relative tolerance when determining whether a point belongs to a cell. Returns ------- SpatialDataArray Interpolated data. """ if fill_value is None: log.warning( "Default parameter setting 'fill_value=0' will be changed to " "'fill_value=``extrapolate``' in a future version." ) fill_value = 0 # calculate the resulting array shape x = np.atleast_1d(x) y = np.atleast_1d(y) z = np.atleast_1d(z) if method == "nearest": interpolated_values = self._interp_nearest(x=x, y=y, z=z) else: if fill_value == "extrapolate": fill_value_actual = np.nan else: fill_value_actual = fill_value if use_vtk: log.warning("Note that option 'use_vtk=True' will be removed in future versions.") interpolated_values = self._interp_vtk(x=x, y=y, z=z, fill_value=fill_value_actual) else: interpolated_values = self._interp_py( x=x, y=y, z=z, fill_value=fill_value_actual, max_samples_per_step=max_samples_per_step, max_cells_per_step=max_cells_per_step, rel_tol=rel_tol, ) if fill_value == "extrapolate" and method != "nearest": interpolated_values = self._fill_nans_from_nearests( interpolated_values, x=x, y=y, z=z ) return SpatialDataArray( interpolated_values, coords=dict(x=x, y=y, z=z), name=self.values.name ) def _interp_nearest( self, x: ArrayLike, y: ArrayLike, z: ArrayLike, ) -> ArrayLike: """Interpolate data at provided x, y, and z using Scipy's nearest neighbor interpolator. Parameters ---------- x : ArrayLike x-coordinates of sampling points. y : ArrayLike y-coordinates of sampling points. z : ArrayLike z-coordinates of sampling points. 
Returns ------- ArrayLike Interpolated data. """ from scipy.interpolate import NearestNDInterpolator # use scipy's nearest neighbor interpolator X, Y, Z = np.meshgrid(x, y, z, indexing="ij") interp = NearestNDInterpolator(self._points_3d_array, self.values.values) values = interp(X, Y, Z) return values def _fill_nans_from_nearests( self, values: ArrayLike, x: ArrayLike, y: ArrayLike, z: ArrayLike, ) -> ArrayLike: """Replace nan's in ``values`` with nearest data points. Parameters ---------- values : ArrayLike 3D array containing nan's x : ArrayLike x-coordinates of sampling points. y : ArrayLike y-coordinates of sampling points. z : ArrayLike z-coordinates of sampling points. Returns ------- ArrayLike Data without nan's. """ # locate all nans nans = np.isnan(values) if np.sum(nans) > 0: from scipy.interpolate import NearestNDInterpolator # use scipy's nearest neighbor interpolator X, Y, Z = np.meshgrid(x, y, z, indexing="ij") interp = NearestNDInterpolator(self._points_3d_array, self.values.values) values_to_replace_nans = interp(X[nans], Y[nans], Z[nans]) values[nans] = values_to_replace_nans return values @requires_vtk def _interp_vtk( self, x: ArrayLike, y: ArrayLike, z: ArrayLike, fill_value: float, ) -> ArrayLike: """Interpolate data at provided x, y, and z using vtk package. Parameters ---------- x : Union[float, ArrayLike] x-coordinates of sampling points. y : Union[float, ArrayLike] y-coordinates of sampling points. z : Union[float, ArrayLike] z-coordinates of sampling points. fill_value : float = 0 Value to use when filling points without interpolated values. Returns ------- ArrayLike Interpolated data. 
""" shape = (len(x), len(y), len(z)) # create a VTK rectilinear grid to sample onto structured_grid = vtk["mod"].vtkRectilinearGrid() structured_grid.SetDimensions(shape) structured_grid.SetXCoordinates(vtk["numpy_to_vtk"](x)) structured_grid.SetYCoordinates(vtk["numpy_to_vtk"](y)) structured_grid.SetZCoordinates(vtk["numpy_to_vtk"](z)) # create and execute VTK interpolator interpolator = vtk["mod"].vtkResampleWithDataSet() interpolator.SetInputData(structured_grid) interpolator.SetSourceData(self._vtk_obj) interpolator.Update() interpolated = interpolator.GetOutput() # get results in a numpy representation array_id = 0 if self.values.name is None else self.values.name values_numpy = vtk["vtk_to_numpy"](interpolated.GetPointData().GetAbstractArray(array_id)) # fill points without interpolated values if fill_value != 0: mask = vtk["vtk_to_numpy"]( interpolated.GetPointData().GetAbstractArray("vtkValidPointMask") ) values_numpy[mask != 1] = fill_value # VTK arrays are the z-y-x order, reorder interpolation results to x-y-z order values_reordered = np.transpose(np.reshape(values_numpy, shape[::-1]), (2, 1, 0)) return values_reordered @abstractmethod def _interp_py( self, x: ArrayLike, y: ArrayLike, z: ArrayLike, fill_value: float, max_samples_per_step: int, max_cells_per_step: int, rel_tol: float, ) -> ArrayLike: """Dimensionality-specific function (2D and 3D) to interpolate data at provided x, y, and z using vectorized python implementation. Parameters ---------- x : Union[float, ArrayLike] x-coordinates of sampling points. y : Union[float, ArrayLike] y-coordinates of sampling points. z : Union[float, ArrayLike] z-coordinates of sampling points. fill_value : float Value to use when filling points without interpolated values. max_samples_per_step : int Max number of points to interpolate at per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. 
max_cells_per_step : int Max number of cells to interpolate from per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. rel_tol : float Relative tolerance when determining whether a point belongs to a cell. Returns ------- ArrayLike Interpolated data. """ def _interp_py_general( self, x: ArrayLike, y: ArrayLike, z: ArrayLike, fill_value: float, max_samples_per_step: int, max_cells_per_step: int, rel_tol: float, axis_ignore: Union[Axis, None], ) -> ArrayLike: """A general function (2D and 3D) to interpolate data at provided x, y, and z using vectorized python implementation. Parameters ---------- x : Union[float, ArrayLike] x-coordinates of sampling points. y : Union[float, ArrayLike] y-coordinates of sampling points. z : Union[float, ArrayLike] z-coordinates of sampling points. fill_value : float Value to use when filling points without interpolated values. max_samples_per_step : int Max number of points to interpolate at per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. max_cells_per_step : int Max number of cells to interpolate from per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. rel_tol : float Relative tolerance when determining whether a point belongs to a cell. axis_ignore : Union[Axis, None] When interpolating from a 2D dataset, must specify normal axis. Returns ------- ArrayLike Interpolated data. 
""" # get dimensionality of data num_dims = self._point_dims() if num_dims == 2 and axis_ignore is None: raise DataError("Must porvide 'axis_ignore' when interpolating from a 2d dataset.") xyz_grid = [x, y, z] if axis_ignore is not None: xyz_grid.pop(axis_ignore) # get numpy arrays for points and cells cell_connections = ( self.cells.values ) # (num_cells, num_cell_vertices), num_cell_vertices=num_cell_faces points = self.points.values # (num_points, num_dims) num_cells = len(cell_connections) num_points = len(points) # compute tolerances based on total size of unstructured grid bounds = self.bounds size = np.subtract(bounds[1], bounds[0]) tol = size * rel_tol diag_tol = np.linalg.norm(tol) # compute (index) positions of unstructured points w.r.t. target Cartesian grid points # (i.e. between which Cartesian grid points a given unstructured grid point is located) # we perturb grid values in both directions to make sure we don't miss any points # due to numerical precision xyz_pos_l = np.zeros((num_dims, num_points), dtype=int) xyz_pos_r = np.zeros((num_dims, num_points), dtype=int) for dim in range(num_dims): xyz_pos_l[dim] = np.searchsorted(xyz_grid[dim] + tol[dim], points[:, dim]) xyz_pos_r[dim] = np.searchsorted(xyz_grid[dim] - tol[dim], points[:, dim]) # let's allocate an array for resulting values # every time we process a chunk of samples, we will write into this array interpolated_values = fill_value + np.zeros( [len(xyz_comp) for xyz_comp in xyz_grid], dtype=self.values.dtype ) processed_cells_global = 0 # to ovoid OOM for large datasets, we process only certain number of cells at a time while processed_cells_global < num_cells: target_processed_cells_global = min( num_cells, processed_cells_global + max_cells_per_step ) connections_to_process = cell_connections[ processed_cells_global:target_processed_cells_global ] # now we transfer this information to each cell. That is, each cell knows how its vertices # positioned relative to Cartesian grid points. 
# (num_dims, num_cells, num_vertices=num_cell_faces) xyz_pos_l_per_cell = xyz_pos_l[:, connections_to_process] xyz_pos_r_per_cell = xyz_pos_r[:, connections_to_process] # taking min/max among all cell vertices (per each dimension separately) # we get min and max indices of Cartesian grid points that may receive their values # from a given cell. # (num_dims, num_cells) cell_ind_min = np.min(xyz_pos_l_per_cell, axis=2) cell_ind_max = np.max(xyz_pos_r_per_cell, axis=2) # calculate number of Cartesian grid points where we will perform interpolation for a given # cell. Note that this number is much larger than actually needed, because essentially for # each cell we consider all Cartesian grid points that fall into the cell's bounding box. # We use word "sample" to represent such Cartesian grid points. # (num_cells,) num_samples_per_cell = np.prod(cell_ind_max - cell_ind_min, axis=0) # find cells that have non-zero number of samples # we use "ne" as a shortcut for "non empty" ne_cells = num_samples_per_cell > 0 # (num_cells,) num_ne_cells = np.sum(ne_cells) # indices of cells with non-zero number of samples in the original list of cells # (num_cells,) ne_cell_inds = np.arange(processed_cells_global, target_processed_cells_global)[ ne_cells ] # restrict to non-empty cells only num_samples_per_ne_cell = num_samples_per_cell[ne_cells] cum_num_samples_per_ne_cell = np.cumsum(num_samples_per_ne_cell) ne_cell_ind_min = cell_ind_min[:, ne_cells] ne_cell_ind_max = cell_ind_max[:, ne_cells] # Next we need to perform actual interpolation at all sample points # this is computationally expensive operation and because we try to do everything # in the vectorized form, it can require a lot of memory, sometimes even causing OOM errors. # To avoid that, we impose restrictions on how many cells/samples can be processed at a time # effectivelly performing these operations in chunks. 
# Note that currently this is done sequentially, but could be relatively easy to parallelize # start counters of how many cells/samples have been processed processed_samples = 0 processed_cells = 0 while processed_cells < num_ne_cells: # how many cells we would like to process by the end of this step target_processed_cells = min(num_ne_cells, processed_cells + max_cells_per_step) # find how many cells we can processed based on number of allowed samples target_processed_samples = processed_samples + max_samples_per_step target_processed_cells_from_samples = ( np.searchsorted(cum_num_samples_per_ne_cell, target_processed_samples) + 1 ) # take min between the two target_processed_cells = min( target_processed_cells, target_processed_cells_from_samples ) # select cells and corresponding samples to process step_ne_cell_ind_min = ne_cell_ind_min[:, processed_cells:target_processed_cells] step_ne_cell_ind_max = ne_cell_ind_max[:, processed_cells:target_processed_cells] step_ne_cell_inds = ne_cell_inds[processed_cells:target_processed_cells] # process selected cells and points xyz_inds, interpolated = self._interp_py_chunk( xyz_grid=xyz_grid, cell_inds=step_ne_cell_inds, cell_ind_min=step_ne_cell_ind_min, cell_ind_max=step_ne_cell_ind_max, sdf_tol=diag_tol, ) if num_dims == 3: interpolated_values[xyz_inds[0], xyz_inds[1], xyz_inds[2]] = interpolated else: interpolated_values[xyz_inds[0], xyz_inds[1]] = interpolated processed_cells = target_processed_cells processed_samples = cum_num_samples_per_ne_cell[target_processed_cells - 1] processed_cells_global = target_processed_cells_global # in case of 2d grid broadcast results along normal direction assuming translational # invariance if num_dims == 2: orig_shape = [len(x), len(y), len(z)] flat_shape = orig_shape.copy() flat_shape[axis_ignore] = 1 interpolated_values = np.reshape(interpolated_values, flat_shape) interpolated_values = np.broadcast_to( interpolated_values, (len(x), len(y), len(z)) ).copy() return 
interpolated_values def _interp_py_chunk( self, xyz_grid: Tuple[ArrayLike[float], ...], cell_inds: ArrayLike[int], cell_ind_min: ArrayLike[int], cell_ind_max: ArrayLike[int], sdf_tol: float, ) -> Tuple[Tuple[ArrayLike, ...], ArrayLike]: """For each cell listed in ``cell_inds`` perform interpolation at a rectilinear subarray of xyz_grid given by a (3D) index span (cell_ind_min, cell_ind_max). Parameters ---------- xyz_grid : Tuple[ArrayLike[float], ...] x, y, and z coordiantes defining rectilinear grid. cell_inds : ArrayLike[int] Indices of cells to perfrom interpolation from. cell_ind_min : ArrayLike[int] Starting x, y, and z indices of points for interpolation for each cell. cell_ind_max : ArrayLike[int] End x, y, and z indices of points for interpolation for each cell. sdf_tol : float Effective zero level set value, below which a point is considered to be inside a cell. Returns ------- Tuple[Tuple[ArrayLike, ...], ArrayLike] x, y, and z indices of interpolated values and values themselves. """ # get dimensionality of data num_dims = self._point_dims() num_cell_faces = self._cell_num_vertices() # get mesh info as numpy arrays points = self.points.values # (num_points, num_dims) data_values = self.values.values # (num_points,) cell_connections = self.cells.values[cell_inds] # compute number of samples to generate per cell num_samples_per_cell = np.prod(cell_ind_max - cell_ind_min, axis=0) # at this point we know how many samples we need to perform per each cell and we also # know span indices of these samples (in x, y, and z arrays) # we would like to perform all interpolations in a vectorized form, however, we have # a different number of interpolation samples for different cells. Thus, we need to # arange all samples in a linear way (flatten). 
Basically, we want to have data in this # form: # cell_ind | x_ind | y_ind | z_ind # -------------------------------- # 0 | 23 | 5 | 11 # 0 | 23 | 5 | 12 # 0 | 23 | 6 | 11 # 0 | 23 | 6 | 12 # 1 | 41 | 11 | 0 # 1 | 42 | 11 | 0 # ... | ... | ... | ... # to do that we start with performing arange for each cell, but in vectorized way # this gives us something like this # [0, 1, 2, 3, 0, 1, 0, 1, 2, 3, 4, 5, 6, ...] # |<-cell 0->|<-cell 1->|<- cell 2 ->|<- ... num_cells = len(num_samples_per_cell) num_samples_cumul = num_samples_per_cell.cumsum() num_samples_total = num_samples_cumul[-1] # one big arange array inds_flat = np.arange(num_samples_total) # now subtract previous number of samples inds_flat[num_samples_per_cell[0] :] -= np.repeat( num_samples_cumul[:-1], num_samples_per_cell[1:] ) # convert flat indices into 3d/2d indices as: # x_ind = [23, 23, 23, 23, 41, 41, ...] # y_ind = [ 5, 5, 5, 5, 6, 6, ...] # z_ind = [11, 12, 11, 12, 0, 0, ...] # |<- cell 0 ->|<- cell 1 ->|<- ... num_samples_y = np.repeat(cell_ind_max[1] - cell_ind_min[1], num_samples_per_cell) # note: in 2d x, y correspond to (x, y, z).pop(normal_axis) if num_dims == 3: num_samples_z = np.repeat(cell_ind_max[2] - cell_ind_min[2], num_samples_per_cell) inds_flat, z_inds = np.divmod(inds_flat, num_samples_z) x_inds, y_inds = np.divmod(inds_flat, num_samples_y) start_inds = np.repeat(cell_ind_min, num_samples_per_cell, axis=1) x_inds = x_inds + start_inds[0] y_inds = y_inds + start_inds[1] if num_dims == 3: z_inds = z_inds + start_inds[2] # finally, we repeat cell indices corresponding number of times to obtain how # (x_ind, y_ind, z_ind) map to cell indices. So, now we have four arras: # x_ind = [23, 23, 23, 23, 41, 41, ...] # y_ind = [ 5, 5, 5, 5, 6, 6, ...] # z_ind = [11, 12, 11, 12, 0, 0, ...] # cell_map = [ 0, 0, 0, 0, 1, 1, ...] # |<- cell 0 ->|<- cell 1 ->|<- ... 
step_cell_map = np.repeat(np.arange(num_cells), num_samples_per_cell) # let's put these arrays aside for a moment and perform the second preparatory step # specifically, for each face of each cell we will compute normal vector and distance # to the opposing cell vertex. This will allows us quickly calculate SDF of a cell at # each sample point as well as perform linear interpolation. # first, we collect coordinates of cell vertices into a single array # (num_cells, num_cell_vertices, num_dims) cell_vertices = np.float64(points[cell_connections, :]) # array for resulting normals and distances normal = np.zeros((num_cell_faces, num_cells, num_dims)) dist = np.zeros((num_cell_faces, num_cells)) # loop face by face # note that by face_ind we denote both index of face in a cell and index of the opposing vertex for face_ind in range(num_cell_faces): # select vertices forming the given face face_pinds = list(np.arange(num_cell_faces)) face_pinds.pop(face_ind) # calculate normal to the face # in 3D: cross product of two vectors lying in the face plane # in 2D: (-ty, tx) for a vector (tx, ty) along the face p0 = cell_vertices[:, face_pinds[0]] p01 = cell_vertices[:, face_pinds[1]] - p0 p0Opp = cell_vertices[:, face_ind] - p0 if num_dims == 3: p02 = cell_vertices[:, face_pinds[2]] - p0 n = np.cross(p01, p02) else: n = np.roll(p01, 1, axis=1) n[:, 0] = -n[:, 0] n_norm = np.linalg.norm(n, axis=1) n = n / n_norm[:, None] # compute distance to the opposing vertex by taking a dot product between normal # and a vector connecting the opposing vertex and the face d = np.einsum("ij,ij->i", n, p0Opp) # obtained normal direction is arbitrary here. We will orient it such that it points # away from the triangle (and distance to the opposing vertex is negative). 
to_flip = d > 0 d[to_flip] *= -1 n[to_flip, :] *= -1 # set distances in degenerate triangles to something positive to ignore later dist_zero = d == 0 if any(dist_zero): d[dist_zero] = 1 # record obtained info normal[face_ind] = n dist[face_ind] = d # now we all set up to proceed with actual interpolation at each sample point # the main idea here is that: # - we use `cell_map` to grab normals and distances # of cells in which the given sample point is (potentially) located. # - use `x_ind, y_ind, z_ind` to find actual coordinates of a given sample point # - combine the above two to calculate cell SDF and interpolated value at a given sample # point # - having cell SDF at the sample point actually tells us whether its inside the cell # (keep value) or outside of it (discard interpolated value) # to perform SDF calculation and interpolation we will loop face by face and recording # their contributions. That is, # cell_sdf = max(face0_sdf, face1_sdf, ...) # interpolated_value = value0 * face0_sdf / dist0_sdf + ... 
# (because face0_sdf / dist0_sdf is linear shape function for vertex0) sdf = -inf * np.ones(num_samples_total) interpolated = np.zeros(num_samples_total, dtype=self._double_type) # coordinates of each sample point sample_xyz = np.zeros((num_samples_total, num_dims)) sample_xyz[:, 0] = xyz_grid[0][x_inds] sample_xyz[:, 1] = xyz_grid[1][y_inds] if num_dims == 3: sample_xyz[:, 2] = xyz_grid[2][z_inds] # loop face by face for face_ind in range(num_cell_faces): # find a vector connecting sample point and face if face_ind == 0: vertex_ind = 1 # anythin other than 0 vec = sample_xyz - cell_vertices[step_cell_map, vertex_ind, :] if face_ind == 1: # since three faces share a point only do this once vertex_ind = 0 # it belongs to every face 1, 2, and 3 vec = sample_xyz - cell_vertices[step_cell_map, 0, :] # compute distance from every sample point to the face of corresponding cell # using dot product tmp = normal[face_ind, step_cell_map, :] * vec d = np.sum(tmp, axis=1) # take max between distance to obtain the overall SDF of a cell sdf = np.maximum(sdf, d) # perform linear interpolation. Here we use the fact that when computing face SDF # at a given point and dividing it by the distance to the opposing vertex we get # a linear shape function for that vertex. So, we just need to multiply that by # the data value at that vertex to find its contribution into intepolated value. # (decomposed in an attempt to reduce memory consumption) tmp = self._double_type(data_values[cell_connections[step_cell_map, face_ind]]) tmp *= d tmp /= dist[face_ind, step_cell_map] # ignore degenerate cells dist_zero = dist[face_ind, step_cell_map] > 0 if any(dist_zero): sdf[dist_zero] = 10 * sdf_tol interpolated += tmp # The resulting array of interpolated values contain multiple candidate values for # every Cartesian point because bounding boxes of cells overlap. # Thus, we need to keep only those that come cell actually containing a given point. 
# This can be easily determined by the sign of the cell SDF sampled at a given point. valid_samples = sdf < sdf_tol interpolated_valid = interpolated[valid_samples] xyz_valid_inds = [] xyz_valid_inds.append(x_inds[valid_samples]) xyz_valid_inds.append(y_inds[valid_samples]) if num_dims == 3: xyz_valid_inds.append(z_inds[valid_samples]) return xyz_valid_inds, interpolated_valid @abstractmethod @requires_vtk def sel( self, x: Union[float, ArrayLike] = None, y: Union[float, ArrayLike] = None, z: Union[float, ArrayLike] = None, ) -> Union[UnstructuredGridDataset, SpatialDataArray]: """Extract/interpolate data along one or more Cartesian directions. At least of x, y, and z must be provided. Parameters ---------- x : Union[float, ArrayLike] = None x-coordinate of the slice. y : Union[float, ArrayLike] = None y-coordinate of the slice. z : Union[float, ArrayLike] = None z-coordinate of the slice. Returns ------- Union[TriangularGridDataset, SpatialDataArray] Extracted data. """ @requires_vtk def sel_inside(self, bounds: Bound) -> UnstructuredGridDataset: """Return a new UnstructuredGridDataset that contains the minimal amount data necessary to cover a spatial region defined by ``bounds``. Parameters ---------- bounds : Tuple[float, float, float], Tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns ------- UnstructuredGridDataset Extracted spatial data array. """ if any(bmin > bmax for bmin, bmax in zip(*bounds)): raise DataError( "Min and max bounds must be packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``." ) data_bounds = self.bounds tol = 1e-6 # For extracting cells covering target region we use vtk's filter that extract cells based # on provided implicit function. However, when we provide to it the implicit function of # the entire box, it has a couple of issues coming from the fact that the algorithm # eliminates every cells for which the implicit function has positive sign at all vertices. 
# As result, sometimes there are cells that despite overlaping with the target domain still # being eliminated. Two common cases: # - near corners of the target domain # - target domain is very thin # That's why we perform selection by sequentially eliminating cells on the outer side of # each of the 6 surfaces of the bounding box separately. tmp = self._vtk_obj for direction in range(2): for dim in range(3): sign = -1 + 2 * direction plane_pos = bounds[direction][dim] # Dealing with situation when target region does intersect with any cell: # in this case we shift target region so that it barely touches at least some # of cells if sign < 0 and plane_pos > data_bounds[1][dim] - tol: plane_pos = data_bounds[1][dim] - tol if sign > 0 and plane_pos < data_bounds[0][dim] + tol: plane_pos = data_bounds[0][dim] + tol # if all cells are on the inside side of the plane for a given surface # we don't need to check for intersection if plane_pos <= data_bounds[1][dim] and plane_pos >= data_bounds[0][dim]: plane = vtk["mod"].vtkPlane() center = [0, 0, 0] normal = [0, 0, 0] center[dim] = plane_pos normal[dim] = sign plane.SetOrigin(center) plane.SetNormal(normal) extractor = vtk["mod"].vtkExtractGeometry() extractor.SetImplicitFunction(plane) extractor.ExtractInsideOn() extractor.ExtractBoundaryCellsOn() extractor.SetInputData(tmp) extractor.Update() tmp = extractor.GetOutput() return self._from_vtk_obj(tmp, remove_degenerate_cells=True, remove_unused_points=True) def does_cover(self, bounds: Bound) -> bool: """Check whether data fully covers specified by ``bounds`` spatial region. If data contains only one point along a given direction, then it is assumed the data is constant along that direction and coverage is not checked. Parameters ---------- bounds : Tuple[float, float, float], Tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns ------- bool Full cover check outcome. 
""" return all( (dmin <= smin and dmax >= smax) for dmin, dmax, smin, smax in zip(self.bounds[0], self.bounds[1], bounds[0], bounds[1]) ) @requires_vtk def reflect( self, axis: Axis, center: float, reflection_only: bool = False ) -> UnstructuredGridDataset: """Reflect unstructured data across the plane define by parameters ``axis`` and ``center``. By default the original data is preserved, setting ``reflection_only`` to ``True`` will produce only deflected data. Parameters ---------- axis : Literal[0, 1, 2] Normal direction of the reflection plane. center : float Location of the reflection plane along its normal direction. reflection_only : bool = False Return only reflected data. Returns ------- UnstructuredGridDataset Data after reflextion is performed. """ reflector = vtk["mod"].vtkReflectionFilter() reflector.SetPlane([reflector.USE_X, reflector.USE_Y, reflector.USE_Z][axis]) reflector.SetCenter(center) reflector.SetCopyInput(not reflection_only) reflector.SetInputData(self._vtk_obj) reflector.Update() return self._from_vtk_obj(reflector.GetOutput())
[docs] class TriangularGridDataset(UnstructuredGridDataset): """Dataset for storing triangular grid data. Data values are associated with the nodes of the grid. Note ---- To use full functionality of unstructured datasets one must install ``vtk`` package (``pip install tidy3d[vtk]`` or ``pip install vtk``). Otherwise the functionality of unstructured datasets is limited to creation, writing to/loading from a file, and arithmetic manipulations. Example ------- >>> tri_grid_points = PointDataArray( ... [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], ... coords=dict(index=np.arange(4), axis=np.arange(2)), ... ) >>> >>> tri_grid_cells = CellDataArray( ... [[0, 1, 2], [1, 2, 3]], ... coords=dict(cell_index=np.arange(2), vertex_index=np.arange(3)), ... ) >>> >>> tri_grid_values = IndexedDataArray( ... [1.0, 2.0, 3.0, 4.0], coords=dict(index=np.arange(4)), ... ) >>> >>> tri_grid = TriangularGridDataset( ... normal_axis=1, ... normal_pos=0, ... points=tri_grid_points, ... cells=tri_grid_cells, ... values=tri_grid_values, ... 
) """ normal_axis: Axis = pd.Field( ..., title="Grid Axis", description="Orientation of the grid.", ) normal_pos: float = pd.Field( ..., title="Position", description="Coordinate of the grid along the normal direction.", ) @cached_property def bounds(self) -> Bound: """Grid bounds.""" bounds_2d = super().bounds bounds_3d = self._points_2d_to_3d(bounds_2d) return tuple(bounds_3d[0]), tuple(bounds_3d[1]) @classmethod def _point_dims(cls) -> pd.PositiveInt: """Dimensionality of stored grid point coordinates.""" return 2 def _points_2d_to_3d(self, pts: ArrayLike) -> ArrayLike: """Convert 2d points into 3d points.""" return np.insert(pts, obj=self.normal_axis, values=self.normal_pos, axis=1) @cached_property def _points_3d_array(self) -> ArrayLike: """3D representation of grid points.""" return self._points_2d_to_3d(self.points.data) @classmethod def _cell_num_vertices(cls) -> pd.PositiveInt: """Number of vertices in a cell.""" return 3 @classmethod @requires_vtk def _vtk_cell_type(cls): """VTK cell type to use in the VTK representation.""" return vtk["mod"].VTK_TRIANGLE @classmethod @requires_vtk def _from_vtk_obj( cls, vtk_obj, field=None, remove_degenerate_cells: bool = False, remove_unused_points: bool = False, ): """Initialize from a vtkUnstructuredGrid instance.""" # get points cells data from vtk object if isinstance(vtk_obj, vtk["mod"].vtkPolyData): cells_vtk = vtk_obj.GetPolys() elif isinstance(vtk_obj, vtk["mod"].vtkUnstructuredGrid): cells_vtk = vtk_obj.GetCells() cells_numpy = vtk["vtk_to_numpy"](cells_vtk.GetConnectivityArray()) cell_offsets = vtk["vtk_to_numpy"](cells_vtk.GetOffsetsArray()) if not np.all(np.diff(cell_offsets) == cls._cell_num_vertices()): raise DataError( "Only triangular 'vtkUnstructuredGrid' or 'vtkPolyData' can be converted into " "'TriangularGridDataset'." 
) points_numpy = vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData()) # data values are read directly into Tidy3D array values = cls._get_values_from_vtk(vtk_obj, len(points_numpy), field) # detect zero size dimension bounds = np.max(points_numpy, axis=0) - np.min(points_numpy, axis=0) zero_dims = np.where(np.isclose(bounds, 0))[0] if len(zero_dims) != 1: raise DataError( f"Provided vtk grid does not represent a two dimensional grid. Found zero size dimensions are {zero_dims}." ) normal_axis = zero_dims[0] normal_pos = points_numpy[0][normal_axis] tan_dims = [0, 1, 2] tan_dims.remove(normal_axis) # convert 3d coordinates into 2d points_2d_numpy = points_numpy[:, tan_dims] # create Tidy3D points and cells arrays num_cells = len(cells_numpy) // cls._cell_num_vertices() cells_numpy = np.reshape(cells_numpy, (num_cells, cls._cell_num_vertices())) cells = CellDataArray( cells_numpy, coords=dict( cell_index=np.arange(num_cells), vertex_index=np.arange(cls._cell_num_vertices()) ), ) points = PointDataArray( points_2d_numpy, coords=dict(index=np.arange(len(points_numpy)), axis=np.arange(cls._point_dims())), ) if remove_degenerate_cells: cells = cls._remove_degenerate_cells(cells=cells) if remove_unused_points: points, values, cells = cls._remove_unused_points( points=points, values=values, cells=cells ) return cls( normal_axis=normal_axis, normal_pos=normal_pos, points=points, cells=cells, values=values, )
[docs] @requires_vtk def plane_slice(self, axis: Axis, pos: float) -> SpatialDataArray: """Slice data with a plane and return the resulting line as a SpatialDataArray. Parameters ---------- axis : Axis The normal direction of the slicing plane. pos : float Position of the slicing plane along its normal direction. Returns ------- SpatialDataArray The resulting slice. """ if axis == self.normal_axis: raise DataError( f"Triangular grid (normal: {self.normal_axis}) cannot be sliced by a parallel " "plane." ) # perform slicing in vtk and get unprocessed points and values slice_vtk = self._plane_slice_raw(axis=axis, pos=pos) points_numpy = vtk["vtk_to_numpy"](slice_vtk.GetPoints().GetData()) values = self._get_values_from_vtk(slice_vtk, len(points_numpy)) # axis of the resulting line slice_axis = 3 - self.normal_axis - axis # sort found intersection in ascending order sorting = np.argsort(points_numpy[:, slice_axis], kind="mergesort") # assemble coords for SpatialDataArray coords = [None, None, None] coords[axis] = [pos] coords[self.normal_axis] = [self.normal_pos] coords[slice_axis] = points_numpy[sorting, slice_axis] coords_dict = dict(zip("xyz", coords)) # reshape values from a 1d array into a 3d array new_shape = [1, 1, 1] new_shape[slice_axis] = len(values) values_reshaped = np.reshape(values.data[sorting], new_shape) return SpatialDataArray(values_reshaped, coords=coords_dict, name=self.values.name)
@property
def _triangulation_obj(self) -> Triangulation:
    """Matplotlib triangulation built from the grid points and cells (used for plotting)."""
    # the two columns of ``points`` are passed as the plotting coordinates
    first_coords = self.points[:, 0]
    second_coords = self.points[:, 1]
    return Triangulation(first_coords, second_coords, self.cells)
@equal_aspect
@add_ax_if_none
def plot(
    self,
    ax: Ax = None,
    field: bool = True,
    grid: bool = True,
    cbar: bool = True,
    cmap: str = "viridis",
    vmin: float = None,
    vmax: float = None,
    shading: Literal["gouraud", "flat"] = "gouraud",
    cbar_kwargs: Dict = None,
    pcolor_kwargs: Dict = None,
) -> Ax:
    """Plot the data field and/or the unstructured grid.

    Parameters
    ----------
    ax : matplotlib.axes._subplots.Axes = None
        matplotlib axes to plot on, if not specified, one is created.
    field : bool = True
        Whether to plot the data field.
    grid : bool = True
        Whether to plot the unstructured grid.
    cbar : bool = True
        Display colorbar (only if ``field == True``).
    cmap : str = "viridis"
        Color map to use for plotting.
    vmin : float = None
        The lower bound of data range that the colormap covers. If ``None``, it is inferred
        from the data and other keyword arguments.
    vmax : float = None
        The upper bound of data range that the colormap covers. If ``None``, it is inferred
        from the data and other keyword arguments.
    shading : Literal["gouraud", "flat"] = "gouraud"
        Type of shading to use when plotting the data field.
    cbar_kwargs : Dict = None
        Additional parameters passed to colorbar object. A ``"label"`` entry overrides the
        default label (the name of the stored values).
    pcolor_kwargs : Dict = None
        Additional parameters passed to ``ax.tripcolor``.

    Returns
    -------
    matplotlib.axes._subplots.Axes
        The supplied or created matplotlib axes.

    Raises
    ------
    DataError
        If both ``field`` and ``grid`` are ``False``.
    """

    # ``None`` defaults avoid sharing mutable dictionaries between calls
    if cbar_kwargs is None:
        cbar_kwargs = {}
    if pcolor_kwargs is None:
        pcolor_kwargs = {}

    if not (field or grid):
        raise DataError("Nothing to plot ('field == False', 'grid == False').")

    # plot data field if requested
    if field:
        plot_obj = ax.tripcolor(
            self._triangulation_obj,
            self.values.data,
            shading=shading,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            **pcolor_kwargs,
        )

        if cbar:
            # label the colorbar with the data name unless the caller supplied a label
            label_kwargs = {}
            if "label" not in cbar_kwargs:
                label_kwargs["label"] = self.values.name
            plt.colorbar(plot_obj, **cbar_kwargs, **label_kwargs)

    # plot grid if requested
    if grid:
        ax.triplot(
            self._triangulation_obj,
            color=plot_params_grid.edgecolor,
            linewidth=plot_params_grid.linewidth,
        )

    # set labels and titles: in-plane axes label the plot, the normal axis goes in the title
    ax_labels = ["x", "y", "z"]
    normal_axis_name = ax_labels.pop(self.normal_axis)
    ax.set_xlabel(ax_labels[0])
    ax.set_ylabel(ax_labels[1])
    ax.set_title(f"{normal_axis_name} = {self.normal_pos}")
    return ax
def interp(
    self,
    x: Union[float, ArrayLike],
    y: Union[float, ArrayLike],
    z: Union[float, ArrayLike],
    fill_value: Union[float, Literal["extrapolate"]] = None,
    use_vtk: bool = False,
    method: Literal["linear", "nearest"] = "linear",
    ignore_normal_pos: bool = True,
    max_samples_per_step: int = DEFAULT_MAX_SAMPLES_PER_STEP,
    max_cells_per_step: int = DEFAULT_MAX_CELLS_PER_STEP,
    rel_tol: float = DEFAULT_TOLERANCE_CELL_FINDING,
) -> SpatialDataArray:
    """Interpolate data at provided x, y, and z. Note that data is assumed to be invariant
    along the dataset's normal direction.

    Parameters
    ----------
    x : Union[float, ArrayLike]
        x-coordinates of sampling points.
    y : Union[float, ArrayLike]
        y-coordinates of sampling points.
    z : Union[float, ArrayLike]
        z-coordinates of sampling points.
    fill_value : Union[float, Literal["extrapolate"]] = None
        Value to use when filling points without interpolated values. If ``"extrapolate"``
        then nearest values are used. If ``None``, a warning is logged and ``0`` is used.
        Note: in a future version the default value will be changed to ``"extrapolate"``.
    use_vtk : bool = False
        Use vtk's interpolation functionality or Tidy3D's own implementation. Note: this
        option will be removed in a future version.
    method: Literal["linear", "nearest"] = "linear"
        Interpolation method to use.
    ignore_normal_pos : bool = True
        (Deprecated) Assume data is invariant along the normal direction to the grid plane.
        Passing ``False`` only logs a warning; the data is always treated as invariant.
    max_samples_per_step : int = 1e4
        Max number of points to interpolate at per iteration (used only if `use_vtk=False`).
        Using a higher number may speed up calculations but, at the same time, it increases
        RAM usage.
    max_cells_per_step : int = 1e4
        Max number of cells to interpolate from per iteration (used only if `use_vtk=False`).
        Using a higher number may speed up calculations but, at the same time, it increases
        RAM usage.
    rel_tol : float = 1e-6
        Relative tolerance when determining whether a point belongs to a cell.

    Returns
    -------
    SpatialDataArray
        Interpolated data.
    """

    if fill_value is None:
        log.warning(
            "Default parameter setting 'fill_value=0' will be changed to "
            "'fill_value=``extrapolate``' in a future version."
        )
        fill_value = 0

    if not ignore_normal_pos:
        log.warning(
            "Parameter 'ignore_normal_pos' is depreciated. It is always assumed that data "
            "contained in 'TriangularGridDataset' is invariant in the normal direction. "
            "That is, 'ignore_normal_pos=True' is used."
        )

    x = np.atleast_1d(x)
    y = np.atleast_1d(y)
    z = np.atleast_1d(z)

    # interpolate in the grid plane only: the normal coordinate is pinned to the grid position
    xyz = [x, y, z]
    xyz[self.normal_axis] = [self.normal_pos]

    # NOTE(review): 'rel_tol' was previously accepted but never forwarded; this assumes the
    # parent 'interp' accepts 'rel_tol' (it must supply it to '_interp_py') - confirm.
    interp_inplane = super().interp(
        **dict(zip("xyz", xyz)),
        fill_value=fill_value,
        use_vtk=use_vtk,
        method=method,
        max_samples_per_step=max_samples_per_step,
        max_cells_per_step=max_cells_per_step,
        rel_tol=rel_tol,
    )

    # replicate the in-plane result along the normal direction to match the requested points
    interp_broadcasted = np.broadcast_to(interp_inplane, [len(comp) for comp in (x, y, z)])

    return SpatialDataArray(
        interp_broadcasted, coords=dict(x=x, y=y, z=z), name=self.values.name
    )
def _interp_py( self, x: ArrayLike, y: ArrayLike, z: ArrayLike, fill_value: float, max_samples_per_step: int, max_cells_per_step: int, rel_tol: float, ) -> ArrayLike: """2D-specific function to interpolate data at provided x, y, and z using vectorized python implementation. Parameters ---------- x : Union[float, ArrayLike] x-coordinates of sampling points. y : Union[float, ArrayLike] y-coordinates of sampling points. z : Union[float, ArrayLike] z-coordinates of sampling points. fill_value : float Value to use when filling points without interpolated values. max_samples_per_step : int Max number of points to interpolate at per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. max_cells_per_step : int Max number of cells to interpolate from per iteration (used only if `use_vtk=False`). Using a higher number may speed up calculations but, at the same time, it increases RAM usage. rel_tol : float Relative tolerance when determining whether a point belongs to a cell. Returns ------- ArrayLike Interpolated data. """ return self._interp_py_general( x=x, y=y, z=z, fill_value=fill_value, max_samples_per_step=max_samples_per_step, max_cells_per_step=max_cells_per_step, rel_tol=rel_tol, axis_ignore=self.normal_axis, )
@requires_vtk
def sel(
    self,
    x: Union[float, ArrayLike] = None,
    y: Union[float, ArrayLike] = None,
    z: Union[float, ArrayLike] = None,
) -> SpatialDataArray:
    """Extract/interpolate data along one or more Cartesian directions. At least one of x, y,
    and z must be provided.

    Parameters
    ----------
    x : Union[float, ArrayLike] = None
        x-coordinate of the slice.
    y : Union[float, ArrayLike] = None
        y-coordinate of the slice.
    z : Union[float, ArrayLike] = None
        z-coordinate of the slice.

    Returns
    -------
    SpatialDataArray
        Extracted data.

    Raises
    ------
    DataError
        If a coordinate along the grid normal differs from the grid position, or if no
        in-plane coordinate remains to select along.
    """

    xyz = [x, y, z]
    axes = [ind for ind, comp in enumerate(xyz) if comp is not None]
    num_provided = len(axes)
    normal = self.normal_axis

    if normal in axes:
        # data exists only at the grid's own position along the normal direction
        if xyz[normal] != self.normal_pos:
            raise DataError(
                f"No data for {'xyz'[self.normal_axis]} = {xyz[self.normal_axis]} (unstructured"
                f" grid is defined at {'xyz'[self.normal_axis]} = {self.normal_pos})."
            )

        # a matching normal coordinate carries no information unless all three are given
        if num_provided < 3:
            num_provided -= 1
            axes.remove(normal)

    if num_provided == 0:
        raise DataError("At least one of 'x', 'y', and 'z' must be specified.")

    if num_provided == 1:
        # one in-plane coordinate: slice the grid with a plane
        slicing_axis = axes[0]
        return self.plane_slice(axis=slicing_axis, pos=xyz[slicing_axis])

    # two in-plane coordinates or all three: point interpolation
    pos = list(xyz)
    if num_provided == 2:
        pos[normal] = [self.normal_pos]
    return self.interp(x=pos[0], y=pos[1], z=pos[2])
@requires_vtk
def reflect(
    self, axis: Axis, center: float, reflection_only: bool = False
) -> UnstructuredGridDataset:
    """Reflect unstructured data across the plane defined by parameters ``axis`` and
    ``center``. By default the original data is preserved, setting ``reflection_only`` to
    ``True`` will produce only reflected data.

    Parameters
    ----------
    axis : Literal[0, 1, 2]
        Normal direction of the reflection plane.
    center : float
        Location of the reflection plane along its normal direction.
    reflection_only : bool = False
        Return only reflected data.

    Returns
    -------
    UnstructuredGridDataset
        Data after reflection is performed.

    Raises
    ------
    DataError
        If ``axis`` is the grid normal and ``reflection_only`` is ``False``.
    """

    # in-plane reflections are handled by the parent implementation
    if axis != self.normal_axis:
        return super().reflect(axis=axis, center=center, reflection_only=reflection_only)

    # along the normal, only a pure reflection is allowed
    if not reflection_only:
        raise DataError(
            "Reflection in the normal direction to the grid is prohibited unless 'reflection_only=True'."
        )

    # a pure reflection along the normal just mirrors the grid position about ``center``
    mirrored_pos = 2 * center - self.normal_pos
    return self.updated_copy(normal_pos=mirrored_pos)
@requires_vtk
def sel_inside(self, bounds: Bound) -> TriangularGridDataset:
    """Return a new ``TriangularGridDataset`` that contains the minimal amount of data
    necessary to cover a spatial region defined by ``bounds``.

    Parameters
    ----------
    bounds : Tuple[float, float, float], Tuple[float, float, float]
        Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.

    Returns
    -------
    TriangularGridDataset
        Extracted spatial data.

    Raises
    ------
    DataError
        If any min bound is larger than the corresponding max bound.
    """
    for lower, upper in zip(*bounds):
        if lower > upper:
            raise DataError(
                "Min and max bounds must be packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``."
            )

    # make the region unbounded along the normal direction before delegating to the parent
    region_min = list(bounds[0])
    region_max = list(bounds[1])
    region_min[self.normal_axis] = -inf
    region_max[self.normal_axis] = inf

    return super().sel_inside([region_min, region_max])
def does_cover(self, bounds: Bound) -> bool:
    """Check whether data fully covers the spatial region specified by ``bounds``. If data
    contains only one point along a given direction, then it is assumed the data is constant
    along that direction and coverage is not checked.

    Parameters
    ----------
    bounds : Tuple[float, float, float], Tuple[float, float, float]
        Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.

    Returns
    -------
    bool
        Full cover check outcome.
    """

    # pin both bounds to the grid position along the normal direction, so the parent check
    # is performed against the plane of the grid
    region_min = list(bounds[0])
    region_max = list(bounds[1])
    region_min[self.normal_axis] = self.normal_pos
    region_max[self.normal_axis] = self.normal_pos

    return super().does_cover([region_min, region_max])
class TetrahedralGridDataset(UnstructuredGridDataset):
    """Dataset for storing tetrahedral grid data. Data values are associated with the nodes of
    the grid.

    Note
    ----
    To use full functionality of unstructured datasets one must install ``vtk`` package (``pip
    install tidy3d[vtk]`` or ``pip install vtk``). Otherwise the functionality of unstructured
    datasets is limited to creation, writing to/loading from a file, and arithmetic
    manipulations.

    Example
    -------
    >>> tet_grid_points = PointDataArray(
    ...     [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    ...     coords=dict(index=np.arange(4), axis=np.arange(3)),
    ... )
    >>>
    >>> tet_grid_cells = CellDataArray(
    ...     [[0, 1, 2, 3]],
    ...     coords=dict(cell_index=np.arange(1), vertex_index=np.arange(4)),
    ... )
    >>>
    >>> tet_grid_values = IndexedDataArray(
    ...     [1.0, 2.0, 3.0, 4.0], coords=dict(index=np.arange(4)),
    ... )
    >>>
    >>> tet_grid = TetrahedralGridDataset(
    ...     points=tet_grid_points,
    ...     cells=tet_grid_cells,
    ...     values=tet_grid_values,
    ... )
    """

    @classmethod
    def _point_dims(cls) -> pd.PositiveInt:
        """Number of coordinates stored per grid point."""
        return 3

    @cached_property
    def _points_3d_array(self) -> Bound:
        """3D coordinates of grid points (the raw data of ``points``)."""
        return self.points.data

    @classmethod
    def _cell_num_vertices(cls) -> pd.PositiveInt:
        """Number of vertices in a (tetrahedral) cell."""
        return 4

    @classmethod
    @requires_vtk
    def _vtk_cell_type(cls):
        """VTK cell type to use in the VTK representation."""
        return vtk["mod"].VTK_TETRA

    @classmethod
    @requires_vtk
    def _from_vtk_obj(
        cls,
        grid,
        field=None,
        remove_degenerate_cells: bool = False,
        remove_unused_points: bool = False,
    ) -> TetrahedralGridDataset:
        """Initialize from a vtkUnstructuredGrid instance.

        Parameters
        ----------
        grid
            vtk grid object to read points, cells, and values from.
        field = None
            Passed to ``_get_values_from_vtk`` when reading the values.
        remove_degenerate_cells : bool = False
            Whether to pass the cells through ``_remove_degenerate_cells``.
        remove_unused_points : bool = False
            Whether to pass points, values, and cells through ``_remove_unused_points``.

        Raises
        ------
        DataError
            If the grid contains cells of a type other than the tetrahedral vtk cell type.
        """

        to_numpy = vtk["vtk_to_numpy"]

        # read point, cells, and values info from a vtk instance
        connectivity = to_numpy(grid.GetCells().GetConnectivityArray())
        coordinates = to_numpy(grid.GetPoints().GetData())
        values = cls._get_values_from_vtk(grid, len(coordinates), field)

        # verify cell_types
        cell_types = to_numpy(grid.GetCellTypesArray())
        if not np.all(cell_types == cls._vtk_cell_type()):
            raise DataError("Only tetrahedral 'vtkUnstructuredGrid' is currently supported")

        # pack point and cell information into Tidy3D arrays:
        # the flat connectivity list becomes one row of vertex indices per cell
        verts_per_cell = cls._cell_num_vertices()
        num_cells = len(connectivity) // verts_per_cell
        cell_coords = dict(
            cell_index=np.arange(num_cells), vertex_index=np.arange(verts_per_cell)
        )
        cells = CellDataArray(
            np.reshape(connectivity, (num_cells, verts_per_cell)), coords=cell_coords
        )

        point_coords = dict(
            index=np.arange(len(coordinates)), axis=np.arange(cls._point_dims())
        )
        points = PointDataArray(coordinates, coords=point_coords)

        # optional clean-up steps
        if remove_degenerate_cells:
            cells = cls._remove_degenerate_cells(cells=cells)

        if remove_unused_points:
            points, values, cells = cls._remove_unused_points(
                points=points, values=values, cells=cells
            )

        return cls(points=points, cells=cells, values=values)

    @requires_vtk
    def plane_slice(self, axis: Axis, pos: float) -> TriangularGridDataset:
        """Slice data with a plane and return the resulting :class:`.TriangularGridDataset`.

        Parameters
        ----------
        axis : Axis
            The normal direction of the slicing plane.
        pos : float
            Position of the slicing plane along its normal direction.

        Returns
        -------
        TriangularGridDataset
            The resulting slice.
        """

        raw_slice = self._plane_slice_raw(axis=axis, pos=pos)

        # the raw vtk slice is cleaned of degenerate cells and unused points on conversion
        return TriangularGridDataset._from_vtk_obj(
            raw_slice, remove_degenerate_cells=True, remove_unused_points=True
        )

    @requires_vtk
    def line_slice(self, axis: Axis, pos: Coordinate) -> SpatialDataArray:
        """Slice data with a line and return the resulting :class:`.SpatialDataArray`.

        Parameters
        ----------
        axis : Axis
            The axis of the slicing line.
        pos : Tuple[float, float, float]
            Position of the slicing line.

        Returns
        -------
        SpatialDataArray
            The resulting slice.

        Raises
        ------
        DataError
            If the slicing line does not intersect the grid.
        """

        vtk_mod = vtk["mod"]
        grid_bounds = self.bounds

        # the line spans the dataset bounds along ``axis`` and passes through ``pos``
        point_start = list(pos)
        point_end = list(pos)
        point_start[axis] = grid_bounds[0][axis]
        point_end[axis] = grid_bounds[1][axis]

        # create cutting line
        line = vtk_mod.vtkLineSource()
        line.SetPoint1(point_start)
        line.SetPoint2(point_end)
        line.SetResolution(1)

        # this should be done using vtkProbeLineFilter
        # but for some reason it crashes Python
        # so, we use a workaround:
        # 1) extract cells that are intersected by line (to speed up further slicing)
        # 2) do plane slice along first direction
        # 3) do second plane slice along second direction

        extractor = vtk_mod.vtkExtractCellsAlongPolyLine()
        extractor.SetSourceConnection(line.GetOutputPort())
        extractor.SetInputData(self._vtk_obj)
        extractor.Update()

        cells_on_line_vtk = extractor.GetOutput()

        if cells_on_line_vtk.GetNumberOfPoints() == 0:
            raise DataError("Slicing line does not intersect the unstructured grid.")

        cells_on_line = TetrahedralGridDataset._from_vtk_obj(
            cells_on_line_vtk, remove_degenerate_cells=True, remove_unused_points=True
        )

        # the two directions perpendicular to the line
        tan_dims = [0, 1, 2]
        tan_dims.remove(axis)
        first_dim, second_dim = tan_dims

        # first plane slice gives a triangular grid, the second one gives the line data
        triangular_slice = cells_on_line.plane_slice(axis=first_dim, pos=pos[first_dim])
        return triangular_slice.plane_slice(axis=second_dim, pos=pos[second_dim])

    @requires_vtk
    def sel(
        self,
        x: Union[float, ArrayLike] = None,
        y: Union[float, ArrayLike] = None,
        z: Union[float, ArrayLike] = None,
    ) -> Union[TriangularGridDataset, SpatialDataArray]:
        """Extract/interpolate data along one or more Cartesian directions. At least one of x,
        y, and z must be provided.

        Parameters
        ----------
        x : Union[float, ArrayLike] = None
            x-coordinate of the slice.
        y : Union[float, ArrayLike] = None
            y-coordinate of the slice.
        z : Union[float, ArrayLike] = None
            z-coordinate of the slice.

        Returns
        -------
        Union[TriangularGridDataset, SpatialDataArray]
            Extracted data: a plane slice for one coordinate, a line slice for two, and
            interpolated values for three.

        Raises
        ------
        DataError
            If no coordinate is given, or if an array is given while slicing (fewer than
            three coordinates provided).
        """

        xyz = [x, y, z]
        axes = [ind for ind, comp in enumerate(xyz) if comp is not None]
        num_provided = len(axes)

        # slicing works with scalar positions only
        if num_provided < 3 and any(not np.isscalar(xyz[ind]) for ind in axes):
            raise DataError(
                "Providing x, y, or z as array is only allowed for interpolation. That is, when all"
                " three x, y, and z are provided or method '.interp()' is used explicitly."
            )

        if num_provided == 0:
            raise DataError("At least one of 'x', 'y', and 'z' must be specified.")

        if num_provided == 1:
            plane_axis = axes[0]
            return self.plane_slice(axis=plane_axis, pos=xyz[plane_axis])

        if num_provided == 2:
            # the line runs along the one axis that was not specified; its entry in ``pos``
            # is filled with a placeholder value
            line_axis = 3 - axes[0] - axes[1]
            xyz[line_axis] = 0
            return self.line_slice(axis=line_axis, pos=xyz)

        # all three coordinates are given
        return self.interp(x=x, y=y, z=z)

    def _interp_py(
        self,
        x: ArrayLike,
        y: ArrayLike,
        z: ArrayLike,
        fill_value: float,
        max_samples_per_step: int,
        max_cells_per_step: int,
        rel_tol: float,
    ) -> ArrayLike:
        """3D-specific interpolation of data at provided x, y, and z using the vectorized
        python implementation. All arguments are handed to ``_interp_py_general`` with
        ``axis_ignore=None``.

        Parameters
        ----------
        x : Union[float, ArrayLike]
            x-coordinates of sampling points.
        y : Union[float, ArrayLike]
            y-coordinates of sampling points.
        z : Union[float, ArrayLike]
            z-coordinates of sampling points.
        fill_value : float
            Value to use when filling points without interpolated values.
        max_samples_per_step : int
            Max number of points to interpolate at per iteration (used only if
            `use_vtk=False`). Using a higher number may speed up calculations but, at the
            same time, it increases RAM usage.
        max_cells_per_step : int
            Max number of cells to interpolate from per iteration (used only if
            `use_vtk=False`). Using a higher number may speed up calculations but, at the
            same time, it increases RAM usage.
        rel_tol : float
            Relative tolerance when determining whether a point belongs to a cell.

        Returns
        -------
        ArrayLike
            Interpolated data.
        """

        general_kwargs = dict(
            x=x,
            y=y,
            z=z,
            fill_value=fill_value,
            max_samples_per_step=max_samples_per_step,
            max_cells_per_step=max_cells_per_step,
            rel_tol=rel_tol,
            axis_ignore=None,
        )
        return self._interp_py_general(**general_kwargs)
# Type aliases for spatially distributed data defined on unstructured and/or regular grids.
UnstructuredGridDatasetType = Union[TriangularGridDataset, TetrahedralGridDataset]
CustomSpatialDataType = Union[SpatialDataArray, UnstructuredGridDatasetType]
CustomSpatialDataTypeAnnotated = Union[SpatialDataArray, annotate_type(UnstructuredGridDatasetType)]


def _get_numpy_array(data_array: Union[ArrayLike, DataArray, UnstructuredGridDataset]) -> ArrayLike:
    """Get numpy representation of dataarray/dataset values.

    For an unstructured dataset the values of its ``values`` array are returned, for an xarray
    data array its ``values``, and anything else is converted with ``np.array``.
    """
    if isinstance(data_array, UnstructuredGridDataset):
        return data_array.values.values
    if isinstance(data_array, xr.DataArray):
        return data_array.values
    return np.array(data_array)


def _zeros_like(
    data_array: Union[ArrayLike, xr.DataArray, UnstructuredGridDataset]
) -> Union[ArrayLike, xr.DataArray, UnstructuredGridDataset]:
    """Get a zeroed replica of dataarray/dataset (same container type as the input)."""
    if isinstance(data_array, UnstructuredGridDataset):
        # only the stored values are replaced, the rest of the dataset is copied
        return data_array.updated_copy(values=xr.zeros_like(data_array.values))
    if isinstance(data_array, xr.DataArray):
        return xr.zeros_like(data_array)
    return np.zeros_like(data_array)


def _ones_like(
    data_array: Union[ArrayLike, xr.DataArray, UnstructuredGridDataset]
) -> Union[ArrayLike, xr.DataArray, UnstructuredGridDataset]:
    """Get a unity replica of dataarray/dataset (same container type as the input)."""
    if isinstance(data_array, UnstructuredGridDataset):
        # only the stored values are replaced, the rest of the dataset is copied
        return data_array.updated_copy(values=xr.ones_like(data_array.values))
    if isinstance(data_array, xr.DataArray):
        return xr.ones_like(data_array)
    return np.ones_like(data_array)


def _check_same_coordinates(
    a: Union[ArrayLike, xr.DataArray, UnstructuredGridDataset],
    b: Union[ArrayLike, xr.DataArray, UnstructuredGridDataset],
) -> bool:
    """Check whether two arrays are defined at the same coordinates.

    Parameters
    ----------
    a : Union[ArrayLike, xr.DataArray, UnstructuredGridDataset]
        First array/dataset.
    b : Union[ArrayLike, xr.DataArray, UnstructuredGridDataset]
        Second array/dataset.

    Returns
    -------
    bool
        ``True`` if both inputs are of compatible types and share the same grid (unstructured
        datasets), the same coordinates (xarray data arrays), or the same shape (anything
        else); ``False`` otherwise.
    """

    # we can have xarray.DataArray's of different types but still same coordinates
    # we will deal with that case separately
    both_xarrays = isinstance(a, xr.DataArray) and isinstance(b, xr.DataArray)
    if (not both_xarrays) and type(a) is not type(b):
        return False

    if isinstance(a, UnstructuredGridDataset):
        # grids of different size cannot coincide; checking shapes first also keeps
        # ``np.allclose`` from failing to broadcast mismatched point arrays
        if np.shape(a.points) != np.shape(b.points) or np.shape(a.cells) != np.shape(b.cells):
            return False

        if not np.allclose(a.points, b.points) or not np.all(a.cells == b.cells):
            return False

        # planar grids must additionally lie in the same plane
        if isinstance(a, TriangularGridDataset):
            if a.normal_axis != b.normal_axis or a.normal_pos != b.normal_pos:
                return False

    elif isinstance(a, xr.DataArray):
        if a.coords.keys() != b.coords.keys() or a.coords != b.coords:
            return False

    else:
        if np.shape(a) != np.shape(b):
            return False

    return True