Source code for tidy3d.components.data.data_array

"""Storing tidy3d data at its most fundamental level as xr.DataArray objects"""

from __future__ import annotations

from abc import ABC
from collections.abc import Mapping
from typing import Any, Optional, Union

import autograd.numpy as anp
import h5py
import numpy as np
import xarray as xr
from autograd.tracer import isbox
from xarray.core import missing
from xarray.core.indexes import PandasIndex
from xarray.core.indexing import _outer_to_numpy_indexer
from xarray.core.types import InterpOptions, Self
from xarray.core.utils import OrderedSet, either_dict_or_kwargs
from xarray.core.variable import as_variable

from tidy3d.compat import alignment
from tidy3d.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box
from tidy3d.components.geometry.bound_ops import bounds_contains
from tidy3d.components.types import Axis, Bound
from tidy3d.constants import (
    AMP,
    HERTZ,
    MICROMETER,
    OHM,
    PICOSECOND_PER_NANOMETER_PER_KILOMETER,
    RADIAN,
    SECOND,
    VOLT,
    WATT,
)
from tidy3d.exceptions import DataError, FileError

# Maps each known dimension name to the attributes (``units`` and/or ``long_name``)
# that ``DataArray.assign_coord_attrs`` copies onto the matching coordinate.
# Dimensions that are not listed here simply receive no extra attributes.
DIM_ATTRS = {
    "x": {"units": MICROMETER, "long_name": "x position"},
    "y": {"units": MICROMETER, "long_name": "y position"},
    "z": {"units": MICROMETER, "long_name": "z position"},
    "f": {"units": HERTZ, "long_name": "frequency"},
    "t": {"units": SECOND, "long_name": "time"},
    "direction": {"long_name": "propagation direction"},
    "mode_index": {"long_name": "mode index"},
    "eme_port_index": {"long_name": "EME port index"},
    "eme_cell_index": {"long_name": "EME cell index"},
    "mode_index_in": {"long_name": "mode index in"},
    "mode_index_out": {"long_name": "mode index out"},
    "sweep_index": {"long_name": "sweep index"},
    "theta": {"units": RADIAN, "long_name": "elevation angle"},
    "phi": {"units": RADIAN, "long_name": "azimuth angle"},
    "ux": {"long_name": "normalized kx"},
    "uy": {"long_name": "normalized ky"},
    "orders_x": {"long_name": "diffraction order"},
    "orders_y": {"long_name": "diffraction order"},
    "face_index": {"long_name": "face index"},
    "vertex_index": {"long_name": "vertex index"},
    "axis": {"long_name": "axis"},
}


# Name of the hdf5 dataset that holds ``DataArray.values`` inside a group
# (xarray's default name too). ``to_hdf5_handle`` writes under this key and
# ``from_hdf5`` reads it back.
DATA_ARRAY_VALUE_NAME = "__xarray_dataarray_variable__"


class DataArray(xr.DataArray):
    """Subclass of ``xr.DataArray`` that requires _dims to match the keys of the coords."""

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    # stores an ordered tuple of strings corresponding to the data dimensions
    _dims = ()
    # stores a dictionary of attributes corresponding to the data values
    _data_attrs: dict[str, str] = {}

    def __init__(self, data, *args, **kwargs):
        """Build the array, first converting plain autograd boxes to ``TidyArrayBox``.

        Parameters
        ----------
        data :
            Array-like data, an autograd box, or an ``xr.Variable`` / ``xr.DataArray``.
            If an ``xr.Variable`` / ``xr.DataArray`` wraps a plain autograd box, its
            ``.data`` is replaced in place by the converted box.
        *args, **kwargs :
            Forwarded unchanged to ``xr.DataArray.__init__``.
        """
        # if data is a vanilla autograd box, convert to our box
        if isbox(data) and not is_tidy_box(data):
            data = TidyArrayBox.from_arraybox(data)
        # do the same for xr.Variable or xr.DataArray type
        elif (
            isinstance(data, (xr.Variable, xr.DataArray))
            and isbox(data.data)
            and not is_tidy_box(data.data)
        ):
            data.data = TidyArrayBox.from_arraybox(data.data)
        super().__init__(data, *args, **kwargs)

    @classmethod
    def __get_validators__(cls):
        """Validators that get run when :class:`.DataArray` objects are added to pydantic models."""
        # order matters: the first validator converts the raw value into ``cls``
        yield cls.check_unloaded_data
        yield cls.validate_dims
        yield cls.assign_data_attrs
        yield cls.assign_coord_attrs

    @classmethod
    def check_unloaded_data(cls, val):
        """If the data comes in as the raw data array string, raise a custom warning.

        Raises
        ------
        DataError
            If ``val`` is a string found in ``DATA_ARRAY_MAP``, i.e. only the type
            name was stored and the actual data is not present.
        """
        if isinstance(val, str) and val in DATA_ARRAY_MAP:
            raise DataError(
                f"Trying to load {cls.__name__} but the data is not present. "
                "Note that data will not be saved to .json file. "
                "use .hdf5 format instead if data present."
            )
        return cls(val)

    @classmethod
    def validate_dims(cls, val):
        """Make sure the dims are the same as _dims, then put them in the correct order.

        Raises
        ------
        ValueError
            If the set of dimensions of ``val`` differs from ``cls._dims``.
        """
        if set(val.dims) != set(cls._dims):
            raise ValueError(f"wrong dims, expected '{cls._dims}', got '{val.dims}'")
        return val.transpose(*cls._dims)

    @classmethod
    def assign_data_attrs(cls, val):
        """Assign the correct data attributes to the :class:`.DataArray`."""
        for attr_name, attr in cls._data_attrs.items():
            val.attrs[attr_name] = attr
        return val

    def _interp_validator(self, field_name: Optional[str] = None) -> None:
        """Ensure the data can be interpolated or selected by checking for duplicate coordinates.

        NOTE
        ----
        This does not check every 'DataArray' by default. Instead, when required, this check can be
        called from a validator, as is the case with 'CustomMedium' and 'CustomFieldSource'.

        Parameters
        ----------
        field_name : Optional[str] = None
            Name used in the error message; defaults to ``"DataArray"``.

        Raises
        ------
        DataError
            If any coordinate contains duplicated values.
        """
        if field_name is None:
            field_name = "DataArray"

        for dim, coord in self.coords.items():
            if coord.to_index().duplicated().any():
                raise DataError(
                    f"Field '{field_name}' contains duplicate coordinates in dimension '{dim}'. "
                    "Duplicates can be removed by running "
                    f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'."
                )

    @classmethod
    def assign_coord_attrs(cls, val):
        """Assign the correct coordinate attributes to the :class:`.DataArray`."""
        for dim in cls._dims:
            dim_attrs = DIM_ATTRS.get(dim)
            # dimensions without an entry in DIM_ATTRS are left untouched
            if dim_attrs is not None:
                for attr_name, attr in dim_attrs.items():
                    val.coords[dim].attrs[attr_name] = attr
        return val

    @classmethod
    def __modify_schema__(cls, field_schema):
        """Sets the schema of DataArray object."""
        schema = {
            "title": "DataArray",
            "type": "xr.DataArray",
            "properties": {
                "_dims": {
                    "title": "_dims",
                    "type": "Tuple[str, ...]",
                },
            },
            "required": ["_dims"],
        }
        field_schema.update(schema)

    @classmethod
    def _json_encoder(cls, val):
        """What function to call when writing a DataArray to json."""
        return type(val).__name__

    def __eq__(self, other) -> bool:
        """Whether two data array objects are equal.

        Two arrays are equal when ``other`` is an ``xr.DataArray``, the data have the
        same shape and values, and every coordinate of ``self`` exists on ``other``
        with the same values.
        """
        if not isinstance(other, xr.DataArray):
            return False
        if not self.data.shape == other.data.shape or not np.all(self.data == other.data):
            return False
        for key, val in self.coords.items():
            # a coordinate missing on ``other`` means "not equal", not a KeyError
            if key not in other.coords:
                return False
            if not np.all(np.array(val) == np.array(other.coords[key])):
                return False
        return True

    @property
    def values(self):
        """
        The array's data converted to a numpy.ndarray.

        Returns
        -------
        np.ndarray
            The values of the DataArray. When the data is an autograd box, the box
            itself is returned instead.
        """
        return self.data if isbox(self.data) else super().values

    @values.setter
    def values(self, value: Any) -> None:
        self.variable.values = value

    @property
    def abs(self):
        """Absolute value of data array."""
        return abs(self)

    @property
    def is_uniform(self):
        """Whether each element is of equal value in the data array"""
        raw_data = self.data.ravel()
        return np.allclose(raw_data, raw_data[0])

    def to_hdf5(self, fname: Union[str, h5py.File], group_path: str) -> None:
        """Save an xr.DataArray to the hdf5 file or file handle with a given path to the group.

        Parameters
        ----------
        fname : Union[str, h5py.File]
            File name (the file is opened in ``"w"`` mode) or an already open handle.
        group_path : str
            Path of the group to create inside the file.
        """
        # file name passed
        if isinstance(fname, str):
            with h5py.File(fname, "w") as f_handle:
                self.to_hdf5_handle(f_handle=f_handle, group_path=group_path)

        # file handle passed
        else:
            self.to_hdf5_handle(f_handle=fname, group_path=group_path)

    def to_hdf5_handle(self, f_handle: h5py.File, group_path: str) -> None:
        """Save an xr.DataArray to the hdf5 file handle with a given path to the group.

        The (static) data is written under ``DATA_ARRAY_VALUE_NAME`` and every
        coordinate under its own name inside the newly created group.
        """
        sub_group = f_handle.create_group(group_path)
        sub_group[DATA_ARRAY_VALUE_NAME] = get_static(self.data)
        for key, val in self.coords.items():
            # unicode coordinates of any width are written as a list of python
            # strings; ``from_hdf5`` decodes the stored byte strings on load
            if val.dtype.kind == "U":
                sub_group[key] = val.values.tolist()
            else:
                sub_group[key] = val

    @classmethod
    def from_hdf5(cls, fname: str, group_path: str) -> Self:
        """Load an DataArray from an hdf5 file with a given path to the group.

        Only the coordinates named in ``cls._dims`` that are present in the group
        are loaded.
        """
        with h5py.File(fname, "r") as f:
            sub_group = f[group_path]
            values = np.array(sub_group[DATA_ARRAY_VALUE_NAME])
            coords = {dim: np.array(sub_group[dim]) for dim in cls._dims if dim in sub_group}
            for key, val in coords.items():
                # object dtype: decode each stored byte string back to ``str``
                if val.dtype == "O":
                    coords[key] = [byte_string.decode() for byte_string in val.tolist()]
            return cls(values, coords=coords, dims=cls._dims)

    @classmethod
    def from_file(cls, fname: str, group_path: str) -> Self:
        """Load an DataArray from an hdf5 file with a given path to the group.

        Raises
        ------
        FileError
            If ``fname`` does not contain ``".hdf5"``.
        """
        if ".hdf5" not in fname:
            raise FileError(
                f"'DataArray' objects must be written to '.hdf5' format. Given filename of {fname}."
            )
        return cls.from_hdf5(fname=fname, group_path=group_path)

    def __hash__(self) -> int:
        """Generate hash value for a :class:`.DataArray` instance, needed for custom components."""
        import dask

        token_str = dask.base.tokenize(self)
        return hash(token_str)

    def multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self:
        """Multiply self by value at indices.

        Returns a modified copy; when either the data or ``value`` is an autograd
        box, the autograd-compatible implementation is used instead.
        """
        if isbox(self.data) or isbox(value):
            return self._ag_multiply_at(value, coord_name, indices)

        self_mult = self.copy()
        self_mult[{coord_name: indices}] *= value
        return self_mult

    def _ag_multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self:
        """Autograd multiply_at override when tracing."""
        key = {coord_name: indices}
        _, index_tuple, _ = self.variable._broadcast_indexes(key)
        idx = _outer_to_numpy_indexer(index_tuple, self.data.shape)
        # boolean mask of the selected entries, so the update is a pure ``where``
        mask = np.zeros(self.data.shape, dtype="?")
        mask[idx] = True
        return self.copy(deep=False, data=anp.where(mask, self.data * value, self.data))

    def interp(
        self,
        coords: Mapping[Any, Any] | None = None,
        method: InterpOptions = "linear",
        assume_sorted: bool = False,
        kwargs: Mapping[str, Any] | None = None,
        **coords_kwargs: Any,
    ) -> Self:
        """Interpolate this DataArray to new coordinate values.

        Parameters
        ----------
        coords : Union[Mapping[Any, Any], None] = None
            A mapping from dimension names to new coordinate labels.
        method : InterpOptions = "linear"
            The interpolation method to use.
        assume_sorted : bool = False
            If True, skip sorting of coordinates.
        kwargs : Union[Mapping[str, Any], None] = None
            Additional keyword arguments to pass to the interpolation function.
        **coords_kwargs : Any
            The keyword arguments form of coords.

        Returns
        -------
        DataArray
            A new DataArray with interpolated values.

        Raises
        ------
        KeyError
            If any of the specified coordinates are not in the DataArray.
        """
        # traced data goes through the autograd-compatible implementation
        if isbox(self.data):
            return self._ag_interp(coords, method, assume_sorted, kwargs, **coords_kwargs)

        return super().interp(coords, method, assume_sorted, kwargs, **coords_kwargs)

    def _ag_interp(
        self,
        coords: Union[Mapping[Any, Any], None] = None,
        method: InterpOptions = "linear",
        assume_sorted: bool = False,
        kwargs: Union[Mapping[str, Any], None] = None,
        **coords_kwargs: Any,
    ) -> Self:
        """Autograd interp override when tracing over self.data.

        This implementation closely follows the interp implementation of xarray
        to match its behavior as closely as possible while supporting autograd.

        See:
        - https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html
        - https://docs.xarray.dev/en/latest/generated/xarray.Dataset.interp.html
        """
        if kwargs is None:
            kwargs = {}

        ds = self._to_temp_dataset()

        coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
        indexers = dict(ds._validate_interp_indexers(coords))

        if coords:
            # Find shared dimensions between the dataset and the indexers
            sdims = (
                set(ds.dims)
                .intersection(*[set(nx.dims) for nx in indexers.values()])
                .difference(coords.keys())
            )
            indexers.update({d: ds.variables[d] for d in sdims})

        obj = ds if assume_sorted else ds.sortby(list(coords))

        # workaround to get a variable for a dimension without a coordinate
        validated_indexers = {
            k: (obj._variables.get(k, as_variable((k, range(obj.sizes[k])))), v)
            for k, v in indexers.items()
        }

        # only existing keys are reassigned, so mutating during iteration is safe
        for k, v in validated_indexers.items():
            obj, newidx = missing._localize(obj, {k: v})
            validated_indexers[k] = newidx[k]

        variables = {}
        reindex = False
        for name, var in obj._variables.items():
            if name in indexers:
                continue

            dtype_kind = var.dtype.kind
            if dtype_kind in "uifc":
                # Interpolation for numeric types
                var_indexers = {k: v for k, v in validated_indexers.items() if k in var.dims}
                variables[name] = self._ag_interp_func(var, var_indexers, method, **kwargs)
            elif dtype_kind in "ObU" and (validated_indexers.keys() & var.dims):
                # Stepwise interpolation for non-numeric types
                reindex = True
            elif all(d not in indexers for d in var.dims):
                # Keep variables not dependent on interpolated coords
                variables[name] = var

        if reindex:
            # Reindex for non-numeric types
            reindex_indexers = {k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,)}
            reindexed = alignment.reindex(
                obj,
                indexers=reindex_indexers,
                method="nearest",
                exclude_vars=variables.keys(),
            )
            indexes = dict(reindexed._indexes)
            variables.update(reindexed.variables)
        else:
            # Get the indexes that are not being interpolated along
            indexes = {k: v for k, v in obj._indexes.items() if k not in indexers}

        # Get the coords that also exist in the variables
        coord_names = obj._coord_names & variables.keys()
        selected = ds._replace_with_new_dims(variables.copy(), coord_names, indexes=indexes)

        # Attach indexer as coordinate
        for k, v in indexers.items():
            if v.dims == (k,):
                index = PandasIndex(v, k, coord_dtype=v.dtype)
                index_vars = index.create_variables({k: v})
                indexes[k] = index
                variables.update(index_vars)
            else:
                variables[k] = v

        # Extract coordinates from indexers
        coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords)
        variables.update(coord_vars)
        indexes.update(new_indexes)

        coord_names = obj._coord_names & variables.keys() | coord_vars.keys()
        ds = ds._replace_with_new_dims(variables, coord_names, indexes=indexes)
        return self._from_temp_dataset(ds)

    @staticmethod
    def _ag_interp_func(var, indexes_coords, method, **kwargs):
        """
        Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`.

        The implementation follows xarray's interp implementation in xarray.core.missing,
        but replaces some of the pre-processing as well as the actual interpolation
        function with an autograd-compatible approach.

        Parameters
        ----------
        var : xr.Variable
            The variable to be interpolated.
        indexes_coords : dict
            A dictionary mapping dimension names to coordinate values for interpolation.
        method : str
            The interpolation method to use.
        **kwargs : dict
            Additional keyword arguments to pass to the interpolation function.

        Returns
        -------
        xr.Variable
            The interpolated variable.
        """
        if not indexes_coords:
            return var.copy()
        result = var
        for indep_indexes_coords in missing.decompose_interp(indexes_coords):
            var = result

            # target dimensions
            dims = list(indep_indexes_coords)
            x, new_x = zip(*[indep_indexes_coords[d] for d in dims])
            destination = missing.broadcast_variables(*new_x)

            broadcast_dims = [d for d in var.dims if d not in dims]
            original_dims = broadcast_dims + dims
            new_dims = broadcast_dims + list(destination[0].dims)

            x, new_x = missing._floatize_x(x, new_x)

            # move the interpolated axes to the front before calling ``interpn``
            permutation = [var.dims.index(dim) for dim in original_dims]
            combined_permutation = permutation[-len(x) :] + permutation[: -len(x)]
            data = anp.transpose(var.data, combined_permutation)
            xi = anp.stack([anp.ravel(new_xi.data) for new_xi in new_x], axis=-1)

            result = interpn([xn.data for xn in x], data, xi, method=method, **kwargs)

            # put the interpolated axis last and restore the target coordinate shape
            result = anp.moveaxis(result, 0, -1)
            result = anp.reshape(result, result.shape[:-1] + new_x[0].shape)

            result = xr.Variable(new_dims, result, attrs=var.attrs, fastpath=True)

            out_dims: OrderedSet = OrderedSet()
            for d in var.dims:
                if d in dims:
                    out_dims.update(indep_indexes_coords[d][1].dims)
                else:
                    out_dims.add(d)
            if len(out_dims) > 1:
                result = result.transpose(*out_dims)
        return result

    def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> DataArray:
        """Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd compatible

        Constraints / Edge cases:
            - `coords` must map to a specific value eg {x: '1'}, does not broadcast to arrays
            - `data` will be reshaped to try to match `self.shape` except where `coords` present

        Raises
        ------
        ValueError
            If ``data`` cannot be reshaped to the expected shape.
        """

        # make mask
        mask = xr.zeros_like(self, dtype=bool)
        mask.loc[coords] = True

        # reshape `data` to line up with `self.dims`, with shape of 1 along the selected axis
        old_data = self.data
        new_shape = list(old_data.shape)
        for i, dim in enumerate(self.dims):
            if dim in coords:
                new_shape[i] = 1
        try:
            new_data = data.reshape(new_shape)
        except ValueError as e:
            raise ValueError(
                "Couldn't reshape the supplied 'data' to update 'DataArray'. The provided data was "
                f"of shape {data.shape} and tried to reshape to {new_shape}. If you encounter this "
                "error please raise an issue on the tidy3d github repository with the context."
            ) from e

        # broadcast data to repeat data along the selected dimensions to match mask
        new_data = new_data + np.zeros_like(old_data)

        new_data = np.where(mask, new_data, old_data)

        return self.copy(deep=True, data=new_data)
class FreqDataArray(DataArray):
    """Frequency-domain array.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> fd = FreqDataArray((1+1j) * np.random.random((2,)), coords=dict(f=f))
    """

    __slots__ = ()
    _dims = ("f",)


class FreqModeDataArray(DataArray):
    """Array over frequency and mode index.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> fd = FreqModeDataArray((1+1j) * np.random.random((2, 5)), coords=coords)
    """

    __slots__ = ()
    _dims = ("f", "mode_index")


class TimeDataArray(DataArray):
    """Time-domain array.

    Example
    -------
    >>> t = [0, 1e-12, 2e-12]
    >>> td = TimeDataArray((1+1j) * np.random.random((3,)), coords=dict(t=t))
    """

    __slots__ = ()
    # a one-element tuple, like every other ``_dims`` (a bare string only
    # worked because the name happens to be a single character)
    _dims = ("t",)


class MixedModeDataArray(DataArray):
    """Scalar property associated with mode pairs

    Example
    -------
    >>> f = [1e14, 2e14, 3e14]
    >>> mode_index_0 = np.arange(4)
    >>> mode_index_1 = np.arange(2)
    >>> coords = dict(f=f, mode_index_0=mode_index_0, mode_index_1=mode_index_1)
    >>> data = MixedModeDataArray((1+1j) * np.random.random((3, 4, 2)), coords=coords)
    """

    __slots__ = ()
    _dims = ("f", "mode_index_0", "mode_index_1")


class AbstractSpatialDataArray(DataArray, ABC):
    """Spatial distribution."""

    __slots__ = ()
    _dims = ("x", "y", "z")
    _data_attrs = {"long_name": "field value"}

    @property
    def _spatially_sorted(self) -> SpatialDataArray:
        """Check whether sorted and sort if not."""
        needs_sorting = []
        for axis in "xyz":
            axis_coords = self.coords[axis].values
            # any decreasing step means this axis must be sorted
            if len(axis_coords) > 1 and np.any(axis_coords[1:] < axis_coords[:-1]):
                needs_sorting.append(axis)

        if len(needs_sorting) > 0:
            return self.sortby(needs_sorting)

        return self

    def sel_inside(self, bounds: Bound) -> SpatialDataArray:
        """Return a new SpatialDataArray that contains the minimal amount data necessary to cover
        a spatial region defined by ``bounds``. Note that the returned data is sorted with respect
        to spatial coordinates.

        Parameters
        ----------
        bounds : Tuple[float, float, float], Tuple[float, float float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.

        Returns
        -------
        SpatialDataArray
            Extracted spatial data array.

        Raises
        ------
        DataError
            If any min bound is larger than the corresponding max bound.
        """
        if any(bmin > bmax for bmin, bmax in zip(*bounds)):
            raise DataError(
                "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'."
            )

        # make sure data is sorted with respect to coordinates
        sorted_self = self._spatially_sorted

        inds_list = []

        coords = (sorted_self.x, sorted_self.y, sorted_self.z)

        for coord, smin, smax in zip(coords, bounds[0], bounds[1]):
            length = len(coord)

            # one point along direction, assume invariance
            if length == 1:
                comp_inds = [0]
            else:
                # if data does not cover structure at all take the closest index
                if smax < coord[0]:  # structure is completely on the left side
                    # take the first 2 points if possible, so that linear
                    # interpolation is possible
                    comp_inds = np.arange(0, min(2, length))

                elif smin > coord[-1]:  # structure is completely on the right side
                    # take the last 2 points if possible, so that linear
                    # interpolation is possible
                    comp_inds = np.arange(max(0, length - 2), length)

                else:
                    if smin < coord[0]:
                        ind_min = 0
                    else:
                        # one point before the first coordinate >= smin
                        ind_min = max(0, (coord >= smin).argmax().data - 1)

                    if smax > coord[-1]:
                        ind_max = length - 1
                    else:
                        # first coordinate >= smax
                        ind_max = (coord >= smax).argmax().data

                    comp_inds = np.arange(ind_min, ind_max + 1)

            inds_list.append(comp_inds)

        return sorted_self.isel(x=inds_list[0], y=inds_list[1], z=inds_list[2])

    def does_cover(self, bounds: Bound, rtol: float = 0.0, atol: float = 0.0) -> bool:
        """Check whether data fully covers specified by ``bounds`` spatial region. If data contains
        only one point along a given direction, then it is assumed the data is constant along that
        direction and coverage is not checked.

        Parameters
        ----------
        bounds : Tuple[float, float, float], Tuple[float, float float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
        rtol : float = 0.0
            Relative tolerance for comparing bounds
        atol : float = 0.0
            Absolute tolerance for comparing bounds

        Returns
        -------
        bool
            Full cover check outcome.

        Raises
        ------
        DataError
            If any min bound is larger than the corresponding max bound.
        """
        if any(bmin > bmax for bmin, bmax in zip(*bounds)):
            raise DataError(
                "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'."
            )
        self_min = [0] * 3
        self_max = [0] * 3
        for dim, coords in enumerate((self.x, self.y, self.z)):
            if len(coords) == 1:
                # single point: treated as invariant, so reuse the requested bounds
                self_min[dim] = bounds[0][dim]
                self_max[dim] = bounds[1][dim]
            else:
                self_min[dim] = np.min(coords)
                self_max[dim] = np.max(coords)
        self_bounds = (tuple(self_min), tuple(self_max))
        return bounds_contains(self_bounds, bounds, rtol=rtol, atol=atol)
class SpatialDataArray(AbstractSpatialDataArray):
    """Spatial distribution.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3,4,5,6]
    >>> coords = dict(x=x, y=y, z=z)
    >>> fd = SpatialDataArray((1+1j) * np.random.random((2,3,4)), coords=coords)
    """

    __slots__ = ()

    def reflect(self, axis: Axis, center: float, reflection_only: bool = False) -> SpatialDataArray:
        """Reflect data across the plane define by parameters ``axis`` and ``center`` from right to
        left. Note that the returned data is sorted with respect to spatial coordinates.

        Parameters
        ----------
        axis : Literal[0, 1, 2]
            Normal direction of the reflection plane.
        center : float
            Location of the reflection plane along its normal direction.
        reflection_only : bool = False
            Return only reflected data.

        Returns
        -------
        SpatialDataArray
            Data after reflection is performed.

        Raises
        ------
        DataError
            If ``center`` lies to the right of the first data coordinate along ``axis``.
        """

        sorted_self = self._spatially_sorted

        coords = [sorted_self.x.values, sorted_self.y.values, sorted_self.z.values]
        data = np.array(sorted_self.data)

        data_left_bound = coords[axis][0]

        if np.isclose(center, data_left_bound):
            # the plane coincides with the first sample: do not duplicate it
            num_duplicates = 1
        elif center > data_left_bound:
            raise DataError("Reflection center must be outside and to the left of the data region.")
        else:
            num_duplicates = 0

        if reflection_only:
            coords[axis] = 2 * center - coords[axis]
            coords_dict = dict(zip("xyz", coords))

            tmp_arr = SpatialDataArray(sorted_self.data, coords=coords_dict)

            return tmp_arr.sortby("xyz"[axis])

        shape = np.array(np.shape(data))
        old_len = shape[axis]
        shape[axis] = 2 * old_len - num_duplicates

        ind_left = [slice(shape[0]), slice(shape[1]), slice(shape[2])]
        ind_right = [slice(shape[0]), slice(shape[1]), slice(shape[2])]

        # left half receives the data in reversed order, right half as is
        ind_left[axis] = slice(old_len - 1, None, -1)
        ind_right[axis] = slice(old_len - num_duplicates, None)

        # keep the dtype of the input so that complex data keeps its imaginary part
        new_data = np.zeros(shape, dtype=data.dtype)

        new_data[ind_left[0], ind_left[1], ind_left[2]] = data
        new_data[ind_right[0], ind_right[1], ind_right[2]] = data

        new_coords = np.zeros(shape[axis])
        new_coords[old_len - num_duplicates :] = coords[axis]
        new_coords[old_len - 1 :: -1] = 2 * center - coords[axis]

        coords[axis] = new_coords
        coords_dict = dict(zip("xyz", coords))

        return SpatialDataArray(new_data, coords=coords_dict)
class ScalarFieldDataArray(AbstractSpatialDataArray):
    """Frequency-domain scalar field sampled on an ``(x, y, z)`` grid.

    Example
    -------
    >>> x = [1, 2]
    >>> y = [2, 3, 4]
    >>> z = [3, 4, 5, 6]
    >>> f = [2e14, 3e14]
    >>> coords = dict(x=x, y=y, z=z, f=f)
    >>> values = (1 + 1j) * np.random.random((2, 3, 4, 2))
    >>> fd = ScalarFieldDataArray(values, coords=coords)
    """

    __slots__ = ()
    # spatial dimensions first, frequency last
    _dims = ("x", "y", "z", "f")
class ScalarFieldTimeDataArray(AbstractSpatialDataArray):
    """Time-domain scalar field sampled on an ``(x, y, z)`` grid.

    Example
    -------
    >>> x = [1, 2]
    >>> y = [2, 3, 4]
    >>> z = [3, 4, 5, 6]
    >>> t = [0, 1e-12, 2e-12]
    >>> coords = dict(x=x, y=y, z=z, t=t)
    >>> values = np.random.random((2, 3, 4, 3))
    >>> fd = ScalarFieldTimeDataArray(values, coords=coords)
    """

    __slots__ = ()
    # spatial dimensions first, time last
    _dims = ("x", "y", "z", "t")
class ScalarModeFieldDataArray(AbstractSpatialDataArray):
    """Frequency-domain mode field on an ``(x, y, z)`` grid, indexed by mode.

    Example
    -------
    >>> x = [1, 2]
    >>> y = [2, 3, 4]
    >>> z = [3, 4, 5, 6]
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index)
    >>> values = (1 + 1j) * np.random.random((2, 3, 4, 2, 5))
    >>> fd = ScalarModeFieldDataArray(values, coords=coords)
    """

    __slots__ = ()
    # spatial dimensions, then frequency, then mode index
    _dims = ("x", "y", "z", "f", "mode_index")
class ScalarModeFieldCylindricalDataArray(AbstractSpatialDataArray):
    """Frequency-domain mode field on a ``(rho, theta, axial)`` grid, indexed by mode.

    Example
    -------
    >>> rho = [1, 2]
    >>> theta = [2, 3, 4]
    >>> axial = [3, 4, 5, 6]
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> coords = dict(rho=rho, theta=theta, axial=axial, f=f, mode_index=mode_index)
    >>> values = (1 + 1j) * np.random.random((2, 3, 4, 2, 5))
    >>> fd = ScalarModeFieldCylindricalDataArray(values, coords=coords)
    """

    __slots__ = ()
    # cylindrical spatial dimensions, then frequency, then mode index
    _dims = ("rho", "theta", "axial", "f", "mode_index")
class FluxDataArray(DataArray):
    """Frequency-domain flux through a surface.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> coords = dict(f=f)
    >>> fd = FluxDataArray(np.random.random(2), coords=coords)
    """

    __slots__ = ()
    _dims = ("f",)
    # attributes attached to the data values by ``assign_data_attrs``
    _data_attrs = {"units": WATT, "long_name": "flux"}
class FluxTimeDataArray(DataArray):
    """Time-domain flux through a surface.

    Example
    -------
    >>> t = [0, 1e-12, 2e-12]
    >>> coords = dict(t=t)
    >>> data = FluxTimeDataArray(np.random.random(3), coords=coords)
    """

    __slots__ = ()
    _dims = ("t",)
    # attributes attached to the data values by ``assign_data_attrs``
    _data_attrs = {"units": WATT, "long_name": "flux"}
class ModeAmpsDataArray(DataArray):
    """Complex-valued amplitudes of forward and backward propagating modes.

    Example
    -------
    >>> direction = ["+", "-"]
    >>> f = [1e14, 2e14, 3e14]
    >>> mode_index = np.arange(4)
    >>> coords = dict(direction=direction, f=f, mode_index=mode_index)
    >>> values = (1 + 1j) * np.random.random((2, 3, 4))
    >>> data = ModeAmpsDataArray(values, coords=coords)
    """

    __slots__ = ()
    _dims = ("direction", "f", "mode_index")
    # attributes attached to the data values by ``assign_data_attrs``
    _data_attrs = {"units": "sqrt(W)", "long_name": "mode amplitudes"}
class ModeIndexDataArray(DataArray):
    """Complex-valued effective propagation index of a mode, over frequency and mode index.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(4)
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> values = (1 + 1j) * np.random.random((2, 4))
    >>> data = ModeIndexDataArray(values, coords=coords)
    """

    __slots__ = ()
    _dims = ("f", "mode_index")
    # attributes attached to the data values by ``assign_data_attrs``
    _data_attrs = {"long_name": "Propagation index"}
class GroupIndexDataArray(DataArray):
    """Group index of a mode, over frequency and mode index.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(4)
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> values = (1 + 1j) * np.random.random((2, 4))
    >>> data = GroupIndexDataArray(values, coords=coords)
    """

    __slots__ = ()
    _dims = ("f", "mode_index")
    # attributes attached to the data values by ``assign_data_attrs``
    _data_attrs = {"long_name": "Group index"}


class ModeDispersionDataArray(DataArray):
    """Dispersion parameter of a mode, over frequency and mode index.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(4)
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> values = (1 + 1j) * np.random.random((2, 4))
    >>> data = ModeDispersionDataArray(values, coords=coords)
    """

    __slots__ = ()
    _dims = ("f", "mode_index")
    # attributes attached to the data values by ``assign_data_attrs``
    _data_attrs = {
        "long_name": "Dispersion parameter",
        "units": PICOSECOND_PER_NANOMETER_PER_KILOMETER,
    }
class FieldProjectionAngleDataArray(DataArray):
    """Frequency-domain far fields over the angles ``theta`` and ``phi``.

    Example
    -------
    >>> f = np.linspace(1e14, 2e14, 10)
    >>> r = np.atleast_1d(5)
    >>> theta = np.linspace(0, np.pi, 10)
    >>> phi = np.linspace(0, 2*np.pi, 20)
    >>> coords = dict(r=r, theta=theta, phi=phi, f=f)
    >>> shape = (len(r), len(theta), len(phi), len(f))
    >>> values = (1 + 1j) * np.random.random(shape)
    >>> data = FieldProjectionAngleDataArray(values, coords=coords)
    """

    __slots__ = ()
    _dims = ("r", "theta", "phi", "f")
    # attributes attached to the data values by ``assign_data_attrs``
    _data_attrs = {"long_name": "radiation vectors"}
class FieldProjectionCartesianDataArray(DataArray):
    """Frequency-domain far fields over the local ``x`` and ``y`` coordinates.

    Example
    -------
    >>> f = np.linspace(1e14, 2e14, 10)
    >>> x = np.linspace(0, 5, 10)
    >>> y = np.linspace(0, 10, 20)
    >>> z = np.atleast_1d(5)
    >>> coords = dict(x=x, y=y, z=z, f=f)
    >>> shape = (len(x), len(y), len(z), len(f))
    >>> values = (1 + 1j) * np.random.random(shape)
    >>> data = FieldProjectionCartesianDataArray(values, coords=coords)
    """

    __slots__ = ()
    _dims = ("x", "y", "z", "f")
    # attributes attached to the data values by ``assign_data_attrs``
    _data_attrs = {"long_name": "radiation vectors"}
class FieldProjectionKSpaceDataArray(DataArray):
    """Frequency-domain far fields over the normalized ``kx`` and ``ky`` vectors on the
    observation plane.

    Example
    -------
    >>> f = np.linspace(1e14, 2e14, 10)
    >>> r = np.atleast_1d(5)
    >>> ux = np.linspace(0, 5, 10)
    >>> uy = np.linspace(0, 10, 20)
    >>> coords = dict(ux=ux, uy=uy, r=r, f=f)
    >>> shape = (len(ux), len(uy), len(r), len(f))
    >>> values = (1 + 1j) * np.random.random(shape)
    >>> data = FieldProjectionKSpaceDataArray(values, coords=coords)
    """

    __slots__ = ()
    _dims = ("ux", "uy", "r", "f")
    # attributes attached to the data values by ``assign_data_attrs``
    _data_attrs = {"long_name": "radiation vectors"}
class DiffractionDataArray(DataArray):
    """Diffraction power amplitudes over the diffraction orders and frequency.

    Example
    -------
    >>> f = np.linspace(1e14, 2e14, 10)
    >>> orders_x = np.linspace(-1, 1, 3)
    >>> orders_y = np.linspace(-2, 2, 5)
    >>> coords = dict(orders_x=orders_x, orders_y=orders_y, f=f)
    >>> shape = (len(orders_x), len(orders_y), len(f))
    >>> values = (1 + 1j) * np.random.random(shape)
    >>> data = DiffractionDataArray(values, coords=coords)
    """

    __slots__ = ()
    _dims = ("orders_x", "orders_y", "f")
    # attributes attached to the data values by ``assign_data_attrs``
    _data_attrs = {"long_name": "diffraction amplitude"}
class TriangleMeshDataArray(DataArray):
    """Data of the triangles of a surface mesh as in the STL file format."""

    __slots__ = ()
    _dims = ("face_index", "vertex_index", "axis")
    _data_attrs = {"long_name": "surface mesh triangles"}


class HeatDataArray(DataArray):
    """Heat data array.

    Example
    -------
    >>> T = [0, 1e-12, 2e-12]
    >>> td = HeatDataArray((1+1j) * np.random.random((3,)), coords=dict(T=T))
    """

    __slots__ = ()
    # A 1-tuple rather than the bare string "T", so that ``_dims`` has the same
    # type as on every other data array class and can be compared directly with
    # the tuple of dimension names that xarray reports in ``DataArray.dims``.
    _dims = ("T",)


# NOTE(review): the docstring examples of the EME arrays below do not follow
# ``_dims`` exactly ("sweep_index" is omitted, or the coordinates are given in a
# different order). Confirm against the ``DataArray`` base class whether the
# missing dimension is inserted / reordered automatically.
class EMEScalarModeFieldDataArray(AbstractSpatialDataArray):
    """Spatial distribution of a mode in frequency-domain as a function of mode index
    and EME cell index.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3]
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> eme_cell_index = np.arange(5)
    >>> coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index, eme_cell_index=eme_cell_index)
    >>> fd = EMEScalarModeFieldDataArray((1+1j) * np.random.random((2,3,1,2,5,5)), coords=coords)
    """

    __slots__ = ()
    _dims = ("x", "y", "z", "f", "sweep_index", "eme_cell_index", "mode_index")


class EMEFreqModeDataArray(DataArray):
    """Array over frequency, mode index, and EME cell index.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> eme_cell_index = np.arange(5)
    >>> coords = dict(f=f, mode_index=mode_index, eme_cell_index=eme_cell_index)
    >>> fd = EMEFreqModeDataArray((1+1j) * np.random.random((2, 5, 5)), coords=coords)
    """

    __slots__ = ()
    _dims = ("f", "sweep_index", "eme_cell_index", "mode_index")


class EMEScalarFieldDataArray(AbstractSpatialDataArray):
    """Spatial distribution of a field excited from an EME port in frequency-domain as a
    function of mode index at the EME port and the EME port index.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3,4,5,6]
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(5)
    >>> eme_port_index = [0, 1]
    >>> coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index, eme_port_index=eme_port_index)
    >>> fd = EMEScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2,5,2)), coords=coords)
    """

    __slots__ = ()
    _dims = ("x", "y", "z", "f", "sweep_index", "eme_port_index", "mode_index")


class EMECoefficientDataArray(DataArray):
    """EME expansion coefficient of the mode ``mode_index_out`` in the EME cell
    ``eme_cell_index``, when excited from mode ``mode_index_in`` of EME port
    ``eme_port_index``.

    Example
    -------
    >>> mode_index_in = [0, 1]
    >>> mode_index_out = [0, 1]
    >>> eme_cell_index = np.arange(5)
    >>> eme_port_index = [0, 1]
    >>> f = [2e14]
    >>> coords = dict(
    ...     f=f,
    ...     mode_index_out=mode_index_out,
    ...     mode_index_in=mode_index_in,
    ...     eme_cell_index=eme_cell_index,
    ...     eme_port_index=eme_port_index
    ... )
    >>> fd = EMECoefficientDataArray((1 + 1j) * np.random.random((1, 2, 2, 5, 2)), coords=coords)
    """

    __slots__ = ()
    _dims = (
        "f",
        "sweep_index",
        "eme_port_index",
        "eme_cell_index",
        "mode_index_out",
        "mode_index_in",
    )
    _data_attrs = {"long_name": "mode expansion coefficient"}


class EMESMatrixDataArray(DataArray):
    """Scattering matrix elements for a fixed pair of ports, possibly with an extra
    sweep index.

    Example
    -------
    >>> mode_index_in = [0, 1]
    >>> mode_index_out = [0, 1, 2]
    >>> f = [2e14]
    >>> sweep_index = np.arange(10)
    >>> coords = dict(
    ...     f=f,
    ...     mode_index_out=mode_index_out,
    ...     mode_index_in=mode_index_in,
    ...     sweep_index=sweep_index
    ... )
    >>> fd = EMESMatrixDataArray((1 + 1j) * np.random.random((1, 3, 2, 10)), coords=coords)
    """

    __slots__ = ()
    _dims = ("f", "sweep_index", "mode_index_out", "mode_index_in")
    _data_attrs = {"long_name": "scattering matrix element"}


class EMEModeIndexDataArray(DataArray):
    """Complex-valued effective propagation index of an EME mode, also indexed by
    EME cell.

    Example
    -------
    >>> f = [2e14, 3e14]
    >>> mode_index = np.arange(4)
    >>> eme_cell_index = np.arange(5)
    >>> coords = dict(f=f, mode_index=mode_index, eme_cell_index=eme_cell_index)
    >>> data = EMEModeIndexDataArray((1+1j) * np.random.random((2,4,5)), coords=coords)
    """

    __slots__ = ()
    _dims = ("f", "sweep_index", "eme_cell_index", "mode_index")
    _data_attrs = {"long_name": "Propagation index"}


class ChargeDataArray(DataArray):
    """Charge data array.

    Example
    -------
    >>> n = [0, 1e-12, 2e-12]
    >>> p = [0, 3e-12, 4e-12]
    >>> td = ChargeDataArray((1+1j) * np.random.random((3,3)), coords=dict(n=n, p=p))
    """

    __slots__ = ()
    _dims = ("n", "p")


class SteadyVoltageDataArray(DataArray):
    """Steady voltage data array. Data array used with steady state simulations with
    voltage as dimension.

    Example
    -------
    >>> import tidy3d as td
    >>> intensities = [0, 1, 4]
    >>> V = [-1, -0.5, 0]
    >>> voltage_dataarray = td.SteadyVoltageDataArray(data=intensities, coords={"v": V})
    """

    __slots__ = ()
    _dims = ("v",)
class PointDataArray(DataArray):
    """Two-dimensional storage of per-point vector data (point coordinates or field
    components) for a collection of points.

    Entries along ``index`` identify a point of the collection; entries along ``axis``
    select the point coordinate (or field component) in that direction.

    Example
    -------
    >>> point_array = PointDataArray(
    ...     (1+1j) * np.random.random((5, 3)), coords=dict(index=np.arange(5), axis=np.arange(3)),
    ... )
    >>> # get coordinates of a point number 3
    >>> point3 = point_array.sel(index=3)
    >>> # get x coordinates of all points
    >>> x_coords = point_array.sel(axis=0)
    >>>
    >>> field_da = PointDataArray(
    ...     np.random.random((120, 3)), coords=dict(index=np.arange(120), axis=np.arange(3)),
    ... )
    >>> # get field of point number 90
    >>> field_point90 = field_da.sel(index=90)
    >>> # get z component of all points
    >>> z_field = field_da.sel(axis=2)
    """

    # Empty slots: this subclass declares no per-instance attributes of its own.
    __slots__ = ()

    # Dimension names; the examples above lay out the data axes in this order.
    _dims = ("index", "axis")
class CellDataArray(DataArray):
    """Two-dimensional storage of the point indices that make up each cell in a
    collection of same-type cells (for example: triangles, tetrahedra, etc).

    Entries along ``cell_index`` identify a cell of the collection; entries along
    ``vertex_index`` give the placement (index) of a point within a cell (for example:
    0, 1, or 2 for triangles; 0, 1, 2, or 3 for tetrahedra).

    Example
    -------
    >>> cell_array = CellDataArray(
    ...     (1+1j) * np.random.random((4, 3)),
    ...     coords=dict(cell_index=np.arange(4), vertex_index=np.arange(3)),
    ... )
    >>> # get indices of points composing cell number 3
    >>> cell3 = cell_array.sel(cell_index=3)
    >>> # get indices of points that represent the first vertex in each cell
    >>> first_vertices = cell_array.sel(vertex_index=0)
    """

    # Empty slots: this subclass declares no per-instance attributes of its own.
    __slots__ = ()

    # Dimension names; the example above lays out the data axes in this order.
    _dims = ("cell_index", "vertex_index")
class IndexedDataArray(DataArray):
    """One-dimensional array enumerated by the coordinate ``index``.

    It is typically used in conjunction with a ``PointDataArray`` to store
    point-associated data or a ``CellDataArray`` to store cell-associated data.

    Example
    -------
    >>> indexed_array = IndexedDataArray(
    ...     (1+1j) * np.random.random((3,)), coords=dict(index=np.arange(3))
    ... )
    """

    # Empty slots: this subclass declares no per-instance attributes of its own.
    __slots__ = ()

    # Single dimension, written as a 1-tuple.
    _dims = ("index",)
class IndexedVoltageDataArray(DataArray):
    """Stores a two-dimensional array with coordinates ``index`` and ``voltage``, where
    ``index`` is usually associated with ``PointDataArray`` and ``voltage`` indicates at what
    bias/DC-voltage the data was obtained with.

    Example
    -------
    >>> indexed_array = IndexedVoltageDataArray(
    ...     (1+1j) * np.random.random((3,2)), coords=dict(index=np.arange(3), voltage=[-1, 1])
    ... )
    """

    __slots__ = ()
    _dims = ("index", "voltage")


class IndexedTimeDataArray(DataArray):
    """Stores a two-dimensional array with coordinates ``index`` and ``t``, where
    ``index`` is usually associated with ``PointDataArray`` and ``t`` indicates at what
    simulated time the data was obtained.

    Example
    -------
    >>> indexed_array = IndexedTimeDataArray(
    ...     (1+1j) * np.random.random((3,2)), coords=dict(index=np.arange(3), t=[0, 1])
    ... )
    """

    __slots__ = ()
    _dims = ("index", "t")


class IndexedFieldVoltageDataArray(DataArray):
    """Stores indexed values of vector fields for different voltages. It is typically used
    in conjunction with a ``PointDataArray`` to store point-associated vector data.

    Example
    -------
    >>> indexed_array = IndexedFieldVoltageDataArray(
    ...     (1+1j) * np.random.random((4,3,2)), coords=dict(index=np.arange(4), axis=np.arange(3), voltage=[-1, 1])
    ... )
    """

    __slots__ = ()
    _dims = ("index", "axis", "voltage")


class SpatialVoltageDataArray(AbstractSpatialDataArray):
    """Spatial distribution with voltage mapping.

    Example
    -------
    >>> x = [1,2]
    >>> y = [2,3,4]
    >>> z = [3,4,5,6]
    >>> v = [-1, 1]
    >>> coords = dict(x=x, y=y, z=z, voltage=v)
    >>> fd = SpatialVoltageDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
    """

    __slots__ = ()
    _dims = ("x", "y", "z", "voltage")


class PerturbationCoefficientDataArray(DataArray):
    """Perturbation coefficient array over the dimensions ``wvl`` and ``coeff``."""

    __slots__ = ()
    _dims = ("wvl", "coeff")


class VoltageArray(DataArray):
    """Base class that attaches voltage units and long name to the stored data.

    It declares no dimensions itself; the concrete voltage arrays combine it with a
    frequency, time, or frequency-mode array class.
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    _data_attrs = {"units": VOLT, "long_name": "voltage"}


class CurrentArray(DataArray):
    """Base class that attaches current units and long name to the stored data.

    It declares no dimensions itself; the concrete current arrays combine it with a
    frequency, time, or frequency-mode array class.
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    _data_attrs = {"units": AMP, "long_name": "current"}


class ImpedanceArray(DataArray):
    """Base class that attaches impedance units and long name to the stored data.

    It declares no dimensions itself; the concrete impedance arrays combine it with a
    frequency, time, or frequency-mode array class.
    """

    # Always set __slots__ = () to avoid xarray warnings
    __slots__ = ()
    _data_attrs = {"units": OHM, "long_name": "impedance"}


# Voltage arrays
class VoltageFreqDataArray(VoltageArray, FreqDataArray):
    """Voltage data array in frequency domain.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9, 4e9]
    >>> coords = dict(f=f)
    >>> data = np.random.random(3) + 1j * np.random.random(3)
    >>> vfd = VoltageFreqDataArray(data, coords=coords)
    """

    __slots__ = ()


class VoltageTimeDataArray(VoltageArray, TimeDataArray):
    """Voltage data array in time domain.

    Example
    -------
    >>> import numpy as np
    >>> t = [0, 1e-9, 2e-9, 3e-9]
    >>> coords = dict(t=t)
    >>> data = np.sin(2 * np.pi * 1e9 * np.array(t))
    >>> vtd = VoltageTimeDataArray(data, coords=coords)
    """

    __slots__ = ()


class VoltageFreqModeDataArray(VoltageArray, FreqModeDataArray):
    """Voltage data array in frequency-mode domain.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> mode_index = [0, 1]
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> data = np.random.random((2, 2)) + 1j * np.random.random((2, 2))
    >>> vfmd = VoltageFreqModeDataArray(data, coords=coords)
    """

    __slots__ = ()


# Current arrays
class CurrentFreqDataArray(CurrentArray, FreqDataArray):
    """Current data array in frequency domain.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9, 4e9]
    >>> coords = dict(f=f)
    >>> data = np.random.random(3) + 1j * np.random.random(3)
    >>> cfd = CurrentFreqDataArray(data, coords=coords)
    """

    __slots__ = ()


class CurrentTimeDataArray(CurrentArray, TimeDataArray):
    """Current data array in time domain.

    Example
    -------
    >>> import numpy as np
    >>> t = [0, 1e-9, 2e-9, 3e-9]
    >>> coords = dict(t=t)
    >>> data = np.cos(2 * np.pi * 1e9 * np.array(t))
    >>> ctd = CurrentTimeDataArray(data, coords=coords)
    """

    __slots__ = ()


class CurrentFreqModeDataArray(CurrentArray, FreqModeDataArray):
    """Current data array in frequency-mode domain.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> mode_index = [0, 1]
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> data = np.random.random((2, 2)) + 1j * np.random.random((2, 2))
    >>> cfmd = CurrentFreqModeDataArray(data, coords=coords)
    """

    __slots__ = ()


# Impedance arrays
class ImpedanceFreqDataArray(ImpedanceArray, FreqDataArray):
    """Impedance data array in frequency domain.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9, 4e9]
    >>> coords = dict(f=f)
    >>> data = 50.0 + 1j * np.random.random(3)
    >>> zfd = ImpedanceFreqDataArray(data, coords=coords)
    """

    __slots__ = ()


class ImpedanceTimeDataArray(ImpedanceArray, TimeDataArray):
    """Impedance data array in time domain.

    Example
    -------
    >>> import numpy as np
    >>> t = [0, 1e-9, 2e-9, 3e-9]
    >>> coords = dict(t=t)
    >>> data = 50.0 * np.ones_like(t)
    >>> ztd = ImpedanceTimeDataArray(data, coords=coords)
    """

    __slots__ = ()


class ImpedanceFreqModeDataArray(ImpedanceArray, FreqModeDataArray):
    """Impedance data array in frequency-mode domain.

    Example
    -------
    >>> import numpy as np
    >>> f = [2e9, 3e9]
    >>> mode_index = [0, 1]
    >>> coords = dict(f=f, mode_index=mode_index)
    >>> data = 50.0 + 10.0 * np.random.random((2, 2))
    >>> zfmd = ImpedanceFreqModeDataArray(data, coords=coords)
    """

    __slots__ = ()


def _rewrap_by_coords(
    result: DataArray,
    freq_cls: type[DataArray],
    time_cls: type[DataArray],
    freq_mode_cls: type[DataArray],
) -> DataArray:
    """Re-wrap ``result`` in the candidate class that matches its coordinates.

    Parameters
    ----------
    result : DataArray
        Array whose ``data`` and ``coords`` are copied into the new array.
    freq_cls : type[DataArray]
        Class used when neither of the other two cases applies.
    time_cls : type[DataArray]
        Class used when ``result`` has a ``"t"`` coordinate.
    freq_mode_cls : type[DataArray]
        Class used when ``result`` has both ``"f"`` and ``"mode_index"`` coordinates.

    Returns
    -------
    DataArray
        Whatever ``assign_data_attrs`` of the selected class returns for the new array.
    """
    coords = result.coords
    # The frequency-mode case takes precedence over the time case, which in turn
    # takes precedence over the plain-frequency fallback.
    if "f" in coords and "mode_index" in coords:
        cls = freq_mode_cls
    elif "t" in coords:
        cls = time_cls
    else:
        cls = freq_cls
    return cls.assign_data_attrs(cls(data=result.data, coords=result.coords))


def _make_base_result_data_array(result: DataArray) -> IntegralResultTypes:
    """Helper for creating the proper base result type."""
    return _rewrap_by_coords(result, FreqDataArray, TimeDataArray, FreqModeDataArray)


def _make_voltage_data_array(result: DataArray) -> VoltageIntegralResultTypes:
    """Helper for creating the proper voltage array type."""
    return _rewrap_by_coords(
        result, VoltageFreqDataArray, VoltageTimeDataArray, VoltageFreqModeDataArray
    )


def _make_current_data_array(result: DataArray) -> CurrentIntegralResultTypes:
    """Helper for creating the proper current array type."""
    return _rewrap_by_coords(
        result, CurrentFreqDataArray, CurrentTimeDataArray, CurrentFreqModeDataArray
    )


def _make_impedance_data_array(result: DataArray) -> ImpedanceResultTypes:
    """Helper for creating the proper impedance array type."""
    return _rewrap_by_coords(
        result, ImpedanceFreqDataArray, ImpedanceTimeDataArray, ImpedanceFreqModeDataArray
    )


# Concrete data array classes collected by this module. ``FreqModeDataArray`` is
# listed once; a second, redundant entry was removed (it had no effect on
# ``DATA_ARRAY_MAP`` below, since that mapping is keyed by class name).
DATA_ARRAY_TYPES = [
    SpatialDataArray,
    ScalarFieldDataArray,
    ScalarFieldTimeDataArray,
    ScalarModeFieldDataArray,
    FluxDataArray,
    FluxTimeDataArray,
    ModeAmpsDataArray,
    ModeIndexDataArray,
    GroupIndexDataArray,
    ModeDispersionDataArray,
    FieldProjectionAngleDataArray,
    FieldProjectionCartesianDataArray,
    FieldProjectionKSpaceDataArray,
    DiffractionDataArray,
    FreqModeDataArray,
    FreqDataArray,
    TimeDataArray,
    TriangleMeshDataArray,
    HeatDataArray,
    EMEScalarFieldDataArray,
    EMEScalarModeFieldDataArray,
    EMESMatrixDataArray,
    EMECoefficientDataArray,
    EMEModeIndexDataArray,
    EMEFreqModeDataArray,
    ChargeDataArray,
    SteadyVoltageDataArray,
    PointDataArray,
    CellDataArray,
    IndexedDataArray,
    IndexedFieldVoltageDataArray,
    IndexedVoltageDataArray,
    SpatialVoltageDataArray,
    PerturbationCoefficientDataArray,
    IndexedTimeDataArray,
    VoltageFreqDataArray,
    VoltageTimeDataArray,
    VoltageFreqModeDataArray,
    CurrentFreqDataArray,
    CurrentTimeDataArray,
    CurrentFreqModeDataArray,
    ImpedanceFreqDataArray,
    ImpedanceTimeDataArray,
    ImpedanceFreqModeDataArray,
]

# Lookup of every listed class by its class name.
DATA_ARRAY_MAP = {data_array.__name__: data_array for data_array in DATA_ARRAY_TYPES}

IndexedDataArrayTypes = Union[
    IndexedDataArray,
    IndexedVoltageDataArray,
    IndexedTimeDataArray,
    IndexedFieldVoltageDataArray,
    PointDataArray,
]

IntegralResultTypes = Union[FreqDataArray, FreqModeDataArray, TimeDataArray]

VoltageIntegralResultTypes = Union[
    VoltageFreqDataArray, VoltageFreqModeDataArray, VoltageTimeDataArray
]

CurrentIntegralResultTypes = Union[
    CurrentFreqDataArray, CurrentFreqModeDataArray, CurrentTimeDataArray
]

ImpedanceResultTypes = Union[
    ImpedanceFreqDataArray, ImpedanceFreqModeDataArray, ImpedanceTimeDataArray
]