Source code for tidy3d.components.data.unstructured.triangular

"""Defines triangular grid datasets."""

from __future__ import annotations

from typing import Literal, Optional, Union

import numpy as np
import pydantic.v1 as pd

try:
    from matplotlib import pyplot as plt
    from matplotlib.tri import Triangulation
except ImportError:
    pass

from xarray import DataArray as XrDataArray

from tidy3d.components.base import cached_property
from tidy3d.components.data.data_array import (
    CellDataArray,
    IndexedDataArray,
    PointDataArray,
    SpatialDataArray,
)
from tidy3d.components.types import ArrayLike, Ax, Axis, Bound
from tidy3d.components.viz import add_ax_if_none, equal_aspect, plot_params_grid
from tidy3d.constants import inf
from tidy3d.exceptions import DataError
from tidy3d.log import log
from tidy3d.packaging import requires_vtk, vtk

from .base import (
    DEFAULT_MAX_CELLS_PER_STEP,
    DEFAULT_MAX_SAMPLES_PER_STEP,
    DEFAULT_TOLERANCE_CELL_FINDING,
    UnstructuredGridDataset,
)


class TriangularGridDataset(UnstructuredGridDataset):
    """Dataset for storing triangular grid data. Data values are associated with the nodes of
    the grid.

    Note
    ----
    To use full functionality of unstructured datasets one must install ``vtk`` package (``pip
    install tidy3d[vtk]`` or ``pip install vtk``). Otherwise the functionality of unstructured
    datasets is limited to creation, writing to/loading from a file, and arithmetic manipulations.

    Example
    -------
    >>> tri_grid_points = PointDataArray(
    ...     [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    ...     coords=dict(index=np.arange(4), axis=np.arange(2)),
    ... )
    >>>
    >>> tri_grid_cells = CellDataArray(
    ...     [[0, 1, 2], [1, 2, 3]],
    ...     coords=dict(cell_index=np.arange(2), vertex_index=np.arange(3)),
    ... )
    >>>
    >>> tri_grid_values = IndexedDataArray(
    ...     [1.0, 2.0, 3.0, 4.0], coords=dict(index=np.arange(4)),
    ... )
    >>>
    >>> tri_grid = TriangularGridDataset(
    ...     normal_axis=1,
    ...     normal_pos=0,
    ...     points=tri_grid_points,
    ...     cells=tri_grid_cells,
    ...     values=tri_grid_values,
    ... )
    """

    normal_axis: Axis = pd.Field(
        ...,
        title="Grid Axis",
        description="Orientation of the grid.",
    )

    normal_pos: float = pd.Field(
        ...,
        title="Position",
        description="Coordinate of the grid along the normal direction.",
    )

    """ Fundamental parameters to set up based on grid dimensionality """

    @classmethod
    def _point_dims(cls) -> pd.PositiveInt:
        """Dimensionality of stored grid point coordinates."""
        return 2

    @classmethod
    def _cell_num_vertices(cls) -> pd.PositiveInt:
        """Number of vertices in a cell."""
        return 3

    """ Convenience properties """

    @cached_property
    def bounds(self) -> Bound:
        """Grid bounds.

        The in-plane bounds reported by the parent class are extended to 3D by inserting
        ``normal_pos`` at position ``normal_axis`` in both the min and the max corner.
        """
        bounds_2d = super().bounds
        bounds_3d = self._points_2d_to_3d(bounds_2d)
        return tuple(bounds_3d[0]), tuple(bounds_3d[1])

    def _points_2d_to_3d(self, pts: ArrayLike) -> ArrayLike:
        """Convert 2d points into 3d points.

        A column holding ``normal_pos`` is inserted at index ``normal_axis`` of every row.
        """
        return np.insert(pts, obj=self.normal_axis, values=self.normal_pos, axis=1)

    @cached_property
    def _points_3d_array(self) -> ArrayLike:
        """3D representation of grid points."""
        return self._points_2d_to_3d(self.points.data)

    """ VTK interfacing """

    @classmethod
    @requires_vtk
    def _vtk_cell_type(cls):
        """VTK cell type to use in the VTK representation."""
        return vtk["mod"].VTK_TRIANGLE

    @classmethod
    @requires_vtk
    def _from_vtk_obj(
        cls,
        vtk_obj,
        field=None,
        remove_degenerate_cells: bool = False,
        remove_unused_points: bool = False,
        values_type=IndexedDataArray,
        expect_complex=None,
        ignore_invalid_cells: bool = False,
    ):
        """Initialize from a ``vtkUnstructuredGrid`` or ``vtkPolyData`` instance.

        Parameters
        ----------
        vtk_obj
            ``vtkPolyData`` or ``vtkUnstructuredGrid`` instance to convert.
        field
            Forwarded to ``_get_values_from_vtk`` when reading data values.
        remove_degenerate_cells : bool = False
            Pass the cells through ``_remove_degenerate_cells`` before building the dataset.
        remove_unused_points : bool = False
            Pass points, values, and cells through ``_remove_unused_points`` before building
            the dataset.
        values_type
            Forwarded to ``_get_values_from_vtk``.
        expect_complex
            Forwarded to ``_get_values_from_vtk``.
        ignore_invalid_cells : bool = False
            If ``True``, cells whose number of vertices differs from ``_cell_num_vertices()``
            are dropped; otherwise their presence raises a ``DataError``.

        Returns
        -------
        TriangularGridDataset
            Dataset built from the vtk object.

        Raises
        ------
        DataError
            If ``vtk_obj`` is of an unsupported type, contains non-triangular cells (and
            ``ignore_invalid_cells`` is ``False``), or does not have exactly one zero-size
            dimension.
        """

        num_vertices = cls._cell_num_vertices()

        # get points cells data from vtk object
        if isinstance(vtk_obj, vtk["mod"].vtkPolyData):
            cells_vtk = vtk_obj.GetPolys()
        elif isinstance(vtk_obj, vtk["mod"].vtkUnstructuredGrid):
            cells_vtk = vtk_obj.GetCells()
        else:
            # without this branch 'cells_vtk' would be unbound and fail with an obscure error
            raise DataError(
                "Only 'vtkUnstructuredGrid' or 'vtkPolyData' can be converted into "
                f"'TriangularGridDataset' (got '{type(vtk_obj).__name__}')."
            )

        cells_numpy = vtk["vtk_to_numpy"](cells_vtk.GetConnectivityArray())

        # verify cell_types: consecutive offsets differ by the number of vertices in a cell
        cell_offsets = vtk["vtk_to_numpy"](cells_vtk.GetOffsetsArray())
        invalid_cells = np.diff(cell_offsets) != num_vertices
        if np.any(invalid_cells):
            if ignore_invalid_cells:
                # keep only connectivity entries that belong to valid cells
                valid_cell_offsets = cell_offsets[:-1][invalid_cells == 0]
                cells_numpy = cells_numpy[
                    np.ravel(
                        valid_cell_offsets[:, None] + np.arange(num_vertices, dtype=int)[None, :]
                    )
                ]
            else:
                raise DataError(
                    "Only triangular 'vtkUnstructuredGrid' or 'vtkPolyData' can be converted into "
                    "'TriangularGridDataset'."
                )

        points_numpy = vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData())

        # data values are read directly into Tidy3D array
        values = cls._get_values_from_vtk(
            vtk_obj, len(points_numpy), field, values_type, expect_complex
        )

        # detect zero size dimension
        bounds = np.max(points_numpy, axis=0) - np.min(points_numpy, axis=0)
        zero_dims = np.where(np.isclose(bounds, 0, atol=1e-6))[0]

        if len(zero_dims) != 1:
            raise DataError(
                f"Provided vtk grid does not represent a two dimensional grid. Found zero size dimensions are {zero_dims}."
            )

        # cast numpy scalars into plain python numbers before handing them to the model
        normal_axis = int(zero_dims[0])
        normal_pos = float(points_numpy[0][normal_axis])
        tan_dims = [dim for dim in (0, 1, 2) if dim != normal_axis]

        # convert 3d coordinates into 2d
        points_2d_numpy = points_numpy[:, tan_dims]

        # create Tidy3D points and cells arrays
        num_cells = len(cells_numpy) // num_vertices
        cells_numpy = np.reshape(cells_numpy, (num_cells, num_vertices))

        cells = CellDataArray(
            cells_numpy,
            coords={
                "cell_index": np.arange(num_cells),
                "vertex_index": np.arange(num_vertices),
            },
        )

        points = PointDataArray(
            points_2d_numpy,
            coords={"index": np.arange(len(points_numpy)), "axis": np.arange(cls._point_dims())},
        )

        if remove_degenerate_cells:
            cells = cls._remove_degenerate_cells(cells=cells)

        if remove_unused_points:
            points, values, cells = cls._remove_unused_points(
                points=points, values=values, cells=cells
            )

        return cls(
            normal_axis=normal_axis,
            normal_pos=normal_pos,
            points=points,
            cells=cells,
            values=values,
        )

    """ Grid operations """

    @requires_vtk
    def plane_slice(self, axis: Axis, pos: float) -> XrDataArray:
        """Slice data with a plane and return the resulting line as a DataArray.

        Parameters
        ----------
        axis : Axis
            The normal direction of the slicing plane.
        pos : float
            Position of the slicing plane along its normal direction.

        Returns
        -------
        xarray.DataArray
            The resulting slice, sorted along the slice direction.

        Raises
        ------
        DataError
            If ``axis`` coincides with the grid's ``normal_axis``.
        """

        if axis == self.normal_axis:
            raise DataError(
                f"Triangular grid (normal: {self.normal_axis}) cannot be sliced by a parallel "
                "plane."
            )

        # perform slicing in vtk and get unprocessed points and values
        slice_vtk = self._plane_slice_raw(axis=axis, pos=pos)
        points_numpy = vtk["vtk_to_numpy"](slice_vtk.GetPoints().GetData())
        values = self._get_values_from_vtk(
            slice_vtk,
            len(points_numpy),
            field=self._values_coords_dict,
            values_type=self._values_type,
            expect_complex=self.is_complex,
        )

        # axis of the resulting line: the one that is neither the grid nor the plane normal
        slice_axis = 3 - self.normal_axis - axis

        # assemble coords for DataArray
        coords = [None, None, None]
        coords[axis] = [pos]
        coords[self.normal_axis] = [self.normal_pos]
        coords[slice_axis] = points_numpy[:, slice_axis]
        coords_dict = dict(zip("xyz", coords))
        coords_dict.update(self._values_coords_dict)

        # reshape values from a 1d array into a 3d array
        new_shape = [1, 1, 1]
        new_shape[slice_axis] = len(values.index)
        new_shape = new_shape + list(np.shape(values.data))[1:]
        values_reshaped = np.reshape(values.data, new_shape)

        return XrDataArray(values_reshaped, coords=coords_dict, name=self.values.name).sortby(
            "xyz"[slice_axis]
        )

    @requires_vtk
    def reflect(
        self, axis: Axis, center: float, reflection_only: bool = False
    ) -> UnstructuredGridDataset:
        """Reflect unstructured data across the plane defined by parameters ``axis`` and
        ``center``. By default the original data is preserved, setting ``reflection_only`` to
        ``True`` will produce only reflected data.

        Parameters
        ----------
        axis : Literal[0, 1, 2]
            Normal direction of the reflection plane.
        center : float
            Location of the reflection plane along its normal direction.
        reflection_only : bool = False
            Return only reflected data.

        Returns
        -------
        UnstructuredGridDataset
            Data after reflection is performed.

        Raises
        ------
        DataError
            If ``axis`` is the grid's normal axis and ``reflection_only`` is ``False``.
        """

        # reflecting along the normal direction only moves the grid plane, which is allowed
        # solely when the original data is not kept
        if axis == self.normal_axis:
            if reflection_only:
                return self.updated_copy(normal_pos=2 * center - self.normal_pos)
            else:
                raise DataError(
                    "Reflection in the normal direction to the grid is prohibited unless 'reflection_only=True'."
                )

        return super().reflect(axis=axis, center=center, reflection_only=reflection_only)

    """ Interpolation """

    def _spatial_interp(
        self,
        x: Union[float, ArrayLike],
        y: Union[float, ArrayLike],
        z: Union[float, ArrayLike],
        fill_value: Optional[Union[float, Literal["extrapolate"]]] = None,
        use_vtk: bool = False,
        method: Literal["linear", "nearest"] = "linear",
        ignore_normal_pos: bool = True,
        max_samples_per_step: int = DEFAULT_MAX_SAMPLES_PER_STEP,
        max_cells_per_step: int = DEFAULT_MAX_CELLS_PER_STEP,
        rel_tol: float = DEFAULT_TOLERANCE_CELL_FINDING,
    ) -> XrDataArray:
        """Interpolate data at provided x, y, and z. Note that data is assumed to be invariant
        along the dataset's normal direction.

        Parameters
        ----------
        x : Union[float, ArrayLike]
            x-coordinates of sampling points.
        y : Union[float, ArrayLike]
            y-coordinates of sampling points.
        z : Union[float, ArrayLike]
            z-coordinates of sampling points.
        fill_value : Optional[Union[float, Literal["extrapolate"]]] = None
            Value to use when filling points without interpolated values. If ``"extrapolate"``
            then nearest values are used. Note: in a future version the default value will be
            changed to ``"extrapolate"``.
        use_vtk : bool = False
            Use vtk's interpolation functionality or Tidy3D's own implementation. Note: this
            option will be removed in a future version.
        method: Literal["linear", "nearest"] = "linear"
            Interpolation method to use.
        ignore_normal_pos : bool = True
            (Deprecated) Assume data is invariant along the normal direction to the grid plane.
        max_samples_per_step : int = DEFAULT_MAX_SAMPLES_PER_STEP
            Max number of points to interpolate at per iteration (used only if `use_vtk=False`).
            Using a higher number may speed up calculations but, at the same time, it increases
            RAM usage.
        max_cells_per_step : int = DEFAULT_MAX_CELLS_PER_STEP
            Max number of cells to interpolate from per iteration (used only if `use_vtk=False`).
            Using a higher number may speed up calculations but, at the same time, it increases
            RAM usage.
        rel_tol : float = DEFAULT_TOLERANCE_CELL_FINDING
            Relative tolerance when determining whether a point belongs to a cell.

        Returns
        -------
        xarray.DataArray
            Interpolated data.
        """

        if not ignore_normal_pos:
            log.warning(
                "Parameter 'ignore_normal_pos' is depreciated. It is always assumed that data "
                "contained in 'TriangularGridDataset' is invariant in the normal direction. "
                "That is, 'ignore_normal_pos=True' is used."
            )

        x = np.atleast_1d(x)
        y = np.atleast_1d(y)
        z = np.atleast_1d(z)

        # interpolate in the grid plane only: the normal coordinate is pinned to 'normal_pos'
        xyz = [x, y, z]
        xyz[self.normal_axis] = [self.normal_pos]
        interp_inplane = super()._spatial_interp(
            **dict(zip("xyz", xyz)),
            fill_value=fill_value,
            use_vtk=use_vtk,
            method=method,
            max_samples_per_step=max_samples_per_step,
            max_cells_per_step=max_cells_per_step,
            # forward the tolerance so that the user-provided value is not silently ignored
            rel_tol=rel_tol,
        )

        # replicate in-plane result along the normal direction to match requested coordinates
        interp_broadcasted = np.broadcast_to(
            interp_inplane, [len(np.atleast_1d(comp)) for comp in [x, y, z]] + self._fields_shape
        )

        coords_dict = {"x": x, "y": y, "z": z}
        coords_dict.update(self._values_coords_dict)

        if len(self._values_coords_dict) == 0:
            return SpatialDataArray(interp_broadcasted, coords=coords_dict, name=self.values.name)
        else:
            return XrDataArray(interp_broadcasted, coords=coords_dict, name=self.values.name)

    def _interp_py(
        self,
        x: ArrayLike,
        y: ArrayLike,
        z: ArrayLike,
        fill_value: float,
        max_samples_per_step: int,
        max_cells_per_step: int,
        rel_tol: float,
    ) -> ArrayLike:
        """2D-specific function to interpolate data at provided x, y, and z
        using vectorized python implementation.

        Parameters
        ----------
        x : Union[float, ArrayLike]
            x-coordinates of sampling points.
        y : Union[float, ArrayLike]
            y-coordinates of sampling points.
        z : Union[float, ArrayLike]
            z-coordinates of sampling points.
        fill_value : float
            Value to use when filling points without interpolated values.
        max_samples_per_step : int
            Max number of points to interpolate at per iteration (used only if `use_vtk=False`).
            Using a higher number may speed up calculations but, at the same time, it increases
            RAM usage.
        max_cells_per_step : int
            Max number of cells to interpolate from per iteration (used only if `use_vtk=False`).
            Using a higher number may speed up calculations but, at the same time, it increases
            RAM usage.
        rel_tol : float
            Relative tolerance when determining whether a point belongs to a cell.

        Returns
        -------
        ArrayLike
            Interpolated data.
        """

        # delegate to the general implementation, ignoring the grid's normal axis
        return self._interp_py_general(
            x=x,
            y=y,
            z=z,
            fill_value=fill_value,
            max_samples_per_step=max_samples_per_step,
            max_cells_per_step=max_cells_per_step,
            rel_tol=rel_tol,
            axis_ignore=self.normal_axis,
        )

    """ Data selection """

    @requires_vtk
    def sel(
        self,
        x: Union[float, ArrayLike] = None,
        y: Union[float, ArrayLike] = None,
        z: Union[float, ArrayLike] = None,
        method: Optional[Literal["None", "nearest", "pad", "ffill", "backfill", "bfill"]] = None,
        **sel_kwargs,
    ) -> XrDataArray:
        """Extract/interpolate data along one or more spatial or non-spatial directions. Must
        provide at least one argument among 'x', 'y', 'z' or non-spatial dimensions through
        additional arguments. Along spatial dimensions a suitable slicing of grid is applied
        (plane slice, line slice, or interpolation). Selection along non-spatial dimensions is
        forwarded to .sel() xarray function. Parameter 'method' applies only to non-spatial
        dimensions.

        Parameters
        ----------
        x : Union[float, ArrayLike] = None
            x-coordinate of the slice.
        y : Union[float, ArrayLike] = None
            y-coordinate of the slice.
        z : Union[float, ArrayLike] = None
            z-coordinate of the slice.
        method: Literal[None, "nearest", "pad", "ffill", "backfill", "bfill"] = None
            Method to use in xarray sel() function.
        **sel_kwargs : dict
            Keyword arguments to pass to the xarray sel() function.

        Returns
        -------
        xarray.DataArray
            Extracted data.

        Raises
        ------
        DataError
            If a coordinate along the normal axis differs from ``normal_pos``, or if no
            selection dimension is provided at all.
        """

        xyz = [x, y, z]
        axes = [ind for ind, comp in enumerate(xyz) if comp is not None]
        num_provided = len(axes)

        if self.normal_axis in axes:
            # element-wise comparison so that array-like input is handled unambiguously
            if np.any(np.asarray(xyz[self.normal_axis]) != self.normal_pos):
                raise DataError(
                    f"No data for {'xyz'[self.normal_axis]} = {xyz[self.normal_axis]} (unstructured"
                    f" grid is defined at {'xyz'[self.normal_axis]} = {self.normal_pos})."
                )

            # a matching normal coordinate does not restrict the selection any further
            if num_provided < 3:
                num_provided -= 1
                axes.remove(self.normal_axis)

        if num_provided == 0 and len(sel_kwargs) == 0:
            raise DataError("At least one dimension for selection must be provided.")

        self_after_non_spatial_sel = self._non_spatial_sel(method=method, **sel_kwargs)

        if num_provided == 1:
            axis = axes[0]
            return self_after_non_spatial_sel.plane_slice(axis=axis, pos=xyz[axis])

        if num_provided == 2:
            pos = [x, y, z]
            pos[self.normal_axis] = [self.normal_pos]
            return self_after_non_spatial_sel.interp(x=pos[0], y=pos[1], z=pos[2])

        if num_provided == 3:
            return self_after_non_spatial_sel.interp(x=x, y=y, z=z)

        return self_after_non_spatial_sel

    @requires_vtk
    def sel_inside(self, bounds: Bound) -> TriangularGridDataset:
        """Return a new ``TriangularGridDataset`` that contains the minimal amount data
        necessary to cover a spatial region defined by ``bounds``.

        Parameters
        ----------
        bounds : Tuple[float, float, float], Tuple[float, float, float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.

        Returns
        -------
        TriangularGridDataset
            Extracted spatial data array.

        Raises
        ------
        DataError
            If any min bound exceeds the corresponding max bound.
        """
        if any(bmin > bmax for bmin, bmax in zip(*bounds)):
            raise DataError(
                "Min and max bounds must be packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``."
            )

        # expand along normal direction
        new_bounds = [list(bounds[0]), list(bounds[1])]

        new_bounds[0][self.normal_axis] = -inf
        new_bounds[1][self.normal_axis] = inf

        return super().sel_inside(new_bounds)

    def does_cover(self, bounds: Bound) -> bool:
        """Check whether data fully covers specified by ``bounds`` spatial region. If data
        contains only one point along a given direction, then it is assumed the data is constant
        along that direction and coverage is not checked.

        Parameters
        ----------
        bounds : Tuple[float, float, float], Tuple[float, float, float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.

        Returns
        -------
        bool
            Full cover check outcome.
        """

        # collapse bounds onto the grid plane along the normal direction
        new_bounds = [list(bounds[0]), list(bounds[1])]

        new_bounds[0][self.normal_axis] = self.normal_pos
        new_bounds[1][self.normal_axis] = self.normal_pos

        return super().does_cover(new_bounds)

    """ Plotting """

    @property
    def _triangulation_obj(self) -> Triangulation:
        """Matplotlib triangular representation of the grid to use in plotting."""
        return Triangulation(self.points[:, 0], self.points[:, 1], self.cells)

    @equal_aspect
    @add_ax_if_none
    def plot(
        self,
        ax: Ax = None,
        field: bool = True,
        grid: bool = True,
        cbar: bool = True,
        cmap: str = "viridis",
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        shading: Literal["gouraud", "flat"] = "gouraud",
        cbar_kwargs: Optional[dict] = None,
        pcolor_kwargs: Optional[dict] = None,
    ) -> Ax:
        """Plot the data field and/or the unstructured grid.

        Parameters
        ----------
        ax : matplotlib.axes._subplots.Axes = None
            matplotlib axes to plot on, if not specified, one is created.
        field : bool = True
            Whether to plot the data field.
        grid : bool = True
            Whether to plot the unstructured grid.
        cbar : bool = True
            Display colorbar (only if ``field == True``).
        cmap : str = "viridis"
            Color map to use for plotting.
        vmin : float = None
            The lower bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        vmax : float = None
            The upper bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        shading : Literal["gouraud", "flat"] = "gouraud"
            Type of shading to use when plotting the data field.
        cbar_kwargs : Dict = {}
            Additional parameters passed to colorbar object.
        pcolor_kwargs: Dict = {}
            Additional parameters passed to ax.tripcolor()

        Returns
        -------
        matplotlib.axes._subplots.Axes
            The supplied or created matplotlib axes.

        Raises
        ------
        DataError
            If neither ``field`` nor ``grid`` is requested, or if the field is requested while
            the dataset contains more than one field.
        """

        if cbar_kwargs is None:
            cbar_kwargs = {}
        if pcolor_kwargs is None:
            pcolor_kwargs = {}
        if not (field or grid):
            raise DataError("Nothing to plot ('field == False', 'grid == False').")

        # plot data field if requested
        if field:
            if self._num_fields != 1:
                raise DataError(
                    "Unstructured dataset contains more than 1 field. "
                    "Use '.sel()' to select a single field from available dimensions "
                    f"{self._values_coords_dict} before plotting."
                )
            plot_obj = ax.tripcolor(
                self._triangulation_obj,
                self.values.data.ravel(),
                shading=shading,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
                **pcolor_kwargs,
            )

            if cbar:
                # a user-supplied label takes precedence over the default one
                label_kwargs = {}
                if "label" not in cbar_kwargs:
                    label_kwargs["label"] = self.values.name
                plt.colorbar(plot_obj, **cbar_kwargs, **label_kwargs)

        # plot grid if requested
        if grid:
            ax.triplot(
                self._triangulation_obj,
                color=plot_params_grid.edgecolor,
                linewidth=plot_params_grid.linewidth,
            )

        # set labels and titles
        ax_labels = ["x", "y", "z"]
        normal_axis_name = ax_labels.pop(self.normal_axis)
        ax.set_xlabel(ax_labels[0])
        ax.set_ylabel(ax_labels[1])
        ax.set_title(f"{normal_axis_name} = {self.normal_pos}")
        return ax

    def get_cell_volumes(self):
        """Get areas associated to each cell of the grid.

        The area of each triangle is half the magnitude of the 2D cross product of two of its
        edge vectors sharing the first vertex.
        """
        v0 = self.points[self.cells.sel(vertex_index=0)]
        e01 = self.points[self.cells.sel(vertex_index=1)] - v0
        e02 = self.points[self.cells.sel(vertex_index=2)] - v0

        areas = e01[:, 0] * e02[:, 1] - e01[:, 1] * e02[:, 0]

        return 0.5 * np.abs(areas)