"""Collections of DataArrays."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Union, Dict, Callable, Any
import xarray as xr
import numpy as np
import pydantic.v1 as pd
from matplotlib.tri import Triangulation
from matplotlib import pyplot as plt
import numbers
from .data_array import DataArray
from .data_array import ScalarFieldDataArray, ScalarFieldTimeDataArray, ScalarModeFieldDataArray
from .data_array import ModeIndexDataArray, GroupIndexDataArray, ModeDispersionDataArray
from .data_array import TriangleMeshDataArray
from .data_array import TimeDataArray
from .data_array import PointDataArray, IndexedDataArray, CellDataArray, SpatialDataArray
from ..viz import equal_aspect, add_ax_if_none, plot_params_grid
from ..base import Tidy3dBaseModel, cached_property
from ..base import skip_if_fields_missing
from ..types import Axis, Bound, ArrayLike, Ax, Coordinate, Literal
from ...packaging import vtk, requires_vtk
from ...exceptions import DataError, ValidationError, Tidy3dNotImplementedError
from ...constants import PICOSECOND_PER_NANOMETER_PER_KILOMETER
from ...log import log
class Dataset(Tidy3dBaseModel, ABC):
"""Abstract base class for objects that store collections of `:class:`.DataArray`s."""
[docs]
class AbstractFieldDataset(Dataset, ABC):
"""Collection of scalar fields with some symmetry properties."""
@property
@abstractmethod
def field_components(self) -> Dict[str, DataArray]:
"""Maps the field components to their associated data."""
[docs]
def apply_phase(self, phase: float) -> AbstractFieldDataset:
"""Create a copy where all elements are phase-shifted by a value (in radians)."""
if phase == 0.0:
return self
phasor = np.exp(1j * phase)
field_components_shifted = {}
for fld_name, fld_cmp in self.field_components.items():
fld_cmp_shifted = phasor * fld_cmp
field_components_shifted[fld_name] = fld_cmp_shifted
return self.updated_copy(**field_components_shifted)
@property
@abstractmethod
def grid_locations(self) -> Dict[str, str]:
"""Maps field components to the string key of their grid locations on the yee lattice."""
@property
@abstractmethod
def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]:
"""Maps field components to their (positive) symmetry eigenvalues."""
[docs]
def package_colocate_results(self, centered_fields: Dict[str, ScalarFieldDataArray]) -> Any:
"""How to package the dictionary of fields computed via self.colocate()."""
return xr.Dataset(centered_fields)
[docs]
def colocate(self, x=None, y=None, z=None) -> xr.Dataset:
"""Colocate all of the data at a set of x, y, z coordinates.
Parameters
----------
x : Optional[array-like] = None
x coordinates of locations.
If not supplied, does not try to colocate on this dimension.
y : Optional[array-like] = None
y coordinates of locations.
If not supplied, does not try to colocate on this dimension.
z : Optional[array-like] = None
z coordinates of locations.
If not supplied, does not try to colocate on this dimension.
Returns
-------
xr.Dataset
Dataset containing all fields at the same spatial locations.
For more details refer to `xarray's Documentation <https://tinyurl.com/cyca3krz>`_.
Note
----
For many operations (such as flux calculations and plotting),
it is important that the fields are colocated at the same spatial locations.
Be sure to apply this method to your field data in those cases.
"""
if hasattr(self, "monitor") and self.monitor.colocate:
with log as consolidated_logger:
consolidated_logger.warning(
"Colocating data that has already been colocated during the solver "
"run. For most accurate results when colocating to custom coordinates set "
"'Monitor.colocate' to 'False' to use the raw data on the Yee grid "
"and avoid double interpolation. Note: the default value was changed to 'True' "
"in Tidy3D version 2.4.0."
)
# convert supplied coordinates to array and assign string mapping to them
supplied_coord_map = {k: np.array(v) for k, v in zip("xyz", (x, y, z)) if v is not None}
# dict of data arrays to combine in dataset and return
centered_fields = {}
# loop through field components
for field_name, field_data in self.field_components.items():
# loop through x, y, z dimensions and raise an error if only one element along dim
for coord_name, coords_supplied in supplied_coord_map.items():
coord_data = np.array(field_data.coords[coord_name])
if coord_data.size == 1:
raise DataError(
f"colocate given {coord_name}={coords_supplied}, but "
f"data only has one coordinate at {coord_name}={coord_data[0]}. "
"Therefore, can't colocate along this dimension. "
f"supply {coord_name}=None to skip it."
)
centered_fields[field_name] = field_data.interp(
**supplied_coord_map, kwargs={"bounds_error": True}
)
# combine all centered fields in a dataset
return self.package_colocate_results(centered_fields)
EMScalarFieldType = Union[ScalarFieldDataArray, ScalarFieldTimeDataArray, ScalarModeFieldDataArray]
class ElectromagneticFieldDataset(AbstractFieldDataset, ABC):
"""Stores a collection of E and H fields with x, y, z components."""
Ex: EMScalarFieldType = pd.Field(
None,
title="Ex",
description="Spatial distribution of the x-component of the electric field.",
)
Ey: EMScalarFieldType = pd.Field(
None,
title="Ey",
description="Spatial distribution of the y-component of the electric field.",
)
Ez: EMScalarFieldType = pd.Field(
None,
title="Ez",
description="Spatial distribution of the z-component of the electric field.",
)
Hx: EMScalarFieldType = pd.Field(
None,
title="Hx",
description="Spatial distribution of the x-component of the magnetic field.",
)
Hy: EMScalarFieldType = pd.Field(
None,
title="Hy",
description="Spatial distribution of the y-component of the magnetic field.",
)
Hz: EMScalarFieldType = pd.Field(
None,
title="Hz",
description="Spatial distribution of the z-component of the magnetic field.",
)
@property
def field_components(self) -> Dict[str, DataArray]:
"""Maps the field components to their associated data."""
fields = {
"Ex": self.Ex,
"Ey": self.Ey,
"Ez": self.Ez,
"Hx": self.Hx,
"Hy": self.Hy,
"Hz": self.Hz,
}
return {field_name: field for field_name, field in fields.items() if field is not None}
@property
def grid_locations(self) -> Dict[str, str]:
"""Maps field components to the string key of their grid locations on the yee lattice."""
return dict(Ex="Ex", Ey="Ey", Ez="Ez", Hx="Hx", Hy="Hy", Hz="Hz")
@property
def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]:
"""Maps field components to their (positive) symmetry eigenvalues."""
return dict(
Ex=lambda dim: -1 if (dim == 0) else +1,
Ey=lambda dim: -1 if (dim == 1) else +1,
Ez=lambda dim: -1 if (dim == 2) else +1,
Hx=lambda dim: +1 if (dim == 0) else -1,
Hy=lambda dim: +1 if (dim == 1) else -1,
Hz=lambda dim: +1 if (dim == 2) else -1,
)
[docs]
class FieldDataset(ElectromagneticFieldDataset):
"""Dataset storing a collection of the scalar components of E and H fields in the freq. domain
Example
-------
>>> x = [-1,1]
>>> y = [-2,0,2]
>>> z = [-3,-1,1,3]
>>> f = [2e14, 3e14]
>>> coords = dict(x=x, y=y, z=z, f=f)
>>> scalar_field = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
>>> data = FieldDataset(Ex=scalar_field, Hz=scalar_field)
"""
Ex: ScalarFieldDataArray = pd.Field(
None,
title="Ex",
description="Spatial distribution of the x-component of the electric field.",
)
Ey: ScalarFieldDataArray = pd.Field(
None,
title="Ey",
description="Spatial distribution of the y-component of the electric field.",
)
Ez: ScalarFieldDataArray = pd.Field(
None,
title="Ez",
description="Spatial distribution of the z-component of the electric field.",
)
Hx: ScalarFieldDataArray = pd.Field(
None,
title="Hx",
description="Spatial distribution of the x-component of the magnetic field.",
)
Hy: ScalarFieldDataArray = pd.Field(
None,
title="Hy",
description="Spatial distribution of the y-component of the magnetic field.",
)
Hz: ScalarFieldDataArray = pd.Field(
None,
title="Hz",
description="Spatial distribution of the z-component of the magnetic field.",
)
[docs]
class FieldTimeDataset(ElectromagneticFieldDataset):
"""Dataset storing a collection of the scalar components of E and H fields in the time domain
Example
-------
>>> x = [-1,1]
>>> y = [-2,0,2]
>>> z = [-3,-1,1,3]
>>> t = [0, 1e-12, 2e-12]
>>> coords = dict(x=x, y=y, z=z, t=t)
>>> scalar_field = ScalarFieldTimeDataArray(np.random.random((2,3,4,3)), coords=coords)
>>> data = FieldTimeDataset(Ex=scalar_field, Hz=scalar_field)
"""
Ex: ScalarFieldTimeDataArray = pd.Field(
None,
title="Ex",
description="Spatial distribution of the x-component of the electric field.",
)
Ey: ScalarFieldTimeDataArray = pd.Field(
None,
title="Ey",
description="Spatial distribution of the y-component of the electric field.",
)
Ez: ScalarFieldTimeDataArray = pd.Field(
None,
title="Ez",
description="Spatial distribution of the z-component of the electric field.",
)
Hx: ScalarFieldTimeDataArray = pd.Field(
None,
title="Hx",
description="Spatial distribution of the x-component of the magnetic field.",
)
Hy: ScalarFieldTimeDataArray = pd.Field(
None,
title="Hy",
description="Spatial distribution of the y-component of the magnetic field.",
)
Hz: ScalarFieldTimeDataArray = pd.Field(
None,
title="Hz",
description="Spatial distribution of the z-component of the magnetic field.",
)
[docs]
def apply_phase(self, phase: float) -> AbstractFieldDataset:
"""Create a copy where all elements are phase-shifted by a value (in radians)."""
if phase != 0.0:
raise ValueError("Can't apply phase to time-domain field data, which is real-valued.")
return self
[docs]
class ModeSolverDataset(ElectromagneticFieldDataset):
"""Dataset storing scalar components of E and H fields as a function of freq. and mode_index.
Example
-------
>>> from tidy3d import ModeSpec
>>> x = [-1,1]
>>> y = [0]
>>> z = [-3,-1,1,3]
>>> f = [2e14, 3e14]
>>> mode_index = np.arange(5)
>>> field_coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index)
>>> field = ScalarModeFieldDataArray((1+1j)*np.random.random((2,1,4,2,5)), coords=field_coords)
>>> index_coords = dict(f=f, mode_index=mode_index)
>>> index_data = ModeIndexDataArray((1+1j) * np.random.random((2,5)), coords=index_coords)
>>> data = ModeSolverDataset(
... Ex=field,
... Ey=field,
... Ez=field,
... Hx=field,
... Hy=field,
... Hz=field,
... n_complex=index_data
... )
"""
Ex: ScalarModeFieldDataArray = pd.Field(
None,
title="Ex",
description="Spatial distribution of the x-component of the electric field of the mode.",
)
Ey: ScalarModeFieldDataArray = pd.Field(
None,
title="Ey",
description="Spatial distribution of the y-component of the electric field of the mode.",
)
Ez: ScalarModeFieldDataArray = pd.Field(
None,
title="Ez",
description="Spatial distribution of the z-component of the electric field of the mode.",
)
Hx: ScalarModeFieldDataArray = pd.Field(
None,
title="Hx",
description="Spatial distribution of the x-component of the magnetic field of the mode.",
)
Hy: ScalarModeFieldDataArray = pd.Field(
None,
title="Hy",
description="Spatial distribution of the y-component of the magnetic field of the mode.",
)
Hz: ScalarModeFieldDataArray = pd.Field(
None,
title="Hz",
description="Spatial distribution of the z-component of the magnetic field of the mode.",
)
n_complex: ModeIndexDataArray = pd.Field(
...,
title="Propagation Index",
description="Complex-valued effective propagation constants associated with the mode.",
)
n_group_raw: GroupIndexDataArray = pd.Field(
None,
alias="n_group", # This is for backwards compatibility only when loading old data
title="Group Index",
description="Index associated with group velocity of the mode.",
)
dispersion_raw: ModeDispersionDataArray = pd.Field(
None,
title="Dispersion",
description="Dispersion parameter for the mode.",
units=PICOSECOND_PER_NANOMETER_PER_KILOMETER,
)
@property
def field_components(self) -> Dict[str, DataArray]:
"""Maps the field components to their associated data."""
fields = {
"Ex": self.Ex,
"Ey": self.Ey,
"Ez": self.Ez,
"Hx": self.Hx,
"Hy": self.Hy,
"Hz": self.Hz,
}
return {field_name: field for field_name, field in fields.items() if field is not None}
@property
def n_eff(self) -> ModeIndexDataArray:
"""Real part of the propagation index."""
return self.n_complex.real
@property
def k_eff(self) -> ModeIndexDataArray:
"""Imaginary part of the propagation index."""
return self.n_complex.imag
@property
def n_group(self) -> GroupIndexDataArray:
"""Group index."""
if self.n_group_raw is None:
log.warning(
"The group index was not computed. To calculate group index, pass "
"'group_index_step = True' in the 'ModeSpec'.",
log_once=True,
)
return self.n_group_raw
@property
def dispersion(self) -> ModeDispersionDataArray:
r"""Dispersion parameter.
.. math::
D = -\frac{\lambda}{c_0} \frac{{\rm d}^2 n_{\text{eff}}}{{\rm d}\lambda^2}
"""
if self.dispersion_raw is None:
log.warning(
"The dispersion was not computed. To calculate dispersion, pass "
"'group_index_step = True' in the 'ModeSpec'.",
log_once=True,
)
return self.dispersion_raw
[docs]
def plot_field(self, *args, **kwargs):
"""Warn user to use the :class:`.ModeSolver` ``plot_field`` function now."""
raise DeprecationWarning(
"The 'plot_field()' method was moved to the 'ModeSolver' object."
"Once the 'ModeSolver' is constructed, one may call '.plot_field()' on the object and "
"the modes will be computed and displayed with 'Simulation' overlay."
)
[docs]
class PermittivityDataset(AbstractFieldDataset):
"""Dataset storing the diagonal components of the permittivity tensor.
Example
-------
>>> x = [-1,1]
>>> y = [-2,0,2]
>>> z = [-3,-1,1,3]
>>> f = [2e14, 3e14]
>>> coords = dict(x=x, y=y, z=z, f=f)
>>> sclr_fld = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
>>> data = PermittivityDataset(eps_xx=sclr_fld, eps_yy=sclr_fld, eps_zz=sclr_fld)
"""
@property
def field_components(self) -> Dict[str, ScalarFieldDataArray]:
"""Maps the field components to their associated data."""
return dict(eps_xx=self.eps_xx, eps_yy=self.eps_yy, eps_zz=self.eps_zz)
@property
def grid_locations(self) -> Dict[str, str]:
"""Maps field components to the string key of their grid locations on the yee lattice."""
return dict(eps_xx="Ex", eps_yy="Ey", eps_zz="Ez")
@property
def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]:
"""Maps field components to their (positive) symmetry eigenvalues."""
return dict(eps_xx=None, eps_yy=None, eps_zz=None)
eps_xx: ScalarFieldDataArray = pd.Field(
...,
title="Epsilon xx",
description="Spatial distribution of the xx-component of the relative permittivity.",
)
eps_yy: ScalarFieldDataArray = pd.Field(
...,
title="Epsilon yy",
description="Spatial distribution of the yy-component of the relative permittivity.",
)
eps_zz: ScalarFieldDataArray = pd.Field(
...,
title="Epsilon zz",
description="Spatial distribution of the zz-component of the relative permittivity.",
)
class TriangleMeshDataset(Dataset):
"""Dataset for storing triangular surface data."""
surface_mesh: TriangleMeshDataArray = pd.Field(
...,
title="Surface mesh data",
description="Dataset containing the surface triangles and corresponding face indices "
"for a surface mesh.",
)
class TimeDataset(Dataset):
"""Dataset for storing a function of time."""
values: TimeDataArray = pd.Field(
..., title="Values", description="Values as a function of time."
)
class UnstructuredGridDataset(Dataset, np.lib.mixins.NDArrayOperatorsMixin, ABC):
"""Abstract base for datasets that store unstructured grid data."""
points: PointDataArray = pd.Field(
...,
title="Grid Points",
description="Coordinates of points composing the unstructured grid.",
)
values: IndexedDataArray = pd.Field(
...,
title="Point Values",
description="Values stored at the grid points.",
)
cells: CellDataArray = pd.Field(
...,
title="Grid Cells",
description="Cells composing the unstructured grid specified as connections between grid "
"points.",
)
@property
def name(self) -> str:
"""Dataset name."""
# we redirect name to values.name
return self.values.name
@pd.validator("cells", always=True)
def match_cells_to_vtk_type(cls, val):
"""Check that cell connections does not have duplicate points."""
if vtk is None:
return val
# using val.astype(np.int32/64) directly causes issues when dataarray are later checked ==
return CellDataArray(val.data.astype(vtk["id_type"], copy=False), coords=val.coords)
@pd.validator("values", always=True)
@skip_if_fields_missing(["points"])
def number_of_values_matches_points(cls, val, values):
"""Check that the number of data values matches the number of grid points."""
num_values = len(val)
points = values.get("points")
num_points = len(points)
if num_points != num_values:
raise ValidationError(
f"The number of data values ({num_values}) does not match the number of grid "
f"points ({num_points})."
)
return val
@pd.validator("points", always=True)
def points_right_dims(cls, val):
"""Check that point coordinates have the right dimensionality."""
axis_coords_expected = np.arange(cls._point_dims())
axis_coords_given = val.axis.data
if np.any(axis_coords_given != axis_coords_expected):
raise ValidationError(
f"Points array is expected to have {axis_coords_expected} coord values along 'axis'"
f" (given: {axis_coords_given})."
)
return val
@pd.validator("cells", always=True)
def cells_right_type(cls, val):
"""Check that cell are of the right type."""
vertex_coords_expected = np.arange(cls._cell_num_vertices())
vertex_coords_given = val.vertex_index.data
if np.any(vertex_coords_given != vertex_coords_expected):
raise ValidationError(
f"Cell connections array is expected to have {vertex_coords_expected} coord values"
f" along 'vertex_index' (given: {vertex_coords_given})."
)
return val
@pd.validator("cells", always=True)
@skip_if_fields_missing(["points"])
def check_cell_vertex_range(cls, val, values):
"""Check that cell connections use only defined points."""
all_point_indices_used = val.data.ravel()
min_index_used = np.min(all_point_indices_used)
max_index_used = np.max(all_point_indices_used)
points = values.get("points")
num_points = len(points)
if max_index_used != num_points - 1 or min_index_used != 0:
raise ValidationError(
"Cell connections array uses undefined point indices in the range "
f"[{min_index_used}, {max_index_used}]. The valid range of point indices is "
f"[0, {num_points-1}]."
)
return val
@pd.validator("cells", always=True)
def check_valid_cells(cls, val):
"""Check that cell connections does not have duplicate points."""
indices = val.data
for i in range(cls._cell_num_vertices() - 1):
for j in range(i + 1, cls._cell_num_vertices()):
if np.any(indices[:, i] == indices[:, j]):
log.warning("Unstructured grid contains degenerate cells.")
return val
def rename(self, name: str) -> UnstructuredGridDataset:
"""Return a renamed array."""
return self.updated_copy(values=self.values.rename(name))
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
"""Override of numpy functions."""
out = kwargs.get("out", ())
for x in inputs + out:
# Only support operations with the same class or a scalar
if not isinstance(x, (numbers.Number, type(self))):
return Tidy3dNotImplementedError
# Defer to the implementation of the ufunc on unwrapped values.
inputs = tuple(x.values if isinstance(x, UnstructuredGridDataset) else x for x in inputs)
if out:
kwargs["out"] = tuple(
x.values if isinstance(x, UnstructuredGridDataset) else x for x in out
)
result = getattr(ufunc, method)(*inputs, **kwargs)
if type(result) is tuple:
# multiple return values
return tuple(self.updated_copy(values=x) for x in result)
elif method == "at":
# no return value
return None
else:
# one return value
return self.updated_copy(values=result)
@property
def real(self) -> UnstructuredGridDataset:
"""Real part of dataset."""
return self.updated_copy(values=self.values.real)
@property
def imag(self) -> UnstructuredGridDataset:
"""Imaginary part of dataset."""
return self.updated_copy(values=self.values.imag)
@property
def abs(self) -> UnstructuredGridDataset:
"""Absolute value of dataset."""
return self.updated_copy(values=self.values.abs)
@cached_property
def bounds(self) -> Bound:
"""Grid bounds."""
return tuple(np.min(self.points.data, axis=0)), tuple(np.max(self.points.data, axis=0))
@classmethod
@abstractmethod
def _point_dims(cls) -> pd.PositiveInt:
"""Dimensionality of stored grid point coordinates."""
@cached_property
@abstractmethod
def _points_3d_array(self) -> Bound:
"""3D coordinates of grid points."""
@classmethod
@abstractmethod
def _cell_num_vertices(cls) -> pd.PositiveInt:
"""Number of vertices in a cell."""
@classmethod
@abstractmethod
@requires_vtk
def _vtk_cell_type(cls):
"""VTK cell type to use in the VTK representation."""
@cached_property
def _vtk_offsets(self) -> ArrayLike:
"""Offsets array to use in the VTK representation."""
offsets = np.arange(len(self.cells) + 1) * self._cell_num_vertices()
if vtk is None:
return offsets
return offsets.astype(vtk["id_type"], copy=False)
@property
@requires_vtk
def _vtk_cells(self):
"""VTK cell array to use in the VTK representation."""
cells = vtk["mod"].vtkCellArray()
cells.SetData(
vtk["numpy_to_vtkIdTypeArray"](self._vtk_offsets),
vtk["numpy_to_vtkIdTypeArray"](self.cells.data.ravel()),
)
return cells
@property
@requires_vtk
def _vtk_points(self):
"""VTK point array to use in the VTK representation."""
pts = vtk["mod"].vtkPoints()
pts.SetData(vtk["numpy_to_vtk"](self._points_3d_array))
return pts
@property
@requires_vtk
def _vtk_obj(self):
"""A VTK representation (vtkUnstructuredGrid) of the grid."""
grid = vtk["mod"].vtkUnstructuredGrid()
grid.SetPoints(self._vtk_points)
grid.SetCells(self._vtk_cell_type(), self._vtk_cells)
point_data_vtk = vtk["numpy_to_vtk"](self.values.data)
point_data_vtk.SetName(self.values.name)
grid.GetPointData().AddArray(point_data_vtk)
return grid
@requires_vtk
def _plane_slice_raw(self, axis: Axis, pos: float):
"""Slice data with a plane and return the resulting VTK object."""
if pos > self.bounds[1][axis] or pos < self.bounds[0][axis]:
raise DataError(
f"Slicing plane (axis: {axis}, pos: {pos}) does not intersect the unstructured grid "
f"(extent along axis {axis}: {self.bounds[0][axis]}, {self.bounds[1][axis]})."
)
origin = [0, 0, 0]
origin[axis] = pos
normal = [0, 0, 0]
normal[axis] = 1
# create cutting plane
plane = vtk["mod"].vtkPlane()
plane.SetOrigin(origin[0], origin[1], origin[2])
plane.SetNormal(normal[0], normal[1], normal[2])
# create cutter
cutter = vtk["mod"].vtkPlaneCutter()
cutter.SetPlane(plane)
cutter.SetInputData(self._vtk_obj)
cutter.InterpolateAttributesOn()
cutter.Update()
# clean up the slice
cleaner = vtk["mod"].vtkCleanPolyData()
cleaner.SetInputData(cutter.GetOutput())
cleaner.Update()
return cleaner.GetOutput()
@abstractmethod
@requires_vtk
def plane_slice(
self, axis: Axis, pos: float
) -> Union[SpatialDataArray, UnstructuredGridDataset]:
"""Slice data with a plane and return the Tidy3D representation of the result
(``UnstructuredGridDataset``).
Parameters
----------
axis : Axis
The normal direction of the slicing plane.
pos : float
Position of the slicing plane along its normal direction.
Returns
-------
Union[SpatialDataArray, UnstructuredGridDataset]
The resulting slice.
"""
@staticmethod
@requires_vtk
def _read_vtkUnstructuredGrid(fname: str):
"""Load a :class:`vtkUnstructuredGrid` from a file."""
reader = vtk["mod"].vtkXMLUnstructuredGridReader()
reader.SetFileName(fname)
reader.Update()
grid = reader.GetOutput()
return grid
@classmethod
@abstractmethod
@requires_vtk
def _from_vtk_obj(cls, vtk_obj, field=None) -> UnstructuredGridDataset:
"""Initialize from a vtk object."""
@classmethod
@requires_vtk
def from_vtu(cls, file: str, field: str = None) -> UnstructuredGridDataset:
"""Load unstructured data from a vtu file.
Parameters
----------
fname : str
Full path to the .vtu file to load the unstructured data from.
field : str = None
Name of the field to load.
Returns
-------
UnstructuredGridDataset
Unstructured data.
"""
grid = cls._read_vtkUnstructuredGrid(file)
return cls._from_vtk_obj(grid, field=field)
@requires_vtk
def to_vtu(self, fname: str):
"""Exports unstructured grid data into a .vtu file.
Parameters
----------
fname : str
Full path to the .vtu file to save the unstructured data to.
"""
writer = vtk["mod"].vtkXMLUnstructuredGridWriter()
writer.SetFileName(fname)
writer.SetInputData(self._vtk_obj)
writer.Write()
@classmethod
@requires_vtk
def _get_values_from_vtk(
cls, vtk_obj, num_points: pd.PositiveInt, field: str = None
) -> IndexedDataArray:
"""Get point data values from a VTK object."""
point_data = vtk_obj.GetPointData()
num_point_arrays = point_data.GetNumberOfArrays()
if num_point_arrays == 0:
log.warning(
"No point data is found in a VTK object. '.values' will be initialized to zeros."
)
values_numpy = np.zeros(num_points)
values_name = None
else:
if field is not None:
array_vtk = point_data.GetAbstractArray(field)
else:
array_vtk = point_data.GetAbstractArray(0)
# currently we assume there is only one point data array provided in the VTK object
if num_point_arrays > 1 and field is None:
array_name = array_vtk.GetName()
log.warning(
f"{num_point_arrays} point data arrays are found in a VTK object. "
f"Only the first array (name: {array_name}) will be used to initialize "
"'.values' while the rest will be ignored."
)
# currently we assume data is scalar
num_components = array_vtk.GetNumberOfComponents()
if num_components > 1:
raise DataError(
f"Found point data array in a VTK object is expected to have only 1 component. Found {num_components} components."
)
# check that number of values matches number of grid points
num_tuples = array_vtk.GetNumberOfTuples()
if num_tuples != num_points:
raise DataError(
f"The length of found point data array ({num_tuples}) does not match the number of grid points ({num_points})."
)
values_numpy = vtk["vtk_to_numpy"](array_vtk)
values_name = array_vtk.GetName()
values = IndexedDataArray(
values_numpy, coords=dict(index=np.arange(len(values_numpy))), name=values_name
)
return values
@requires_vtk
def box_clip(self, bounds: Bound) -> UnstructuredGridDataset:
"""Clip the unstructured grid using a box defined by ``bounds``.
Parameters
----------
bounds : Tuple[float, float, float], Tuple[float, float float]
Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
Returns
-------
UnstructuredGridDataset
Clipped grid.
"""
# make and run a VTK clipper
clipper = vtk["mod"].vtkBoxClipDataSet()
clipper.SetOrientation(0)
clipper.SetBoxClip(
bounds[0][0], bounds[1][0], bounds[0][1], bounds[1][1], bounds[0][2], bounds[1][2]
)
clipper.SetInputData(self._vtk_obj)
clipper.GenerateClipScalarsOn()
clipper.GenerateClippedOutputOff()
clipper.Update()
clip = clipper.GetOutput()
# clean grid from unused points
grid_cleaner = vtk["mod"].vtkRemoveUnusedPoints()
grid_cleaner.SetInputData(clip)
grid_cleaner.GenerateOriginalPointIdsOff()
grid_cleaner.Update()
clean_clip = grid_cleaner.GetOutput()
# no intersection check
if clean_clip.GetNumberOfPoints() == 0:
raise DataError("Clipping box does not intersect the unstructured grid.")
return self._from_vtk_obj(clean_clip)
@requires_vtk
def interp(
self,
x: Union[float, ArrayLike],
y: Union[float, ArrayLike],
z: Union[float, ArrayLike],
fill_value: float = 0,
) -> SpatialDataArray:
"""Interpolate data at provided x, y, and z.
Parameters
----------
x : Union[float, ArrayLike]
x-coordinates of sampling points.
y : Union[float, ArrayLike]
y-coordinates of sampling points.
z : Union[float, ArrayLike]
z-coordinates of sampling points.
fill_value : float = 0
Value to use when filling points without interpolated values.
Returns
-------
SpatialDataArray
Interpolated data.
"""
# calculate the resulting array shape
x = np.atleast_1d(x)
y = np.atleast_1d(y)
z = np.atleast_1d(z)
shape = (len(x), len(y), len(z))
# create a VTK rectilinear grid to sample onto
structured_grid = vtk["mod"].vtkRectilinearGrid()
structured_grid.SetDimensions(shape)
structured_grid.SetXCoordinates(vtk["numpy_to_vtk"](x))
structured_grid.SetYCoordinates(vtk["numpy_to_vtk"](y))
structured_grid.SetZCoordinates(vtk["numpy_to_vtk"](z))
# create and execute VTK interpolator
interpolator = vtk["mod"].vtkResampleWithDataSet()
interpolator.SetInputData(structured_grid)
interpolator.SetSourceData(self._vtk_obj)
interpolator.Update()
interpolated = interpolator.GetOutput()
# get results in a numpy representation
array_id = 0 if self.values.name is None else self.values.name
values_numpy = vtk["vtk_to_numpy"](interpolated.GetPointData().GetAbstractArray(array_id))
# fill points without interpolated values
if fill_value != 0:
mask = vtk["vtk_to_numpy"](
interpolated.GetPointData().GetAbstractArray("vtkValidPointMask")
)
values_numpy[mask != 1] = fill_value
# VTK arrays are the z-y-x order, reorder interpolation results to x-y-z order
values_reordered = np.transpose(np.reshape(values_numpy, shape[::-1]), (2, 1, 0))
return SpatialDataArray(values_reordered, coords=dict(x=x, y=y, z=z), name=self.values.name)
@abstractmethod
@requires_vtk
def sel(
self,
x: Union[float, ArrayLike] = None,
y: Union[float, ArrayLike] = None,
z: Union[float, ArrayLike] = None,
) -> Union[UnstructuredGridDataset, SpatialDataArray]:
"""Extract/interpolate data along one or more Cartesian directions. At least of x, y, and z
must be provided.
Parameters
----------
x : Union[float, ArrayLike] = None
x-coordinate of the slice.
y : Union[float, ArrayLike] = None
y-coordinate of the slice.
z : Union[float, ArrayLike] = None
z-coordinate of the slice.
Returns
-------
Union[TriangularGridDataset, SpatialDataArray]
Extracted data.
"""
@requires_vtk
def reflect(
self, axis: Axis, center: float, reflection_only: bool = False
) -> UnstructuredGridDataset:
"""Reflect unstructured data across the plane define by parameters ``axis`` and ``center``.
By default the original data is preserved, setting ``reflection_only`` to ``True`` will
produce only deflected data.
Parameters
----------
axis : Literal[0, 1, 2]
Normal direction of the reflection plane.
center : float
Location of the reflection plane along its normal direction.
reflection_only : bool = False
Return only reflected data.
Returns
-------
UnstructuredGridDataset
Data after reflextion is performed.
"""
reflector = vtk["mod"].vtkReflectionFilter()
reflector.SetPlane([reflector.USE_X, reflector.USE_Y, reflector.USE_Z][axis])
reflector.SetCenter(center)
reflector.SetCopyInput(not reflection_only)
reflector.SetInputData(self._vtk_obj)
reflector.Update()
return self._from_vtk_obj(reflector.GetOutput())
class TriangularGridDataset(UnstructuredGridDataset):
"""Dataset for storing triangular grid data.
Note
----
To use full functionality of unstructured datasets one must install ``vtk`` package (``pip
install tidy3d[vtk]`` or ``pip install vtk``). Otherwise the functionality of unstructured
datasets is limited to creation, writing to/loading from a file, and arithmetic manipulations.
Example
-------
>>> tri_grid_points = PointDataArray(
... [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
... coords=dict(index=np.arange(4), axis=np.arange(2)),
... )
>>>
>>> tri_grid_cells = CellDataArray(
... [[0, 1, 2], [1, 2, 3]],
... coords=dict(cell_index=np.arange(2), vertex_index=np.arange(3)),
... )
>>>
>>> tri_grid_values = IndexedDataArray(
... [1.0, 2.0, 3.0, 4.0], coords=dict(index=np.arange(4)),
... )
>>>
>>> tri_grid = TriangularGridDataset(
... normal_axis=1,
... normal_pos=0,
... points=tri_grid_points,
... cells=tri_grid_cells,
... values=tri_grid_values,
... )
"""
normal_axis: Axis = pd.Field(
...,
title="Grid Axis",
description="Orientation of the grid.",
)
normal_pos: float = pd.Field(
...,
title="Position",
description="Coordinate of the grid along the normal direction.",
)
@cached_property
def bounds(self) -> Bound:
"""Grid bounds."""
bounds_2d = super().bounds
bounds_3d = self._points_2d_to_3d(bounds_2d)
return tuple(bounds_3d[0]), tuple(bounds_3d[1])
@classmethod
def _point_dims(cls) -> pd.PositiveInt:
"""Dimensionality of stored grid point coordinates."""
return 2
def _points_2d_to_3d(self, pts: ArrayLike) -> ArrayLike:
"""Convert 2d points into 3d points."""
return np.insert(pts, obj=self.normal_axis, values=self.normal_pos, axis=1)
@cached_property
def _points_3d_array(self) -> ArrayLike:
"""3D representation of grid points."""
return self._points_2d_to_3d(self.points.data)
@classmethod
def _cell_num_vertices(cls) -> pd.PositiveInt:
"""Number of vertices in a cell."""
return 3
@classmethod
@requires_vtk
def _vtk_cell_type(cls):
"""VTK cell type to use in the VTK representation."""
return vtk["mod"].VTK_TRIANGLE
@classmethod
@requires_vtk
def _from_vtk_obj(cls, vtk_obj, field=None):
"""Initialize from a vtkUnstructuredGrid instance."""
# get points cells data from vtk object
if isinstance(vtk_obj, vtk["mod"].vtkPolyData):
cells_vtk = vtk_obj.GetPolys()
elif isinstance(vtk_obj, vtk["mod"].vtkUnstructuredGrid):
cells_vtk = vtk_obj.GetCells()
cells_numpy = vtk["vtk_to_numpy"](cells_vtk.GetConnectivityArray())
cell_offsets = vtk["vtk_to_numpy"](cells_vtk.GetOffsetsArray())
if not np.all(np.diff(cell_offsets) == cls._cell_num_vertices()):
raise DataError(
"Only triangular 'vtkUnstructuredGrid' or 'vtkPolyData' can be converted into "
"'TriangularGridDataset'."
)
points_numpy = vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData())
# data values are read directly into Tidy3D array
values = cls._get_values_from_vtk(vtk_obj, len(points_numpy), field)
# detect zero size dimension
bounds = np.max(points_numpy, axis=0) - np.min(points_numpy, axis=0)
zero_dims = np.where(np.isclose(bounds, 0))[0]
if len(zero_dims) != 1:
raise DataError(
f"Provided vtk grid does not represent a two dimensional grid. Found zero size dimensions are {zero_dims}."
)
normal_axis = zero_dims[0]
normal_pos = points_numpy[0][normal_axis]
tan_dims = [0, 1, 2]
tan_dims.remove(normal_axis)
# convert 3d coordinates into 2d
points_2d_numpy = points_numpy[:, tan_dims]
# create Tidy3D points and cells arrays
num_cells = len(cells_numpy) // cls._cell_num_vertices()
cells_numpy = np.reshape(cells_numpy, (num_cells, cls._cell_num_vertices()))
cells = CellDataArray(
cells_numpy,
coords=dict(
cell_index=np.arange(num_cells), vertex_index=np.arange(cls._cell_num_vertices())
),
)
points = PointDataArray(
points_2d_numpy,
coords=dict(index=np.arange(len(points_numpy)), axis=np.arange(cls._point_dims())),
)
return cls(
normal_axis=normal_axis,
normal_pos=normal_pos,
points=points,
cells=cells,
values=values,
)
@requires_vtk
def plane_slice(self, axis: Axis, pos: float) -> SpatialDataArray:
"""Slice data with a plane and return the resulting line as a SpatialDataArray.
Parameters
----------
axis : Axis
The normal direction of the slicing plane.
pos : float
Position of the slicing plane along its normal direction.
Returns
-------
SpatialDataArray
The resulting slice.
"""
if axis == self.normal_axis:
raise DataError(
f"Triangular grid (normal: {self.normal_axis}) cannot be sliced by a parallel "
"plane."
)
# perform slicing in vtk and get unprocessed points and values
slice_vtk = self._plane_slice_raw(axis=axis, pos=pos)
points_numpy = vtk["vtk_to_numpy"](slice_vtk.GetPoints().GetData())
values = self._get_values_from_vtk(slice_vtk, len(points_numpy))
# axis of the resulting line
slice_axis = 3 - self.normal_axis - axis
# sort found intersection in ascending order
sorting = np.argsort(points_numpy[:, slice_axis], kind="mergesort")
# assemble coords for SpatialDataArray
coords = [None, None, None]
coords[axis] = [pos]
coords[self.normal_axis] = [self.normal_pos]
coords[slice_axis] = points_numpy[sorting, slice_axis]
coords_dict = dict(zip("xyz", coords))
# reshape values from a 1d array into a 3d array
new_shape = [1, 1, 1]
new_shape[slice_axis] = len(values)
values_reshaped = np.reshape(values.data[sorting], new_shape)
return SpatialDataArray(values_reshaped, coords=coords_dict, name=self.values.name)
@property
def _triangulation_obj(self) -> Triangulation:
"""Matplotlib triangular representation of the grid to use in plotting."""
return Triangulation(self.points[:, 0], self.points[:, 1], self.cells)
@equal_aspect
@add_ax_if_none
def plot(
self,
ax: Ax = None,
field: bool = True,
grid: bool = True,
cbar: bool = True,
cmap: str = "viridis",
vmin: float = None,
vmax: float = None,
shading: Literal["gourand", "flat"] = "gouraud",
cbar_kwargs: Dict = None,
) -> Ax:
"""Plot the data field and/or the unstructured grid.
Parameters
----------
ax : matplotlib.axes._subplots.Axes = None
matplotlib axes to plot on, if not specified, one is created.
field : bool = True
Whether to plot the data field.
grid : bool = True
Whether to plot the unstructured grid.
cbar : bool = True
Display colorbar (only if ``field == True``).
cmap : str = "viridis"
Color map to use for plotting.
vmin : float = None
The lower bound of data range that the colormap covers. If ``None``, they are
inferred from the data and other keyword arguments.
vmax : float = None
The upper bound of data range that the colormap covers. If ``None``, they are
inferred from the data and other keyword arguments.
shading : Literal["gourand", "flat"] = "gourand"
Type of shading to use when plotting the data field.
cbar_kwargs : Dict = {}
Additional parameters passed to colorbar object.
Returns
-------
matplotlib.axes._subplots.Axes
The supplied or created matplotlib axes.
"""
if cbar_kwargs is None:
cbar_kwargs = {}
if not (field or grid):
raise DataError("Nothing to plot ('field == False', 'grid == False').")
# plot data field if requested
if field:
plot_obj = ax.tripcolor(
self._triangulation_obj,
self.values.data,
shading=shading,
cmap=cmap,
vmin=vmin,
vmax=vmax,
)
if cbar:
label_kwargs = {}
if "label" not in cbar_kwargs:
label_kwargs["label"] = self.values.name
plt.colorbar(plot_obj, **cbar_kwargs, **label_kwargs)
# plot grid if requested
if grid:
ax.triplot(
self._triangulation_obj,
color=plot_params_grid.edgecolor,
linewidth=plot_params_grid.linewidth,
)
# set labels and titles
ax_labels = ["x", "y", "z"]
normal_axis_name = ax_labels.pop(self.normal_axis)
ax.set_xlabel(ax_labels[0])
ax.set_ylabel(ax_labels[1])
ax.set_title(f"{normal_axis_name} = {self.normal_pos}")
return ax
@requires_vtk
def interp(
self,
x: Union[float, ArrayLike],
y: Union[float, ArrayLike],
z: Union[float, ArrayLike],
fill_value: float = 0,
ignore_normal_pos: bool = True,
) -> SpatialDataArray:
"""Interpolate data at provided x, y, and z.
Parameters
----------
x : Union[float, ArrayLike]
x-coordinates of sampling points.
y : Union[float, ArrayLike]
y-coordinates of sampling points.
z : Union[float, ArrayLike]
z-coordinates of sampling points.
fill_value : float = 0
Value to use when filling points without interpolated values.
ignore_normal_pos : bool = True
Assume data is invariant along the normal direction to the grid plane.
Returns
-------
SpatialDataArray
Interpolated data.
"""
x = np.atleast_1d(x)
y = np.atleast_1d(y)
z = np.atleast_1d(z)
if ignore_normal_pos:
xyz = [x, y, z]
xyz[self.normal_axis] = [self.normal_pos]
interp_inplane = super().interp(**dict(zip("xyz", xyz)), fill_value=fill_value)
interp_broadcasted = np.broadcast_to(
interp_inplane, [len(np.atleast_1d(comp)) for comp in [x, y, z]]
)
return SpatialDataArray(
interp_broadcasted, coords=dict(x=x, y=y, z=z), name=self.values.name
)
return super().interp(x=x, y=y, z=z, fill_value=fill_value)
@requires_vtk
def sel(
self,
x: Union[float, ArrayLike] = None,
y: Union[float, ArrayLike] = None,
z: Union[float, ArrayLike] = None,
) -> SpatialDataArray:
"""Extract/interpolate data along one or more Cartesian directions. At least of x, y, and z
must be provided.
Parameters
----------
x : Union[float, ArrayLike] = None
x-coordinate of the slice.
y : Union[float, ArrayLike] = None
y-coordinate of the slice.
z : Union[float, ArrayLike] = None
z-coordinate of the slice.
Returns
-------
SpatialDataArray
Extracted data.
"""
xyz = [x, y, z]
axes = [ind for ind, comp in enumerate(xyz) if comp is not None]
num_provided = len(axes)
if self.normal_axis in axes:
if xyz[self.normal_axis] != self.normal_pos:
raise DataError(
f"No data for {'xyz'[self.normal_axis]} = {xyz[self.normal_axis]} (unstructured"
f" grid is defined at {'xyz'[self.normal_axis]} = {self.normal_pos})."
)
if num_provided < 3:
num_provided -= 1
axes.remove(self.normal_axis)
if num_provided == 0:
raise DataError("At least one of 'x', 'y', and 'z' must be specified.")
if num_provided == 1:
axis = axes[0]
return self.plane_slice(axis=axis, pos=xyz[axis])
if num_provided == 2:
pos = [x, y, z]
pos[self.normal_axis] = [self.normal_pos]
return self.interp(x=pos[0], y=pos[1], z=pos[2])
if num_provided == 3:
return self.interp(x=x, y=y, z=z)
@requires_vtk
def reflect(
self, axis: Axis, center: float, reflection_only: bool = False
) -> UnstructuredGridDataset:
"""Reflect unstructured data across the plane define by parameters ``axis`` and ``center``.
By default the original data is preserved, setting ``reflection_only`` to ``True`` will
produce only deflected data.
Parameters
----------
axis : Literal[0, 1, 2]
Normal direction of the reflection plane.
center : float
Location of the reflection plane along its normal direction.
reflection_only : bool = False
Return only reflected data.
Returns
-------
UnstructuredGridDataset
Data after reflextion is performed.
"""
# disallow reflecting along normal direction
if axis == self.normal_axis:
raise DataError("Reflection in the normal direction to the grid is prohibited.")
return super().reflect(axis=axis, center=center, reflection_only=reflection_only)
class TetrahedralGridDataset(UnstructuredGridDataset):
"""Dataset for storing tetrahedral grid data.
Note
----
To use full functionality of unstructured datasets one must install ``vtk`` package (``pip
install tidy3d[vtk]`` or ``pip install vtk``). Otherwise the functionality of unstructured
datasets is limited to creation, writing to/loading from a file, and arithmetic manipulations.
Example
-------
>>> tet_grid_points = PointDataArray(
... [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
... coords=dict(index=np.arange(4), axis=np.arange(3)),
... )
>>>
>>> tet_grid_cells = CellDataArray(
... [[0, 1, 2, 3]],
... coords=dict(cell_index=np.arange(1), vertex_index=np.arange(4)),
... )
>>>
>>> tet_grid_values = IndexedDataArray(
... [1.0, 2.0, 3.0, 4.0], coords=dict(index=np.arange(4)),
... )
>>>
>>> tet_grid = TetrahedralGridDataset(
... points=tet_grid_points,
... cells=tet_grid_cells,
... values=tet_grid_values,
... )
"""
@classmethod
def _point_dims(cls) -> pd.PositiveInt:
"""Dimensionality of stored grid point coordinates."""
return 3
@cached_property
def _points_3d_array(self) -> Bound:
"""3D coordinates of grid points."""
return self.points.data
@classmethod
def _cell_num_vertices(cls) -> pd.PositiveInt:
"""Number of vertices in a cell."""
return 4
@classmethod
@requires_vtk
def _vtk_cell_type(cls):
"""VTK cell type to use in the VTK representation."""
return vtk["mod"].VTK_TETRA
@classmethod
@requires_vtk
def _from_vtk_obj(cls, grid, field=None) -> TetrahedralGridDataset:
"""Initialize from a vtkUnstructuredGrid instance."""
# read point, cells, and values info from a vtk instance
cells_numpy = vtk["vtk_to_numpy"](grid.GetCells().GetConnectivityArray())
points_numpy = vtk["vtk_to_numpy"](grid.GetPoints().GetData())
values = cls._get_values_from_vtk(grid, len(points_numpy), field)
# verify cell_types
cells_types = vtk["vtk_to_numpy"](grid.GetCellTypesArray())
if not np.all(cells_types == cls._vtk_cell_type()):
raise DataError("Only tetrahedral 'vtkUnstructuredGrid' is currently supported")
# pack point and cell information into Tidy3D arrays
num_cells = len(cells_numpy) // cls._cell_num_vertices()
cells_numpy = np.reshape(cells_numpy, (num_cells, cls._cell_num_vertices()))
cells = CellDataArray(
cells_numpy,
coords=dict(
cell_index=np.arange(num_cells), vertex_index=np.arange(cls._cell_num_vertices())
),
)
points = PointDataArray(
points_numpy,
coords=dict(index=np.arange(len(points_numpy)), axis=np.arange(cls._point_dims())),
)
return cls(points=points, cells=cells, values=values)
@requires_vtk
def plane_slice(self, axis: Axis, pos: float) -> TriangularGridDataset:
"""Slice data with a plane and return the resulting :class:.`TriangularGridDataset`.
Parameters
----------
axis : Axis
The normal direction of the slicing plane.
pos : float
Position of the slicing plane along its normal direction.
Returns
-------
TriangularGridDataset
The resulting slice.
"""
slice_vtk = self._plane_slice_raw(axis=axis, pos=pos)
return TriangularGridDataset._from_vtk_obj(slice_vtk)
@requires_vtk
def line_slice(self, axis: Axis, pos: Coordinate) -> SpatialDataArray:
"""Slice data with a line and return the resulting :class:.`SpatialDataArray`.
Parameters
----------
axis : Axis
The axis of the slicing line.
pos : Tuple[float, float, float]
Position of the slicing line.
Returns
-------
SpatialDataArray
The resulting slice.
"""
bounds = self.bounds
start = list(pos)
end = list(pos)
start[axis] = bounds[0][axis]
end[axis] = bounds[1][axis]
# create cutting plane
line = vtk["mod"].vtkLineSource()
line.SetPoint1(start)
line.SetPoint2(end)
line.SetResolution(1)
# this should be done using vtkProbeLineFilter
# but for some reason it crashes Python
# so, we use a workaround:
# 1) extract cells that are intersected by line (to speed up further slicing)
# 2) do plane slice along first direction
# 3) do second plane slice along second direction
prober = vtk["mod"].vtkExtractCellsAlongPolyLine()
prober.SetSourceConnection(line.GetOutputPort())
prober.SetInputData(self._vtk_obj)
prober.Update()
extracted_cells_vtk = prober.GetOutput()
if extracted_cells_vtk.GetNumberOfPoints() == 0:
raise DataError("Slicing line does not intersect the unstructured grid.")
extracted_cells = TetrahedralGridDataset._from_vtk_obj(extracted_cells_vtk)
tan_dims = [0, 1, 2]
tan_dims.remove(axis)
# first plane slice
plane_slice = extracted_cells.plane_slice(axis=tan_dims[0], pos=pos[tan_dims[0]])
# second plane slice
line_slice = plane_slice.plane_slice(axis=tan_dims[1], pos=pos[tan_dims[1]])
return line_slice
@requires_vtk
def sel(
self,
x: Union[float, ArrayLike] = None,
y: Union[float, ArrayLike] = None,
z: Union[float, ArrayLike] = None,
) -> Union[TriangularGridDataset, SpatialDataArray]:
"""Extract/interpolate data along one or more Cartesian directions. At least of x, y, and z
must be provided.
Parameters
----------
x : Union[float, ArrayLike] = None
x-coordinate of the slice.
y : Union[float, ArrayLike] = None
y-coordinate of the slice.
z : Union[float, ArrayLike] = None
z-coordinate of the slice.
Returns
-------
Union[TriangularGridDataset, SpatialDataArray]
Extracted data.
"""
xyz = [x, y, z]
axes = [ind for ind, comp in enumerate(xyz) if comp is not None]
num_provided = len(axes)
if num_provided < 3 and any(not np.isscalar(comp) for comp in xyz if comp is not None):
raise DataError(
"Providing x, y, or z as array is only allowed for interpolation. That is, when all"
" three x, y, and z are provided or method '.interp()' is used explicitly."
)
if num_provided == 0:
raise DataError("At least one of 'x', 'y', and 'z' must be specified.")
if num_provided == 1:
axis = axes[0]
return self.plane_slice(axis=axis, pos=xyz[axis])
if num_provided == 2:
axis = 3 - axes[0] - axes[1]
xyz[axis] = 0
return self.line_slice(axis=axis, pos=xyz)
if num_provided == 3:
return self.interp(x=x, y=y, z=z)
UnstructuredGridDatasetType = Union[TriangularGridDataset, TetrahedralGridDataset]