# Source code for tidy3d.components.data.unstructured.surface

"""Defines triangular grid datasets."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from pydantic import Field

from tidy3d.components.base import cached_property
from tidy3d.components.data.data_array import (
    CellDataArray,
    IndexedDataArrayTypes,
    PointDataArray,
)
from tidy3d.components.viz import add_plotter_if_none
from tidy3d.exceptions import DataError, Tidy3dNotImplementedError
from tidy3d.packaging import pyvista, requires_pyvista, requires_vtk, vtk

from .base import (
    DEFAULT_MAX_CELLS_PER_STEP,
    DEFAULT_MAX_SAMPLES_PER_STEP,
    DEFAULT_TOLERANCE_CELL_FINDING,
    UnstructuredDataset,
)

if TYPE_CHECKING:
    from typing import Any, Literal, Optional, Union

    from pydantic import PositiveInt
    from xarray import DataArray as XrDataArray

    from tidy3d.components.types import ArrayLike, Axis


class TriangularSurfaceDataset(UnstructuredDataset):
    """Dataset for storing triangulated surface data. Data values are associated with the nodes of
    the mesh.

    Note
    ----
    To use full functionality of unstructured datasets one must install ``vtk`` package (``pip
    install tidy3d[vtk]`` or ``pip install vtk``). For visualization, ``pyvista`` is recommended
    (``pip install pyvista``). Otherwise the functionality of unstructured datasets is limited to
    creation, writing to/loading from a file, and arithmetic manipulations.

    Example
    -------
    >>> import numpy as np
    >>> from tidy3d.components.data.data_array import PointDataArray, CellDataArray, IndexedDataArray
    >>>
    >>> # Create a simple triangulated surface
    >>> tri_grid_points = PointDataArray(
    ...     [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]],
    ...     coords=dict(index=np.arange(4), axis=np.arange(3)),
    ... )
    >>>
    >>> tri_grid_cells = CellDataArray(
    ...     [[0, 1, 2], [1, 2, 3]],
    ...     coords=dict(cell_index=np.arange(2), vertex_index=np.arange(3)),
    ... )
    >>>
    >>> tri_grid_values = IndexedDataArray(
    ...     [1.0, 2.0, 3.0, 4.0], coords=dict(index=np.arange(4)),
    ... )
    >>>
    >>> tri_grid = TriangularSurfaceDataset(
    ...     points=tri_grid_points,
    ...     cells=tri_grid_cells,
    ...     values=tri_grid_values,
    ... )
    >>>
    >>> # Visualize the surface (change show=False to show=True to display the plot)
    >>> _ = tri_grid.plot(show=False)
    >>>
    >>> # Customize the visualization (change show=False to show=True to display the plot)
    >>> _ = tri_grid.plot(cmap='plasma', grid=True, grid_color='white', show=False)
    >>>
    >>> # For vector fields
    >>> vector_values = IndexedDataArray(
    ...     [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0]],
    ...     coords=dict(index=np.arange(4), axis=np.arange(3)),
    ... )
    >>> vector_grid = tri_grid.updated_copy(values=vector_values)
    >>>
    >>> # Plot as arrow field (change show=False to show=True to display the plot)
    >>> _ = vector_grid.quiver(scale=0.2, show=False)
    """

    points: PointDataArray = Field(
        ...,
        title="Surface Points",
        description="Coordinates of points composing the triangulated surface.",
    )

    values: IndexedDataArrayTypes = Field(
        ...,
        title="Surface Values",
        description="Values stored at the surface points.",
    )

    cells: CellDataArray = Field(
        ...,
        title="Surface Cells",
        description="Cells composing the triangulated surface specified as connections between surface "
        "points.",
    )

    # ------------------------------------------------------------------
    # Fundamental parameters to set up based on grid dimensionality
    # ------------------------------------------------------------------

    @classmethod
    def _point_dims(cls) -> PositiveInt:
        """Dimensionality of stored surface point coordinates (always 3 for a surface in 3D)."""
        return 3

    @classmethod
    def _cell_num_vertices(cls) -> PositiveInt:
        """Number of vertices in a cell (3, since cells are triangles)."""
        return 3

    # ------------------------------------------------------------------
    # Convenience properties
    # ------------------------------------------------------------------

    @cached_property
    def _points_3d_array(self) -> ArrayLike:
        """3D representation of points.

        Surface points are already stored with 3 coordinates, so the raw point array is returned
        as-is.
        """
        return self.points.data

    # ------------------------------------------------------------------
    # VTK interfacing
    # ------------------------------------------------------------------

    @classmethod
    @requires_vtk
    def _vtk_cell_type(cls) -> int:
        """VTK cell type to use in the VTK representation (``VTK_TRIANGLE``)."""
        return vtk["mod"].VTK_TRIANGLE

    # ------------------------------------------------------------------
    # Grid operations
    # ------------------------------------------------------------------

    @requires_vtk
    def plane_slice(self, axis: Axis, pos: float) -> XrDataArray:
        """Slice data with a plane and return the resulting line as a DataArray.

        Parameters
        ----------
        axis : Axis
            The normal direction of the slicing plane.
        pos : float
            Position of the slicing plane along its normal direction.

        Returns
        -------
        xarray.DataArray
            The resulting slice.

        Raises
        ------
        Tidy3dNotImplementedError
            Always; slicing of surface datasets is not implemented yet.
        """
        raise Tidy3dNotImplementedError("Slicing of unstructured surfaces is not implemented yet.")

    # ------------------------------------------------------------------
    # Interpolation
    # ------------------------------------------------------------------

    def _spatial_interp(
        self,
        x: Union[float, ArrayLike],
        y: Union[float, ArrayLike],
        z: Union[float, ArrayLike],
        fill_value: Optional[Union[float, Literal["extrapolate"]]] = None,
        use_vtk: bool = False,
        method: Literal["linear", "nearest"] = "linear",
        ignore_normal_pos: bool = True,
        max_samples_per_step: int = DEFAULT_MAX_SAMPLES_PER_STEP,
        max_cells_per_step: int = DEFAULT_MAX_CELLS_PER_STEP,
        rel_tol: float = DEFAULT_TOLERANCE_CELL_FINDING,
    ) -> XrDataArray:
        """Interpolate data along spatial dimensions at provided x, y, and z.

        Raises
        ------
        Tidy3dNotImplementedError
            Always; spatial interpolation from surface datasets is not implemented yet.
        """
        raise Tidy3dNotImplementedError(
            "Spatial interpolation from unstructured surfaces is not implemented yet."
        )

    # ------------------------------------------------------------------
    # Data selection
    # ------------------------------------------------------------------

    def sel(
        self,
        x: Union[float, ArrayLike] = None,
        y: Union[float, ArrayLike] = None,
        z: Union[float, ArrayLike] = None,
        method: Optional[Literal["None", "nearest", "pad", "ffill", "backfill", "bfill"]] = None,
        **sel_kwargs: Any,
    ) -> Union[TriangularSurfaceDataset, XrDataArray]:
        """Extract/interpolate data along one or more spatial or non-spatial directions.

        Currently only non-spatial dimensions are supported, through additional keyword
        arguments. Selection along non-spatial dimensions is forwarded to the xarray ``.sel()``
        function. Parameter ``method`` applies only to non-spatial dimensions.

        Parameters
        ----------
        x : Union[float, ArrayLike] = None
            x-coordinate of the slice. Must be ``None`` (not supported yet).
        y : Union[float, ArrayLike] = None
            y-coordinate of the slice. Must be ``None`` (not supported yet).
        z : Union[float, ArrayLike] = None
            z-coordinate of the slice. Must be ``None`` (not supported yet).
        method: Literal[None, "nearest", "pad", "ffill", "backfill", "bfill"] = None
            Method to use in xarray sel() function.
        **sel_kwargs : dict
            Keyword arguments to pass to the xarray sel() function.

        Returns
        -------
        Union[TriangularSurfaceDataset, xarray.DataArray]
            Extracted data.

        Raises
        ------
        Tidy3dNotImplementedError
            If any of ``x``, ``y``, ``z`` is provided.
        """
        if any(comp is not None for comp in [x, y, z]):
            raise Tidy3dNotImplementedError(
                "Surface datasets do not support selection along x, y, or z yet."
            )

        return self._non_spatial_sel(method=method, **sel_kwargs)

    def get_cell_volumes(self) -> ArrayLike:
        """Get areas associated to each cell of the grid.

        The area of each triangle is half the norm of the cross product of two of its edge
        vectors.

        Returns
        -------
        numpy.ndarray
            Array of shape ``(n_cells,)`` with the area of each triangular cell, ordered like
            ``cells``.
        """
        # Positional indexing assumes ``points`` is laid out as (point, coordinate), which is the
        # same layout ``plot`` relies on when handing ``points.data`` to PyVista.
        point_coords = self.points.data
        v0, v1, v2 = (
            point_coords[self.cells.sel(vertex_index=vertex).data] for vertex in range(3)
        )
        return 0.5 * np.linalg.norm(np.cross(v1 - v0, v2 - v0), axis=1)

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------

    @requires_pyvista
    @add_plotter_if_none
    def plot(
        self,
        plotter: Any = None,
        field: bool = True,
        grid: bool = False,
        cbar: bool = True,
        cmap: str = "viridis",
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        grid_color: str = "black",
        grid_width: float = 1.0,
        opacity: float = 1.0,
        show: bool = True,
        windowed: Optional[bool] = None,
        window_size: tuple = (800, 600),
        **mesh_kwargs: Any,
    ) -> Any:
        """Plot the surface mesh and/or associated data using PyVista.

        Parameters
        ----------
        plotter : pyvista.Plotter = None
            PyVista plotter to add the mesh to. If not specified, one is created.
        field : bool = True
            Whether to plot the data field.
        grid : bool = False
            Whether to show the mesh edges/grid lines.
        cbar : bool = True
            Display scalar bar (only if ``field == True``).
        cmap : str = "viridis"
            Color map to use for plotting.
        vmin : float = None
            The lower bound of data range that the colormap covers. If ``None``, inferred from
            the data.
        vmax : float = None
            The upper bound of data range that the colormap covers. If ``None``, inferred from
            the data.
        grid_color : str = "black"
            Color of grid lines when ``grid == True``.
        grid_width : float = 1.0
            Width of grid lines when ``grid == True``.
        opacity : float = 1.0
            Opacity of the mesh (0=transparent, 1=opaque).
        show : bool = True
            Whether to display the plot immediately. If False, returns plotter for further
            customization.
        windowed : bool = None
            Whether to display in an external window. If None (default), automatically detects
            environment (inline in notebooks, windowed otherwise). Set to True to force external
            window even when running in a notebook, which provides better interactivity and
            performance. Only used when ``plotter`` is None.
        window_size : tuple = (800, 600)
            Size of the window (width, height) when creating new plotter.
        **mesh_kwargs
            Additional keyword arguments passed to plotter.add_mesh().

        Returns
        -------
        pyvista.Plotter
            The plotter object with the mesh added. Returns None if ``show=True`` and plotter
            was auto-created.

        Raises
        ------
        DataError
            If nothing to plot, if multiple fields are present without selection, or if the
            data is complex-valued.
        """
        if not (field or grid):
            raise DataError("Nothing to plot ('field == False', 'grid == False').")

        # Only a single scalar field can be mapped onto the surface colors.
        if field and self._num_fields != 1:
            raise DataError(
                "Unstructured dataset contains more than 1 field. "
                "Use '.sel()' to select a single field from available dimensions "
                f"{self._non_spatial_coords_dict} before plotting."
            )

        if field and self.is_complex:
            raise DataError(
                "Cannot plot complex-valued data directly. "
                "Please select a real component using '.real', '.imag', or '.abs' before plotting."
            )

        pv = pyvista["mod"]

        values_name = self.values.name if self.values.name else "values"

        # PolyData faces are a flat array [n, v0, v1, v2, n, v3, v4, v5, ...] where ``n`` is the
        # number of vertices of each face (3 for triangles).
        num_vertices = self._cell_num_vertices()
        faces = np.hstack(
            [np.full((len(self.cells), 1), num_vertices, dtype=int), self.cells.data]
        )
        mesh = pv.PolyData(self.points.data, faces.ravel())

        scalars = None
        clim = None
        if field:
            data_values = self.values.values
            if len(self._non_spatial_shape) > 0:
                # Collapse the (single) non-spatial field dimension into a per-point column.
                data_values = data_values.reshape(len(self.points.values), self._num_fields)
            mesh[values_name] = data_values
            scalars = values_name

            if vmin is not None or vmax is not None:
                # The color range must be a complete (min, max) pair, so fill in whichever
                # bound was not given from the data itself, as documented for vmin/vmax.
                clim = (
                    np.nanmin(data_values) if vmin is None else vmin,
                    np.nanmax(data_values) if vmax is None else vmax,
                )

        plotter.add_mesh(
            mesh,
            scalars=scalars,
            cmap=cmap,
            clim=clim,
            show_edges=grid,
            edge_color=grid_color,
            line_width=grid_width,
            opacity=opacity,
            show_scalar_bar=False,  # added explicitly below when cbar=True
            **mesh_kwargs,
        )

        if field and cbar:
            scalar_bar_args = {
                "title": self.values.name or "Value",
                "n_labels": 5,
                "italic": False,
                "fmt": "%.2e",
                "font_family": "arial",
            }
            plotter.add_scalar_bar(**scalar_bar_args)

        plotter.add_axes(xlabel="x", ylabel="y", zlabel="z")

        return plotter

    @requires_pyvista
    @add_plotter_if_none
    def quiver(
        self,
        plotter: Any = None,
        dim: str = "axis",
        scale: float = 0.1,
        downsampling: int = 1,
        color: str = "magnitude",
        cbar: bool = True,
        cmap: str = "Spectral",
        show: bool = True,
        windowed: Optional[bool] = None,
        window_size: tuple = (800, 600),
        **arrow_kwargs: Any,
    ) -> Any:
        """Plot the associated data as vector field using arrows.

        Field ``values`` must have length 3 along the dimension representing x, y, and z
        components.

        Parameters
        ----------
        plotter : pyvista.Plotter = None
            PyVista plotter to add the arrows to. If not specified, one is created.
        dim : str = "axis"
            Dimension along which x, y, z components are stored.
        scale : float = 0.1
            Size of the longest arrow relative to the diagonal length of the surface bounding
            box.
        downsampling : int = 1
            Step for selecting points for plotting (1 for plotting all points). Must be >= 1.
        color : str = "magnitude"
            How to color arrows. If "magnitude", colors by vector magnitude using cmap.
            Otherwise, should be a valid color string (e.g., 'red', 'blue', '#FF0000').
        cbar : bool = True
            Display scalar bar (only if ``color == "magnitude"``).
        cmap : str = "Spectral"
            Color map to use when coloring by magnitude.
        show : bool = True
            Whether to display the plot immediately.
        windowed : bool = None
            Whether to display in an external window. If None (default), automatically detects
            environment (inline in notebooks, windowed otherwise). Set to True to force external
            window. Only used when ``plotter`` is None.
        window_size : tuple = (800, 600)
            Size of the window (width, height) when creating new plotter.
        **arrow_kwargs
            Additional keyword arguments passed to plotter.add_mesh() for the arrow glyphs.

        Returns
        -------
        pyvista.Plotter
            The plotter object with arrows added. Returns None if ``show=True`` and plotter was
            auto-created.

        Raises
        ------
        DataError
            If dataset doesn't contain exactly 3 fields for vector components, or if
            ``downsampling`` is smaller than 1.
        """
        if self._num_fields != 3:
            raise DataError(
                "Unstructured dataset must contain exactly 3 fields for quiver plotting. "
                "Use '.sel()' to select fields from available dimensions "
                f"{self._non_spatial_coords_dict} before plotting."
            )

        if downsampling < 1:
            raise DataError(f"'downsampling' must be a positive integer, got {downsampling}.")

        points = self.points.data[::downsampling]

        # Vector components (real part only), downsampled the same way as the points.
        u = self.values.sel(**{dim: 0}).real.data[::downsampling]
        v = self.values.sel(**{dim: 1}).real.data[::downsampling]
        w = self.values.sel(**{dim: 2}).real.data[::downsampling]
        vectors = np.column_stack([u, v, w])

        mag = np.linalg.norm(vectors, axis=1)
        mag_max = np.max(mag)

        # Choose the glyph factor so that the longest arrow spans ``scale`` times the bounding-box
        # diagonal; fall back to an unnormalized factor when all vectors vanish.
        bounds = np.array(self.bounds)
        size = np.subtract(bounds[1], bounds[0])
        diag = np.linalg.norm(size)
        scale_factor = scale * diag / mag_max if mag_max > 0 else scale * diag

        pv = pyvista["mod"]

        point_cloud = pv.PolyData(points)
        point_cloud["vectors"] = vectors

        # The glyph filter scales each arrow by its vector magnitude multiplied by ``factor``
        # (the vectors themselves are passed unscaled).
        arrows = point_cloud.glyph(
            orient="vectors",
            scale=True,
            factor=scale_factor,
            geom=pv.Arrow(),
        )

        if color == "magnitude":
            # Each input point produces the same number of glyph vertices, so the per-point
            # magnitude is repeated for every vertex of its arrow.
            n_points_per_arrow = arrows.n_points // len(points)
            arrows["magnitude"] = np.repeat(mag, n_points_per_arrow)
            plotter.add_mesh(arrows, scalars="magnitude", cmap=cmap, **arrow_kwargs)

            if cbar:
                scalar_bar_args = {
                    "title": self.values.name or "Magnitude",
                    "n_labels": 5,
                    "italic": False,
                    "fmt": "%.2e",
                    "font_family": "arial",
                }
                plotter.add_scalar_bar(**scalar_bar_args)
        else:
            plotter.add_mesh(arrows, color=color, **arrow_kwargs)

        plotter.add_axes(xlabel="x", ylabel="y", zlabel="z")

        return plotter