"""Storing tidy3d data at its most fundamental level as xr.DataArray objects"""
from __future__ import annotations
import pathlib
from abc import ABC
from typing import TYPE_CHECKING, Any, Union
import autograd.numpy as anp
import h5py
import numpy as np
import xarray as xr
from autograd.tracer import isbox
from pydantic_core import core_schema
from xarray.core import missing
from xarray.core.indexes import PandasIndex
from xarray.core.indexing import _outer_to_numpy_indexer
from xarray.core.utils import OrderedSet, either_dict_or_kwargs
from xarray.core.variable import as_variable
from tidy3d.compat import alignment
from tidy3d.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box
from tidy3d.components.geometry.bound_ops import bounds_contains
from tidy3d.constants import (
AMP,
HERTZ,
MICROMETER,
OHM,
PICOSECOND_PER_NANOMETER_PER_KILOMETER,
RADIAN,
SECOND,
VOLT,
WATT,
)
from tidy3d.exceptions import DataError, FileError, format_chained_exception_message
if TYPE_CHECKING:
from collections.abc import Hashable, Mapping
from os import PathLike
from typing import Optional
from numpy.typing import NDArray
from pydantic.annotated_handlers import GetCoreSchemaHandler
from pydantic.json_schema import GetJsonSchemaHandler, JsonSchemaValue
from xarray.core.types import InterpOptions, Self
from tidy3d.components.autograd import InterpolationType
from tidy3d.components.grid.grid import Coords
from tidy3d.components.types import Axis, Bound
from tidy3d.components.types.base import Coordinate
# Maps each known dimension name to the attributes (``long_name`` and, where
# applicable, ``units``) that are attached to the matching coordinate.
# Dimensions that are not listed here simply get no attributes assigned.
DIM_ATTRS = {
    "x": {"units": MICROMETER, "long_name": "x position"},
    "y": {"units": MICROMETER, "long_name": "y position"},
    "z": {"units": MICROMETER, "long_name": "z position"},
    "f": {"units": HERTZ, "long_name": "frequency"},
    "t": {"units": SECOND, "long_name": "time"},
    "direction": {"long_name": "propagation direction"},
    "mode_index": {"long_name": "mode index"},
    "terminal_label": {"long_name": "terminal label"},
    "terminal_label_out": {"long_name": "output terminal label"},
    "terminal_label_in": {"long_name": "input terminal label"},
    "eme_port_index": {"long_name": "EME port index"},
    "eme_cell_index": {"long_name": "EME cell index"},
    "mode_index_in": {"long_name": "mode index in"},
    "mode_index_out": {"long_name": "mode index out"},
    "sweep_index": {"long_name": "sweep index"},
    "theta": {"units": RADIAN, "long_name": "elevation angle"},
    "phi": {"units": RADIAN, "long_name": "azimuth angle"},
    "ux": {"long_name": "normalized kx"},
    "uy": {"long_name": "normalized ky"},
    "orders_x": {"long_name": "diffraction order"},
    "orders_y": {"long_name": "diffraction order"},
    "face_index": {"long_name": "face index"},
    "vertex_index": {"long_name": "vertex index"},
    "axis": {"long_name": "axis"},
}
# Name under which the DataArray values are stored inside an hdf5 group
# (this is also xarray's default variable name).
DATA_ARRAY_VALUE_NAME = "__xarray_dataarray_variable__"
[docs]
class DataArray(xr.DataArray):
"""Subclass of ``xr.DataArray`` that requires _dims to match the keys of the coords."""
# Always set __slots__ = () to avoid xarray warnings
__slots__ = ()
# stores an ordered tuple of strings corresponding to the data dimensions
_dims = ()
# stores a dictionary of attributes corresponding to the data values
_data_attrs: dict[str, str] = {}
def __init__(self, data: Any, *args: Any, **kwargs: Any) -> None:
    """Build the array, converting plain autograd boxes into ``TidyArrayBox``.

    The conversion is applied either to ``data`` itself or, when ``data`` is an
    ``xr.Variable`` / ``xr.DataArray``, to its wrapped ``.data`` (updated in place).
    """

    def _is_plain_box(candidate: Any) -> bool:
        # a traced value that has not been wrapped in our own box type yet
        return isbox(candidate) and not is_tidy_box(candidate)

    if _is_plain_box(data):
        data = TidyArrayBox.from_arraybox(data)
    elif isinstance(data, (xr.Variable, xr.DataArray)):
        if _is_plain_box(data.data):
            data.data = TidyArrayBox.from_arraybox(data.data)

    super().__init__(data, *args, **kwargs)
@classmethod
def __get_pydantic_core_schema__(
    cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
    """Build the pydantic core schema used to validate and serialize this type.

    Validation parses the input into an instance of ``cls`` and then runs
    ``_validate_dims``, ``_assign_data_attrs`` and ``_assign_coord_attrs`` in that
    order. In JSON mode an instance is serialized to its class name only.
    """

    def _initial_parser(value: Any) -> Self:
        # already the right type: nothing to do
        if isinstance(value, cls):
            return value

        # the class name alone is the JSON placeholder, it carries no data
        is_placeholder = isinstance(value, str) and value == cls.__name__
        if is_placeholder:
            raise DataError(
                f"Trying to load '{cls.__name__}' from string placeholder '{value}' "
                "but the actual data is missing. DataArrays are not typically stored "
                "in JSON. Load from HDF5 or ensure the DataArray object is provided."
            )

        try:
            built = cls(value)
            if isinstance(built, cls):
                return built
            raise TypeError(
                f"Constructor for {cls.__name__} returned unexpected type {type(built)}"
            )
        except Exception as e:
            raise ValueError(
                f"Could not construct '{cls.__name__}' from input of type '{type(value)}'. "
                f"Ensure input is compatible with xarray.DataArray constructor. Original error: {e}"
            ) from e

    def _serialize_to_name(instance: Self) -> str:
        return type(instance).__name__

    # parse first, then chain the post-processing validators in order
    validation_schema = core_schema.no_info_plain_validator_function(_initial_parser)
    post_steps = (cls._validate_dims, cls._assign_data_attrs, cls._assign_coord_attrs)
    for step in post_steps:
        validation_schema = core_schema.no_info_after_validator_function(
            step, validation_schema
        )

    # serialization behavior:
    # - 'json' mode: emit the class name via _serialize_to_name
    # - 'python' mode: pydantic's default handling of the object
    serialization_schema = core_schema.plain_serializer_function_ser_schema(
        _serialize_to_name,
        return_schema=core_schema.str_schema(),
        when_used="json",
    )

    # the same validation rules apply to python and JSON input
    return core_schema.json_or_python_schema(
        python_schema=validation_schema,
        json_schema=validation_schema,
        serialization=serialization_schema,
    )
@classmethod
def __get_pydantic_json_schema__(
    cls, core_schema_obj: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
    """Return the JSON schema entry for this type (how it looks in a schema, not its data)."""
    description = (
        f"Placeholder for a '{cls.__name__}' object. Actual data is typically "
        "serialized separately (e.g., via HDF5) and not embedded in JSON."
    )
    schema = {"type": "string", "title": cls.__name__, "description": description}
    return schema
@classmethod
def _validate_dims(cls, val: Self) -> Self:
    """Check that ``val`` has exactly the dims in ``_dims`` and return it in that order.

    Raises
    ------
    ValueError
        If the set of dims of ``val`` differs from the set in ``cls._dims``.
    """
    expected = cls._dims
    if set(val.dims) != set(expected):
        raise ValueError(
            f"Wrong dims for {cls.__name__}, expected '{cls._dims}', got '{val.dims}'"
        )
    # same dims: only transpose when the order differs
    return val if val.dims == expected else val.transpose(*expected)
@classmethod
def _assign_data_attrs(cls, val: Self) -> Self:
    """Copy every entry of ``cls._data_attrs`` into ``val.attrs`` and return ``val``."""
    data_attrs = cls._data_attrs
    for key in data_attrs:
        val.attrs[key] = data_attrs[key]
    return val
@classmethod
def _assign_coord_attrs(cls, val: Self) -> Self:
    """Apply the ``DIM_ATTRS`` templates to the coordinates of ``val``.

    Only names that are a dim of ``val``, listed in ``cls._dims`` and present in
    ``val.coords`` are touched; names without a ``DIM_ATTRS`` entry are skipped.
    """
    candidates = set(val.dims) & set(cls._dims) & set(val.coords)
    for dim_name in candidates:
        wanted = DIM_ATTRS.get(dim_name)
        if wanted:
            current = val.coords[dim_name].attrs
            # write only the entries that are absent or hold a different value
            outdated = {k: v for k, v in wanted.items() if current.get(k) != v}
            current.update(outdated)
    return val
def _interp_validator(self, field_name: Optional[str] = None) -> None:
    """Raise if any coordinate holds duplicate values, which breaks interpolation/selection.

    NOTE
    ----
    This is not run on every 'DataArray' by default. It is meant to be called from a
    validator when required, as is the case with 'CustomMedium' and 'CustomFieldSource'.

    Raises
    ------
    DataError
        If a coordinate contains duplicated values.
    """
    # fall back to the class name when no field name is supplied
    field_name = self.__class__.__name__ if field_name is None else field_name

    for dim, coord in self.coords.items():
        has_duplicates = coord.to_index().duplicated().any()
        if not has_duplicates:
            continue
        raise DataError(
            f"Field '{field_name}' contains duplicate coordinates in dimension '{dim}'. "
            "Duplicates can be removed by running "
            f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'."
        )
def __eq__(self, other: Any) -> bool:
    """Whether two data array objects are equal.

    Two arrays are equal when ``other`` is an ``xr.DataArray``, the data have the
    same shape and values, and every coordinate of ``self`` exists in ``other``
    with the same shape and values.

    Parameters
    ----------
    other : Any
        Object to compare against.

    Returns
    -------
    bool
        ``True`` if equal, ``False`` otherwise (never raises on mismatched coordinates).
    """
    if not isinstance(other, xr.DataArray):
        return False
    if not self.data.shape == other.data.shape or not np.all(self.data == other.data):
        return False
    for key, val in self.coords.items():
        # a coordinate missing from ``other`` means "not equal", not a KeyError
        if key not in other.coords:
            return False
        own_coord = np.array(val)
        other_coord = np.array(other.coords[key])
        # check shapes first so the elementwise comparison cannot fail to broadcast
        if own_coord.shape != other_coord.shape or not np.all(own_coord == other_coord):
            return False
    return True
@property
def values(self) -> NDArray:
    """
    The array's data converted to a numpy.ndarray.

    When the underlying data is an autograd box it is returned as is.

    Returns
    -------
    np.ndarray
        The values of the DataArray.
    """
    if isbox(self.data):
        return self.data
    return super().values

@values.setter
def values(self, value: Any) -> None:
    """Set the values on the underlying variable."""
    self.variable.values = value
[docs]
def to_numpy(self) -> np.ndarray:
    """Return `.data` when traced to avoid `dtype=object` NumPy conversion."""
    if isbox(self.data):
        return self.data
    return super().to_numpy()
@property
def abs(self) -> Self:
    """Element-wise absolute value of the data array."""
    magnitude = abs(self)
    return magnitude
@property
def angle(self) -> Self:
    """Angle (phase) of the data, returned as a new array of the same type and coords."""
    phase = np.angle(self.values)
    array_type = type(self)
    return array_type(phase, coords=self.coords)
@property
def is_uniform(self) -> bool:
    """Whether all elements are close (``np.allclose``) to the first element."""
    flat = self.data.ravel()
    reference = flat[0]
    return np.allclose(flat, reference)
[docs]
def to_hdf5(self, fname: Union[PathLike, h5py.File], group_path: str) -> None:
    """Save an ``xr.DataArray`` to the hdf5 file or file handle with a given path to the group.

    Parameters
    ----------
    fname : Union[PathLike, h5py.File]
        Either a path (``str`` or any ``os.PathLike``) or an already open hdf5 handle.
        When a path is given, missing parent directories are created and the file is
        opened in ``"w"`` mode.
    group_path : str
        Path of the group inside the file to write to.
    """
    # local import: at module level only ``PathLike`` is imported, under TYPE_CHECKING
    import os

    # accept every path-like object, not only ``str`` / ``pathlib.Path``;
    # anything else is treated as an open handle
    if isinstance(fname, (str, os.PathLike)):
        path = pathlib.Path(fname)
        path.parent.mkdir(parents=True, exist_ok=True)
        with h5py.File(path, "w") as f_handle:
            self.to_hdf5_handle(f_handle=f_handle, group_path=group_path)
    else:
        self.to_hdf5_handle(f_handle=fname, group_path=group_path)
[docs]
def to_hdf5_handle(self, f_handle: h5py.File, group_path: str) -> None:
    """Write the data and every coordinate into a new group ``group_path`` of ``f_handle``."""
    group = f_handle.create_group(group_path)
    # store the static (untraced) data under the shared value name
    group[DATA_ARRAY_VALUE_NAME] = get_static(self.data)
    for name, coord in self.coords.items():
        is_unicode = coord.dtype.kind == "U"
        # unicode string coordinates are stored as a plain list for HDF5
        group[name] = coord.values.tolist() if is_unicode else coord
@classmethod
def _from_hdf5_handle(cls, f_handle: h5py.File, group_path: str) -> Self:
    """Load a DataArray from the group ``group_path`` of an open hdf5 file handle."""
    group = f_handle[group_path]
    values = np.array(group[DATA_ARRAY_VALUE_NAME])

    coords = {}
    for dim in cls._dims:
        if dim not in group:
            continue
        loaded = np.array(group[dim])
        if loaded.dtype == "O":
            # object arrays hold byte strings; decode them to ``str``
            loaded = [raw.decode() for raw in loaded.tolist()]
        coords[dim] = loaded

    return cls(values, coords=coords, dims=cls._dims)
[docs]
@classmethod
def from_hdf5(cls, fname: Union[PathLike, h5py.File], group_path: str) -> Self:
    """Load a DataArray from an hdf5 file or an open file handle, at ``group_path``."""
    if not isinstance(fname, h5py.File):
        # a path was given: open it read-only for the duration of the load
        with h5py.File(pathlib.Path(fname), "r") as opened:
            return cls._from_hdf5_handle(f_handle=opened, group_path=group_path)
    return cls._from_hdf5_handle(f_handle=fname, group_path=group_path)
[docs]
@classmethod
def from_file(cls, fname: PathLike, group_path: str) -> Self:
    """Load a DataArray from an hdf5 file with a given path to the group.

    Raises
    ------
    FileError
        If none of the suffixes of ``fname`` is ``.hdf5`` (case-insensitive).
    """
    path = pathlib.Path(fname)
    lowered_suffixes = {suffix.lower() for suffix in path.suffixes}
    if ".hdf5" not in lowered_suffixes:
        raise FileError(
            f"'DataArray' objects must be written to '.hdf5' format. Given filename of {path}."
        )
    return cls.from_hdf5(fname=path, group_path=group_path)
def __hash__(self) -> int:
    """Hash of a :class:`.DataArray` instance, derived from its dask token.

    Needed for custom components.
    """
    from dask.base import tokenize

    return hash(tokenize(self))
[docs]
def multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self:
    """Return a copy of self multiplied by ``value`` at ``indices`` along ``coord_name``."""
    traced = isbox(self.data) or isbox(value)
    if traced:
        # in-place item assignment is not used while tracing
        return self._ag_multiply_at(value, coord_name, indices)

    scaled = self.copy()
    scaled[{coord_name: indices}] *= value
    return scaled
def _ag_multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self:
    """Autograd variant of ``multiply_at``: select with a boolean mask instead of assigning."""
    selector = {coord_name: indices}
    _dims, index_tuple, _order = self.variable._broadcast_indexes(selector)
    numpy_index = _outer_to_numpy_indexer(index_tuple, self.data.shape)

    # boolean mask that is True exactly at the selected positions
    selected = np.zeros(self.data.shape, dtype="?")
    selected[numpy_index] = True

    scaled = self.data * value
    return self.copy(deep=False, data=anp.where(selected, scaled, self.data))
[docs]
def interp(
    self,
    coords: Mapping[Any, Any] | None = None,
    method: InterpOptions = "linear",
    assume_sorted: bool = False,
    kwargs: Mapping[str, Any] | None = None,
    **coords_kwargs: Any,
) -> Self:
    """Interpolate this DataArray to new coordinate values.

    Dispatches to the autograd-compatible ``_ag_interp`` when the data is traced,
    and to xarray's ``interp`` otherwise.

    Parameters
    ----------
    coords : Union[Mapping[Any, Any], None] = None
        A mapping from dimension names to new coordinate labels.
    method : InterpOptions = "linear"
        The interpolation method to use.
    assume_sorted : bool = False
        If True, skip sorting of coordinates.
    kwargs : Union[Mapping[str, Any], None] = None
        Additional keyword arguments to pass to the interpolation function.
    **coords_kwargs : Any
        The keyword arguments form of coords.

    Returns
    -------
    DataArray
        A new DataArray with interpolated values.

    Raises
    ------
    KeyError
        If any of the specified coordinates are not in the DataArray.
    """
    interpolate = self._ag_interp if isbox(self.data) else super().interp
    return interpolate(coords, method, assume_sorted, kwargs, **coords_kwargs)
def _ag_interp(
    self,
    coords: Union[Mapping[Any, Any], None] = None,
    method: InterpOptions = "linear",
    assume_sorted: bool = False,
    kwargs: Union[Mapping[str, Any], None] = None,
    **coords_kwargs: Any,
) -> Self:
    """Autograd interp override when tracing over self.data.

    This implementation closely follows the interp implementation of xarray
    to match its behavior as closely as possible while supporting autograd.
    Numeric variables are interpolated through ``_ag_interp_func``; non-numeric
    variables that depend on an interpolated dimension are reindexed with
    ``method="nearest"``.

    The parameters are the same as those of ``interp``.

    See:
    - https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html
    - https://docs.xarray.dev/en/latest/generated/xarray.Dataset.interp.html
    """
    if kwargs is None:
        kwargs = {}

    # work on a temporary dataset, converted back to a DataArray at the end
    ds = self._to_temp_dataset()

    # merge the mapping and keyword forms of the coordinates
    coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
    indexers = dict(ds._validate_interp_indexers(coords))

    if coords:
        # Find shared dimensions between the dataset and the indexers
        sdims = (
            set(ds.dims)
            .intersection(*[set(nx.dims) for nx in indexers.values()])
            .difference(coords.keys())
        )
        indexers.update({d: ds.variables[d] for d in sdims})

    # sort along the interpolated coordinates unless told they are already sorted
    obj = ds if assume_sorted else ds.sortby(list(coords))

    # workaround to get a variable for a dimension without a coordinate
    validated_indexers = {
        k: (obj._variables.get(k, as_variable((k, range(obj.sizes[k])))), v)
        for k, v in indexers.items()
    }

    # NOTE(review): ``missing._localize`` is an xarray internal; presumably it narrows
    # ``obj`` to the neighborhood of the new coordinates -- confirm against xarray.
    # Only existing keys are reassigned, so mutating during iteration is safe here.
    for k, v in validated_indexers.items():
        obj, newidx = missing._localize(obj, {k: v})
        validated_indexers[k] = newidx[k]

    variables = {}
    reindex = False
    for name, var in obj._variables.items():
        # the interpolated coordinates themselves are attached further below
        if name in indexers:
            continue

        dtype_kind = var.dtype.kind
        if dtype_kind in "uifc":
            # Interpolation for numeric types
            var_indexers = {k: v for k, v in validated_indexers.items() if k in var.dims}
            variables[name] = self._ag_interp_func(var, var_indexers, method, **kwargs)
        elif dtype_kind in "ObU" and (validated_indexers.keys() & var.dims):
            # Stepwise interpolation for non-numeric types
            reindex = True
        elif all(d not in indexers for d in var.dims):
            # Keep variables not dependent on interpolated coords
            variables[name] = var

    if reindex:
        # Reindex for non-numeric types
        reindex_indexers = {k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,)}
        reindexed = alignment.reindex(
            obj,
            indexers=reindex_indexers,
            method="nearest",
            exclude_vars=variables.keys(),
        )
        indexes = dict(reindexed._indexes)
        variables.update(reindexed.variables)
    else:
        # Get the indexes that are not being interpolated along
        indexes = {k: v for k, v in obj._indexes.items() if k not in indexers}

    # Get the coords that also exist in the variables
    coord_names = obj._coord_names & variables.keys()
    # a copy of ``variables`` is passed because ``variables`` is extended below
    selected = ds._replace_with_new_dims(variables.copy(), coord_names, indexes=indexes)

    # Attach indexer as coordinate
    for k, v in indexers.items():
        if v.dims == (k,):
            # 1D indexer along its own dimension: becomes an indexed coordinate
            index = PandasIndex(v, k, coord_dtype=v.dtype)
            index_vars = index.create_variables({k: v})
            indexes[k] = index
            variables.update(index_vars)
        else:
            variables[k] = v

    # Extract coordinates from indexers
    coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords)
    variables.update(coord_vars)
    indexes.update(new_indexes)

    # ``&`` binds tighter than ``|``: (existing coord names still present) plus the new ones
    coord_names = obj._coord_names & variables.keys() | coord_vars.keys()
    ds = ds._replace_with_new_dims(variables, coord_names, indexes=indexes)
    return self._from_temp_dataset(ds)
@staticmethod
def _ag_interp_func(
    var: xr.Variable,
    indexes_coords: dict[str, tuple[xr.Variable, xr.Variable]],
    method: InterpolationType,
    **kwargs: Any,
) -> xr.Variable:
    """
    Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`.

    The implementation follows xarray's interp implementation in xarray.core.missing,
    but replaces some of the pre-processing as well as the actual interpolation
    function with an autograd-compatible approach.

    Parameters
    ----------
    var : xr.Variable
        The variable to be interpolated.
    indexes_coords : dict
        A dictionary mapping dimension names to ``(original coordinate, new coordinate)``
        pairs of variables.
    method : Literal["nearest", "linear"]
        The interpolation method to use.
    **kwargs : dict
        Additional keyword arguments to pass to the interpolation function.

    Returns
    -------
    xr.Variable
        The interpolated variable (a copy of ``var`` if ``indexes_coords`` is empty).
    """
    if not indexes_coords:
        return var.copy()

    result = var
    # each group yielded by ``decompose_interp`` is interpolated in turn,
    # feeding the result of one step into the next
    for indep_indexes_coords in missing.decompose_interp(indexes_coords):
        var = result

        # target dimensions
        dims = list(indep_indexes_coords)
        x, new_x = zip(*[indep_indexes_coords[d] for d in dims])
        destination = missing.broadcast_variables(*new_x)

        # non-interpolated dims first, interpolated dims last
        broadcast_dims = [d for d in var.dims if d not in dims]
        original_dims = broadcast_dims + dims
        new_dims = broadcast_dims + list(destination[0].dims)

        x, new_x = missing._floatize_x(x, new_x)

        # move the interpolated axes (the last ``len(x)`` of ``original_dims``) to the front
        permutation = [var.dims.index(dim) for dim in original_dims]
        combined_permutation = permutation[-len(x) :] + permutation[: -len(x)]
        data = anp.transpose(var.data, combined_permutation)
        # query points: one row per flattened target point, one column per interpolated dim
        xi = anp.stack([anp.ravel(new_xi.data) for new_xi in new_x], axis=-1)

        # NOTE(review): ``interpn`` is the project's autograd helper; the reshaping below
        # assumes its output has the query points on axis 0 -- confirm against its docs.
        result = interpn([xn.data for xn in x], data, xi, method=method, **kwargs)

        # put the point axis last, then unflatten it to the shape of the new coordinates
        result = anp.moveaxis(result, 0, -1)
        result = anp.reshape(result, result.shape[:-1] + new_x[0].shape)

        result = xr.Variable(new_dims, result, attrs=var.attrs, fastpath=True)

        # output dims follow the order of ``var.dims``, with each interpolated dim
        # replaced by the dims of its new coordinate
        out_dims: OrderedSet = OrderedSet()
        for d in var.dims:
            if d in dims:
                out_dims.update(indep_indexes_coords[d][1].dims)
            else:
                out_dims.add(d)
        if len(out_dims) > 1:
            result = result.transpose(*out_dims)
    return result
def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> DataArray:
    """Return a deep copy of this ``DataArray`` with ``data`` written at ``coords``.

    Autograd compatible: the update is done with a mask and ``np.where``.

    Constraints / Edge cases:
        - `coords` must map to a specific value eg {x: '1'}, does not broadcast to arrays
        - `data` will be reshaped to try to match `self.shape` except where `coords` present

    Raises
    ------
    ValueError
        If ``data`` cannot be reshaped to the expected shape.
    """
    # boolean mask marking the entries selected by ``coords``
    mask = xr.zeros_like(self, dtype=bool)
    mask.loc[coords] = True

    old_data = self.data
    # target shape: that of ``self``, with length 1 along every selected dimension
    new_shape = [
        1 if dim in coords else size for dim, size in zip(self.dims, old_data.shape)
    ]

    try:
        reshaped = data.reshape(new_shape)
    except ValueError as e:
        raise ValueError(
            format_chained_exception_message(
                "Couldn't reshape the supplied 'data' to update 'DataArray'. The provided "
                f"data was of shape {data.shape} and tried to reshape to {new_shape}. If "
                "you encounter this error please raise an issue on the tidy3d github "
                "repository with this context.",
                e,
            )
        ) from e

    # adding zeros broadcasts the new data to the full shape so it lines up with the mask
    broadcast = reshaped + np.zeros_like(old_data)
    merged = np.where(mask, broadcast, old_data)
    return self.copy(deep=True, data=merged)
[docs]
class FreqDataArray(DataArray):
    """Frequency-domain array with the single dimension ``f``.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> fd = FreqDataArray((1+1j) * np.random.random((2,)), coords=dict(f=f))
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("f",)
class FreqVoltageDataArray(DataArray):
    """Frequency-domain array over the dimensions ``f`` and ``v``.

    NOTE(review): ``v`` is presumably a voltage axis, judging by the class name -- confirm.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> v = [0.1, 0.2, 0.3]
    >>> coords = dict(f=f, v=v)
    >>> fd = FreqVoltageDataArray((1+1j) * np.random.random((2, 3)), coords=coords)
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = (
        "f",
        "v",
    )
[docs]
class ModeDataArray(DataArray):
    """Mode index data array with the single dimension ``mode_index``.

    Example
    -------
    >>> mode_index = np.arange(4)
    >>> coords = dict(mode_index=mode_index)
    >>> data = ModeDataArray((1+1j) * np.random.random(4), coords=coords)
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("mode_index",)
[docs]
class TerminalDataArray(DataArray):
    """Terminal data array with the single dimension ``terminal_label``.

    Example
    -------
    >>> terminal_label = ["t0", "t1", "t2", "t3", "t4"]
    >>> coords = dict(terminal_label=terminal_label)
    >>> data = TerminalDataArray((1+1j) * np.random.random(5), coords=coords)
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("terminal_label",)
[docs]
class FreqModeDataArray(DataArray):
    """Array over frequency ``f`` and mode index ``mode_index``.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> fd = FreqModeDataArray((1+1j) * np.random.random((2, 5)), coords=coords)
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("f", "mode_index")
[docs]
class FreqTerminalDataArray(DataArray):
    """Array over frequency ``f`` and terminal label ``terminal_label``.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> terminal_label = ["t0", "t1", "t2", "t3", "t4"]
    >>> coords = dict(f=f, terminal_label=terminal_label)
    >>> fd = FreqTerminalDataArray((1+1j) * np.random.random((2, 5)), coords=coords)
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("f", "terminal_label")
class FreqModeModeDataArray(DataArray):
    """Array over frequency ``f``, output mode index ``mode_index_out`` and input mode
    index ``mode_index_in``.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index_out = np.arange(5)
    >>> mode_index_in = np.arange(5)
    >>> coords = dict(f=f, mode_index_out=mode_index_out, mode_index_in=mode_index_in)
    >>> fd = FreqModeModeDataArray((1+1j) * np.random.random((2, 5, 5)), coords=coords)
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("f", "mode_index_out", "mode_index_in")
[docs]
class FreqTerminalModeDataArray(DataArray):
    """Array over frequency ``f``, terminal label ``terminal_label`` and mode index
    ``mode_index``, in that order.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> terminal_label = ["t0", "t1"]
    >>> coords = dict(f=f, terminal_label=terminal_label, mode_index=mode_index)
    >>> fd = FreqTerminalModeDataArray((1+1j) * np.random.random((2, 2, 5)), coords=coords)
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("f", "terminal_label", "mode_index")
class FreqModeTerminalDataArray(DataArray):
    """Array over frequency ``f``, mode index ``mode_index`` and terminal label
    ``terminal_label``, in that order.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> terminal_label = ["t0", "t1"]
    >>> coords = dict(f=f, mode_index=mode_index, terminal_label=terminal_label)
    >>> fd = FreqModeTerminalDataArray((1+1j) * np.random.random((2, 5, 2)), coords=coords)
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("f", "mode_index", "terminal_label")
[docs]
class FreqTerminalTerminalDataArray(DataArray):
    """Array over frequency ``f``, output terminal label ``terminal_label_out`` and input
    terminal label ``terminal_label_in``.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> terminal_label_out = ["t0", "t1"]
    >>> terminal_label_in = ["t0", "t1"]
    >>> coords = dict(f=f, terminal_label_out=terminal_label_out, terminal_label_in=terminal_label_in)
    >>> fd = FreqTerminalTerminalDataArray((1+1j) * np.random.random((2, 2, 2)), coords=coords)
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("f", "terminal_label_out", "terminal_label_in")
class TimeDataArray(DataArray):
    """Time-domain array with the single dimension ``t``.

    Example
    -------
    >>> t = [0, 1e-12, 2e-12]
    >>> td = TimeDataArray((1+1j) * np.random.random((3,)), coords=dict(t=t))
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("t",)
class MixedModeDataArray(DataArray):
    """Scalar property associated with mode pairs, over the dimensions ``f``,
    ``mode_index_0`` and ``mode_index_1``.

    Example
    -------
    >>> f = [1e14, 2e14, 3e14]
    >>> mode_index_0 = np.arange(4)
    >>> mode_index_1 = np.arange(2)
    >>> coords = dict(f=f, mode_index_0=mode_index_0, mode_index_1=mode_index_1)
    >>> data = MixedModeDataArray((1+1j) * np.random.random((3, 4, 2)), coords=coords)
    """

    # empty __slots__ avoids xarray warnings for subclasses
    __slots__ = ()
    # required dimensions, in order
    _dims = ("f", "mode_index_0", "mode_index_1")
[docs]
class AbstractSpatialDataArray(DataArray, ABC):
"""Spatial distribution."""
__slots__ = ()
_dims = ("x", "y", "z")
_data_attrs = {"long_name": "field value"}
[docs]
def plot(self, *args: Any, field: bool = True, grid: bool = False, **kwargs: Any) -> Any:
    """Plot the spatial data.

    Accepts the same arguments as xarray's ``DataArray.plot()``. The extra
    ``grid`` and ``field`` keyword arguments are accepted for API
    compatibility with :meth:`TriangularGridDataset.plot` but grid overlay
    is not supported on structured data.

    Parameters
    ----------
    field : bool = True
        Whether to plot the data field. Must be ``True`` for structured
        data.
    grid : bool = False
        Not supported for structured data. Raises ``DataError`` if
        ``True``.

    Raises
    ------
    DataError
        If ``grid`` is ``True`` or ``field`` is ``False``.
    """
    if grid:
        raise DataError("The 'grid' argument is only supported for unstructured data.")
    if not field:
        raise DataError("The 'field' argument is only supported for unstructured data.")

    # build xarray's own plot accessor for this array and delegate to it
    accessor = xr.DataArray.plot(self)
    return accessor(*args, **kwargs)
@property
def _spatially_sorted(self) -> Self:
    """Return self if x, y and z are non-decreasing, else a copy sorted along the unsorted axes."""

    def _is_unsorted(axis: str) -> bool:
        vals = np.atleast_1d(self.coords[axis].values)
        # unsorted as soon as any value is smaller than its predecessor
        return vals.size > 1 and np.any(vals[1:] < vals[:-1])

    unsorted_axes = [axis for axis in "xyz" if _is_unsorted(axis)]
    return self.sortby(unsorted_axes) if unsorted_axes else self
[docs]
def shifted_spatial_coords(self, center: Coordinate) -> Self:
    """Return an array whose x, y, z coordinates (where present) are offset by ``center``."""
    result = self
    for index, name in enumerate("xyz"):
        if name not in result.coords:
            continue
        base = np.asarray(result.coords[name].data)
        result = result.assign_coords({name: base + center[index]})
    return result
[docs]
def interpolate_to_grid(
    self,
    grid: Coords,
    *,
    offset: Optional[Coordinate] = None,
    method: InterpOptions = "linear",
    target_dims: Optional[tuple[str, ...]] = None,
) -> Self:
    """Interpolate onto ``grid``, optionally shifting the data by ``offset`` first and
    transposing the result to ``target_dims``."""
    source = self if offset is None else self.shifted_spatial_coords(offset)
    result = grid.spatial_interp(source, method)

    if target_dims is None:
        return result
    # already in the requested order: nothing to transpose
    if tuple(result.dims) == tuple(target_dims):
        return result
    return result.transpose(*target_dims)
[docs]
def sel_inside(self, bounds: Bound, *, include_interp_padding: bool = True) -> Self:
    """Return a new SpatialDataArray that contains the minimal amount data necessary to cover
    a spatial region defined by ``bounds``. Note that the returned data is sorted with respect
    to spatial coordinates.

    Parameters
    ----------
    bounds : Tuple[float, float, float], Tuple[float, float float]
        Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
    include_interp_padding : bool = True
        If ``True`` (default), include neighbor points around bounds to support interpolation.
        If ``False``, keep only points whose coordinates are inside bounds.

    Returns
    -------
    SpatialDataArray
        Extracted spatial data array.

    Raises
    ------
    DataError
        If any minimum bound is larger than the corresponding maximum bound.
    """
    if any(bmin > bmax for bmin, bmax in zip(*bounds)):
        raise DataError(
            "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'."
        )

    # make sure data is sorted with respect to coordinates
    sorted_self = self._spatially_sorted

    if not include_interp_padding:
        selected = sorted_self
        for coord, smin, smax, dim in zip(
            (sorted_self.x, sorted_self.y, sorted_self.z),
            bounds[0],
            bounds[1],
            "xyz",
        ):
            coord_vals = np.atleast_1d(coord.values)
            # a single point along a direction is kept as is (assumed invariant)
            if coord_vals.size <= 1:
                continue
            selected = selected.sel({dim: slice(smin, smax)})
        return selected

    inds_list = []
    coords = (sorted_self.x, sorted_self.y, sorted_self.z)
    for coord, smin, smax in zip(coords, bounds[0], bounds[1]):
        length = len(coord)

        if length == 1:
            # one point along direction, assume invariance
            comp_inds = [0]
        elif smax < coord[0]:
            # region is completely on the left side of the data: take the first two
            # points (if available) so that linear interpolation is possible
            comp_inds = np.arange(0, min(2, length))
        elif smin > coord[-1]:
            # region is completely on the right side of the data: take the last two
            # points (if available) so that linear interpolation is possible
            comp_inds = np.arange(max(0, length - 2), length)
        else:
            if smin < coord[0]:
                ind_min = 0
            else:
                # one point before the first coordinate >= smin
                ind_min = max(0, (coord >= smin).data.argmax() - 1)

            if smax > coord[-1]:
                ind_max = length - 1
            else:
                # first coordinate >= smax
                ind_max = (coord >= smax).data.argmax()

            comp_inds = np.arange(ind_min, ind_max + 1)

        inds_list.append(comp_inds)

    return sorted_self.isel(x=inds_list[0], y=inds_list[1], z=inds_list[2])
[docs]
def does_cover(self, bounds: Bound, rtol: float = 0.0, atol: float = 0.0) -> bool:
    """Check whether the data fully covers the spatial region specified by ``bounds``. If the
    data contains only one point along a given direction, it is assumed to be constant along
    that direction and coverage is not checked there.

    Parameters
    ----------
    bounds : Tuple[float, float, float], Tuple[float, float float]
        Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
    rtol : float = 0.0
        Relative tolerance for comparing bounds
    atol : float = 0.0
        Absolute tolerance for comparing bounds

    Returns
    -------
    bool
        Full cover check outcome.

    Raises
    ------
    DataError
        If any minimum bound is larger than the corresponding maximum bound.
    """
    if any(bmin > bmax for bmin, bmax in zip(*bounds)):
        raise DataError(
            "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'."
        )

    lows = []
    highs = []
    for axis_index, coord in enumerate((self.x, self.y, self.z)):
        if len(coord) == 1:
            # single point: reuse the requested bounds so this axis always passes
            lows.append(bounds[0][axis_index])
            highs.append(bounds[1][axis_index])
        else:
            lows.append(np.min(coord))
            highs.append(np.max(coord))

    data_bounds = (tuple(lows), tuple(highs))
    return bounds_contains(data_bounds, bounds, rtol=rtol, atol=atol)
[docs]
class SpatialDataArray(AbstractSpatialDataArray):
"""Spatial distribution.
Example
-------
>>> x = [1,2]
>>> y = [2,3,4]
>>> z = [3,4,5,6]
>>> coords = dict(x=x, y=y, z=z)
>>> fd = SpatialDataArray((1+1j) * np.random.random((2,3,4)), coords=coords)
"""
__slots__ = ()
[docs]
def reflect(
    self, axis: Axis, center: float, reflection_only: bool = False, symmetry: float = 1
) -> Self:
    """Reflect data across the plane defined by parameters ``axis`` and ``center`` from right to
    left. Note that the returned data is sorted with respect to spatial coordinates.

    Parameters
    ----------
    axis : Literal[0, 1, 2]
        Normal direction of the reflection plane.
    center : float
        Location of the reflection plane along its normal direction.
    reflection_only : bool = False
        Return only reflected data.
    symmetry : float = 1
        Symmetry factor of the reflection.

    Returns
    -------
    SpatialDataArray
        Data after reflection is performed.

    Raises
    ------
    DataError
        If ``center`` lies to the right of the first data coordinate along ``axis``.
    """
    sorted_self = self._spatially_sorted

    coords = [sorted_self.x.values, sorted_self.y.values, sorted_self.z.values]
    data = np.array(sorted_self.data)

    data_left_bound = coords[axis][0]

    if np.isclose(center, data_left_bound):
        # the plane passes through the first data point: it is shared by both halves
        num_duplicates = 1
    elif center > data_left_bound:
        raise DataError("Reflection center must be outside and to the left of the data region.")
    else:
        num_duplicates = 0

    if reflection_only:
        coords[axis] = 2 * center - coords[axis]
        coords_dict = dict(zip("xyz", coords))

        tmp_arr = SpatialDataArray(sorted_self.data * symmetry, coords=coords_dict)

        return tmp_arr.sortby("xyz"[axis])

    shape = np.array(np.shape(data))
    old_len = shape[axis]
    shape[axis] = 2 * old_len - num_duplicates

    ind_left = [slice(shape[0]), slice(shape[1]), slice(shape[2])]
    ind_right = [slice(shape[0]), slice(shape[1]), slice(shape[2])]

    # the left half is written in reversed order, the right half keeps the original order
    ind_left[axis] = slice(old_len - 1, None, -1)
    ind_right[axis] = slice(old_len - num_duplicates, None)

    mirrored = data * symmetry

    # Allocate with a dtype that can hold both halves. A plain ``np.zeros(shape)`` is
    # float64 and would drop the imaginary part of complex data. Including float64 in
    # the promotion keeps the float64 result for integer and real inputs.
    out_dtype = np.result_type(data.dtype, mirrored.dtype, np.float64)
    new_data = np.zeros(shape, dtype=out_dtype)

    # write the mirrored half first so original values win at a shared point
    new_data[ind_left[0], ind_left[1], ind_left[2]] = mirrored
    new_data[ind_right[0], ind_right[1], ind_right[2]] = data

    new_coords = np.zeros(shape[axis])
    new_coords[old_len - num_duplicates :] = coords[axis]
    new_coords[old_len - 1 :: -1] = 2 * center - coords[axis]

    coords[axis] = new_coords
    coords_dict = dict(zip("xyz", coords))

    return SpatialDataArray(new_data, coords=coords_dict)
[docs]
class ScalarFieldDataArray(AbstractSpatialDataArray):
    """Spatial distribution in the frequency-domain.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3,4,5,6]
    >>> f = [2e14, 3e14]
    >>> coords = dict(x=x, y=y, z=z, f=f)
    >>> fd = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: three spatial axes plus frequency.
    _dims = ("x", "y", "z", "f")
[docs]
class ScalarFieldTimeDataArray(AbstractSpatialDataArray):
    """Spatial distribution in the time-domain.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3,4,5,6]
    >>> t = [0, 1e-12, 2e-12]
    >>> coords = dict(x=x, y=y, z=z, t=t)
    >>> fd = ScalarFieldTimeDataArray(np.random.random((2,3,4,3)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: three spatial axes plus time.
    _dims = ("x", "y", "z", "t")
[docs]
class ScalarModeFieldDataArray(AbstractSpatialDataArray):
    """Spatial distribution of a mode in frequency-domain as a function of mode index.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3,4,5,6]
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index)
    >>> fd = ScalarModeFieldDataArray((1+1j) * np.random.random((2,3,4,2,5)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("x", "y", "z", "f", "mode_index")
class ScalarModeFieldCylindricalDataArray(AbstractSpatialDataArray):
    """Spatial distribution of a mode in frequency-domain as a function of mode index,
    stored on cylindrical coordinates (``rho``, ``theta``, ``axial``).

    Example
    -------
    >>> rho = [1,2]
    >>> theta = [2,3,4]
    >>> axial = [3,4,5,6]
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> coords = dict(rho=rho, theta=theta, axial=axial, f=f, mode_index=mode_index)
    >>> fd = ScalarModeFieldCylindricalDataArray((1+1j) * np.random.random((2,3,4,2,5)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type; the spatial axes are cylindrical rather than x/y/z.
    _dims = ("rho", "theta", "axial", "f", "mode_index")
class ScalarTerminalFieldDataArray(AbstractSpatialDataArray):
    """Spatial distribution of a terminal field in frequency-domain as a function of terminal label.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3,4,5,6]
    >>> f = [2e14, 3e14]
    >>> terminal_label = ["t0", "t1", "t2"]
    >>> coords = dict(x=x, y=y, z=z, f=f, terminal_label=terminal_label)
    >>> fd = ScalarTerminalFieldDataArray((1+1j) * np.random.random((2,3,4,2,3)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("x", "y", "z", "f", "terminal_label")
[docs]
class FluxDataArray(DataArray):
    """Flux through a surface in the frequency-domain.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> coords = dict(f=f)
    >>> fd = FluxDataArray(np.random.random(2), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("f",)
    # Metadata describing the stored values (flux, in watts).
    _data_attrs = {"units": WATT, "long_name": "flux"}
[docs]
class FluxTimeDataArray(DataArray):
    """Flux through a surface in the time-domain.

    Example
    -------
    >>> t = [0, 1e-12, 2e-12]
    >>> coords = dict(t=t)
    >>> data = FluxTimeDataArray(np.random.random(3), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("t",)
    # Metadata describing the stored values (flux, in watts).
    _data_attrs = {"units": WATT, "long_name": "flux"}
[docs]
class ModeAmpsDataArray(DataArray):
    """Forward and backward propagating complex-valued mode amplitudes.

    Example
    -------
    >>> direction = ["+", "-"]
    >>> f = [1e14, 2e14, 3e14]
    >>> mode_index = np.arange(4)
    >>> coords = dict(direction=direction, f=f, mode_index=mode_index)
    >>> data = ModeAmpsDataArray((1+1j) * np.random.random((2, 3, 4)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("direction", "f", "mode_index")
    # Metadata describing the stored values.
    _data_attrs = {"units": "sqrt(W)", "long_name": "mode amplitudes"}
[docs]
class ModeIndexDataArray(DataArray):
    """Complex-valued effective propagation index of a mode.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(4)
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> data = ModeIndexDataArray((1+1j) * np.random.random((2,4)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("f", "mode_index")
    # Metadata describing the stored values (no units entry is set).
    _data_attrs = {"long_name": "Propagation index"}
class GroupIndexDataArray(DataArray):
    """Group index of a mode.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(4)
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> data = GroupIndexDataArray((1+1j) * np.random.random((2,4)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("f", "mode_index")
    # Metadata describing the stored values (no units entry is set).
    _data_attrs = {"long_name": "Group index"}
class ModeDispersionDataArray(DataArray):
    """Dispersion parameter of a mode.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(4)
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> data = ModeDispersionDataArray((1+1j) * np.random.random((2,4)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("f", "mode_index")
    # Metadata describing the stored values, including the dispersion units constant.
    _data_attrs = {
        "long_name": "Dispersion parameter",
        "units": PICOSECOND_PER_NANOMETER_PER_KILOMETER,
    }
[docs]
class FieldProjectionAngleDataArray(DataArray):
    """Far fields in frequency domain as a function of angles theta and phi.

    Example
    -------
    >>> f = np.linspace(1e14, 2e14, 10)
    >>> r = np.atleast_1d(5)
    >>> theta = np.linspace(0, np.pi, 10)
    >>> phi = np.linspace(0, 2*np.pi, 20)
    >>> coords = dict(r=r, theta=theta, phi=phi, f=f)
    >>> values = (1+1j) * np.random.random((len(r), len(theta), len(phi), len(f)))
    >>> data = FieldProjectionAngleDataArray(values, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("r", "theta", "phi", "f")
    # Metadata describing the stored values.
    _data_attrs = {"long_name": "radiation vectors"}
[docs]
class FieldProjectionCartesianDataArray(DataArray):
    """Far fields in frequency domain as a function of local x and y coordinates.

    Example
    -------
    >>> f = np.linspace(1e14, 2e14, 10)
    >>> x = np.linspace(0, 5, 10)
    >>> y = np.linspace(0, 10, 20)
    >>> z = np.atleast_1d(5)
    >>> coords = dict(x=x, y=y, z=z, f=f)
    >>> values = (1+1j) * np.random.random((len(x), len(y), len(z), len(f)))
    >>> data = FieldProjectionCartesianDataArray(values, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("x", "y", "z", "f")
    # Metadata describing the stored values.
    _data_attrs = {"long_name": "radiation vectors"}
[docs]
class FieldProjectionKSpaceDataArray(DataArray):
    """Far fields in frequency domain as a function of normalized
    kx and ky vectors on the observation plane.

    Example
    -------
    >>> f = np.linspace(1e14, 2e14, 10)
    >>> r = np.atleast_1d(5)
    >>> ux = np.linspace(0, 5, 10)
    >>> uy = np.linspace(0, 10, 20)
    >>> coords = dict(ux=ux, uy=uy, r=r, f=f)
    >>> values = (1+1j) * np.random.random((len(ux), len(uy), len(r), len(f)))
    >>> data = FieldProjectionKSpaceDataArray(values, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("ux", "uy", "r", "f")
    # Metadata describing the stored values.
    _data_attrs = {"long_name": "radiation vectors"}
[docs]
class DiffractionDataArray(DataArray):
    """Diffraction power amplitudes as a function of diffraction orders and frequency.

    Example
    -------
    >>> f = np.linspace(1e14, 2e14, 10)
    >>> orders_x = np.linspace(-1, 1, 3)
    >>> orders_y = np.linspace(-2, 2, 5)
    >>> coords = dict(orders_x=orders_x, orders_y=orders_y, f=f)
    >>> values = (1+1j) * np.random.random((len(orders_x), len(orders_y), len(f)))
    >>> data = DiffractionDataArray(values, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("orders_x", "orders_y", "f")
    # Metadata describing the stored values.
    _data_attrs = {"long_name": "diffraction amplitude"}
class TriangleMeshDataArray(DataArray):
    """Data of the triangles of a surface mesh as in the STL file format."""

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    # NOTE(review): presumably one entry per face, per vertex of that face, per coordinate
    # axis — confirm against the code that builds these arrays.
    _dims = ("face_index", "vertex_index", "axis")
    # Metadata describing the stored values.
    _data_attrs = {"long_name": "surface mesh triangles"}
class HeatDataArray(DataArray):
    """Heat data array.

    Example
    -------
    >>> T = [0, 1e-12, 2e-12]
    >>> td = HeatDataArray((1+1j) * np.random.random((3,)), coords=dict(T=T))
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    # NOTE(review): presumably ``T`` is temperature — confirm against callers.
    _dims = ("T",)
class EMEScalarModeFieldDataArray(AbstractSpatialDataArray):
    """Spatial distribution of a mode in frequency-domain as a function of mode index
    and EME cell index.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3]
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> eme_cell_index = np.arange(5)
    >>> coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index, eme_cell_index=eme_cell_index)
    >>> fd = EMEScalarModeFieldDataArray((1+1j) * np.random.random((2,3,1,2,5,5)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    # NOTE(review): the docstring example omits ``sweep_index`` and lists ``mode_index``
    # before ``eme_cell_index`` — presumably ``sweep_index`` is optional; confirm.
    _dims = ("x", "y", "z", "f", "sweep_index", "eme_cell_index", "mode_index")
class EMEFreqModeDataArray(DataArray):
    """Array over frequency, mode index, and EME cell index.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> eme_cell_index = np.arange(5)
    >>> coords = dict(f=f, mode_index=mode_index, eme_cell_index=eme_cell_index)
    >>> fd = EMEFreqModeDataArray((1+1j) * np.random.random((2, 5, 5)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    # NOTE(review): the docstring example omits ``sweep_index`` — presumably that
    # dimension is optional; confirm.
    _dims = ("f", "sweep_index", "eme_cell_index", "mode_index")
class EMEScalarFieldDataArray(AbstractSpatialDataArray):
    """Spatial distribution of a field excited from an EME port in frequency-domain as a
    function of mode index at the EME port and the EME port index.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3,4,5,6]
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> eme_port_index = [0, 1]
    >>> coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index, eme_port_index=eme_port_index)
    >>> fd = EMEScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2,5,2)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    # NOTE(review): the docstring example omits ``sweep_index`` and lists ``mode_index``
    # before ``eme_port_index`` — presumably ``sweep_index`` is optional; confirm.
    _dims = ("x", "y", "z", "f", "sweep_index", "eme_port_index", "mode_index")
class EMECoefficientDataArray(DataArray):
    """EME expansion coefficient of the mode `mode_index_out` in the EME cell
    `eme_cell_index`, when excited from mode `mode_index_in` of EME port `eme_port_index`.

    Example
    -------
    >>> mode_index_in = [0, 1]
    >>> mode_index_out = [0, 1]
    >>> eme_cell_index = np.arange(5)
    >>> eme_port_index = [0, 1]
    >>> f = [2e14]
    >>> coords = dict(
    ...     f=f,
    ...     mode_index_out=mode_index_out,
    ...     mode_index_in=mode_index_in,
    ...     eme_cell_index=eme_cell_index,
    ...     eme_port_index=eme_port_index
    ... )
    >>> fd = EMECoefficientDataArray((1 + 1j) * np.random.random((1, 2, 2, 5, 2)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    # NOTE(review): the docstring example omits ``sweep_index`` and passes the remaining
    # coords in a different order than listed here — confirm the intended axis order.
    _dims = (
        "f",
        "sweep_index",
        "eme_port_index",
        "eme_cell_index",
        "mode_index_out",
        "mode_index_in",
    )
    # Metadata describing the stored values.
    _data_attrs = {"long_name": "mode expansion coefficient"}
class EMESMatrixDataArray(DataArray):
    """Scattering matrix elements for a fixed pair of ports, possibly with an extra
    sweep index.

    Example
    -------
    >>> mode_index_in = [0, 1]
    >>> mode_index_out = [0, 1, 2]
    >>> f = [2e14]
    >>> sweep_index = np.arange(10)
    >>> coords = dict(
    ...     f=f,
    ...     mode_index_out=mode_index_out,
    ...     mode_index_in=mode_index_in,
    ...     sweep_index=sweep_index
    ... )
    >>> fd = EMESMatrixDataArray((1 + 1j) * np.random.random((1, 3, 2, 10)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    # NOTE(review): the docstring example places ``sweep_index`` last, whereas it is
    # second here — confirm which order the data axes are expected to follow.
    _dims = ("f", "sweep_index", "mode_index_out", "mode_index_in")
    # Metadata describing the stored values.
    _data_attrs = {"long_name": "scattering matrix element"}
class EMEInterfaceSMatrixDataArray(DataArray):
    """Scattering matrix elements at a single cell interface for a fixed pair of ports,
    possibly with an extra sweep index.

    Example
    -------
    >>> mode_index_in = [0, 1]
    >>> mode_index_out = [0, 1, 2]
    >>> eme_cell_index = [2, 4]
    >>> f = [2e14]
    >>> sweep_index = np.arange(10)
    >>> coords = dict(
    ...     f=f,
    ...     sweep_index=sweep_index,
    ...     eme_cell_index=eme_cell_index,
    ...     mode_index_out=mode_index_out,
    ...     mode_index_in=mode_index_in,
    ... )
    >>> fd = EMEInterfaceSMatrixDataArray((1 + 1j) * np.random.random((1, 10, 2, 3, 2)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type (same order as the coords in the example above).
    _dims = ("f", "sweep_index", "eme_cell_index", "mode_index_out", "mode_index_in")
    # Metadata describing the stored values.
    _data_attrs = {"long_name": "scattering matrix element"}
class EMEModeIndexDataArray(DataArray):
    """Complex-valued effective propagation index of an EME mode,
    also indexed by EME cell.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(4)
    >>> eme_cell_index = np.arange(5)
    >>> coords = dict(f=f, mode_index=mode_index, eme_cell_index=eme_cell_index)
    >>> data = EMEModeIndexDataArray((1+1j) * np.random.random((2,4,5)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    # NOTE(review): the docstring example omits ``sweep_index`` and lists ``mode_index``
    # before ``eme_cell_index`` — presumably ``sweep_index`` is optional; confirm.
    _dims = ("f", "sweep_index", "eme_cell_index", "mode_index")
    # Metadata describing the stored values (no units entry is set).
    _data_attrs = {"long_name": "Propagation index"}
class EMEFluxDataArray(DataArray):
    """Power flux of an EME mode, also indexed by EME cell.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> sweep_index = np.arange(2)
    >>> eme_cell_index = np.arange(5)
    >>> mode_index = np.arange(4)
    >>> coords = dict(f=f, sweep_index=sweep_index, eme_cell_index=eme_cell_index, mode_index=mode_index)
    >>> data = EMEFluxDataArray(np.random.random((2,2,5,4)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type (same order as the coords in the example above).
    _dims = ("f", "sweep_index", "eme_cell_index", "mode_index")
    # Metadata describing the stored values (flux, in watts).
    _data_attrs = {"units": WATT, "long_name": "flux"}
class ChargeDataArray(DataArray):
    """Charge data array.

    Example
    -------
    >>> n = [0, 1e-12, 2e-12]
    >>> p = [0, 3e-12, 4e-12]
    >>> td = ChargeDataArray((1+1j) * np.random.random((3,3)), coords=dict(n=n, p=p))
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    # NOTE(review): presumably ``n`` and ``p`` are electron and hole densities — confirm.
    _dims = ("n", "p")
[docs]
class SteadyVoltageDataArray(DataArray):
    """Steady voltage data array. Data array used with steady state
    simulations with voltage as dimension.

    Example
    -------
    >>> import tidy3d as td
    >>> intensities = [0, 1, 4]
    >>> V = [-1, -0.5, 0]
    >>> voltage_dataarray = td.SteadyVoltageDataArray(data=intensities, coords={"v": V})
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: a single voltage axis named ``v``.
    _dims = ("v",)
[docs]
class PointDataArray(DataArray):
    """A two-dimensional array that stores coordinates/field components for a collection of points.
    Dimension ``index`` denotes the index of a point in the collection, and dimension ``axis``
    denotes the field component (or point coordinate) in that direction.

    Example
    -------
    >>> point_array = PointDataArray(
    ...     (1+1j) * np.random.random((5, 3)), coords=dict(index=np.arange(5), axis=np.arange(3)),
    ... )
    >>> # get coordinates of a point number 3
    >>> point3 = point_array.sel(index=3)
    >>> # get x coordinates of all points
    >>> x_coords = point_array.sel(axis=0)
    >>>
    >>> field_da = PointDataArray(
    ...     np.random.random((120, 3)), coords=dict(index=np.arange(120), axis=np.arange(3)),
    ... )
    >>> # get field of point number 90
    >>> field_point90 = field_da.sel(index=90)
    >>> # get z component of all points
    >>> z_field = field_da.sel(axis=2)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: point index, then component/coordinate axis.
    _dims = ("index", "axis")
[docs]
class CellDataArray(DataArray):
    """A two-dimensional array that stores indices of points composing each cell in a collection of
    cells of the same type (for example: triangles, tetrahedra, etc). Dimension ``cell_index``
    denotes the index of a cell in the collection, and dimension ``vertex_index`` denotes placement
    (index) of a point in a cell (for example: 0, 1, or 2 for triangles; 0, 1, 2, or 3 for
    tetrahedra).

    Example
    -------
    >>> cell_array = CellDataArray(
    ...     (1+1j) * np.random.random((4, 3)),
    ...     coords=dict(cell_index=np.arange(4), vertex_index=np.arange(3)),
    ... )
    >>> # get indices of points composing cell number 3
    >>> cell3 = cell_array.sel(cell_index=3)
    >>> # get indices of points that represent the first vertex in each cell
    >>> first_vertices = cell_array.sel(vertex_index=0)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: cell index, then vertex position within the cell.
    _dims = ("cell_index", "vertex_index")
[docs]
class IndexedDataArray(DataArray):
    """Stores a one-dimensional array enumerated by coordinate ``index``. It is typically used
    in conjunction with a ``PointDataArray`` to store point-associated data or a ``CellDataArray``
    to store cell-associated data.

    Example
    -------
    >>> indexed_array = IndexedDataArray(
    ...     (1+1j) * np.random.random((3,)), coords=dict(index=np.arange(3))
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("index",)
[docs]
class IndexedVoltageDataArray(DataArray):
    """Stores a two-dimensional array with coordinates ``index`` and ``voltage``, where
    ``index`` is usually associated with ``PointDataArray`` and ``voltage`` indicates at what
    bias/DC-voltage the data was obtained with.

    Example
    -------
    >>> indexed_array = IndexedVoltageDataArray(
    ...     (1+1j) * np.random.random((3,2)), coords=dict(index=np.arange(3), voltage=[-1, 1])
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("index", "voltage")
[docs]
class IndexedTimeDataArray(DataArray):
    """Stores a two-dimensional array with coordinates ``index`` and ``t``, where
    ``index`` is usually associated with ``PointDataArray`` and ``t`` indicates at what
    simulated time the data was obtained.

    Example
    -------
    >>> indexed_array = IndexedTimeDataArray(
    ...     (1+1j) * np.random.random((3,2)), coords=dict(index=np.arange(3), t=[0, 1])
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    _dims = ("index", "t")
[docs]
class IndexedFieldVoltageDataArray(DataArray):
    """Stores indexed values of vector fields for different voltages. It is typically used
    in conjunction with a ``PointDataArray`` to store point-associated vector data.

    Example
    -------
    >>> indexed_array = IndexedFieldVoltageDataArray(
    ...     (1+1j) * np.random.random((4,3,2)), coords=dict(index=np.arange(4), axis=np.arange(3), voltage=[-1, 1])
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: point index, vector component, voltage.
    _dims = ("index", "axis", "voltage")
class SpatialVoltageDataArray(AbstractSpatialDataArray):
    """Spatial distribution with voltage mapping.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3,4,5,6]
    >>> v = [-1, 1]
    >>> coords = dict(x=x, y=y, z=z, voltage=v)
    >>> fd = SpatialVoltageDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: three spatial axes plus voltage.
    _dims = ("x", "y", "z", "voltage")
class PerturbationCoefficientDataArray(DataArray):
    """Data array with dimensions ``wvl`` and ``coeff``."""

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type.
    # NOTE(review): presumably ``wvl`` is wavelength and ``coeff`` enumerates perturbation
    # coefficients — confirm against the code that consumes this array.
    _dims = ("wvl", "coeff")
class VoltageArray(DataArray):
    """Base for voltage-valued arrays; defines only the data attributes (units and long name).

    No ``_dims`` are set here; the concrete voltage classes in this module combine this
    class with a second base that carries the dimensions.
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Metadata describing the stored values (voltage, in volts).
    _data_attrs = {"units": VOLT, "long_name": "voltage"}
class CurrentArray(DataArray):
    """Base for current-valued arrays; defines only the data attributes (units and long name).

    No ``_dims`` are set here; the concrete current classes in this module combine this
    class with a second base that carries the dimensions.
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Metadata describing the stored values (current, in amperes).
    _data_attrs = {"units": AMP, "long_name": "current"}
class ImpedanceArray(DataArray):
    """Base for impedance-valued arrays; defines only the data attributes (units and long name).

    No ``_dims`` are set here; the concrete impedance classes in this module combine this
    class with a second base that carries the dimensions.
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Metadata describing the stored values (impedance, in ohms).
    _data_attrs = {"units": OHM, "long_name": "impedance"}
# Voltage arrays
[docs]
class VoltageFreqDataArray(VoltageArray, FreqDataArray):
    """Voltage data array in frequency domain.

    Combines the voltage data attributes of ``VoltageArray`` with ``FreqDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9, 4e9]
    >>> coords = dict(f=f)
    >>> data = np.random.random(3) + 1j * np.random.random(3)
    >>> vfd = VoltageFreqDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class VoltageTimeDataArray(VoltageArray, TimeDataArray):
    """Voltage data array in time domain.

    Combines the voltage data attributes of ``VoltageArray`` with ``TimeDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> t = [0, 1e-9, 2e-9, 3e-9]
    >>> coords = dict(t=t)
    >>> data = np.sin(2 * np.pi * 1e9 * np.array(t))
    >>> vtd = VoltageTimeDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class VoltageFreqModeDataArray(VoltageArray, FreqModeDataArray):
    """Voltage data array in frequency-mode domain.

    Combines the voltage data attributes of ``VoltageArray`` with ``FreqModeDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> mode_index = [0, 1]
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> data = np.random.random((2, 2)) + 1j * np.random.random((2, 2))
    >>> vfmd = VoltageFreqModeDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
class VoltageFreqTerminalDataArray(VoltageArray, FreqTerminalDataArray):
    """Voltage data array in frequency-terminal domain.

    Combines the voltage data attributes of ``VoltageArray`` with ``FreqTerminalDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> terminal_label = ["t0", "t1"]
    >>> coords = dict(f=f, terminal_label=terminal_label)
    >>> data = np.random.random((2, 2)) + 1j * np.random.random((2, 2))
    >>> vftd = VoltageFreqTerminalDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class VoltageFreqTerminalModeDataArray(VoltageArray, FreqTerminalModeDataArray):
    """Voltage transformation matrix data array from modes to terminals in frequency domain.

    Combines the voltage data attributes of ``VoltageArray`` with
    ``FreqTerminalModeDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> terminal_label = ["t0", "t1"]
    >>> mode_index = [0, 1]
    >>> coords = dict(f=f, terminal_label=terminal_label, mode_index=mode_index)
    >>> data = np.random.random((2, 2, 2)) + 1j * np.random.random((2, 2, 2))
    >>> vtransform = VoltageFreqTerminalModeDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
class VoltageFreqModeTerminalDataArray(VoltageArray, FreqModeTerminalDataArray):
    """Inverse voltage transformation matrix data array from terminals to modes in frequency domain.

    Combines the voltage data attributes of ``VoltageArray`` with
    ``FreqModeTerminalDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> mode_index = [0, 1]
    >>> terminal_label = ["t0", "t1"]
    >>> coords = dict(f=f, mode_index=mode_index, terminal_label=terminal_label)
    >>> data = np.random.random((2, 2, 2)) + 1j * np.random.random((2, 2, 2))
    >>> vtransform_inv = VoltageFreqModeTerminalDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
# Current arrays
[docs]
class CurrentFreqDataArray(CurrentArray, FreqDataArray):
    """Current data array in frequency domain.

    Combines the current data attributes of ``CurrentArray`` with ``FreqDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9, 4e9]
    >>> coords = dict(f=f)
    >>> data = np.random.random(3) + 1j * np.random.random(3)
    >>> cfd = CurrentFreqDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class CurrentTimeDataArray(CurrentArray, TimeDataArray):
    """Current data array in time domain.

    Combines the current data attributes of ``CurrentArray`` with ``TimeDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> t = [0, 1e-9, 2e-9, 3e-9]
    >>> coords = dict(t=t)
    >>> data = np.cos(2 * np.pi * 1e9 * np.array(t))
    >>> ctd = CurrentTimeDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class CurrentFreqModeDataArray(CurrentArray, FreqModeDataArray):
    """Current data array in frequency-mode domain.

    Combines the current data attributes of ``CurrentArray`` with ``FreqModeDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> mode_index = [0, 1]
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> data = np.random.random((2, 2)) + 1j * np.random.random((2, 2))
    >>> cfmd = CurrentFreqModeDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class CurrentFreqTerminalDataArray(CurrentArray, FreqTerminalDataArray):
    """Current data array in frequency-terminal domain.

    Combines the current data attributes of ``CurrentArray`` with ``FreqTerminalDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> terminal_label = ["t0", "t1", "t2", "t3", "t4"]
    >>> coords = dict(f=f, terminal_label=terminal_label)
    >>> data = np.random.random((2, 5)) + 1j * np.random.random((2, 5))
    >>> cftd = CurrentFreqTerminalDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class CurrentFreqTerminalModeDataArray(CurrentArray, FreqTerminalModeDataArray):
    """Current transformation matrix data array from modes to terminals in frequency domain.

    Combines the current data attributes of ``CurrentArray`` with
    ``FreqTerminalModeDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> mode_index = [0, 1]
    >>> terminal_label = ["t0", "t1"]
    >>> coords = dict(f=f, terminal_label=terminal_label, mode_index=mode_index)
    >>> data = np.random.random((2, 2, 2)) + 1j * np.random.random((2, 2, 2))
    >>> itransform = CurrentFreqTerminalModeDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
# Impedance arrays
[docs]
class ImpedanceModeDataArray(ImpedanceArray, ModeDataArray):
    """Impedance data array in mode index domain.

    Combines the impedance data attributes of ``ImpedanceArray`` with ``ModeDataArray``.

    Example
    -------
    >>> mode_index = np.arange(4)
    >>> coords = dict(mode_index=mode_index)
    >>> data = ImpedanceModeDataArray((1+1j) * np.random.random(4), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class ImpedanceTerminalDataArray(ImpedanceArray, TerminalDataArray):
    """Impedance data array in terminal index domain.

    Combines the impedance data attributes of ``ImpedanceArray`` with ``TerminalDataArray``.

    Example
    -------
    >>> terminal_label = ["t0", "t1", "t2", "t3", "t4"]
    >>> coords = dict(terminal_label=terminal_label)
    >>> data = ImpedanceTerminalDataArray((1+1j) * np.random.random(5), coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class ImpedanceFreqDataArray(ImpedanceArray, FreqDataArray):
    """Impedance data array in frequency domain.

    Combines the impedance data attributes of ``ImpedanceArray`` with ``FreqDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9, 4e9]
    >>> coords = dict(f=f)
    >>> data = 50.0 + 1j * np.random.random(3)
    >>> zfd = ImpedanceFreqDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class ImpedanceTimeDataArray(ImpedanceArray, TimeDataArray):
    """Impedance data array in time domain.

    Combines the impedance data attributes of ``ImpedanceArray`` with ``TimeDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> t = [0, 1e-9, 2e-9, 3e-9]
    >>> coords = dict(t=t)
    >>> data = 50.0 * np.ones_like(t)
    >>> ztd = ImpedanceTimeDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class ImpedanceFreqModeDataArray(ImpedanceArray, FreqModeDataArray):
    """Impedance data array in frequency-mode domain.

    Combines the impedance data attributes of ``ImpedanceArray`` with ``FreqModeDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> mode_index = [0, 1]
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> data = 50.0 + 10.0 * np.random.random((2, 2))
    >>> zfmd = ImpedanceFreqModeDataArray(data, coords=coords)
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
class ImpedanceFreqModeModeDataArray(ImpedanceArray, FreqModeModeDataArray):
    """Impedance matrix data array between modes in frequency domain.

    Combines the impedance data attributes of ``ImpedanceArray`` with
    ``FreqModeModeDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> mode_index_out = np.arange(2)
    >>> mode_index_in = np.arange(2)
    >>> coords = dict(f=f, mode_index_out=mode_index_out, mode_index_in=mode_index_in)
    >>> data = ImpedanceFreqModeModeDataArray(50.0 + 10.0 * np.random.random((2, 2, 2)), coords=coords)
    >>> zfmmd = data
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class ImpedanceFreqTerminalTerminalDataArray(ImpedanceArray, FreqTerminalTerminalDataArray):
    """Impedance matrix data array between terminals in frequency domain.

    Combines the impedance data attributes of ``ImpedanceArray`` with
    ``FreqTerminalTerminalDataArray``.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> terminal_label_out = ["t0", "t1"]
    >>> terminal_label_in = ["t0", "t1"]
    >>> coords = dict(f=f, terminal_label_out=terminal_label_out, terminal_label_in=terminal_label_in)
    >>> data = ImpedanceFreqTerminalTerminalDataArray(50.0 + 10.0 * np.random.random((2, 2, 2)), coords=coords)
    >>> zfttd = data
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
[docs]
class IndexedSurfaceFreqDataArray(DataArray):
    """Stores indexed values of scalar fields on the sides of a surface in frequency domain.
    It is typically used in conjunction with a ``PointDataArray`` to store point-associated
    scalar data.

    Example
    -------
    >>> surface_side_array = IndexedSurfaceFreqDataArray(
    ...     (1+1j) * np.random.random((4,2,1)), coords=dict(index=np.arange(4), side=["outside", "inside"], f=[1e9])
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: point index, surface side, frequency.
    _dims = ("index", "side", "f")
[docs]
class IndexedSurfaceTimeDataArray(DataArray):
    """Stores indexed values of scalar fields on the sides of a surface in time domain.
    It is typically used in conjunction with a ``PointDataArray`` to store point-associated
    scalar data.

    Example
    -------
    >>> surface_side_array = IndexedSurfaceTimeDataArray(
    ...     (1+1j) * np.random.random((4,2,1)), coords=dict(index=np.arange(4), side=["outside", "inside"], t=[1e9])
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: point index, surface side, time.
    _dims = ("index", "side", "t")
[docs]
class IndexedSurfaceFieldDataArray(DataArray):
    """Stores indexed values of vector fields on the sides of a surface in frequency domain.
    It is typically used in conjunction with a ``PointDataArray`` to store point-associated vector data.

    Example
    -------
    >>> indexed_array = IndexedSurfaceFieldDataArray(
    ...     (1+1j) * np.random.random((4,2,3,1)), coords=dict(index=np.arange(4), side=["outside", "inside"], axis=np.arange(3), f=[1e9])
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: point index, surface side, vector component, frequency.
    _dims = ("index", "side", "axis", "f")
[docs]
class IndexedSurfaceFieldTimeDataArray(DataArray):
    """Stores indexed values of vector fields on the sides of a surface in time domain.
    It is typically used in conjunction with a ``PointDataArray`` to store point-associated vector data.

    Example
    -------
    >>> indexed_array = IndexedSurfaceFieldTimeDataArray(
    ...     (1+1j) * np.random.random((4,2,3,1)), coords=dict(index=np.arange(4), side=["outside", "inside"], axis=np.arange(3), t=[0])
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: point index, surface side, vector component, time.
    _dims = ("index", "side", "axis", "t")
class IndexedFieldDataArray(DataArray):
    """Stores indexed values of vector fields in frequency domain. It is typically used
    in conjunction with a ``PointDataArray`` to store point-associated vector data.

    Example
    -------
    >>> indexed_array = IndexedFieldDataArray(
    ...     (1+1j) * np.random.random((4,3,1)), coords=dict(index=np.arange(4), axis=np.arange(3), f=[1e9])
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: point index, vector component, frequency.
    _dims = ("index", "axis", "f")
class IndexedFieldTimeDataArray(DataArray):
    """Stores indexed values of vector fields in time domain. It is typically used
    in conjunction with a ``PointDataArray`` to store point-associated vector data.

    Example
    -------
    >>> indexed_array = IndexedFieldTimeDataArray(
    ...     (1+1j) * np.random.random((4,3,1)), coords=dict(index=np.arange(4), axis=np.arange(3), t=[0])
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: point index, vector component, time.
    _dims = ("index", "axis", "t")
class IndexedFreqDataArray(DataArray):
    """Stores indexed values of scalar fields in frequency domain. It is typically used
    in conjunction with a ``PointDataArray`` to store point-associated scalar data.

    Example
    -------
    >>> indexed_array = IndexedFreqDataArray(
    ...     (1+1j) * np.random.random((4,1)), coords=dict(index=np.arange(4), f=[1e9])
    ... )
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # Dimension names of this array type: point index, frequency.
    _dims = ("index", "f")
def _make_base_result_data_array(result: DataArray) -> IntegralResultType:
    """Re-wrap ``result`` in the base result class that matches its coordinates.

    The class is selected from the coordinate names found in ``result.coords``, checked from
    the most specific combination to the least specific; ``FreqDataArray`` is the fallback.
    The data attributes of the selected class are assigned before returning.
    """
    coord_names = result.coords
    has_freq = "f" in coord_names
    has_mode = "mode_index" in coord_names

    if has_freq and "terminal_label" in coord_names:
        # Terminal-resolved data, with or without an additional mode dimension.
        array_type = FreqTerminalModeDataArray if has_mode else FreqTerminalDataArray
    elif has_freq and has_mode:
        array_type = FreqModeDataArray
    elif "t" in coord_names:
        array_type = TimeDataArray
    else:
        array_type = FreqDataArray

    wrapped = array_type(data=result.data, coords=result.coords)
    return array_type._assign_data_attrs(wrapped)
def _make_voltage_data_array(result: DataArray) -> VoltageIntegralResultType:
    """Re-wrap ``result`` in the voltage array class that matches its coordinates.

    The class is selected from the coordinate names found in ``result.coords``, checked from
    the most specific combination to the least specific; ``VoltageFreqDataArray`` is the
    fallback. The data attributes of the selected class are assigned before returning.
    """
    coord_names = result.coords
    has_freq = "f" in coord_names
    has_mode = "mode_index" in coord_names

    if has_freq and "terminal_label" in coord_names:
        # Terminal-resolved data, with or without an additional mode dimension.
        array_type = (
            VoltageFreqTerminalModeDataArray if has_mode else VoltageFreqTerminalDataArray
        )
    elif has_freq and has_mode:
        array_type = VoltageFreqModeDataArray
    elif "t" in coord_names:
        array_type = VoltageTimeDataArray
    else:
        array_type = VoltageFreqDataArray

    wrapped = array_type(data=result.data, coords=result.coords)
    return array_type._assign_data_attrs(wrapped)
def _make_current_data_array(result: DataArray) -> CurrentIntegralResultType:
    """Re-wrap ``result`` in the current array class that matches its coordinates.

    The class is selected from the coordinate names found in ``result.coords``, checked from
    the most specific combination to the least specific; ``CurrentFreqDataArray`` is the
    fallback. The data attributes of the selected class are assigned before returning.
    """
    coord_names = result.coords
    has_freq = "f" in coord_names
    has_mode = "mode_index" in coord_names

    if has_freq and "terminal_label" in coord_names:
        # Terminal-resolved data, with or without an additional mode dimension.
        array_type = (
            CurrentFreqTerminalModeDataArray if has_mode else CurrentFreqTerminalDataArray
        )
    elif has_freq and has_mode:
        array_type = CurrentFreqModeDataArray
    elif "t" in coord_names:
        array_type = CurrentTimeDataArray
    else:
        array_type = CurrentFreqDataArray

    wrapped = array_type(data=result.data, coords=result.coords)
    return array_type._assign_data_attrs(wrapped)
def _make_impedance_data_array(result: DataArray) -> ImpedanceResultType:
    """Re-wrap ``result`` in the impedance array class that matches its coordinates.

    The class is selected from the coordinate names found in ``result.coords``, checked from
    the most specific combination to the least specific; ``ImpedanceFreqDataArray`` is the
    fallback. The data attributes of the selected class are assigned before returning.
    """
    coord_names = result.coords
    has_freq = "f" in coord_names

    if (
        has_freq
        and "terminal_label_out" in coord_names
        and "terminal_label_in" in coord_names
    ):
        # Terminal-to-terminal impedance matrix.
        array_type = ImpedanceFreqTerminalTerminalDataArray
    elif has_freq and "mode_index" in coord_names:
        array_type = ImpedanceFreqModeDataArray
    elif "t" in coord_names:
        array_type = ImpedanceTimeDataArray
    else:
        array_type = ImpedanceFreqDataArray

    wrapped = array_type(data=result.data, coords=result.coords)
    return array_type._assign_data_attrs(wrapped)
# Data array classes collected for the name lookup built below (``DATA_ARRAY_MAP``).
# NOTE(review): ``ScalarModeFieldCylindricalDataArray`` is defined in this module but is not
# listed here, so it has no entry in ``DATA_ARRAY_MAP`` — confirm whether that is intentional.
DATA_ARRAY_TYPES = [
    SpatialDataArray,
    ScalarFieldDataArray,
    ScalarFieldTimeDataArray,
    ScalarModeFieldDataArray,
    ScalarTerminalFieldDataArray,
    FluxDataArray,
    FluxTimeDataArray,
    ModeAmpsDataArray,
    ModeIndexDataArray,
    GroupIndexDataArray,
    ModeDispersionDataArray,
    FieldProjectionAngleDataArray,
    FieldProjectionCartesianDataArray,
    FieldProjectionKSpaceDataArray,
    DiffractionDataArray,
    ModeDataArray,
    TerminalDataArray,
    FreqModeDataArray,
    FreqDataArray,
    TimeDataArray,
    FreqVoltageDataArray,
    FreqTerminalDataArray,
    FreqTerminalModeDataArray,
    FreqModeTerminalDataArray,
    FreqTerminalTerminalDataArray,
    TriangleMeshDataArray,
    HeatDataArray,
    EMEScalarFieldDataArray,
    EMEScalarModeFieldDataArray,
    EMESMatrixDataArray,
    EMEInterfaceSMatrixDataArray,
    EMECoefficientDataArray,
    EMEModeIndexDataArray,
    EMEFluxDataArray,
    EMEFreqModeDataArray,
    ChargeDataArray,
    SteadyVoltageDataArray,
    PointDataArray,
    CellDataArray,
    IndexedDataArray,
    IndexedFieldVoltageDataArray,
    IndexedVoltageDataArray,
    SpatialVoltageDataArray,
    PerturbationCoefficientDataArray,
    IndexedTimeDataArray,
    VoltageFreqDataArray,
    VoltageTimeDataArray,
    VoltageFreqModeDataArray,
    VoltageFreqTerminalDataArray,
    VoltageFreqTerminalModeDataArray,
    VoltageFreqModeTerminalDataArray,
    CurrentFreqDataArray,
    CurrentTimeDataArray,
    CurrentFreqModeDataArray,
    CurrentFreqTerminalDataArray,
    CurrentFreqTerminalModeDataArray,
    ImpedanceModeDataArray,
    ImpedanceTerminalDataArray,
    ImpedanceFreqDataArray,
    ImpedanceTimeDataArray,
    ImpedanceFreqModeDataArray,
    FreqModeModeDataArray,
    ImpedanceFreqModeModeDataArray,
    ImpedanceFreqTerminalTerminalDataArray,
    IndexedSurfaceFieldDataArray,
    IndexedSurfaceFieldTimeDataArray,
    IndexedFieldDataArray,
    IndexedFieldTimeDataArray,
    IndexedFreqDataArray,
    IndexedSurfaceFreqDataArray,
    IndexedSurfaceTimeDataArray,
]
# Maps each class name (``__name__``) to the class itself, for every entry of DATA_ARRAY_TYPES.
DATA_ARRAY_MAP = {data_array.__name__: data_array for data_array in DATA_ARRAY_TYPES}
# Type alias covering the ``Indexed*`` data array classes plus ``PointDataArray``.
IndexedDataArrayTypes = Union[
    IndexedDataArray,
    IndexedVoltageDataArray,
    IndexedSurfaceFieldDataArray,
    IndexedSurfaceFieldTimeDataArray,
    IndexedFieldDataArray,
    IndexedFieldTimeDataArray,
    IndexedFreqDataArray,
    IndexedTimeDataArray,
    IndexedFieldVoltageDataArray,
    IndexedSurfaceFreqDataArray,
    IndexedSurfaceTimeDataArray,
    PointDataArray,
]
# Return type of ``_make_base_result_data_array``: every class that helper can select.
IntegralResultType = Union[
    FreqDataArray,
    FreqModeDataArray,
    FreqTerminalDataArray,
    FreqTerminalModeDataArray,
    TimeDataArray,
]
# Return type of ``_make_voltage_data_array``: every class that helper can select.
VoltageIntegralResultType = Union[
    VoltageFreqDataArray,
    VoltageFreqModeDataArray,
    VoltageFreqTerminalDataArray,
    VoltageTimeDataArray,
    VoltageFreqTerminalModeDataArray,
]
# Return type of ``_make_current_data_array``: every class that helper can select.
CurrentIntegralResultType = Union[
    CurrentFreqDataArray,
    CurrentFreqModeDataArray,
    CurrentFreqTerminalDataArray,
    CurrentTimeDataArray,
    CurrentFreqTerminalModeDataArray,
]
# Return type of ``_make_impedance_data_array``: every class that helper can select.
ImpedanceResultType = Union[
    ImpedanceFreqDataArray,
    ImpedanceFreqModeDataArray,
    ImpedanceTimeDataArray,
    ImpedanceFreqTerminalTerminalDataArray,
]
class _TracedDataset(xr.Dataset):
    """``xr.Dataset`` whose item access yields tidy3d ``DataArray`` objects.

    ``xr.Dataset`` stores the data of the arrays it is built from as internal Variables,
    and on ``__getitem__`` it wraps them in plain ``xr.DataArray`` objects. Those lack the
    custom ``.values`` property that autograd compatibility relies on.

    Overriding ``_construct_dataarray`` to hand back a tidy3d ``DataArray`` keeps that
    custom ``.values`` behavior, which returns ``.data`` directly while tracing and thereby
    avoids ``np.asarray``, which breaks autodiff.
    """

    __slots__ = ()

    def _construct_dataarray(self, name: Hashable) -> DataArray:
        """Build the array for ``name`` via xarray, then re-wrap it as a tidy3d DataArray."""
        plain_array = super()._construct_dataarray(name)
        return DataArray(plain_array)