Source code for tidy3d.components.data.unstructured.tetrahedral

"""Defines tetrahedral grid datasets."""

from __future__ import annotations

from typing import Union

import numpy as np
import pydantic.v1 as pd
from xarray import DataArray as XrDataArray

from tidy3d.components.base import cached_property
from tidy3d.components.data.data_array import (
    CellDataArray,
    IndexedDataArray,
    PointDataArray,
)
from tidy3d.components.types import ArrayLike, Axis, Bound, Coordinate
from tidy3d.exceptions import DataError
from tidy3d.packaging import requires_vtk, vtk

from .base import UnstructuredGridDataset
from .triangular import TriangularGridDataset


class TetrahedralGridDataset(UnstructuredGridDataset):
    """Dataset for storing tetrahedral grid data. Data values are associated with the nodes of
    the grid.

    Note
    ----
    To use full functionality of unstructured datasets one must install ``vtk`` package (``pip
    install tidy3d[vtk]`` or ``pip install vtk``). Otherwise the functionality of unstructured
    datasets is limited to creation, writing to/loading from a file, and arithmetic manipulations.

    Example
    -------
    >>> tet_grid_points = PointDataArray(
    ...     [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    ...     coords=dict(index=np.arange(4), axis=np.arange(3)),
    ... )
    >>>
    >>> tet_grid_cells = CellDataArray(
    ...     [[0, 1, 2, 3]],
    ...     coords=dict(cell_index=np.arange(1), vertex_index=np.arange(4)),
    ... )
    >>>
    >>> tet_grid_values = IndexedDataArray(
    ...     [1.0, 2.0, 3.0, 4.0], coords=dict(index=np.arange(4)),
    ... )
    >>>
    >>> tet_grid = TetrahedralGridDataset(
    ...     points=tet_grid_points,
    ...     cells=tet_grid_cells,
    ...     values=tet_grid_values,
    ... )
    """

    """ Fundamental parameters to set up based on grid dimensionality """

    @classmethod
    def _traingular_dataset_type(cls) -> type:
        """Corresponding class for triangular grid datasets.
        We need to know this when creating a triangular slice from a tetrahedral grid."""
        # NOTE: the misspelling in the method name is kept on purpose; it is part of the
        # interface used by this class (see ``plane_slice``) and possibly by code elsewhere.
        return TriangularGridDataset

    @classmethod
    def _point_dims(cls) -> pd.PositiveInt:
        """Dimensionality of stored grid point coordinates."""
        return 3

    @classmethod
    def _cell_num_vertices(cls) -> pd.PositiveInt:
        """Number of vertices in a cell."""
        return 4

    """ Convenience properties """

    @cached_property
    def _points_3d_array(self) -> ArrayLike:
        """3D coordinates of grid points."""
        # points are already stored with three coordinates, so the raw data is returned as is
        return self.points.data

    """ VTK interfacing """

    @classmethod
    @requires_vtk
    def _vtk_cell_type(cls):
        """VTK cell type to use in the VTK representation."""
        return vtk["mod"].VTK_TETRA

    @classmethod
    @requires_vtk
    def _from_vtk_obj(
        cls,
        vtk_obj,
        field=None,
        remove_degenerate_cells: bool = False,
        remove_unused_points: bool = False,
        values_type=IndexedDataArray,
        expect_complex: bool = False,
        ignore_invalid_cells: bool = False,
    ) -> TetrahedralGridDataset:
        """Initialize from a vtkUnstructuredGrid instance.

        Parameters
        ----------
        vtk_obj
            VTK unstructured grid to read points, cells, and values from.
        field
            Forwarded to ``_get_values_from_vtk`` to select/label the values to load.
        remove_degenerate_cells : bool = False
            If ``True``, pass the loaded cells through ``_remove_degenerate_cells``.
        remove_unused_points : bool = False
            If ``True``, pass points, values, and cells through ``_remove_unused_points``.
        values_type
            Forwarded to ``_get_values_from_vtk``.
        expect_complex : bool = False
            Forwarded to ``_get_values_from_vtk``.
        ignore_invalid_cells : bool = False
            If ``True``, cells whose VTK type is not tetrahedral are dropped; otherwise their
            presence raises an error.

        Returns
        -------
        TetrahedralGridDataset
            Dataset built from the VTK object.

        Raises
        ------
        DataError
            If non-tetrahedral cells are present and ``ignore_invalid_cells`` is ``False``.
        """

        num_vertices = cls._cell_num_vertices()

        # read point, cells, and values info from a vtk instance
        cells_numpy = vtk["vtk_to_numpy"](vtk_obj.GetCells().GetConnectivityArray())
        points_numpy = vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData())
        values = cls._get_values_from_vtk(
            vtk_obj, len(points_numpy), field, values_type, expect_complex
        )

        # verify cell_types
        cells_types = vtk["vtk_to_numpy"](vtk_obj.GetCellTypesArray())
        invalid_cells = cells_types != cls._vtk_cell_type()
        if np.any(invalid_cells):
            if not ignore_invalid_cells:
                raise DataError("Only tetrahedral 'vtkUnstructuredGrid' is currently supported")

            # The connectivity array is flat: keep only the entries that belong to valid
            # cells, i.e. ``num_vertices`` consecutive entries starting at each valid offset.
            cell_offsets = vtk["vtk_to_numpy"](vtk_obj.GetCells().GetOffsetsArray())
            valid_cell_offsets = cell_offsets[:-1][~invalid_cells]
            cells_numpy = cells_numpy[
                np.ravel(
                    valid_cell_offsets[:, None] + np.arange(num_vertices, dtype=int)[None, :]
                )
            ]

        # pack point and cell information into Tidy3D arrays
        num_cells = len(cells_numpy) // num_vertices
        cells_numpy = np.reshape(cells_numpy, (num_cells, num_vertices))

        cells = CellDataArray(
            cells_numpy,
            coords={
                "cell_index": np.arange(num_cells),
                "vertex_index": np.arange(num_vertices),
            },
        )

        points = PointDataArray(
            points_numpy,
            coords={"index": np.arange(len(points_numpy)), "axis": np.arange(cls._point_dims())},
        )

        if remove_degenerate_cells:
            cells = cls._remove_degenerate_cells(cells=cells)

        if remove_unused_points:
            points, values, cells = cls._remove_unused_points(
                points=points, values=values, cells=cells
            )

        return cls(points=points, cells=cells, values=values)

    """ Grid operations """

    @requires_vtk
    def plane_slice(self, axis: Axis, pos: float) -> TriangularGridDataset:
        """Slice data with a plane and return the resulting :class:`.TriangularGridDataset`.

        Parameters
        ----------
        axis : Axis
            The normal direction of the slicing plane.
        pos : float
            Position of the slicing plane along its normal direction.

        Returns
        -------
        TriangularGridDataset
            The resulting slice.
        """

        slice_vtk = self._plane_slice_raw(axis=axis, pos=pos)

        return self._traingular_dataset_type()._from_vtk_obj(
            slice_vtk,
            remove_degenerate_cells=True,
            remove_unused_points=True,
            field=self._values_coords_dict,
            values_type=self._values_type,
            expect_complex=self.is_complex,
        )

    @requires_vtk
    def line_slice(self, axis: Axis, pos: Coordinate) -> XrDataArray:
        """Slice data with a line and return the resulting xarray.DataArray.

        Parameters
        ----------
        axis : Axis
            The axis of the slicing line.
        pos : Tuple[float, float, float]
            Position of the slicing line. The component along ``axis`` is ignored: the line
            is extended over the dataset bounds in that direction.

        Returns
        -------
        xarray.DataArray
            The resulting slice.

        Raises
        ------
        DataError
            If the slicing line does not intersect the grid.
        """

        # the line spans the whole dataset extent along ``axis``
        bounds = self.bounds
        start = list(pos)
        end = list(pos)

        start[axis] = bounds[0][axis]
        end[axis] = bounds[1][axis]

        # create cutting line
        line = vtk["mod"].vtkLineSource()
        line.SetPoint1(start)
        line.SetPoint2(end)
        line.SetResolution(1)

        # this should be done using vtkProbeLineFilter
        # but for some reason it crashes Python
        # so, we use a workaround:
        # 1) extract cells that are intersected by line (to speed up further slicing)
        # 2) do plane slice along first direction
        # 3) do second plane slice along second direction

        prober = vtk["mod"].vtkExtractCellsAlongPolyLine()
        prober.SetSourceConnection(line.GetOutputPort())
        prober.SetInputData(self._vtk_obj)
        prober.Update()

        extracted_cells_vtk = prober.GetOutput()

        if extracted_cells_vtk.GetNumberOfPoints() == 0:
            raise DataError("Slicing line does not intersect the unstructured grid.")

        extracted_cells = self._from_vtk_obj_internal(extracted_cells_vtk)

        # the two directions perpendicular to the line
        tan_dims = [0, 1, 2]
        tan_dims.remove(axis)

        # first plane slice
        plane_slice = extracted_cells.plane_slice(axis=tan_dims[0], pos=pos[tan_dims[0]])
        # second plane slice
        line_slice = plane_slice.plane_slice(axis=tan_dims[1], pos=pos[tan_dims[1]])

        return line_slice

    """ Interpolation """

    def _interp_py(
        self,
        x: ArrayLike,
        y: ArrayLike,
        z: ArrayLike,
        fill_value: float,
        max_samples_per_step: int,
        max_cells_per_step: int,
        rel_tol: float,
    ) -> ArrayLike:
        """3D-specific function to interpolate data at provided x, y, and z
        using vectorized python implementation.

        Parameters
        ----------
        x : Union[float, ArrayLike]
            x-coordinates of sampling points.
        y : Union[float, ArrayLike]
            y-coordinates of sampling points.
        z : Union[float, ArrayLike]
            z-coordinates of sampling points.
        fill_value : float
            Value to use when filling points without interpolated values.
        max_samples_per_step : int
            Max number of points to interpolate at per iteration (used only if `use_vtk=False`).
            Using a higher number may speed up calculations but, at the same time, it increases
            RAM usage.
        max_cells_per_step : int
            Max number of cells to interpolate from per iteration (used only if `use_vtk=False`).
            Using a higher number may speed up calculations but, at the same time, it increases
            RAM usage.
        rel_tol : float
            Relative tolerance when determining whether a point belongs to a cell.

        Returns
        -------
        ArrayLike
            Interpolated data.
        """

        # all three directions take part in the interpolation, hence ``axis_ignore=None``
        return self._interp_py_general(
            x=x,
            y=y,
            z=z,
            fill_value=fill_value,
            max_samples_per_step=max_samples_per_step,
            max_cells_per_step=max_cells_per_step,
            rel_tol=rel_tol,
            axis_ignore=None,
        )

    """ Data selection """

    @requires_vtk
    def sel(
        self,
        x: Union[float, ArrayLike] = None,
        y: Union[float, ArrayLike] = None,
        z: Union[float, ArrayLike] = None,
        method=None,
        **sel_kwargs,
    ) -> Union[TriangularGridDataset, XrDataArray]:
        """Extract/interpolate data along one or more spatial or non-spatial directions. Must
        provide at least one argument among 'x', 'y', 'z' or non-spatial dimensions through
        additional arguments. Along spatial dimensions a suitable slicing of grid is applied
        (plane slice, line slice, or interpolation). Selection along non-spatial dimensions is
        forwarded to .sel() xarray function. Parameter 'method' applies only to non-spatial
        dimensions.

        Parameters
        ----------
        x : Union[float, ArrayLike] = None
            x-coordinate of the slice.
        y : Union[float, ArrayLike] = None
            y-coordinate of the slice.
        z : Union[float, ArrayLike] = None
            z-coordinate of the slice.
        method: Literal[None, "nearest", "pad", "ffill", "backfill", "bfill"] = None
            Method to use in xarray sel() function.
        **sel_kwargs : dict
            Keyword arguments to pass to the xarray sel() function.

        Returns
        -------
        Union[TriangularGridDataset, xarray.DataArray]
            Extracted data.

        Raises
        ------
        DataError
            If an array is given for 'x', 'y', or 'z' while fewer than three of them are
            provided, or if no selection argument is given at all.
        """

        xyz = [x, y, z]
        axes = [ind for ind, comp in enumerate(xyz) if comp is not None]
        num_provided = len(axes)

        if num_provided < 3 and any(not np.isscalar(comp) for comp in xyz if comp is not None):
            raise DataError(
                "Providing x, y, or z as array is only allowed for interpolation. That is, when all"
                " three x, y, and z are provided or method '.interp()' is used explicitly."
            )

        if num_provided == 0 and len(sel_kwargs) == 0:
            # ``_non_spatial_dims`` is not defined in this class; read it defensively so that
            # this error path always raises the intended ``DataError``.
            available = list(getattr(self, "_non_spatial_dims", ())) + list("xyz")
            raise DataError(
                "Must provide at least one dimension to select along "
                f"(available: {available})."
            )

        self_after_non_spatial_sel = self._non_spatial_sel(method=method, **sel_kwargs)

        if num_provided == 1:
            # one coordinate fixed -> plane slice normal to that axis
            axis = axes[0]
            return self_after_non_spatial_sel.plane_slice(axis=axis, pos=xyz[axis])

        if num_provided == 2:
            # two coordinates fixed -> line along the remaining axis (indices sum to 0+1+2=3)
            axis = 3 - axes[0] - axes[1]
            # placeholder value: replaced by the dataset bounds inside ``line_slice``
            xyz[axis] = 0
            return self_after_non_spatial_sel.line_slice(axis=axis, pos=xyz)

        if num_provided == 3:
            return self_after_non_spatial_sel.interp(x=x, y=y, z=z)

        # only non-spatial selection was requested
        return self_after_non_spatial_sel

    def get_cell_volumes(self):
        """Get the volumes associated to each cell in the grid.

        The volume of a tetrahedron with vertices ``v0..v3`` is computed as
        ``|((v1 - v0) x (v2 - v0)) . (v3 - v0)| / 6``.

        Returns
        -------
        Array of cell volumes, one entry per cell.
        """
        v0 = self.points[self.cells.sel(vertex_index=0)]
        # edge vectors from the first vertex to the other three
        e01 = self.points[self.cells.sel(vertex_index=1)] - v0
        e02 = self.points[self.cells.sel(vertex_index=2)] - v0
        e03 = self.points[self.cells.sel(vertex_index=3)] - v0

        # scalar triple product, taken per cell (sum over the coordinate axis)
        return np.abs(np.sum(np.cross(e01, e02) * e03, axis=1)) / 6