"""Solve for modes in a 2D cross-sectional plane in a simulation, assuming translational
invariance along a given propagation axis.
"""
from __future__ import annotations
from functools import wraps
from math import isclose
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, get_args
import numpy as np
import xarray as xr
from pydantic import Field, field_validator, model_validator
from tidy3d.components.base import (
Tidy3dBaseModel,
cached_property,
)
from tidy3d.components.boundary import (
PML,
Absorber,
Boundary,
BoundarySpec,
PECBoundary,
StablePML,
)
from tidy3d.components.data.data_array import (
CurrentFreqTerminalModeDataArray,
ImpedanceFreqTerminalTerminalDataArray,
ModeIndexDataArray,
ScalarModeFieldCylindricalDataArray,
ScalarModeFieldDataArray,
VoltageFreqTerminalModeDataArray,
_make_current_data_array,
_make_impedance_data_array,
_make_voltage_data_array,
)
from tidy3d.components.data.monitor_data import ModeSolverData
from tidy3d.components.data.sim_data import SimulationData
from tidy3d.components.eme.data.sim_data import EMESimulationData
from tidy3d.components.eme.simulation import EMESimulation
from tidy3d.components.geometry.base import Box
from tidy3d.components.geometry.utils import (
SnapBehavior,
SnapLocation,
SnappingSpec,
find_snap_location,
snap_box_to_grid,
)
from tidy3d.components.material.tensor_rotation import (
bend_axis_global_axis,
medium_is_rotation_invariant,
rotation_matrix_about_local_axis,
)
from tidy3d.components.medium import (
AnisotropicMedium,
FullyAnisotropicMedium,
IsotropicUniformMediumType,
LossyMetalMedium,
)
from tidy3d.components.microwave.data.dataset import (
TransmissionLineDataset,
TransmissionLineTerminalDataset,
)
from tidy3d.components.microwave.data.monitor_data import MicrowaveModeSolverData
from tidy3d.components.microwave.impedance_calculator import ImpedanceCalculator
from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec, MicrowaveTerminalModeSpec
from tidy3d.components.microwave.monitor import MicrowaveModeMonitor, MicrowaveModeSolverMonitor
from tidy3d.components.microwave.path_integrals.factory import (
make_path_integrals,
make_path_integrals_for_terminal,
)
from tidy3d.components.microwave.path_integrals.mode_plane_analyzer import ModePlaneAnalyzer
from tidy3d.components.monitor import ModeMonitor, ModeSolverMonitor
from tidy3d.components.scene import Scene
from tidy3d.components.simulation import Simulation
from tidy3d.components.source.field import ModeSource
from tidy3d.components.subpixel_spec import SurfaceImpedance
from tidy3d.components.types import ArrayComplex3D, Direction, EMField, FreqArray
from tidy3d.components.types.base import TYPE_TAG_STR, discriminated_union
from tidy3d.components.types.mode_spec import ModeSpecType
from tidy3d.components.validators import (
validate_colocated_integration,
validate_freqs_min,
validate_freqs_not_empty,
)
from tidy3d.components.viz import make_ax, plot_params_pml
from tidy3d.constants import C_0, fp_eps
from tidy3d.exceptions import SetupError, ValidationError
from tidy3d.log import log
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Literal
from matplotlib.colors import Colormap
from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
from tidy3d.compat import Self
from tidy3d.components.data.data_array import (
FreqModeDataArray,
)
from tidy3d.components.grid.grid import Coords, Grid
from tidy3d.components.microwave.impedance_calculator import (
CurrentIntegralType,
VoltageIntegralType,
)
from tidy3d.components.mode_spec import ModeSpec
from tidy3d.components.monitor import AbstractModeMonitor
from tidy3d.components.source.time import SourceTime
from tidy3d.components.structure import Structure
from tidy3d.components.types import (
ArrayComplex4D,
ArrayFloat1D,
ArrayFloat2D,
Ax,
Axis,
Axis2D,
BoundOptional,
EpsSpecType,
PlotScale,
Symmetry,
TensorReal,
)
from tidy3d.components.types.monitor_data import ModeSolverDataType
from tidy3d.packaging import (
check_tidy3d_extras_licensed_feature,
supports_local_subpixel,
tidy3d_extras,
)
# Importing the local solver may not work if e.g. scipy is not installed.
# An ImportError is logged as a warning and recorded in ``LOCAL_SOLVER_IMPORTED`` rather
# than propagated, so this module can still be imported.
IMPORT_ERROR_MSG = """Could not import local solver, 'ModeSolver' objects can still be constructed
but will have to be run through the server.
"""
try:
    from .solver import compute_modes

    LOCAL_SOLVER_IMPORTED = True
except ImportError:
    log.warning(IMPORT_ERROR_MSG)
    LOCAL_SOLVER_IMPORTED = False

# Triple of complex 3D arrays (presumably the three components of one vector field -- confirm)
FIELD = tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]
# Name given to the mode solver monitor created internally by the solver
MODE_MONITOR_NAME = "<<<MODE_SOLVER_MONITOR>>>"
# Warning for field intensity at edges over total field intensity larger than this value
FIELD_DECAY_CUTOFF = 1e-2
# Maximum allowed size of the field data produced by the mode solver
MAX_MODES_DATA_SIZE_GB = 20
# Discriminated unions of the simulation, simulation-data and plane types accepted here
MODE_SIMULATION_TYPE = discriminated_union(Simulation | EMESimulation)
MODE_SIMULATION_DATA_TYPE = discriminated_union(SimulationData | EMESimulationData)
MODE_PLANE_TYPE = discriminated_union(Box | ModeSource | ModeMonitor | ModeSolverMonitor)
# When using ``angle_rotation`` without a bend, use a very large effective radius
EFFECTIVE_RADIUS_FACTOR = 10_000
# Log a warning when the PML covers more than this portion of the mode plane in any axis
WARN_THICK_PML_PERCENT = 50
# Typing helpers used to annotate ``require_fdtd_simulation`` so that decorated callables
# keep their parameter and return types
P = ParamSpec("P")
R = TypeVar("R")
def require_fdtd_simulation(fn: Callable[P, R]) -> Callable[P, R]:
    """Restrict a method to solvers whose ``simulation`` is an FDTD ``Simulation``.

    Parameters
    ----------
    fn : Callable
        Function whose first positional argument exposes a ``simulation`` attribute.

    Returns
    -------
    Callable
        Wrapper carrying the metadata of ``fn``. It raises ``SetupError`` when the
        ``simulation`` is not a ``Simulation`` and otherwise forwards all arguments to ``fn``.
    """

    @wraps(fn)
    def _checked(*args: P.args, **kwargs: P.kwargs) -> R:
        """Verify the simulation type, then delegate to the wrapped function."""
        owner = args[0]
        if isinstance(owner.simulation, Simulation):
            return fn(*args, **kwargs)
        raise SetupError(
            f"The function '{fn.__name__}' is only supported "
            "for 'simulation' of type FDTD 'Simulation'."
        )

    return _checked
[docs]
class ModeSolver(Tidy3dBaseModel):
"""
Interface for solving electromagnetic eigenmodes in a 2D plane with translational
invariance in the third dimension.
See Also
--------
:class:`ModeSource`:
Injects current source to excite modal profile on finite extent plane.
**Notebooks:**
* `Waveguide Y junction <../../notebooks/YJunction.html>`_
* `Photonic crystal waveguide polarization filter <../../../notebooks/PhotonicCrystalWaveguidePolarizationFilter.html>`_
**Lectures:**
* `Prelude to Integrated Photonics Simulation: Mode Injection <https://www.flexcompute.com/fdtd101/Lecture-4-Prelude-to-Integrated-Photonics-Simulation-Mode-Injection/>`_
"""
# Required (no default). Accepts either member of ``MODE_SIMULATION_TYPE``; the concrete
# class is selected through the "type" discriminator.
simulation: MODE_SIMULATION_TYPE = Field(
    title="Simulation",
    description="Simulation or EMESimulation defining all structures and mediums.",
    discriminator="type",
)
# Required (no default). Any member of ``MODE_PLANE_TYPE``; planarity is enforced by the
# ``is_plane`` validator.
plane: MODE_PLANE_TYPE = Field(
    title="Plane",
    description="Cross-sectional plane in which the mode will be computed.",
)
# Required (no default). Discriminated on ``TYPE_TAG_STR``.
mode_spec: ModeSpecType = Field(
    title="Mode specification",
    description="Container with specifications about the modes to be solved for.",
    discriminator=TYPE_TAG_STR,
)
# Required (no default). Checked by the ``validate_freqs_*`` validators declared on the class.
freqs: FreqArray = Field(
    title="Frequencies",
    description="A list of frequencies at which to solve.",
)
# Defaults to "+".
direction: Direction = Field(
    "+",
    title="Propagation direction",
    description="Direction of waveguide mode propagation along the axis defined by its normal "
    "dimension.",
)
# Defaults to True.
colocate: bool = Field(
    True,
    title="Colocate fields",
    description="Toggle whether fields should be colocated to grid cell boundaries (i.e. "
    "primal grid nodes). Default is ``True``.",
)
# Defaults to True; per its description it only matters when ``colocate`` is False.
use_colocated_integration: bool = Field(
    True,
    title="Use Colocated Integration",
    description="Only takes effect when ``colocate=False``. If ``True``, dot products "
    "and overlap integrals still use fields interpolated to grid cell boundaries "
    "(colocated), even though the field data is stored at native Yee grid positions. "
    "Experimental feature that can give improved accuracy by avoiding interpolation of "
    "fields to Yee cell positions for integration.",
)
# Shared validator factory; presumably tied to ``use_colocated_integration`` -- confirm
# in ``tidy3d.components.validators``.
_colocated_integration_validator = validate_colocated_integration()
# Defaults to True.
conjugated_dot_product: bool = Field(
    True,
    title="Conjugated Dot Product",
    description="Use conjugated or non-conjugated dot product for mode decomposition.",
)
# Defaults to all six E and H components.
# NOTE(review): the default is a list while the annotation is a tuple -- confirm that the
# base model coerces defaults.
fields: tuple[EMField, ...] = Field(
    ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"],
    title="Field Components",
    description="Collection of field components to store in the monitor. Note that some "
    "methods like ``flux``, ``dot`` require all tangential field components, while others "
    "like ``mode_area`` require all E-field components.",
)
@model_validator(mode="after")
def _validate_mode_spec(self) -> Self:
    """Check that ``mode_spec.num_modes`` is an ``int``.

    A failure is reported through ``_raise_validation_error_at_loc`` at the ``mode_spec``
    location. The validated model is returned unchanged.
    """
    num_modes = self.mode_spec.num_modes
    if isinstance(num_modes, int):
        return self
    self._raise_validation_error_at_loc(
        ValidationError("num_modes must be an integer."), "mode_spec"
    )
    return self
@field_validator("simulation")
@classmethod
def _convert_to_simulation(cls, val: MODE_SIMULATION_TYPE) -> MODE_SIMULATION_TYPE:
    """Replace objects exposing ``to_simulation`` (e.g. a JaxSimulation) by a static one.

    Values without a ``to_simulation`` attribute are returned untouched. Otherwise the
    first element returned by ``val.to_simulation()`` is used and a warning is logged.
    """
    if not hasattr(val, "to_simulation"):
        return val
    static_sim = val.to_simulation()[0]
    log.warning(
        "'JaxSimulation' is no longer directly supported in 'ModeSolver', "
        "converting to static simulation."
    )
    return static_sim
[docs]
@field_validator("plane")
@classmethod
def is_plane(cls, val: MODE_PLANE_TYPE) -> MODE_PLANE_TYPE:
    """Raise validation error if not planar.

    Parameters
    ----------
    val : MODE_PLANE_TYPE
        Candidate mode plane; exactly one entry of ``val.size`` must equal ``0.0``.

    Returns
    -------
    MODE_PLANE_TYPE
        ``val`` unchanged when it is planar.

    Raises
    ------
    ValidationError
        If ``val.size`` does not contain exactly one zero.
    """
    if val.size.count(0.0) != 1:
        # Report the offending size itself, as the message announces, not the whole object
        raise ValidationError(f"ModeSolver plane must be planar, given size={val.size}")
    return val
# Shared validators built by the factories in ``tidy3d.components.validators``. Judging by
# the factory names they reject an empty ``freqs`` and frequencies below a minimum --
# TODO confirm against the factories.
_freqs_not_empty = validate_freqs_not_empty()
_freqs_lower_bound = validate_freqs_min()
[docs]
@model_validator(mode="after")
def plane_in_sim_bounds(self) -> Self:
    """Require the mode plane to intersect the simulation domain.

    The domain is modelled as a ``Box`` with the simulation's size and center. When
    ``plane`` does not intersect it, a ``SetupError`` is reported at the ``plane`` location.
    """
    sim = self.simulation
    sim_region = Box(size=sim.size, center=sim.center)
    if sim_region.intersects(self.plane):
        return self
    self._raise_validation_error_at_loc(
        SetupError("'ModeSolver.plane' must intersect 'ModeSolver.simulation'."),
        "plane",
    )
    return self
@model_validator(mode="after")
def _warn_plane_crosses_symmetry(self) -> Self:
    """Warn about mode planes that straddle a symmetry plane without being centered on it.

    For every dimension with a non-zero simulation symmetry, a warning is logged when the
    simulation center lies strictly inside the plane bounds while the plane center differs
    from the simulation center.
    """
    sim = self.simulation
    plane = self.plane
    for dim in range(3):
        if sim.symmetry[dim] == 0:
            continue
        sym_position = sim.center[dim]
        lower, upper = plane.bounds[0][dim], plane.bounds[1][dim]
        if not (lower < sym_position < upper):
            continue
        if isclose(plane.center[dim], sym_position):
            continue
        log.warning(
            f"The original simulation is symmetric along {'xyz'[dim]} direction. "
            "The mode simulation region does cross the symmetry plane but is "
            "not symmetric with respect to it. To preserve correct symmetry, "
            "the requested simulation region will be expanded by the solver."
        )
    return self
@model_validator(mode="after")
def _validate_warn_thick_pml(self) -> Self:
    """Warn if the pml covers a significant portion of the mode plane.

    Only the PML thickness check is performed here. The structure-rotation checks are
    run by the dedicated ``_validate_rotate_structures_after`` validator, so they are
    not repeated in this validator.
    """
    self._warn_thick_pml(simulation=self.simulation, plane=self.plane, mode_spec=self.mode_spec)
    return self
@model_validator(mode="after")
def _validate_bend_radius(self) -> Self:
    """Check the bend radius of ``mode_spec`` against the mode plane size.

    Delegates to ``_validate_mode_plane_radius`` using a ``Box`` with the simulation's
    size and center as the simulation geometry.
    """
    sim = self.simulation
    self._validate_mode_plane_radius(
        self.mode_spec,
        self.plane,
        Box(size=sim.size, center=sim.center),
    )
    return self
@model_validator(mode="after")
def _validate_rotate_structures_after(self) -> Self:
    """Run ``_validate_rotate_structures`` as an after-validator and return the model."""
    self._validate_rotate_structures()
    return self
@model_validator(mode="after")
def _validate_num_grid_points(self) -> Self:
    """Reject mode planes with too many grid points.

    The ARPACK subspace size times the eigenvector size must fit in a 32-bit integer.
    Tensorial problems build a 4N x 4N matrix (vs 2N x 2N for scalar), so their
    dimension factor is doubled. A violation is reported as a ``SetupError`` at the
    ``plane`` location.
    """
    # NOTE(review): 2**32 - 1 is the unsigned 32-bit maximum; confirm whether the solver
    # limit is actually the signed one (2**31 - 1).
    max_workspace = 2**32 - 1
    n_cells, _, n_modes = self._num_cells_freqs_modes
    if self._is_tensorial:
        dim_factor = 4
    else:
        dim_factor = 2
    subspace_size = 20 + 2 * n_modes
    workspace = n_cells * subspace_size * dim_factor
    if workspace <= max_workspace:
        return self
    self._raise_validation_error_at_loc(
        SetupError(
            "Too many grid points on the modal plane. Please reduce the modal plane size, apply a coarser grid, "
            "or reduce the number of modes."
        ),
        "plane",
    )
    return self
@classmethod
def _warn_thick_pml(
    cls,
    simulation: Simulation,
    plane: Box,
    mode_spec: ModeSpec,
    msg_prefix: str = "'ModeSolver'",
) -> None:
    """Log a warning for each tangential axis where the PML dominates the mode plane.

    Parameters
    ----------
    simulation : Simulation
        Simulation providing the grid.
    plane : Box
        Mode plane.
    mode_spec : ModeSpec
        Mode specification from which the effective PML cell counts are derived.
    msg_prefix : str = "'ModeSolver'"
        Text placed at the start of the warning message.
    """
    coord_0, coord_1 = cls._plane_grid(simulation=simulation, plane=plane)
    num_cells = (len(coord_0), len(coord_1))
    effective_num_pml = cls._effective_num_pml(
        simulation=simulation, plane=plane, mode_spec=mode_spec
    )
    allowed_fraction = WARN_THICK_PML_PERCENT / 100
    for i, axis_cells in enumerate(num_cells):
        # The PML is counted twice per axis (factor 2)
        if 2 * effective_num_pml[i] <= allowed_fraction * axis_cells:
            continue
        log.warning(
            f"{msg_prefix}: "
            f"The mode solver pml in tangential axis '{i}' "
            f"covers more than '{WARN_THICK_PML_PERCENT}%' of the "
            "mode plane cells. Consider using a larger mode plane "
            "or smaller 'num_pml'."
        )
@staticmethod
def _mode_plane(plane: Box, sim_geom: Box) -> Box:
    """Return the effective mode plane: ``plane`` clipped to the bounds of ``sim_geom``."""
    rmin, rmax = plane.bounds_intersection(plane.bounds, sim_geom.bounds)
    return Box.from_bounds(rmin, rmax)
@classmethod
def _validate_mode_plane_radius(cls, mode_spec: ModeSpec, plane: Box, sim_geom: Box) -> None:
    """Reject bend radii that do not exceed half the mode plane size along the radial axis.

    Nothing is checked when ``mode_spec.bend_radius`` is falsy. The plane is first
    intersected with ``sim_geom``.

    Raises
    ------
    ValueError
        If ``abs(bend_radius)`` is not larger than half the plane size along the radial
        axis, up to the ``fp_eps`` tolerance.
    """
    radius = mode_spec.bend_radius
    if not radius:
        return

    effective_plane = cls._mode_plane(plane=plane, sim_geom=sim_geom)
    normal = effective_plane.size.index(0.0)

    # radial axis is the plane axis that is not the bend axis
    tangential_axes = effective_plane.pop_axis([0, 1, 2], normal)[1]
    radial_ax = tangential_axes[(mode_spec.bend_axis + 1) % 2]
    half_extent = effective_plane.size[radial_ax] / 2

    if np.abs(radius) <= half_extent + fp_eps:
        raise ValueError(
            "Mode solver bend radius is smaller than half the mode plane size "
            "along the radial axis, which can produce wrong results."
        )
def _validate_rotate_structures(self) -> None:
    """Check that the structures crossing the plane can be rotated.

    Only active when ``mode_spec.angle_theta`` is non-zero and ``mode_spec.angle_rotation``
    is set. The intersecting media are validated first, then a trial rotation of the
    intersecting structures is performed and its result discarded.
    """
    spec = self.mode_spec
    if not (np.abs(spec.angle_theta) > 0 and spec.angle_rotation):
        return

    media = self._intersecting_media
    rotate_kwargs = self._rotation_kwargs
    freqs = self._sampling_freqs
    self._validate_plane_rotation_media(
        mediums=media,
        rotate_kwargs=rotate_kwargs,
        freqs=freqs,
    )

    plane_structures = Scene.intersecting_structures(self.plane, self.simulation.structures)
    self._make_rotated_structures(
        structures=plane_structures,
        translate_kwargs=self._rotation_translate_kwargs,
        rotate_kwargs=rotate_kwargs,
        freqs=freqs,
    )
@staticmethod
def _medium_supports_plane_rotation(
    medium: object, rotation_matrix: TensorReal, freqs: FreqArray
) -> bool:
    """Tell whether ``medium`` can be used with angled-plane structure rotation.

    Supported media are the members of ``IsotropicUniformMediumType`` and those
    ``AnisotropicMedium`` / ``FullyAnisotropicMedium`` instances for which
    ``medium_is_rotation_invariant`` holds for ``rotation_matrix`` at ``freqs``.
    """
    if isinstance(medium, get_args(IsotropicUniformMediumType)):
        return True
    if not isinstance(medium, (AnisotropicMedium, FullyAnisotropicMedium)):
        return False
    return medium_is_rotation_invariant(
        medium=medium, rotation_matrix=rotation_matrix, freqs=freqs
    )
@classmethod
def _validate_plane_rotation_media(
    cls,
    mediums: list[object],
    rotate_kwargs: dict[str, float | Axis],
    freqs: FreqArray,
) -> None:
    """Raise if any medium crossing the plane is unsupported by the angled-plane rotation.

    Parameters
    ----------
    mediums : list[object]
        Media intersecting the mode plane.
    rotate_kwargs : dict[str, float | Axis]
        Must provide the ``"axis"`` and ``"angle"`` of the rotation.
    freqs : FreqArray
        Frequencies forwarded to the per-medium support check.

    Raises
    ------
    SetupError
        At the first medium rejected by ``_medium_supports_plane_rotation``.
    """
    rotation_matrix = rotation_matrix_about_local_axis(
        axis=rotate_kwargs["axis"], angle=rotate_kwargs["angle"]
    )
    for medium in mediums:
        if cls._medium_supports_plane_rotation(
            medium=medium,
            rotation_matrix=rotation_matrix,
            freqs=freqs,
        ):
            continue
        raise SetupError(  # post-init-tidy3d-error: ignore
            "'angle_rotation' set to True but the mode solver plane intersects an unsupported "
            "medium. Only uniform isotropic media and rotation-invariant anisotropic media are "
            "supported for the plane rotation."
        )
@staticmethod
def _make_rotated_structures(
    structures: list[Structure],
    translate_kwargs: dict[str, float],
    rotate_kwargs: dict[str, float | Axis],
    freqs: FreqArray,
) -> list[Structure]:
    """Return copies of ``structures`` with their geometries rotated about a shifted origin.

    Each geometry is translated by the negated ``translate_kwargs``, rotated with
    ``rotate_kwargs`` and translated back.

    Raises
    ------
    SetupError
        Wraps any exception raised while rotating, including the ``NotImplementedError``
        raised here for unsupported media.
    """
    try:
        rotation_matrix = rotation_matrix_about_local_axis(
            axis=rotate_kwargs["axis"], angle=rotate_kwargs["angle"]
        )
        rotated = []
        for structure in structures:
            supported = ModeSolver._medium_supports_plane_rotation(
                medium=structure.medium,
                rotation_matrix=rotation_matrix,
                freqs=freqs,
            )
            if not supported:
                raise NotImplementedError(
                    "Mode solver plane intersects an unsupported medium. "
                    "Only uniform isotropic media and rotation-invariant anisotropic "
                    "media are supported for the plane rotation."
                )

            # Shift to the rotation origin, rotate, then shift back
            original_geometry = structure.geometry
            shift_to_origin = {key: -val for key, val in translate_kwargs.items()}
            at_origin = original_geometry.translated(**shift_to_origin)
            turned = at_origin.rotated(**rotate_kwargs)
            restored = turned.translated(**translate_kwargs)
            rotated.append(structure.updated_copy(geometry=restored))
        return rotated
    except Exception as e:
        raise SetupError(  # post-init-tidy3d-error: ignore
            f"'angle_rotation' set to True but could not rotate structures: {e!s}"
        ) from e
@staticmethod
def _rotation_validation_freqs(mode_object: ModeSource | AbstractModeMonitor) -> FreqArray:
    """Frequencies at which structure rotations are validated for ``angle_rotation``.

    A ``ModeSource`` contributes its ``frequency_grid``; any other object contributes its
    ``freqs``. The values are passed through the mode spec's
    ``_sampling_freqs_mode_solver`` and returned as a float array.
    """
    if isinstance(mode_object, ModeSource):
        raw_freqs = mode_object.frequency_grid
    else:
        raw_freqs = mode_object.freqs
    base_freqs = np.asarray(raw_freqs, dtype=float)
    sampled = mode_object.mode_spec._sampling_freqs_mode_solver(freqs=base_freqs)
    return np.asarray(sampled, dtype=float)
@classmethod
def _validate_microwave_mode_spec(cls, mode_spec: MicrowaveModeSpec, plane: Box) -> None:
    """Check a microwave mode spec against the mode plane.

    Delegates to ``mode_spec._check_path_integrals_within_box(plane)``; the return value
    of that helper is discarded.
    """
    mode_spec._check_path_integrals_within_box(plane)
@cached_property
def normal_axis(self) -> Axis:
    """Index of the axis normal to the mode plane, i.e. the first zero entry of ``plane.size``."""
    plane_size = self.plane.size
    return plane_size.index(0.0)
[docs]
@staticmethod
def plane_center_tangential(plane: MODE_PLANE_TYPE) -> tuple[float, float]:
    """Mode plane center in the tangential axes.

    The normal axis is the first zero entry of ``plane.size``; its coordinate is removed
    from ``plane.center`` with ``plane.pop_axis``.
    """
    normal_axis = plane.size.index(0.0)
    return plane.pop_axis(plane.center, normal_axis)[1]
@cached_property
def normal_axis_2d(self) -> Axis2D:
    """Position of the plane normal among the two axes that remain after removing ``bend_axis_3d``."""
    remaining_axes = self.plane.pop_axis((0, 1, 2), axis=self.bend_axis_3d)[1]
    return remaining_axes.index(self.normal_axis)
@staticmethod
def _solver_symmetry(simulation: Simulation, plane: Box) -> tuple[Symmetry, Symmetry]:
    """Symmetry seen by the solver on the two tangential axes of ``plane``.

    The simulation symmetry is kept only along dimensions where the simulation and plane
    centers coincide (``math.isclose``) and is set to 0 elsewhere. The entry of the plane
    normal axis is then dropped.
    """
    normal_axis = plane.size.index(0.0)
    mode_symmetry = [
        sym if isclose(sim_center, plane_center) else 0
        for sym, sim_center, plane_center in zip(
            simulation.symmetry, simulation.center, plane.center
        )
    ]
    tangential_symmetry = plane.pop_axis(mode_symmetry, axis=normal_axis)[1]
    return tuple(tangential_symmetry)
@cached_property
def solver_symmetry(self) -> tuple[Symmetry, Symmetry]:
    """Tangential symmetry used by the solver, computed by ``_solver_symmetry`` for this plane."""
    sim, plane = self.simulation, self.plane
    return self._solver_symmetry(simulation=sim, plane=plane)
@classmethod
def _get_output_grid(
    cls,
    simulation: Simulation,
    plane: Box,
    keep_additional_layers: bool = False,
    truncate_symmetry: bool = True,
) -> Grid:
    """Build the extended mode-solver output grid for ``plane`` in ``simulation``.

    The grid is not snapped to the plane or to zero-size simulation dimensions. It is the
    grid used for output field coordinates and colocation, and it may extend beyond the
    final PEC-bounded eigensolve domain.

    Parameters
    ----------
    simulation : Simulation
        Simulation providing the discretization.
    plane : Box
        Mode plane; the first zero entry of its size defines the normal axis.
    keep_additional_layers : bool = False
        Do not discard layers of cells in front and behind the main layer of cells. Together
        they represent the region where custom medium data is needed for proper subpixel.
    truncate_symmetry : bool = True
        Truncate to symmetry quadrant if symmetry present.

    Returns
    -------
    :class:`.Grid`
        The resulting grid.
    """
    span_inds = simulation._discretize_inds_monitor(plane, colocate=False)
    normal_axis = plane.size.index(0.0)
    solver_symmetry = cls._solver_symmetry(simulation=simulation, plane=plane)

    # Remove the extension along the monitor normal: one index on each side
    if not keep_additional_layers:
        span_inds[normal_axis, 0] += 1
        span_inds[normal_axis, 1] -= 1

    # Do not extend if the simulation has a single pixel along a dimension
    for dim, cell_count in enumerate(simulation.grid.num_cells):
        if cell_count > 1:
            continue
        span_inds[dim] = [0, 1]

    # Truncate to symmetry quadrant if symmetry present
    if truncate_symmetry:
        tangential_axes = Box.pop_axis([0, 1, 2], normal_axis)[1]
        for axis, sym in zip(tangential_axes, solver_symmetry):
            if sym == 0:
                continue
            # Advance the start index by half of the index span on this axis
            half_span = np.diff(span_inds[axis])[0] // 2
            span_inds[axis, 0] += half_span

    return simulation._subgrid(span_inds=span_inds)
@cached_property
def _output_grid(self) -> Grid:
    """Full output grid of this solver, truncated for symmetry and without extra normal layers.

    The grid is deliberately not snapped to the plane or to zero-size simulation
    dimensions yet, because 0-sized cells are currently confusing to the subpixel
    averaging. The final data coordinates along the plane normal, and along dimensions
    where the simulation domain is 2D, are set correctly after the solve.
    """
    grid_options = {
        "simulation": self.simulation,
        "plane": self.plane,
        "keep_additional_layers": False,
        "truncate_symmetry": True,
    }
    return self._get_output_grid(**grid_options)
@staticmethod
def _snapped_mode_domain(grid: Grid, box: Box, normal_axis: int) -> Box:
    """Snap ``box`` to grid boundary positions on its tangential axes.

    Snapping uses ``SnapBehavior.Expand`` at ``SnapLocation.Boundary`` on every axis that
    is not the normal axis and has more than one cell. It is switched off
    (``SnapBehavior.Off``) on the normal axis and on degenerate axes
    (``num_cells <= 1``, arising from 2D simulations).
    """
    behavior = tuple(
        SnapBehavior.Expand
        if (ax != normal_axis and grid.num_cells[ax] > 1)
        else SnapBehavior.Off
        for ax in range(3)
    )
    location = (SnapLocation.Boundary,) * 3
    snap_spec = SnappingSpec(location=location, behavior=behavior)
    return snap_box_to_grid(grid, box, snap_spec)
@cached_property
def _solver_grid_span_inds(self) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int]]:
    """Index ranges ``(beg, end)`` inside ``_output_grid`` delimiting the solver grid.

    Every axis starts with the full range ``(0, num_cells)``. On non-degenerate
    tangential axes, the mode plane is snapped to ``simulation.grid`` via
    ``_snapped_mode_domain`` and its bounds are located in the ``_output_grid``
    boundaries with ``find_snap_location``. On axes with solver symmetry the start index
    stays 0 (the symmetry plane edge).
    """

    def _locate(bounds, value, side):
        """Look up ``value`` in ``bounds`` with the ``fp_eps`` tolerances."""
        return find_snap_location(bounds, value, side, rel_tol=fp_eps, abs_tol=fp_eps)

    out_grid = self._output_grid
    output_bounds = out_grid.boundaries.to_list
    num_cells = out_grid.num_cells
    span_inds: list[tuple[int, int]] = [(0, num_cells[ax]) for ax in range(3)]

    snapped = self._snapped_mode_domain(self.simulation.grid, self.plane, self.normal_axis)
    _, tangential_axes = Box.pop_axis([0, 1, 2], self.normal_axis)

    for tangential_idx, axis in enumerate(tangential_axes):
        # No truncation on degenerate axes (1D mode solves in 2D simulations)
        if num_cells[axis] <= 1:
            continue
        axis_bounds = output_bounds[axis]
        has_symmetry = self.solver_symmetry[tangential_idx] != 0
        # Keep the symmetry edge (index 0) unchanged.
        ind_beg = 0 if has_symmetry else _locate(axis_bounds, snapped.bounds[0][axis], "lower")
        ind_end = _locate(axis_bounds, snapped.bounds[1][axis], "upper")
        span_inds[axis] = (ind_beg, ind_end)

    return tuple(span_inds)
@classmethod
def _get_solver_grid(
    cls,
    output_grid: Grid,
    span_inds: tuple[tuple[int, int], tuple[int, int], tuple[int, int]],
) -> Grid:
    """Build the truncated eigensolve grid as a subgrid of ``output_grid``.

    For each of the x, y, z axes the boundaries are replaced by
    ``output_grid.extended_subspace`` between ``ind_beg`` and ``ind_end + 1``.
    """
    new_boundaries = {
        dim: output_grid.extended_subspace(axis=axis, ind_beg=beg, ind_end=end + 1)
        for axis, (dim, (beg, end)) in enumerate(zip("xyz", span_inds))
    }
    truncated = output_grid.boundaries.updated_copy(**new_boundaries)
    return output_grid.updated_copy(boundaries=truncated)
@cached_property
def _solver_grid(self) -> Grid:
    """Eigensolve grid: ``_output_grid`` truncated to ``_solver_grid_span_inds``."""
    full_grid = self._output_grid
    index_ranges = self._solver_grid_span_inds
    return self._get_solver_grid(output_grid=full_grid, span_inds=index_ranges)
@cached_property
def _solver_grid_pad_widths(self) -> tuple[tuple[int, int], ...]:
    """Per-axis ``(before, after)`` pad widths from the solver grid back to ``_output_grid``.

    ``before`` is the start index of the solver span and ``after`` is the number of
    output-grid cells minus the end index of the span.
    """
    pad_widths = []
    for span, num_cells in zip(self._solver_grid_span_inds, self._output_grid.num_cells):
        ind_beg, ind_end = span
        pad_widths.append((ind_beg, num_cells - ind_end))
    return tuple(pad_widths)
@cached_property
def _solver_grid_trim_inds_tangential(self) -> tuple[tuple[int, int], tuple[int, int]]:
    """Pad/trim widths of the two tangential axes (used by backend mode solver)."""
    pad_widths = self._solver_grid_pad_widths
    tangential_axes = Box.pop_axis([0, 1, 2], self.normal_axis)[1]
    return tuple(pad_widths[axis] for axis in tangential_axes)
@staticmethod
def _compute_solver_field_bounds(
    grid: Grid,
    plane: Box,
    normal_axis: Axis,
    symmetry: tuple = (0, 0, 0),
    symmetry_center: tuple = (0.0, 0.0, 0.0),
) -> BoundOptional:
    """Compute the solver field bounds of ``plane`` on ``grid``.

    Parameters
    ----------
    grid : Grid
        Grid to which the plane is snapped.
    plane : Box
        Mode plane.
    normal_axis : Axis
        Axis whose bounds are reported as ``None``.
    symmetry : tuple = (0, 0, 0)
        Per-axis symmetry; a non-zero entry on a tangential axis replaces the minimum
        bound on that axis with the symmetry center.
    symmetry_center : tuple = (0.0, 0.0, 0.0)
        Per-axis position of the symmetry planes.

    Returns
    -------
    BoundOptional
        ``(rmin, rmax)`` tuples with ``None`` along ``normal_axis`` and the snapped plane
        bounds along the tangential axes.
    """
    snapped = ModeSolver._snapped_mode_domain(grid, plane, normal_axis)
    lower, upper = [], []
    for ax in range(3):
        is_normal = ax == normal_axis
        lower.append(None if is_normal else snapped.bounds[0][ax])
        upper.append(None if is_normal else snapped.bounds[1][ax])

    tangential_axes = Box.pop_axis([0, 1, 2], normal_axis)[1]
    for axis in tangential_axes:
        if symmetry[axis] != 0:
            lower[axis] = symmetry_center[axis]

    return (tuple(lower), tuple(upper))
@cached_property
def _solver_field_bounds(self) -> BoundOptional:
    """Per-axis bounds where solver field data is physically valid.

    Computed by ``_compute_solver_field_bounds`` from the simulation grid, symmetry and
    center together with this solver's plane and normal axis.
    """
    sim = self.simulation
    return self._compute_solver_field_bounds(
        grid=sim.grid,
        plane=self.plane,
        normal_axis=self.normal_axis,
        symmetry=sim.symmetry,
        symmetry_center=sim.center,
    )
@cached_property
def _num_cells_freqs_modes(self) -> tuple[int, int, int]:
    """Return ``(num_cells, num_freqs, num_modes)`` for this solver.

    ``num_cells`` is the product of the solver grid cell counts, ``num_freqs`` the number
    of sampling frequencies and ``num_modes`` the requested number of modes.
    """
    cell_total = np.prod(self._solver_grid.num_cells)
    mode_total = self.mode_spec.num_modes
    freq_total = len(self._sampling_freqs)
    return cell_total, freq_total, mode_total
@property
def _has_microwave_mode_spec(self) -> bool:
    """Whether ``mode_spec`` is a :class:`.MicrowaveModeSpec` but not a :class:`.MicrowaveTerminalModeSpec`."""
    if not isinstance(self.mode_spec, MicrowaveModeSpec):
        return False
    return not self._has_microwave_terminal_mode_spec
@property
def _has_microwave_terminal_mode_spec(self) -> bool:
    """Whether ``mode_spec`` is a :class:`.MicrowaveTerminalModeSpec`."""
    spec = self.mode_spec
    return isinstance(spec, MicrowaveTerminalModeSpec)
[docs]
@supports_local_subpixel
def solve(self) -> ModeSolverData:
    """Return the :class:`.ModeSolverData` with the field and effective index data.

    Returns
    -------
    ModeSolverData
        :class:`.ModeSolverData` object containing the effective index and mode fields.

    Note
    ----
    By default, this method does not use subpixel averaging, which may reduce accuracy.
    For better accuracy, either use the remote mode solver via ``tidy3d.web.run(...)``
    or install ``tidy3d-extras`` (``pip install "tidy3d[extras]"``) and set
    ``config.simulation.use_local_subpixel = True``. See
    :attr:`SimulationConfig.use_local_subpixel <tidy3d.config.sections.SimulationConfig.use_local_subpixel>`
    for details.
    """
    local_subpixel_enabled = tidy3d_extras["use_local_subpixel"]
    if local_subpixel_enabled:
        return self.data

    # only warn if tidy3d-extras local subpixel is not enabled
    log.warning(
        "Use the remote mode solver with subpixel averaging for better accuracy through "
        "'tidy3d.web.run(...)' or the deprecated 'tidy3d.plugins.mode.web.run(...)'. "
        "Alternatively, you can install the package 'tidy3d-extras' using "
        "'pip install \"tidy3d[extras]\"' and set 'config.simulation.use_local_subpixel=True'.",
        log_once=True,
    )
    return self.data
def _freqs_for_group_index(self, freqs: FreqArray) -> FreqArray:
    """Get frequencies used to compute group index.

    Parameters
    ----------
    freqs : FreqArray
        Frequencies for which the group-index sampling frequencies are requested.

    Returns
    -------
    FreqArray
        Result of ``mode_spec._freqs_for_group_index`` evaluated for ``freqs``.
    """
    # Forward the caller-supplied frequencies; ``self.freqs`` must not override them
    return self.mode_spec._freqs_for_group_index(freqs=freqs)
@cached_property
def _sampling_freqs(self) -> FreqArray:
    """Frequencies the solver actually samples (for group index and interpolation).

    Obtained from ``mode_spec._sampling_freqs_mode_solver`` applied to ``self.freqs``.
    """
    requested = self.freqs
    return self.mode_spec._sampling_freqs_mode_solver(freqs=requested)
@cached_property
def grid_snapped(self) -> Grid:
    """Output grid snapped to the plane normal and to simulation 0-sized dims if any."""
    sim, plane = self.simulation, self.plane
    return self._grid_snapped(simulation=sim, plane=plane)
@classmethod
def _grid_snapped(cls, simulation: Simulation, plane: Box) -> Grid:
    """Output grid snapped to the plane normal and to simulation 0-sized dims if any.

    The symmetry-truncated output grid (without additional normal layers) is first
    snapped to ``plane`` with ``snap_to_box_zero_dim``, then passed to
    ``simulation._snap_zero_dim`` with the plane normal axis skipped.
    """
    output_grid = cls._get_output_grid(
        simulation=simulation,
        plane=plane,
        keep_additional_layers=False,
        truncate_symmetry=True,
    )
    # snap to plane center along normal direction
    snapped_to_plane = output_grid.snap_to_box_zero_dim(plane)

    # snap to simulation center if simulation is 0D along a tangential dimension
    return simulation._snap_zero_dim(snapped_to_plane, skip_axis=plane.size.index(0.0))
@cached_property
def data_raw(self) -> ModeSolverDataType:
    """:class:`.ModeSolverData` containing the field and effective index on unexpanded grid.

    When ``mode_spec.angle_rotation`` is set and ``angle_theta`` is non-zero, the result of
    ``rotated_mode_solver_data`` is returned directly. Otherwise the data computed on the
    Yee grid goes through the following steps, in this order: optional conversion to
    :class:`.MicrowaveModeSolverData`, optional colocation, normalization, polarization
    filtering, mode sorting, field-decay warning, field-component filtering, group-index
    post-processing, interpolation-spec handling, and finally the addition of microwave
    (terminal) data.

    Returns
    -------
    ModeSolverDataType
        A mode solver data type object containing the effective index and mode fields.
    """
    # Angled planes with structure rotation are handled by a dedicated code path
    if self.mode_spec.angle_rotation and np.abs(self.mode_spec.angle_theta) > 0:
        return self.rotated_mode_solver_data

    # Compute data on the Yee grid
    mode_solver_data = self._data_on_yee_grid()

    # For microwave mode specs, rebuild the data as ``MicrowaveModeSolverData``. The monitor
    # is excluded from the dump and re-attached as the original object.
    if self._has_microwave_mode_spec or self._has_microwave_terminal_mode_spec:
        data = mode_solver_data.model_dump(exclude={"type", "monitor"})
        data["monitor"] = mode_solver_data.monitor
        mode_solver_data = MicrowaveModeSolverData(**data)

    # Colocate to grid boundaries if requested
    if self.colocate:
        mode_solver_data = self._colocate_data(mode_solver_data=mode_solver_data)

    # normalize modes (return value is not used; presumably normalizes in place -- confirm)
    self._normalize_modes(mode_solver_data=mode_solver_data)

    # filter polarization if requested
    mode_solver_data = self._filter_polarization(mode_solver_data=mode_solver_data)

    # filter and sort modes if requested by sort_spec
    mode_solver_data = mode_solver_data.sort_modes(
        sort_spec=self.mode_spec.sort_spec,
        track_freq=self.mode_spec.track_freq,
    )

    # The decay check runs on the symmetry-expanded data, before components are filtered out
    self._field_decay_warning(mode_solver_data.symmetry_expanded)

    mode_solver_data = self._filter_components(mode_solver_data)

    # Group-index post-processing only applies for a positive step
    if self.mode_spec.group_index_step > 0:
        mode_solver_data = mode_solver_data._group_index_post_process(
            self.mode_spec.group_index_step
        )

    if self.mode_spec._is_interp_spec_applied(self.freqs):
        # set interp_spec back
        interp_spec = self.mode_spec.interp_spec.updated_copy(reduce_data=True)
        mode_solver_data = mode_solver_data.updated_copy(
            monitor=mode_solver_data.monitor.updated_copy(
                freqs=self.freqs,
                mode_spec=self.mode_spec.updated_copy(interp_spec=interp_spec),
            )
        )
        # If data reduction was not requested, switch to the interpolated copy
        if not self.mode_spec.interp_spec.reduce_data:
            mode_solver_data = mode_solver_data.interpolated_copy

    # Calculate and add the transmission line data
    if self._has_microwave_mode_spec:
        mode_solver_data = self._add_microwave_data(mode_solver_data)
    if self._has_microwave_terminal_mode_spec:
        mode_solver_data = self._add_microwave_terminal_data(mode_solver_data)
    return mode_solver_data
@cached_property
def bend_axis_3d(self) -> Axis:
    """3D axis corresponding to the 2D bend axis of this solver.

    For a straight waveguide, the rotated axis is equivalent to the bend axis and can be
    determined using ``angle_phi``. The value is computed by
    ``_bend_axis_3d_from_plane_and_mode_spec``.
    """
    plane, spec = self.plane, self.mode_spec
    return self._bend_axis_3d_from_plane_and_mode_spec(plane, spec)
@staticmethod
def _bend_axis_3d_from_plane_and_mode_spec(plane: Box, mode_spec: ModeSpecType) -> Axis:
    """3D bend axis used by the angled-mode rotation for ``plane`` and ``mode_spec``.

    When ``mode_spec.bend_axis`` is set, it is converted by ``bend_axis_global_axis``.
    Otherwise one of the two tangential axes is picked with the index
    ``int(abs(cos(angle_phi)))``.
    """
    normal_axis = plane.size.index(0.0)
    bend_axis = mode_spec.bend_axis
    if bend_axis is None:
        tangential_axes = plane.pop_axis((0, 1, 2), axis=normal_axis)[1]
        pick = int(abs(np.cos(mode_spec.angle_phi)))
        return tangential_axes[pick]
    return bend_axis_global_axis(normal_axis=normal_axis, bend_axis=bend_axis)
@cached_property
def rotated_mode_solver_data(self) -> ModeSolverData:
    """Mode data on the angled mode plane, built by rotating a normal-incidence solution.

    A reference solver with rotated structures and zero injection angles is solved with all
    six field components, colocated if needed, converted to cylindrical coordinates, and then
    rotated onto this solver's plane. The rotated fields are wrapped in a
    :class:`ModeSolverData`, normalized, and reduced to the components in ``self.fields``.

    Returns
    -------
    ModeSolverData
        Rotated and normalized mode data.
    """
    # Reference solver: rotated geometry, 0-degree angle, and all six components stored
    # because the rotation mixes them.
    all_components = ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
    reference_solver = self.rotated_structures_copy.updated_copy(
        fields=all_components, validate=False
    )
    reference_data = reference_solver.data_raw
    n_complex = reference_data.n_complex

    # The reference data must always be colocated before the coordinate transform.
    if not reference_solver.colocate:
        reference_data = reference_solver._colocate_data(mode_solver_data=reference_data)

    # Cartesian -> cylindrical about the rotated bend center.
    # (For ``mode_spec.bend_radius is None`` the commented-out ``_ref_data_straight`` helper
    # kept in this file could be used instead.)
    cylindrical_reference = self._car_2_cyn(mode_solver_data=reference_data)

    target_solver = self._reduced_simulation_copy_with_fallback

    # Rotate the reference solution onto the monitor plane.
    # (For ``mode_spec.bend_radius is None`` the commented-out ``_mode_rotation_straight``
    # helper kept in this file could be used instead.)
    rotated_fields = self._mode_rotation(
        solver_ref_data_cylindrical=cylindrical_reference,
        solver=target_solver,
    )

    # TODO: At a later time, we should ensure that `eps_spec` is automatically returned
    # to compute the backward propagation mode solution using a mode solver
    # with direction "-".
    eps_spec = ["tensorial_complex" for _ in target_solver.freqs]

    # Finite grid corrections.
    grid_factors, relative_grid_distances = target_solver._grid_correction(
        simulation=target_solver.simulation,
        plane=target_solver.plane,
        mode_spec=target_solver.mode_spec,
        n_complex=n_complex,
        direction=target_solver.direction,
    )

    # Assemble the mode solver data on the Yee grid.
    monitor = target_solver.to_mode_solver_monitor(name=MODE_MONITOR_NAME)
    grid_expanded = target_solver.simulation.discretize_monitor(monitor)
    rotated_data = ModeSolverData(
        monitor=monitor,
        symmetry=target_solver.simulation.symmetry,
        symmetry_center=target_solver.simulation.center,
        grid_expanded=grid_expanded,
        grid_primal_correction=grid_factors[0],
        grid_dual_correction=grid_factors[1],
        eps_spec=eps_spec,
        grid_distances_primal=relative_grid_distances[0],
        grid_distances_dual=relative_grid_distances[1],
        **rotated_fields,
    )

    # Normalization happens in place; component filtering returns a new object.
    self._normalize_modes(mode_solver_data=rotated_data)
    return self._filter_components(rotated_data)
@cached_property
def rotated_structures_copy(self) -> ModeSolver:
    """Copy of this solver whose simulation holds the rotated structures.

    The copied simulation has no sources or monitors, and the copied mode spec has
    ``angle_rotation`` disabled with ``angle_theta`` and ``angle_phi`` reset to zero.

    Returns
    -------
    ModeSolver
        Solver used to compute the normal-incidence reference solution.
    """
    structures = self._rotate_structures

    # With the sources removed below, an automatic grid can no longer infer its wavelength,
    # so pin it from the original sources first.
    grid_spec = self.simulation.grid_spec
    if grid_spec.auto_grid_used and grid_spec.wavelength is None:
        grid_spec = grid_spec.updated_copy(
            wavelength=grid_spec.get_wavelength(self.simulation.sources)
        )

    # Sources are not used by mode solving and can become invalid after rotating structures.
    simulation = self.simulation.updated_copy(
        structures=structures,
        sources=(),
        monitors=(),
        grid_spec=grid_spec,
        validate=False,
    )
    mode_spec = self.mode_spec.updated_copy(angle_rotation=False, angle_theta=0, angle_phi=0)
    return self.updated_copy(simulation=simulation, mode_spec=mode_spec)
@cached_property
def _rotate_structures(self) -> list[Structure]:
    """Rotated versions of the structures that intersect the mode plane.

    The intersecting structures are passed to ``_make_rotated_structures`` together with the
    translation and rotation keyword arguments for ``angle_rotation`` and the sampling
    frequencies.
    """
    intersecting = Scene.intersecting_structures(self.plane, self.simulation.structures)
    rotation_setup = {
        "translate_kwargs": self._rotation_translate_kwargs,
        "rotate_kwargs": self._rotation_kwargs,
        "freqs": self._sampling_freqs,
    }
    return self._make_rotated_structures(structures=intersecting, **rotation_setup)
@cached_property
def _rotation_translate_kwargs(self) -> dict[str, float]:
    """Translations applied before and after rotating structures for ``angle_rotation``.

    Returns
    -------
    dict[str, float]
        Result of ``_rotation_translate_kwargs_for_plane_and_mode_spec`` evaluated on this
        solver's ``plane`` and ``mode_spec``.
    """
    # Delegate to the classmethod so the same rule can be applied without a solver instance.
    return self._rotation_translate_kwargs_for_plane_and_mode_spec(self.plane, self.mode_spec)
@classmethod
def _rotation_translate_kwargs_for_plane_and_mode_spec(
    cls, plane: Box, mode_spec: ModeSpecType
) -> dict[str, float]:
    """Translations applied before and after rotating structures for ``angle_rotation``.

    Returns
    -------
    dict[str, float]
        Mapping with keys ``"x"``, ``"y"``, ``"z"``. The two axes perpendicular to the 3D
        bend axis carry the corresponding ``plane.center`` coordinate; the bend axis itself
        carries ``0.0``.
    """
    rotation_axis = cls._bend_axis_3d_from_plane_and_mode_spec(plane, mode_spec)
    _, (idx_u, idx_v) = plane.pop_axis((0, 1, 2), axis=rotation_axis)
    shifted_axes = (idx_u, idx_v)
    return {
        label: (plane.center[axis] if axis in shifted_axes else 0.0)
        for axis, label in enumerate("xyz")
    }
@cached_property
def _rotation_kwargs(self) -> dict[str, float | Axis]:
    """Rotation applied to intersecting structures for ``angle_rotation``.

    Returns
    -------
    dict[str, float | Axis]
        Result of ``_rotation_kwargs_for_plane_and_mode_spec`` evaluated on this solver's
        ``plane`` and ``mode_spec``.
    """
    # Delegate to the classmethod so the same rule can be applied without a solver instance.
    return self._rotation_kwargs_for_plane_and_mode_spec(self.plane, self.mode_spec)
@classmethod
def _rotation_kwargs_for_plane_and_mode_spec(
    cls, plane: Box, mode_spec: ModeSpecType
) -> dict[str, float | Axis]:
    """Rotation applied to intersecting structures for ``angle_rotation``.

    Returns
    -------
    dict[str, float | Axis]
        ``{"angle": ..., "axis": ...}`` where ``axis`` is the 3D bend axis and ``angle`` is
        ``angle_theta`` projected with ``cos(angle_phi)`` or ``sin(angle_phi)``, with a sign
        that depends on the (normal axis, bend axis) pair. Pairs not listed give ``0.0``.
    """
    normal_axis = plane.size.index(0.0)
    rotation_axis = cls._bend_axis_3d_from_plane_and_mode_spec(plane, mode_spec)

    theta = mode_spec.angle_theta
    cos_phi = np.cos(mode_spec.angle_phi)
    sin_phi = np.sin(mode_spec.angle_phi)

    # Signed rotation angle keyed by (normal axis, rotation axis).
    signed_angle_by_axes = {
        (0, 1): theta * sin_phi,
        (0, 2): -theta * cos_phi,
        (1, 0): -theta * sin_phi,
        (1, 2): theta * cos_phi,
        (2, 0): theta * sin_phi,
        (2, 1): -theta * cos_phi,
    }
    axes_key = (normal_axis, rotation_axis)
    if axes_key in signed_angle_by_axes:
        angle = signed_angle_by_axes[axes_key]
    else:
        angle = 0.0
    return {"angle": angle, "axis": rotation_axis}
@cached_property
def rotated_bend_center(self) -> list:
    """Bend center of the reference geometry, for which the mode plane is normal to the
    azimuthal direction of the bend.

    Returns
    -------
    list
        The plane center with ``_bend_radius`` subtracted along the lateral axis, i.e. the
        axis that is neither the plane normal nor the 3D bend axis.
    """
    lateral_axis = 3 - self.normal_axis - self.bend_axis_3d
    center = list(self.plane.center)
    center[lateral_axis] -= self._bend_radius
    return center
# # Leaving for future reference if needed
# def _ref_data_straight(
# self, mode_solver_data: ModeSolverData
# ) -> dict[Union[ScalarModeFieldDataArray, ModeIndexDataArray]]:
# """Convert reference data to be centered at the monitor center."""
# # Reference solution stored
# lateral_axis = 3 - self.normal_axis - self.bend_axis_3d
# axes = ("x", "y", "z")
# normal_dim = axes[self.normal_axis]
# lateral_dim = axes[lateral_axis]
# solver_data_straight = {}
# for name, field in mode_solver_data.field_components.items():
# solver_data_straight[name] = field.copy()
# for dim, dim_name in enumerate("xyz"):
# # Only shift coordinates for normal_dim and lateral_dim
# if dim_name in (normal_dim, lateral_dim):
# coords_shift = field.coords[dim_name] - self.plane.center[dim]
# solver_data_straight[name].coords[dim_name] = coords_shift
# solver_data_straight["n_complex"] = mode_solver_data.n_complex
# return solver_data_straight
def _car_2_cyn(
    self, mode_solver_data: ModeSolverData
) -> dict[str, ScalarModeFieldCylindricalDataArray | ModeIndexDataArray]:
    """Convert cartesian fields to cylindrical fields centered at the
    rotated bend center.

    Parameters
    ----------
    mode_solver_data : ModeSolverData
        Reference mode data. Coordinates are read from ``Ex`` only, so all six field
        components are expected to share the same (colocated) coordinates.

    Returns
    -------
    dict[str, ScalarModeFieldCylindricalDataArray | ModeIndexDataArray]
        Entries ``"Er"``, ``"Etheta"``, ``"Eaxial"``, ``"Hr"``, ``"Htheta"``, ``"Haxial"`` with
        coordinates ``(rho, theta, axial, f, mode_index)``, plus ``"n_complex"`` copied from
        ``mode_solver_data``.
    """
    # Extract coordinates from one of the six field components as they are colocated
    pts = [mode_solver_data.Ex[name].values.copy() for name in ["x", "y", "z"]]
    f, mode_index = mode_solver_data.Ex.f, mode_solver_data.Ex.mode_index
    # Lateral axis: the one that is neither the plane normal nor the bend axis.
    lateral_axis = 3 - self.normal_axis - self.bend_axis_3d
    idx_w, idx_uv = self.plane.pop_axis((0, 1, 2), axis=self.bend_axis_3d)
    idx_u, idx_v = idx_uv
    # Radial coordinate: distance from the rotated bend center along the lateral axis, sorted
    # ascending.
    pts[lateral_axis] -= self.rotated_bend_center[lateral_axis]
    rho = np.sort(np.abs(pts[lateral_axis]))
    # Single azimuthal angle: direction from the rotated bend center to the plane center.
    theta = np.atleast_1d(
        np.arctan2(
            self.plane.center[idx_v] - self.rotated_bend_center[idx_v],
            self.plane.center[idx_u] - self.rotated_bend_center[idx_u],
        )
    )
    # Axial coordinate: positions along the bend axis, unchanged.
    axial = pts[idx_w]
    # Initialize output data arrays for cylindrical fields
    field_data_cylindrical = {
        field_name_cylindrical: np.zeros(
            (len(rho), len(theta), len(axial), len(f), len(mode_index)), dtype=complex
        )
        for field_name_cylindrical in ("Er", "Etheta", "Eaxial", "Hr", "Htheta", "Haxial")
    }
    cos_theta, sin_theta = np.cos(theta), np.sin(theta)
    axes = ("x", "y", "z")
    axial_name, plane_names = self.plane.pop_axis(axes, axis=self.bend_axis_3d)
    cmp_1, cmp_2 = plane_names
    plane_name_normal = axes[self.normal_axis]
    plane_name_lateral = axes[lateral_axis]
    # Determine which coordinate transformation to use based on the lateral axis
    # (maps the sorted ``rho`` values back to cartesian lateral coordinates).
    if cmp_1 == plane_name_lateral:
        lateral_coord_value = rho * cos_theta + self.rotated_bend_center[idx_u]
    elif cmp_2 == plane_name_lateral:
        lateral_coord_value = rho * sin_theta + self.rotated_bend_center[idx_v]
    # Transform fields from cartesian to cylindrical:
    # pick the slice nearest to the plane center along the normal axis, linearly interpolate
    # (with extrapolation) onto the lateral coordinates above, and move the lateral dimension
    # first so it lines up with the ``rho`` axis of the output arrays.
    fields_interp = {
        field_name: getattr(mode_solver_data, field_name)
        .sel({plane_name_normal: self.plane.center[self.normal_axis]}, method="nearest")
        .interp(
            {plane_name_lateral: lateral_coord_value},
            method="linear",
            kwargs={"fill_value": "extrapolate"},
        )
        .transpose(plane_name_lateral, ...)
        .values
        for field_name in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz")
    }
    # Standard 2D rotation of the in-plane components into (r, theta); the axial component is
    # copied unchanged. Only ``theta`` index 0 is written since ``theta`` has a single entry.
    for field_type in ["E", "H"]:
        field_data_cylindrical[field_type + "r"][:, 0] = (
            fields_interp[field_type + cmp_1] * cos_theta
            + fields_interp[field_type + cmp_2] * sin_theta
        )
        field_data_cylindrical[field_type + "theta"][:, 0] = (
            -fields_interp[field_type + cmp_1] * sin_theta
            + fields_interp[field_type + cmp_2] * cos_theta
        )
        field_data_cylindrical[field_type + "axial"][:, 0] = fields_interp[
            field_type + axial_name
        ]
    coords = {
        "rho": rho,
        "theta": theta,
        "axial": axial,
        "f": f,
        "mode_index": mode_index,
    }
    solver_data_cylindrical = {
        name: ScalarModeFieldCylindricalDataArray(field_data_cylindrical[name], coords=coords)
        for name in ("Er", "Etheta", "Eaxial", "Hr", "Htheta", "Haxial")
    }
    solver_data_cylindrical["n_complex"] = mode_solver_data.n_complex
    return solver_data_cylindrical
# # Leaving for future reference if needed
# def _mode_rotation_straight(
# self,
# solver_ref_data: dict[Union[ModeSolverData]],
# solver: ModeSolver,
# ) -> ModeSolverData:
# """Rotate the mode solver solution from the reference plane
# to the desired monitor plane."""
# rotated_data_arrays = {}
# for field_name in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
# if self.colocate is True:
# # Get colocation coordinates in the solver plane
# normal_dim, _ = self.plane.pop_axis("xyz", self.normal_axis)
# colocate_coords = self._get_colocation_coordinates()
# colocate_coords[normal_dim] = np.atleast_1d(self.plane.center[self.normal_axis])
# x = colocate_coords["x"]
# y = colocate_coords["y"]
# z = colocate_coords["z"]
# xyz_coords = [x.copy(), y.copy(), z.copy()]
# else:
# # Extract coordinate values from the corresponding field component
# xyz_coords = solver.grid_snapped[field_name].to_list
# x, y, z = (coord.copy() for coord in xyz_coords)
# lateral_axis = 3 - self.normal_axis - self.bend_axis_3d
# axes = ("x", "y", "z")
# normal_dim = axes[self.normal_axis]
# lateral_dim = axes[lateral_axis]
# axial_dim = axes[self.bend_axis_3d]
# pts = [coord.copy() for coord in xyz_coords]
# pts[self.normal_axis] -= self.plane.center[self.normal_axis]
# pts[lateral_axis] -= self.plane.center[lateral_axis]
# axial = pts[self.bend_axis_3d]
# f = np.atleast_1d(self.freqs)
# mode_index = np.arange(self.mode_spec.num_modes)
# k0 = 2 * np.pi * f / C_0
# n_eff = solver_ref_data["n_complex"].values
# beta = k0[:, None] * n_eff
# beta = beta.reshape(1, 1, 1, len(f), len(mode_index))
# # Interpolation coords for amplitude and phase of local fields
# amp_interp_coors = pts[lateral_axis] * np.cos(self.mode_spec.angle_theta)
# phase_interp_coors = pts[lateral_axis] * np.sin(self.mode_spec.angle_theta)
# # Initialize output arrays
# shape = (len(x), len(y), len(z), len(f), len(mode_index))
# phase_fields = np.zeros(shape, dtype=complex)
# rotated_field_cmp = np.zeros(shape, dtype=complex)
# phase_shape = [1] * 5
# phase_shape[lateral_axis] = shape[lateral_axis]
# phase_interp_coors = phase_interp_coors.reshape(phase_shape)
# sign = 1 if self.normal_axis_2d == 0 else 1
# if self.direction == "+":
# phase_fields[...] = np.exp(sign * 1j * phase_interp_coors * beta)
# else:
# phase_fields[...] = np.exp(sign * -1j * phase_interp_coors * beta)
# # Interpolate field components
# local_fields = {}
# for field in ["E", "H"]:
# for comp in ["x", "y", "z"]:
# data = solver_ref_data[f"{field}{comp}"].interp(
# {axial_dim: axial, lateral_dim: amp_interp_coors},
# method="linear",
# kwargs={"fill_value": "extrapolate"},
# )
# local_fields[f"{field}{comp}"] = phase_fields * data.values
# if field_name == f"E{lateral_dim}":
# rotated_field_cmp = local_fields[f"E{lateral_dim}"] * np.cos(
# self.mode_spec.angle_theta
# )
# +local_fields[f"E{normal_dim}"] * np.sin(self.mode_spec.angle_theta)
# elif field_name == f"E{normal_dim}":
# rotated_field_cmp = local_fields[f"E{normal_dim}"] * np.sin(
# self.mode_spec.angle_theta
# )
# -local_fields[f"E{lateral_dim}"] * np.sin(self.mode_spec.angle_theta)
# elif field_name == f"E{axial_dim}":
# rotated_field_cmp = local_fields[f"E{axial_dim}"]
# if field_name == f"H{lateral_dim}":
# rotated_field_cmp = local_fields[f"H{lateral_dim}"] * np.cos(
# self.mode_spec.angle_theta
# )
# +local_fields[f"H{normal_dim}"] * np.sin(self.mode_spec.angle_theta)
# elif field_name == f"H{normal_dim}":
# rotated_field_cmp = local_fields[f"H{normal_dim}"] * np.sin(
# self.mode_spec.angle_theta
# )
# -local_fields[f"H{lateral_dim}"] * np.sin(self.mode_spec.angle_theta)
# elif field_name == f"H{axial_dim}":
# rotated_field_cmp = local_fields[f"H{axial_dim}"]
# coords = {"x": x, "y": y, "z": z, "f": f, "mode_index": mode_index}
# rotated_data_arrays[field_name] = ScalarModeFieldDataArray(
# rotated_field_cmp, coords=coords
# )
# rotated_data_arrays["n_complex"] = solver_ref_data["n_complex"]
# return rotated_data_arrays
def _mode_rotation(
    self,
    solver_ref_data_cylindrical: dict[
        str, ScalarModeFieldCylindricalDataArray | ModeIndexDataArray
    ],
    solver: ModeSolver,
) -> dict[str, ScalarModeFieldDataArray | ModeIndexDataArray]:
    """Rotate the mode solver solution from the reference plane in cylindrical coordinates
    to the desired monitor plane.

    Parameters
    ----------
    solver_ref_data_cylindrical : dict
        Cylindrical reference fields (``"Er"``, ``"Etheta"``, ``"Eaxial"``, ``"Hr"``,
        ``"Htheta"``, ``"Haxial"``) and ``"n_complex"``, as produced by ``_car_2_cyn``.
    solver : ModeSolver
        Solver whose snapped grid supplies the target coordinates when ``self.colocate`` is
        not ``True``.

    Returns
    -------
    dict[str, ScalarModeFieldDataArray | ModeIndexDataArray]
        Cartesian field arrays ``"Ex"`` ... ``"Hz"`` with coordinates
        ``(x, y, z, f, mode_index)``, plus ``"n_complex"`` copied from the reference data.
    """
    rotated_data_arrays = {}
    # Each cartesian component is built independently because, without colocation, every
    # component lives on its own set of grid coordinates.
    for field_name in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
        if self.colocate is True:
            # Get colocation coordinates in the solver plane
            normal_dim, _ = self.plane.pop_axis("xyz", self.normal_axis)
            colocate_coords = self._get_colocation_coordinates()
            colocate_coords[normal_dim] = np.atleast_1d(self.plane.center[self.normal_axis])
            x = colocate_coords["x"]
            y = colocate_coords["y"]
            z = colocate_coords["z"]
            xyz_coords = [x.copy(), y.copy(), z.copy()]
        else:
            # Extract coordinate values from one of the six field components
            xyz_coords = solver.grid_snapped[field_name].to_list
            x, y, z = (coord.copy() for coord in xyz_coords)
        f = np.atleast_1d(self._sampling_freqs)
        mode_index = np.arange(self.mode_spec.num_modes)
        # Initialize output arrays
        shape = (x.size, y.size, z.size, len(f), mode_index.size)
        rotated_field_cmp = np.zeros(shape, dtype=complex)
        idx_w, idx_uv = self.plane.pop_axis((0, 1, 2), axis=self.bend_axis_3d)
        idx_u, idx_v = idx_uv
        # Target coordinates relative to the bend center, expressed in cylindrical form.
        # NOTE(review): the elementwise operations below appear to rely on one of the two
        # in-plane coordinate arrays (the one along the plane normal) having length 1 so
        # that they broadcast - confirm.
        pts = [coord.copy() for coord in xyz_coords]
        pts[idx_u] -= self.bend_center[idx_u]
        pts[idx_v] -= self.bend_center[idx_v]
        rho = np.sqrt(pts[idx_u] ** 2 + pts[idx_v] ** 2)
        theta = np.arctan2(pts[idx_v], pts[idx_u])
        axial = pts[idx_w]
        # Azimuthal offset of each target point from the reference plane.
        theta_rel = theta - self.theta_reference
        cos_theta = pts[idx_u] / rho
        sin_theta = pts[idx_v] / rho
        cmp_normal, source_names = self.plane.pop_axis(("x", "y", "z"), axis=self.bend_axis_3d)
        cmp_1, cmp_2 = source_names
        # Phase accumulated per radian of azimuthal angle: k0 * n_eff * |bend radius|.
        k0 = 2 * np.pi * f / C_0
        n_eff = solver_ref_data_cylindrical["n_complex"].values
        beta = k0[:, None] * n_eff * np.abs(self._bend_radius)
        lateral_axis = 3 - self.normal_axis - self.bend_axis_3d
        # Interpolate field components
        fields = {}
        for field in ["E", "H"]:
            for comp in ["r", "theta", "axial"]:
                data = (
                    solver_ref_data_cylindrical[f"{field}{comp}"]
                    .isel(theta=0)
                    .interp(
                        rho=rho,
                        axial=axial,
                        method="linear",
                        kwargs={"fill_value": "extrapolate"},
                    )
                )
                # Put the axial dimension first when it precedes the lateral one in xyz order.
                if lateral_axis > self.bend_axis_3d:
                    data = data.transpose("axial", ...)
                fields[f"{field}{comp}"] = data.values
        # Determine the phase factor based on normal_axis_2d
        sign = -1 if self.normal_axis_2d == 0 else 1
        # The phase sign flips when exactly one of (backward direction, negative bend radius)
        # holds.
        if (self.direction == "+" and self._bend_radius >= 0) or (
            self.direction == "-" and self._bend_radius < 0
        ):
            phase = np.exp(sign * 1j * theta_rel[:, None, None] * beta)
        else:
            phase = np.exp(sign * -1j * theta_rel[:, None, None] * beta)
        # Set fixed index to normal_axis
        idx = [slice(None)] * 3
        idx[self.normal_axis] = 0
        idx_x, idx_y, idx_z = idx
        # Assign rotated fields
        # (insert singleton axes so phase and trig factors line up with the interpolated data,
        # whose leading dimension order depends on the axis ordering handled above).
        if lateral_axis > self.bend_axis_3d:
            phase_expansion = phase[None, :]
            cos_theta_expansion = cos_theta[None, :, None, None]
            sin_theta_expansion = sin_theta[None, :, None, None]
        else:
            phase_expansion = phase[:, None]
            cos_theta_expansion = cos_theta[:, None, None, None]
            sin_theta_expansion = sin_theta[:, None, None, None]
        # Inverse of the cartesian -> cylindrical rotation; the axial component maps to the
        # cartesian component along the bend axis.
        if field_name == f"E{cmp_1}":
            rotated_field_cmp[idx_x, idx_y, idx_z, :, :] = phase_expansion * (
                fields["Er"] * cos_theta_expansion - fields["Etheta"] * sin_theta_expansion
            )
        elif field_name == f"E{cmp_2}":
            rotated_field_cmp[idx_x, idx_y, idx_z, :, :] = phase_expansion * (
                fields["Er"] * sin_theta_expansion + fields["Etheta"] * cos_theta_expansion
            )
        elif field_name == f"E{cmp_normal}":
            rotated_field_cmp[idx_x, idx_y, idx_z, :, :] = phase_expansion * fields["Eaxial"]
        elif field_name == f"H{cmp_1}":
            rotated_field_cmp[idx_x, idx_y, idx_z, :, :] = phase_expansion * (
                fields["Hr"] * cos_theta_expansion - fields["Htheta"] * sin_theta_expansion
            )
        elif field_name == f"H{cmp_2}":
            rotated_field_cmp[idx_x, idx_y, idx_z, :, :] = phase_expansion * (
                fields["Hr"] * sin_theta_expansion + fields["Htheta"] * cos_theta_expansion
            )
        elif field_name == f"H{cmp_normal}":
            rotated_field_cmp[idx_x, idx_y, idx_z, :, :] = phase_expansion * fields["Haxial"]
        coords = {"x": x, "y": y, "z": z, "f": f, "mode_index": mode_index}
        rotated_data_arrays[field_name] = ScalarModeFieldDataArray(
            rotated_field_cmp, coords=coords
        )
    rotated_data_arrays["n_complex"] = solver_ref_data_cylindrical["n_complex"]
    return rotated_data_arrays
@cached_property
def theta_reference(self) -> float:
    """Azimuthal angle of the reference modal plane.

    Returns
    -------
    float
        ``arctan2`` of the plane center relative to ``bend_center``, measured in the plane
        perpendicular to the 3D bend axis.
    """
    _, (axis_u, axis_v) = self.plane.pop_axis((0, 1, 2), axis=self.bend_axis_3d)
    offset_v = self.plane.center[axis_v] - self.bend_center[axis_v]
    offset_u = self.plane.center[axis_u] - self.bend_center[axis_u]
    return np.arctan2(offset_v, offset_u)
@cached_property
def _bend_radius(self) -> float:
    """Bend radius to use when ``angle_rotation`` is on.

    If ``mode_spec.bend_radius`` is set it is returned as is. Otherwise an effectively very
    large radius is used: ``EFFECTIVE_RADIUS_FACTOR`` times the largest extent of the mode
    plane clipped to the simulation bounds. This is only used for the rotation of the
    fields - the reference modes are still computed without any bend applied.
    """
    user_radius = self.mode_spec.bend_radius
    if user_radius is not None:
        return user_radius

    clipped_bounds = self.plane.bounds_intersection(self.plane.bounds, self.simulation.bounds)
    extents = np.array(clipped_bounds[1]) - np.array(clipped_bounds[0])
    return EFFECTIVE_RADIUS_FACTOR * np.amax(extents)
@cached_property
def bend_center(self) -> list[float]:
    """Bend center computed from the plane center, ``angle_theta`` and ``angle_phi``.

    Returns
    -------
    list[float]
        The plane center displaced by ``_bend_radius`` in the plane perpendicular to the 3D
        bend axis. The displacement direction is set by the effective rotation angle and by
        the (normal axis, bend axis) pair; for a pair not covered below the plane center is
        returned unchanged.
    """
    _, (axis_u, axis_v) = self.plane.pop_axis((0, 1, 2), axis=self.bend_axis_3d)
    theta = self.mode_spec.angle_theta
    phi = self.mode_spec.angle_phi
    radius = self._bend_radius
    center = list(self.plane.center)

    axes_key = (self.normal_axis, self.bend_axis_3d)

    # Effective rotation angle: theta projected with cos(phi) or sin(phi).
    if axes_key in {(0, 2), (1, 2), (2, 1)}:
        angle = theta * np.cos(phi)
    elif axes_key in {(0, 1), (1, 0), (2, 0)}:
        angle = theta * np.sin(phi)
    else:
        return center

    # Displacement of the center along the two axes perpendicular to the bend axis.
    if axes_key in {(0, 2), (0, 1), (1, 0)}:
        delta_u = radius * np.sin(angle)
        delta_v = -radius * np.cos(angle)
    else:
        delta_u = -radius * np.cos(angle)
        delta_v = radius * np.sin(angle)

    center[axis_u] = self.plane.center[axis_u] + delta_u
    center[axis_v] = self.plane.center[axis_v] + delta_v
    return center
@cached_property
def _reduced_simulation_copy_with_fallback(self) -> ModeSolver:
    """Return ``reduced_simulation_copy``, or ``self`` if building it raises.

    The reduced copy is preferred for efficiency and is not expected to fail; a failure
    most likely points to an oversight in ``Simulation.subsection``. Falling back to the
    non-reduced solver (with a warning) avoids raising an unneeded error in that case.
    """
    try:
        reduced = self.reduced_simulation_copy
    except Exception as err:
        # Deliberately broad: any failure here should degrade to the slower path.
        log.warning(
            "Mode solver reduced_simulation_copy failed. "
            "Falling back to non-reduced simulation, which may be slower. "
            f"Exception: {err!s}"
        )
        return self
    return reduced
@staticmethod
def _slice_material_to_solver_grid(
mat_data: np.ndarray | None, trim_inds: tuple[tuple[int, int], ...]
) -> np.ndarray | None:
"""Slice modal-plane material arrays from output-grid extents to solver-grid extents.
Parameters
----------
trim_inds : tuple of (trim_before, trim_after) per spatial axis
Number of cells to remove from each end. ``trim_after=0`` keeps to the end.
"""
if mat_data is None:
return None
if all(trim == (0, 0) for trim in trim_inds):
return mat_data
slices = tuple(slice(before, -after if after else None) for before, after in trim_inds)
return mat_data[(..., *slices)]
def _trim_output_field_to_solver_grid(self, field: ArrayComplex4D) -> ArrayComplex4D:
"""Trim an output-grid field array down to ``_solver_grid`` extents."""
(x0, x1), (y0, y1), (z0, z1) = self._solver_grid_span_inds
return field[x0:x1, y0:y1, z0:z1, :]
def _pad_solver_field_to_output_grid(self, field: ArrayComplex4D) -> ArrayComplex4D:
"""Zero-pad a solver-grid field array to ``_output_grid`` extents."""
pad_widths = self._solver_grid_pad_widths
if all(before == 0 and after == 0 for before, after in pad_widths):
return field
return np.pad(
field,
pad_width=(*pad_widths, (0, 0)),
mode="constant",
constant_values=0.0,
)
def _pad_solver_fields_to_output_grid(
self, fields: list[dict[str, ArrayComplex4D]]
) -> list[dict[str, ArrayComplex4D]]:
"""Zero-pad field dictionaries from solver-grid extents to output-grid extents."""
pad_widths = self._solver_grid_pad_widths
if all(before == 0 and after == 0 for before, after in pad_widths):
return fields
return [
{
field_name: self._pad_solver_field_to_output_grid(field)
for field_name, field in fields_freq.items()
}
for fields_freq in fields
]
def _data_on_yee_grid(self) -> ModeSolverData:
    """Solve for all modes, and construct data with fields on the Yee grid.

    Returns
    -------
    ModeSolverData
        Non-colocated mode data at the sampling frequencies, including the finite-grid
        correction factors.
    """
    field_names = ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz")

    # Work on the reduced copy, with freqs replaced by the pre-expanded sampling frequencies
    # and interp_spec / group_index_step stripped so they are not re-applied on the copy
    # (the expansion has already been folded into _sampling_freqs).
    solver = self._reduced_simulation_copy_with_fallback.updated_copy(
        freqs=self._sampling_freqs,
        mode_spec=self.mode_spec.updated_copy(interp_spec=None, group_index_step=False),
    )

    def freq_mode_coords() -> dict:
        """Fresh ``f`` / ``mode_index`` coordinates for one data array."""
        return {
            "f": list(solver.freqs),
            "mode_index": np.arange(solver.mode_spec.num_modes),
        }

    _, plane_boundaries = solver.plane.pop_axis(
        solver._solver_grid.boundaries.to_list, axis=solver.normal_axis
    )

    # Compute the modes at all frequencies, then pad fields out to the output grid.
    n_complex, fields, eps_spec = solver._solve_all_freqs(
        coords=plane_boundaries,
        symmetry=solver.solver_symmetry,
    )
    fields = solver._pad_solver_fields_to_output_grid(fields)

    # Collect the data arrays for the ModeSolverData, starting with the effective index.
    index_data = ModeIndexDataArray(np.stack(n_complex, axis=0), coords=freq_mode_coords())
    data_dict = {"n_complex": index_data}

    # Field data on the Yee grid: stack the per-frequency arrays along the frequency axis.
    for field_name in field_names:
        grid_coords = solver.grid_snapped[field_name].to_list
        stacked = np.stack([per_freq[field_name] for per_freq in fields], axis=-2)
        coords = {
            "x": grid_coords[0],
            "y": grid_coords[1],
            "z": grid_coords[2],
            **freq_mode_coords(),
        }
        data_dict[field_name] = ScalarModeFieldDataArray(stacked, coords=coords)

    # Finite grid corrections.
    grid_factors, relative_grid_distances = solver._grid_correction(
        simulation=solver.simulation,
        plane=solver.plane,
        mode_spec=solver.mode_spec,
        n_complex=index_data,
        direction=solver.direction,
    )

    # Make mode solver data on the Yee grid.
    monitor = solver.to_mode_solver_monitor(name=MODE_MONITOR_NAME, colocate=False)
    grid_expanded = solver.simulation.discretize_monitor(monitor)
    return ModeSolverData(
        monitor=monitor,
        symmetry=solver.simulation.symmetry,
        symmetry_center=solver.simulation.center,
        grid_expanded=grid_expanded,
        grid_primal_correction=grid_factors[0],
        grid_dual_correction=grid_factors[1],
        grid_distances_primal=relative_grid_distances[0],
        grid_distances_dual=relative_grid_distances[1],
        eps_spec=eps_spec,
        **data_dict,
    )
def _data_on_yee_grid_relative(self, basis: ModeSolverData) -> ModeSolverData:
    """Solve for all modes relative to ``basis``, and construct data with fields on the Yee
    grid.

    Parameters
    ----------
    basis : ModeSolverData
        Non-colocated mode data whose fields, trimmed to the solver grid, are passed as the
        basis to ``_solve_all_freqs_relative``.

    Returns
    -------
    ModeSolverData
        Non-colocated mode data at ``self.freqs``, including the finite-grid correction
        factors.

    Raises
    ------
    ValidationError
        If ``basis.monitor.colocate`` is set.
    """
    field_names = ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz")

    if basis.monitor.colocate:
        raise ValidationError("Relative mode solver 'basis' must have 'colocate=False'.")

    def freq_mode_coords() -> dict:
        """Fresh ``f`` / ``mode_index`` coordinates for one data array."""
        return {
            "f": list(self.freqs),
            "mode_index": np.arange(self.mode_spec.num_modes),
        }

    _, plane_boundaries = self.plane.pop_axis(
        self._solver_grid.boundaries.to_list, axis=self.normal_axis
    )

    # One dictionary of basis fields per basis frequency, trimmed to the solver grid.
    num_basis_freqs = len(basis.n_complex.f)
    basis_fields = [
        {
            name: self._trim_output_field_to_solver_grid(
                basis.field_components[name].isel(f=freq_ind).to_numpy()
            )
            for name in field_names
        }
        for freq_ind in range(num_basis_freqs)
    ]

    # Compute the modes at all frequencies, then pad fields out to the output grid.
    n_complex, fields, eps_spec = self._solve_all_freqs_relative(
        coords=plane_boundaries,
        symmetry=self.solver_symmetry,
        basis_fields=basis_fields,
    )
    fields = self._pad_solver_fields_to_output_grid(fields)

    # Collect the data arrays for the ModeSolverData, starting with the effective index.
    index_data = ModeIndexDataArray(np.stack(n_complex, axis=0), coords=freq_mode_coords())
    data_dict = {"n_complex": index_data}

    # Field data on the Yee grid: stack the per-frequency arrays along the frequency axis.
    for field_name in field_names:
        grid_coords = self.grid_snapped[field_name].to_list
        stacked = np.stack([per_freq[field_name] for per_freq in fields], axis=-2)
        coords = {
            "x": grid_coords[0],
            "y": grid_coords[1],
            "z": grid_coords[2],
            **freq_mode_coords(),
        }
        data_dict[field_name] = ScalarModeFieldDataArray(stacked, coords=coords)

    # Finite grid corrections.
    grid_factors, relative_grid_distances = self._grid_correction(
        simulation=self.simulation,
        plane=self.plane,
        mode_spec=self.mode_spec,
        n_complex=index_data,
        direction=self.direction,
    )

    # Make mode solver data on the Yee grid.
    monitor = self.to_mode_solver_monitor(name=MODE_MONITOR_NAME, colocate=False)
    grid_expanded = self.simulation.discretize_monitor(monitor)
    return ModeSolverData(
        monitor=monitor,
        symmetry=self.simulation.symmetry,
        symmetry_center=self.simulation.center,
        grid_expanded=grid_expanded,
        grid_primal_correction=grid_factors[0],
        grid_dual_correction=grid_factors[1],
        grid_distances_primal=relative_grid_distances[0],
        grid_distances_dual=relative_grid_distances[1],
        eps_spec=eps_spec,
        **data_dict,
    )
def _get_colocation_coordinates(self) -> dict[str, ArrayFloat1D]:
"""Get colocation coordinates in the solver plane.
Returns:
colocate_coords (dict): Dictionary containing the colocation coordinates for each dimension.
"""
# Get colocation coordinates in the solver plane
_, plane_dims = self.plane.pop_axis("xyz", self.normal_axis)
colocate_coords = {}
for dim, sym in zip(plane_dims, self.solver_symmetry):
coords = self.grid_snapped.boundaries.to_dict[dim]
if len(coords) > 2:
if sym == 0:
colocate_coords[dim] = coords[1:-1]
else:
colocate_coords[dim] = coords[:-1]
return colocate_coords
def _colocate_data(self, mode_solver_data: ModeSolverData) -> ModeSolverData:
    """Colocate data to Yee grid boundaries.

    Every field component of the symmetry-expanded data is interpolated onto the colocation
    coordinates and cast back to its original dtype. The returned copy also carries a
    freshly built monitor (same mode spec and frequencies as the input data's monitor) and
    the matching expanded grid.
    """
    colocate_coords = self._get_colocation_coordinates()
    expanded = mode_solver_data.symmetry_expanded

    colocated = {
        name: component.interp_within_domain(
            colocate_coords,
            expanded.solver_field_bounds,
            assume_sorted=True,
        ).astype(component.dtype)
        for name, component in expanded.field_components.items()
    }

    # Rebuild the monitor and its grid so they accompany the colocated fields.
    monitor = self.to_mode_solver_monitor(
        name=MODE_MONITOR_NAME,
        mode_spec=mode_solver_data.monitor.mode_spec,
        freqs=mode_solver_data.monitor.freqs,
    )
    grid_expanded = self.simulation.discretize_monitor(monitor)

    updates = {**colocated, "monitor": monitor, "grid_expanded": grid_expanded}
    return mode_solver_data.updated_copy(**updates, deep=False)
def _normalize_modes(self, mode_solver_data: ModeSolverData) -> None:
    """Normalize modes. Note: this modifies ``mode_solver_data`` in-place.

    Parameters
    ----------
    mode_solver_data : ModeSolverData
        Data whose own ``_normalize_modes`` method is invoked.
    """
    # Thin wrapper: the normalization itself is implemented on the data object.
    mode_solver_data._normalize_modes()
def _filter_components(self, mode_solver_data: ModeSolverData) -> ModeSolverData:
skip_components = {
comp: None
for comp in mode_solver_data.field_components.keys()
if comp not in self.fields
}
return mode_solver_data.updated_copy(**skip_components, validate=False)
def _filter_polarization(self, mode_solver_data: ModeSolverData) -> ModeSolverData:
"""Filter polarization."""
filter_pol = self.mode_spec.filter_pol
if filter_pol is None:
return mode_solver_data
num_freqs = len(self._sampling_freqs)
num_modes = self.mode_spec.num_modes
identity = np.arange(num_modes)
sort_inds_2d = np.tile(identity, (num_freqs, 1))
pol_frac = mode_solver_data.pol_fraction
for ifreq in range(num_freqs):
te_frac = pol_frac.te.isel(f=ifreq).values
if filter_pol == "te":
sort_inds = np.concatenate(
(
np.where(te_frac >= 0.5)[0],
np.where(te_frac < 0.5)[0],
np.where(np.isnan(te_frac))[0],
)
)
elif filter_pol == "tm":
sort_inds = np.concatenate(
(
np.where(te_frac <= 0.5)[0],
np.where(te_frac > 0.5)[0],
np.where(np.isnan(te_frac))[0],
)
)
sort_inds_2d[ifreq, : len(sort_inds)] = sort_inds
# If no reordering needed across all frequencies, skip
if np.all(sort_inds_2d == np.tile(identity, (num_freqs, 1))):
return mode_solver_data
return mode_solver_data._apply_mode_reorder(sort_inds_2d)
def _make_path_integrals(
    self,
) -> tuple[
    tuple[VoltageIntegralType | None, ...],
    tuple[CurrentIntegralType | None, ...],
]:
    """Wrapper for making path integrals from the MicrowaveModeSpec. Note: overridden in the
    backend to support auto creation of path integrals.

    Returns
    -------
    tuple
        Voltage integrals and current integrals, as returned by ``make_path_integrals``.

    Raises
    ------
    ValueError
        If ``self._has_microwave_mode_spec`` is false.
    """
    # Path integrals are only defined by a microwave mode spec.
    if not self._has_microwave_mode_spec:
        raise ValueError(
            "Cannot make path integrals for when 'mode_spec' is not a 'MicrowaveModeSpec'."
        )
    return make_path_integrals(self.mode_spec)
def _make_path_integrals_for_terminal(
    self,
) -> dict[str, tuple[VoltageIntegralType | None, CurrentIntegralType | None]]:
    """Wrapper for making path integrals for each terminal from the MicrowaveTerminalModeSpec.

    Returns
    -------
    dict
        Result of ``make_path_integrals_for_terminal`` for this solver's ``mode_spec``.

    Raises
    ------
    ValueError
        If ``self._has_microwave_terminal_mode_spec`` is false.
    """
    # Terminal path integrals are only defined by a microwave terminal mode spec.
    if not self._has_microwave_terminal_mode_spec:
        raise ValueError(
            "Cannot make path integrals for when 'mode_spec' is not a 'MicrowaveTerminalModeSpec'."
        )
    return make_path_integrals_for_terminal(self.mode_spec)
def _generate_transmission_line_data(
    self, mode_solver_data: MicrowaveModeSolverData
) -> TransmissionLineDataset:
    """Compute per-mode transmission line quantities from the mode fields.

    Parameters
    ----------
    mode_solver_data : MicrowaveModeSolverData
        Mode solver data to compute transmission line quantities from.

    Returns
    -------
    TransmissionLineDataset
        Dataset containing characteristic impedance, voltage coefficients, and current
        coefficients, concatenated along ``mode_index``. Modes for which both the voltage
        and the current integral are ``None`` are left out.
    """
    voltage_integrals, current_integrals = self._make_path_integrals()

    # The integrals must see the full symmetry-expanded fields.
    expanded_data = mode_solver_data.symmetry_expanded_copy

    num_modes = self.mode_spec.num_modes
    if len(voltage_integrals) == 1 and num_modes > 1:
        # A single specification is shared by all modes.
        voltage_integrals = voltage_integrals * num_modes
        current_integrals = current_integrals * num_modes

    Z0_list, V_list, I_list = [], [], []
    for mode_index in range(num_modes):
        vi = voltage_integrals[mode_index]
        ci = current_integrals[mode_index]
        if vi is None and ci is None:
            continue
        calculator = ImpedanceCalculator(voltage_integral=vi, current_integral=ci)
        Z0, voltage, current = calculator.compute_impedance(
            expanded_data._isel(mode_index=[mode_index]),
            return_voltage_and_current=True,
        )
        Z0_list.append(Z0)
        V_list.append(voltage)
        I_list.append(current)

    all_mode_Z0 = _make_impedance_data_array(xr.concat(Z0_list, dim="mode_index"))
    all_mode_V = _make_voltage_data_array(xr.concat(V_list, dim="mode_index"))
    all_mode_I = _make_current_data_array(xr.concat(I_list, dim="mode_index"))
    return TransmissionLineDataset(
        Z0=all_mode_Z0, voltage_coeffs=all_mode_V, current_coeffs=all_mode_I
    )
@staticmethod
def _assemble_transform_matrices(
    input_list: dict[str, list] | None,
    terminal_labels: list[str],
) -> xr.DataArray:
    """Assemble voltage/current transform matrices from per-terminal, per-mode data.

    Parameters
    ----------
    input_list : Optional[dict[str, list]]
        Dictionary mapping terminal labels to lists of voltage/current data per mode.
        It is indexed directly, so despite the annotation ``None`` is not handled here.
    terminal_labels : list[str]
        Ordered list of terminal labels.

    Returns
    -------
    xr.DataArray
        Transform matrix with dimensions (f, terminal_label, mode_index).
    """
    # One array per terminal, obtained by joining that terminal's per-mode entries.
    per_terminal = [
        xr.concat(input_list[terminal_label], dim="mode_index")
        for terminal_label in terminal_labels
    ]

    # Join the terminals and label the new dimension.
    result = xr.concat(per_terminal, dim="terminal_label").assign_coords(
        terminal_label=terminal_labels
    )

    # Explicit 0..num_modes-1 coordinate, sized from the first terminal's list.
    num_modes = len(input_list[terminal_labels[0]])
    result = result.assign_coords(mode_index=np.arange(num_modes))

    # Fixed dimension order expected by the callers.
    return result.transpose("f", "terminal_label", "mode_index")
@staticmethod
def _construct_differential_pair_transform(
terminals_mapping: dict[str, str | tuple[str, str]] | None,
terminal_labels: list[str],
voltage_transform: bool = True,
) -> np.ndarray:
"""Construct differential pair transformation matrix Q.
This matrix transforms single-ended terminal voltages to output terminal voltages,
which may include differential pairs. For differential pairs:
- Common mode: V_comm = 0.5 * (V_1 + V_2), I_comm = I_1 + I_2
- Differential mode: V_diff = V_1 - V_2, I_diff = 0.5 * (I_1 - I_2)
Parameters
----------
terminals_mapping : Optional[dict[str, Union[str, tuple[str, str]]]]
Mapping from output terminals (including differential pairs) to input single-ended terminals.
Keys ending with "@comm" or "@diff" indicate differential pair modes.
terminal_labels : list[str]
Ordered list of input single-ended terminal labels.
voltage_transform : bool, optional
Whether it's applied to voltage transform (True), or current transform (False).
Returns
-------
np.ndarray
Transformation matrix Q with shape (num_output_terminals, num_input_terminals).
"""
# if terminals_mapping is None, return an identity matrix
if terminals_mapping is None:
return np.eye(len(terminal_labels))
# Create mapping from terminal labels to indices for fast lookup
label_to_idx = {label: idx for idx, label in enumerate(terminal_labels)}
# Initialize Q matrix
num_output = len(terminals_mapping)
num_input = len(terminal_labels)
Q = np.zeros((num_output, num_input))
# Fill Q matrix based on terminals_mapping
for output_idx, (output_label, input_mapping) in enumerate(terminals_mapping.items()):
if isinstance(input_mapping, str):
# Single-ended terminal: direct mapping
input_idx = label_to_idx[input_mapping]
Q[output_idx, input_idx] = 1.0
else:
# Differential pair: tuple (label1, label2)
label1, label2 = input_mapping
idx1 = label_to_idx[label1]
idx2 = label_to_idx[label2]
if output_label.endswith("@comm"):
factor = 0.5 if voltage_transform else 1.0
Q[output_idx, idx1] = factor
Q[output_idx, idx2] = factor
elif output_label.endswith("@diff"):
factor = 1 if voltage_transform else 0.5
Q[output_idx, idx1] = factor
Q[output_idx, idx2] = -factor
else:
# Should not happen based on the design, but handle gracefully
raise ValueError(
f"Differential pair terminal '{output_label}' must end with '@comm' or '@diff'"
)
return Q
def _generate_transmission_line_terminal_data(
    self, mode_solver_data: MicrowaveModeSolverData
) -> TransmissionLineTerminalDataset:
    """Generate transmission line terminal data from mode field integrals.

    Computes voltage and current path integrals over the mode fields for each
    terminal, then applies a terminal transformation to obtain the impedance
    matrix and mode-to-terminal transformation matrices.

    Parameters
    ----------
    mode_solver_data : MicrowaveModeSolverData
        Mode solver data providing the fields (symmetry expanded here) and the
        frequency / mode index coordinates of the outputs.

    Returns
    -------
    TransmissionLineTerminalDataset
        Dataset containing Z0 matrix (terminal_label_out x terminal_label_in),
        voltage transformation matrix (terminal_label x mode_index),
        and current transformation matrix (terminal_label x mode_index).
    """
    integrals_dict = self._make_path_integrals_for_terminal()
    impedance_definition = self.mode_spec.impedance_definition

    # The integrals must see the full symmetry-expanded fields.
    expanded_data = mode_solver_data.symmetry_expanded_copy

    # {terminal_label: [value_mode0, value_mode1, ...]}
    V_list = {terminal_label: [] for terminal_label in integrals_dict}
    I_list = {terminal_label: [] for terminal_label in integrals_dict}
    for mode_index in range(self.mode_spec.num_modes):
        single_mode_data = expanded_data._isel(mode_index=[mode_index])
        for terminal_label, (v_integral, i_integral) in integrals_dict.items():
            calculator = ImpedanceCalculator(
                voltage_integral=v_integral, current_integral=i_integral
            )
            voltage, current = calculator.compute_voltage_current(single_mode_data)
            V_list[terminal_label].append(voltage)
            I_list[terminal_label].append(current)

    # The terminal transforms are a licensed feature.
    check_tidy3d_extras_licensed_feature("terminal_transformation_matrix")

    # Input terminal labels (before differential pair transformation).
    terminal_labels = list(integrals_dict)

    # Which inputs each impedance definition needs:
    #   VI -> voltages and currents, PI -> currents and power, PV -> voltages and power.
    # Arrays are (num_freqs, num_terminals, num_modes).
    v_array = None
    if impedance_definition in ("VI", "PV"):
        v_array = self._assemble_transform_matrices(V_list, terminal_labels).values
    i_array = None
    if impedance_definition in ("VI", "PI"):
        i_array = self._assemble_transform_matrices(I_list, terminal_labels).values

    # Power matrix, when needed: (num_freqs, num_modes, num_modes).
    # truncate_to_monitor_bounds=True keeps the Yee-grid area elements consistent with
    # complex_flux used by ImpedanceCalculator for PI/PV fallback.
    power_matrix = None
    if impedance_definition in ("PI", "PV"):
        overlap = expanded_data.outer_dot(
            expanded_data,
            conjugate=True,
            bidirectional=False,
            truncate_to_monitor_bounds=True,
        )
        power_matrix = (2 * overlap).values.conj()
        # outer_dot computes (1/2) int E* x H . dS, which flips sign for
        # backward-propagating modes. ImpedanceCalculator.compute_impedance corrects
        # this via flux_sign; apply the same correction here so derived quantities
        # (voltage_transform, Z0) have a consistent sign.
        if self.direction == "-":
            power_matrix *= -1

    # Q matrices for the differential pair transformation.
    terminals_mapping = self.mode_spec.terminals_mapping
    Q_voltage = self._construct_differential_pair_transform(
        terminals_mapping, terminal_labels, voltage_transform=True
    )
    Q_current = self._construct_differential_pair_transform(
        terminals_mapping, terminal_labels, voltage_transform=False
    )

    # tidy3d_extras computes the transforms and Z0.
    extension = tidy3d_extras["mod"].extension
    voltage_transform_arr, current_transform_arr, z0_arr = extension._compute_terminal_transforms(
        v_array,
        i_array,
        power_matrix,
        Q_voltage,
        Q_current,
    )

    # Output terminal labels (after differential pair transformation).
    out_terminal_labels = list(self.mode_spec._terminal_indices)
    freqs = mode_solver_data.n_eff.coords["f"].values
    mode_indices = mode_solver_data.n_eff.coords["mode_index"].values

    # Voltage and current transforms share the (f, terminal_label, mode_index) layout.
    terminal_mode_dims = ["f", "terminal_label", "mode_index"]
    terminal_mode_coords = {
        "f": freqs,
        "terminal_label": out_terminal_labels,
        "mode_index": mode_indices,
    }
    voltage_transform = VoltageFreqTerminalModeDataArray(
        xr.DataArray(
            voltage_transform_arr, coords=terminal_mode_coords, dims=terminal_mode_dims
        )
    )
    current_transform = CurrentFreqTerminalModeDataArray(
        xr.DataArray(
            current_transform_arr, coords=terminal_mode_coords, dims=terminal_mode_dims
        )
    )

    # Z0: (num_freqs, num_terminals, num_terminals).
    Z0 = ImpedanceFreqTerminalTerminalDataArray(
        xr.DataArray(
            z0_arr,
            coords={
                "f": freqs,
                "terminal_label_out": out_terminal_labels,
                "terminal_label_in": out_terminal_labels,
            },
            dims=["f", "terminal_label_out", "terminal_label_in"],
        )
    )
    return TransmissionLineTerminalDataset(
        Z0=Z0, voltage_transform=voltage_transform, current_transform=current_transform
    )
def _add_microwave_data(
    self, mode_solver_data: MicrowaveModeSolverData
) -> MicrowaveModeSolverData:
    """Return a copy of ``mode_solver_data`` with ``transmission_line_data`` filled in.

    The dataset comes from ``_generate_transmission_line_data``, which uses the path
    specifications.
    """
    return mode_solver_data.updated_copy(
        transmission_line_data=self._generate_transmission_line_data(mode_solver_data)
    )
def _add_microwave_terminal_data(
    self, mode_solver_data: MicrowaveModeSolverData
) -> MicrowaveModeSolverData:
    """Return a copy of ``mode_solver_data`` with ``transmission_line_terminal_data`` filled in.

    The dataset comes from ``_generate_transmission_line_terminal_data``, which uses the
    path specifications.
    """
    return mode_solver_data.updated_copy(
        transmission_line_terminal_data=self._generate_transmission_line_terminal_data(
            mode_solver_data
        )
    )
@cached_property
def data(self) -> ModeSolverDataType:
    """:class:`.ModeSolverData` containing the field and effective index data.

    This is the symmetry-expanded copy of ``data_raw``.

    Returns
    -------
    ModeSolverDataType
        A mode solver data type object containing the effective index and mode fields.
    """
    return self.data_raw.symmetry_expanded_copy
@cached_property
def sim_data(self) -> MODE_SIMULATION_DATA_TYPE:
    """:class:`.SimulationData` object containing the :class:`.ModeSolverData` for this object.

    The solver's simulation is copied with the data's monitor appended to its monitors,
    and the data is wrapped in the simulation-data class matching the simulation type.

    Returns
    -------
    SimulationData
        :class:`.SimulationData` object containing the effective index and mode fields
        (an :class:`.EMESimulationData` when the simulation is an :class:`.EMESimulation`).

    Raises
    ------
    SetupError
        If the simulation is neither a ``Simulation`` nor an ``EMESimulation``.
    """
    monitor_data = self.data
    simulation = self.simulation.copy(
        update={"monitors": (*self.simulation.monitors, monitor_data.monitor)}
    )
    data = (monitor_data,)
    if isinstance(simulation, Simulation):
        return SimulationData(simulation=simulation, data=data)
    if isinstance(simulation, EMESimulation):
        return EMESimulationData(
            simulation=simulation, data=data, smatrix=None, port_modes=None
        )
    raise SetupError(
        "The 'simulation' provided does not correspond to any known "
        "'AbstractSimulationData' type."
    )
def _get_epsilon(self, freq: float) -> ArrayComplex4D:
    """Compute the epsilon tensor in the plane. Order of components is xx, xy, xz, yx, etc.

    Each of the nine components is evaluated on ``self._solver_grid`` at ``freq`` and the
    results are stacked along a new leading axis.
    """
    component_keys = ("Ex", "Exy", "Exz", "Eyx", "Ey", "Eyz", "Ezx", "Ezy", "Ez")
    return np.stack(
        [
            self.simulation.epsilon_on_grid(self._solver_grid, component, freq)
            for component in component_keys
        ],
        axis=0,
    )
@staticmethod
def _tensorial_material_profile_modal_plane_tranform(
mat_data: ArrayComplex4D, normal_axis: Axis
) -> ArrayComplex4D:
"""For tensorial material response function such as epsilon and mu, pick and tranform it to
modal plane with normal axis rotated to z.
"""
# get rid of normal axis
mat_tensor = np.take(mat_data, indices=[0], axis=1 + normal_axis)
mat_tensor = np.squeeze(mat_tensor, axis=1 + normal_axis)
# convert to into 3-by-3 representation for easier axis swap
flat_shape = np.shape(mat_tensor) # 9 components flat
tensor_shape = [3, 3, *flat_shape[1:]] # 3-by-3 matrix
mat_tensor = mat_tensor.reshape(tensor_shape)
# swap axes to plane coordinates (normal_axis goes to z)
if normal_axis == 0:
# swap x and y
mat_tensor[[0, 1], :, ...] = mat_tensor[[1, 0], :, ...]
mat_tensor[:, [0, 1], ...] = mat_tensor[:, [1, 0], ...]
if normal_axis <= 1:
# swap x (normal_axis==0) or y (normal_axis==1) and z
mat_tensor[[1, 2], :, ...] = mat_tensor[[2, 1], :, ...]
mat_tensor[:, [1, 2], ...] = mat_tensor[:, [2, 1], ...]
# back to "flat" representation
mat_tensor = mat_tensor.reshape(flat_shape)
# construct to feed to mode solver
return mat_tensor
@staticmethod
def _diagonal_material_profile_modal_plane_tranform(
mat_data: ArrayComplex4D, normal_axis: Axis
) -> ArrayComplex3D:
"""For diagonal material response function such as epsilon and mu, pick and tranform it to
modal plane with normal axis rotated to z.
"""
# get rid of normal axis
mat_tensor = np.take(mat_data, indices=[0], axis=1 + normal_axis)
mat_tensor = np.squeeze(mat_tensor, axis=1 + normal_axis)
# swap axes to plane coordinates (normal_axis goes to z)
if normal_axis == 0:
# swap x and y
mat_tensor[[0, 1], :, ...] = mat_tensor[[1, 0], :, ...]
if normal_axis <= 1:
# swap x (normal_axis==0) or y (normal_axis==1) and z
mat_tensor[[1, 2], :, ...] = mat_tensor[[2, 1], :, ...]
# construct to feed to mode solver
return mat_tensor
def _solver_eps(self, freq: float) -> ArrayComplex4D:
    """Permittivity tensor in the shape needed by the solver, with normal axis rotated to z.

    The nine-component tensor from ``_get_epsilon`` is passed through the tensorial
    modal-plane transformation.
    """
    return self._tensorial_material_profile_modal_plane_tranform(
        self._get_epsilon(freq), self.normal_axis
    )
@supports_local_subpixel
def _solve_all_freqs(
    self,
    coords: tuple[ArrayFloat1D, ArrayFloat1D],
    symmetry: tuple[Symmetry, Symmetry],
) -> tuple[list[float], list[dict[str, ArrayComplex4D]], list[EpsSpecType]]:
    """Call the mode solver at all requested frequencies.

    When ``tidy3d_extras["use_local_subpixel"]`` is set, the work is delegated to a
    ``SubpixelModeSolver`` built from this solver. Otherwise ``_solve_single_freq`` is
    called once per entry of ``self.freqs``.

    Returns
    -------
    tuple
        ``(n_complex, fields, eps_spec)``, each a list with one entry per frequency.
    """
    if tidy3d_extras["use_local_subpixel"]:
        delegate = tidy3d_extras["mod"].SubpixelModeSolver.from_mode_solver(self)
        return delegate._solve_all_freqs(coords=coords, symmetry=symmetry)

    n_complex, fields, eps_spec = [], [], []
    per_freq = (
        self._solve_single_freq(freq=freq, coords=coords, symmetry=symmetry)
        for freq in self.freqs
    )
    for n_freq, fields_freq, eps_spec_freq in per_freq:
        n_complex.append(n_freq)
        fields.append(fields_freq)
        eps_spec.append(eps_spec_freq)
    return n_complex, fields, eps_spec
@supports_local_subpixel
def _solve_all_freqs_relative(
    self,
    coords: tuple[ArrayFloat1D, ArrayFloat1D],
    symmetry: tuple[Symmetry, Symmetry],
    basis_fields: list[dict[str, ArrayComplex4D]],
) -> tuple[list[float], list[dict[str, ArrayComplex4D]], list[EpsSpecType]]:
    """Call the mode solver at all requested frequencies, relative to ``basis_fields``.

    When ``tidy3d_extras["use_local_subpixel"]`` is set, the work is delegated to a
    ``SubpixelModeSolver`` built from this solver. Otherwise ``self.freqs`` and
    ``basis_fields`` are paired in order and ``_solve_single_freq_relative`` is called
    per pair; pairing stops at the shorter of the two sequences.

    Returns
    -------
    tuple
        ``(n_complex, fields, eps_spec)``, each a list with one entry per solved frequency.
    """
    if tidy3d_extras["use_local_subpixel"]:
        delegate = tidy3d_extras["mod"].SubpixelModeSolver.from_mode_solver(self)
        return delegate._solve_all_freqs_relative(
            coords=coords, symmetry=symmetry, basis_fields=basis_fields
        )

    n_complex, fields, eps_spec = [], [], []
    per_freq = (
        self._solve_single_freq_relative(
            freq=freq, coords=coords, symmetry=symmetry, basis_fields=basis_fields_freq
        )
        for freq, basis_fields_freq in zip(self.freqs, basis_fields)
    )
    for n_freq, fields_freq, eps_spec_freq in per_freq:
        n_complex.append(n_freq)
        fields.append(fields_freq)
        eps_spec.append(eps_spec_freq)
    return n_complex, fields, eps_spec
@staticmethod
def _postprocess_solver_fields(
    solver_fields: ArrayComplex4D,
    normal_axis: Axis,
    plane: MODE_PLANE_TYPE,
    mode_spec: ModeSpec,
    coords: tuple[ArrayFloat1D, ArrayFloat1D],
) -> dict[str, ArrayComplex4D]:
    """Postprocess `solver_fields` from `compute_modes` to proper coordinate.

    Every mode index in ``range(mode_spec.num_modes)`` is passed through
    ``ModeSolver._process_fields`` and the per-mode results are stacked along a new
    last axis, one array per field component.

    Returns
    -------
    dict[str, ArrayComplex4D]
        Arrays keyed by "Ex", "Ey", "Ez", "Hx", "Hy", "Hz".
    """
    component_names = ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz")
    diff_coords = (np.diff(coords[0]), np.diff(coords[1]))

    # One 6-tuple of components per mode, already back in the original coordinates.
    per_mode = []
    for mode_index in range(mode_spec.num_modes):
        (Ex, Ey, Ez), (Hx, Hy, Hz) = ModeSolver._process_fields(
            solver_fields, mode_index, normal_axis, plane, diff_coords
        )
        per_mode.append((Ex, Ey, Ez, Hx, Hy, Hz))

    return {
        name: np.stack([components[position] for components in per_mode], axis=-1)
        for position, name in enumerate(component_names)
    }
def _solve_single_freq(
    self,
    freq: float,
    coords: tuple[ArrayFloat1D, ArrayFloat1D],
    symmetry: tuple[Symmetry, Symmetry],
) -> tuple[float, dict[str, ArrayComplex4D], EpsSpecType]:
    """Call the mode solver at a single frequency.

    The fields are rotated from propagation coordinates back to global coordinates.

    Returns
    -------
    tuple
        ``(n_complex, fields, eps_spec)`` where ``fields`` is the post-processed field
        dictionary.

    Raises
    ------
    ImportError
        If ``LOCAL_SOLVER_IMPORTED`` is false.
    """
    if not LOCAL_SOLVER_IMPORTED:
        raise ImportError(IMPORT_ERROR_MSG)

    solver_fields, n_complex, eps_spec = compute_modes(
        eps_cross=self._solver_eps(freq),
        coords=coords,
        freq=freq,
        mode_spec=self.mode_spec,
        symmetry=symmetry,
        direction=self.direction,
        precision=self._precision,
        plane_center=self.plane_center_tangential(self.plane),
    )
    global_fields = self._postprocess_solver_fields(
        solver_fields, self.normal_axis, self.plane, self.mode_spec, coords
    )
    return n_complex, global_fields, eps_spec
@classmethod
def _rotate_field_coords_inverse(
    cls, field: FIELD, normal_axis: Axis, plane: MODE_PLANE_TYPE
) -> FIELD:
    """Move the propagation axis to the z axis in the array.

    The spatial axis of the normal direction is moved to position 3, and the three
    vector components are reordered so that the normal component becomes the third one.
    """
    components = np.moveaxis(field, source=1 + normal_axis, destination=3)
    f_x, f_y, f_z = components
    normal_comp, tangential_comps = plane.pop_axis((f_x, f_y, f_z), axis=normal_axis)
    reordered = plane.unpop_axis(normal_comp, tangential_comps, axis=2)
    return np.stack(reordered, axis=0)
@classmethod
def _postprocess_solver_fields_inverse(
    cls, fields: dict[str, ArrayComplex4D], normal_axis: Axis, plane: MODE_PLANE_TYPE
) -> ArrayComplex4D:
    """Convert ``fields`` to ``solver_fields``. Doesn't change gauge.

    Parameters
    ----------
    fields : dict[str, ArrayComplex4D]
        Field arrays keyed by "Ex", "Ey", "Ez", "Hx", "Hy", "Hz".
    normal_axis : Axis
        Axis normal to the mode plane.
    plane : MODE_PLANE_TYPE
        Mode plane, used for the component reordering.

    Returns
    -------
    ArrayComplex4D
        The six rotated components stacked along axis 0 in the order
        (Ex, Ey, Ez, Hx, Hy, Hz).
    """
    e_global = [fields[name] for name in ("Ex", "Ey", "Ez")]
    h_global = [fields[name] for name in ("Hx", "Hy", "Hz")]
    e_solver = cls._rotate_field_coords_inverse(e_global, normal_axis=normal_axis, plane=plane)
    h_solver = cls._rotate_field_coords_inverse(h_global, normal_axis=normal_axis, plane=plane)

    # apply -1 to H fields if a reflection was involved in the rotation
    if normal_axis == 1:
        h_solver *= -1

    Ex, Ey, Ez = e_solver
    Hx, Hy, Hz = h_solver
    return np.stack((Ex, Ey, Ez, Hx, Hy, Hz), axis=0)
def _solve_single_freq_relative(
    self,
    freq: float,
    coords: tuple[ArrayFloat1D, ArrayFloat1D],
    symmetry: tuple[Symmetry, Symmetry],
    basis_fields: dict[str, ArrayComplex4D],
) -> tuple[float, dict[str, ArrayComplex4D], EpsSpecType]:
    """Call the mode solver at a single frequency.

    Modes are computed as linear combinations of ``basis_fields``, which are first
    converted to the solver's field layout.

    Returns
    -------
    tuple
        ``(n_complex, fields, eps_spec)`` where ``fields`` is the post-processed field
        dictionary.

    Raises
    ------
    ImportError
        If ``LOCAL_SOLVER_IMPORTED`` is false.
    """
    if not LOCAL_SOLVER_IMPORTED:
        raise ImportError(IMPORT_ERROR_MSG)

    basis_in_solver_layout = self._postprocess_solver_fields_inverse(
        fields=basis_fields, normal_axis=self.normal_axis, plane=self.plane
    )
    solver_fields, n_complex, eps_spec = compute_modes(
        eps_cross=self._solver_eps(freq),
        coords=coords,
        freq=freq,
        mode_spec=self.mode_spec,
        symmetry=symmetry,
        direction=self.direction,
        solver_basis_fields=basis_in_solver_layout,
        precision=self._precision,
        plane_center=self.plane_center_tangential(self.plane),
    )
    global_fields = self._postprocess_solver_fields(
        solver_fields, self.normal_axis, self.plane, self.mode_spec, coords
    )
    return n_complex, global_fields, eps_spec
@staticmethod
def _rotate_field_coords(field: FIELD, normal_axis: Axis, plane: MODE_PLANE_TYPE) -> FIELD:
    """Move the propagation axis=z to the proper order in the array.

    The spatial axis at position 3 is moved to the position of the normal axis, and the
    third vector component is inserted at the normal position among the components.
    """
    components = np.moveaxis(field, source=3, destination=1 + normal_axis)
    f_x, f_y, f_z = components
    reordered = plane.unpop_axis(f_z, (f_x, f_y), axis=normal_axis)
    return np.stack(reordered, axis=0)
@staticmethod
def _weighted_coord_max(
array: ArrayFloat2D, u: ArrayFloat1D, v: ArrayFloat1D
) -> tuple[int, int]:
"""2D argmax for an array weighted in both directions."""
if not np.all(np.isfinite(array)): # make sure the array is valid
return 0, 0
m_i = array * u.reshape(-1, 1)
total_weight_i = m_i.sum()
if total_weight_i == 0:
i = 0
else:
indices_i = np.arange(array.shape[0])
weighted_sum = (m_i * indices_i.reshape(-1, 1)).sum()
i = int(0.5 + weighted_sum / total_weight_i)
m_j = array * v
total_weight_j = m_j.sum()
if total_weight_j == 0:
j = 0
else:
indices_j = np.arange(array.shape[1])
weighted_sum = (m_j * indices_j).sum()
j = int(0.5 + weighted_sum / total_weight_j)
return i, j
@staticmethod
def _inverted_gauge(e_field: FIELD, diff_coords: tuple[ArrayFloat1D, ArrayFloat1D]) -> bool:
    """Check if the lower xy region of the mode has a negative sign.

    The real part of the in-plane E component with the larger peak magnitude is examined
    in a window ``e[:i, :j]`` that starts as the whole plane. If the window has a single
    strict sign, that sign decides the result. Otherwise the window is cropped to the
    weighted position returned by ``_weighted_coord_max`` and, if the field there reaches
    half of the previous window's peak magnitude, its sign decides the result.

    Parameters
    ----------
    e_field : FIELD
        E field in solver coordinates; the first two components at index 0 of the last
        axis are used.
    diff_coords : tuple[ArrayFloat1D, ArrayFloat1D]
        Weights along the two plane directions, passed to ``_weighted_coord_max``
        (the caller in this file builds them with ``np.diff`` of the plane coordinates).

    Returns
    -------
    bool
        ``True`` if the examined region is negative, ``False`` otherwise (including when
        the window collapses without a decision).
    """
    dx, dy = diff_coords
    # In-plane components only, at index 0 of the last axis.
    e_x, e_y = e_field[:2, :, :, 0]
    e_x = e_x.real
    e_y = e_y.real
    # Work with whichever in-plane component has the larger peak magnitude.
    e = e_x if np.abs(e_x).max() > np.abs(e_y).max() else e_y
    abs_e = np.abs(e)
    e_2 = abs_e**2
    i, j = e.shape
    while i > 0 and j > 0:
        # A window with one strict sign settles the question immediately.
        if (e[:i, :j] > 0).all():
            return False
        if (e[:i, :j] < 0).all():
            return True
        # The threshold is taken from the current window, before it is cropped.
        threshold = abs_e[:i, :j].max() * 0.5
        i, j = ModeSolver._weighted_coord_max(e_2[:i, :j], dx[:i], dy[:j])
        if abs(e[i, j]) >= threshold:
            return e[i, j] < 0
        # Do not close the window for 1D mode solvers
        if e.shape[0] == 1:
            i = 1
        elif e.shape[1] == 1:
            j = 1
    return False
@staticmethod
def _process_fields(
    mode_fields: ArrayComplex4D,
    mode_index: NonNegativeInt,
    normal_axis: Axis,
    plane: MODE_PLANE_TYPE,
    diff_coords: tuple[ArrayFloat1D, ArrayFloat1D],
) -> tuple[FIELD, FIELD]:
    """Transform solver fields to simulation axes and set gauge.

    The phase and sign corrections are applied in place to the slices of ``mode_fields``
    selected by ``mode_index``.

    Parameters
    ----------
    mode_fields : ArrayComplex4D
        Solver output; indexing the last axis with ``mode_index`` yields the E and H
        fields in solver coordinates.
    mode_index : NonNegativeInt
        Mode to process.
    normal_axis : Axis
        Axis normal to the mode plane.
    plane : MODE_PLANE_TYPE
        Mode plane, used for rotating the components back.
    diff_coords : tuple[ArrayFloat1D, ArrayFloat1D]
        Weights along the two plane directions, used by the gauge sign check.

    Returns
    -------
    tuple[FIELD, FIELD]
        ``((Ex, Ey, Ez), (Hx, Hy, Hz))`` in the original coordinates.
    """
    # E and H of the requested mode, still in solver coordinates.
    E, H = mode_fields[..., mode_index]

    # Gauge: the highest-amplitude in-plane E entry becomes real and positive.
    in_plane = E[:2]
    peak_value = in_plane.ravel()[np.argmax(np.abs(in_plane))]
    gauge_phase = np.angle(peak_value)
    E *= np.exp(-1j * gauge_phase)
    H *= np.exp(-1j * gauge_phase)

    # High-order modes with opposite-sign lobes will show inconsistent sign, heavily dependent
    # on the exact local grid to choose the previous gauge. Flip the gauge sign when necessary
    # to make it more consistent.
    if ModeSolver._inverted_gauge(E, diff_coords):
        E *= -1
        H *= -1

    # Rotate back to original coordinates.
    e_rotated = ModeSolver._rotate_field_coords(E, normal_axis, plane)
    h_rotated = ModeSolver._rotate_field_coords(H, normal_axis, plane)

    # apply -1 to H fields if a reflection was involved in the rotation
    if normal_axis == 1:
        h_rotated *= -1

    Ex, Ey, Ez = e_rotated
    Hx, Hy, Hz = h_rotated
    return ((Ex, Ey, Ez), (Hx, Hy, Hz))
@staticmethod
def _edge_indices_for_component(
    field: xr.DataArray, dim: str, bounds: BoundOptional
) -> list[int] | None:
    """Return the two edge indices along *dim* for a single field component.

    Each E-field component sits at a different Yee-grid position, so the
    coordinates that fall on or just inside the solver bounds differ per
    component.

    Returns
    -------
    list[int] | None
        ``[idx_lo, idx_hi]``, or ``None`` when the component has at most one coordinate
        along *dim*. A bound of ``None`` selects the first / last coordinate.
    """
    axis_coords = np.array(field.coords[dim])
    if len(axis_coords) <= 1:
        return None

    ax = "xyz".index(dim)
    lower_bound = bounds[0][ax]
    upper_bound = bounds[1][ax]

    if lower_bound is None:
        idx_lo = 0
    else:
        idx_lo = find_snap_location(
            axis_coords, lower_bound, "upper", rel_tol=fp_eps, abs_tol=fp_eps
        )

    if upper_bound is None:
        idx_hi = len(axis_coords) - 1
    else:
        idx_hi = find_snap_location(
            axis_coords, upper_bound, "lower", rel_tol=fp_eps, abs_tol=fp_eps
        )
    return [idx_lo, idx_hi]
def _field_decay_warning(self, field_data: ModeSolverData) -> None:
    """Warn if any of the modes do not decay at the edges of the solver domain.

    For every (frequency, mode) pair, the squared magnitude of the E field on the two
    edge indices of each in-plane dimension is summed over the three components and
    compared with the total squared magnitude. A warning is logged when the ratio
    exceeds ``FIELD_DECAY_CUTOFF``.

    Parameters
    ----------
    field_data : ModeSolverData
        Data providing ``Ex``, ``Ey``, ``Ez`` and ``solver_field_bounds``.
    """
    _, plane_dims = self.plane.pop_axis(["x", "y", "z"], axis=self.normal_axis)
    bounds = field_data.solver_field_bounds
    e_fields = (field_data.Ex, field_data.Ey, field_data.Ez)

    # The edge indices depend only on the component's coordinates and the bounds, not on
    # frequency or mode, so they are computed once here rather than inside the loops.
    edge_selections = []
    for dim in plane_dims:
        for E in e_fields:
            edge_idx = self._edge_indices_for_component(E, dim, bounds)
            if edge_idx is not None:
                edge_selections.append((E, dim, edge_idx))

    field_sizes = field_data.Ex.sizes
    for freq_index in range(field_sizes["f"]):
        for mode_index in range(field_sizes["mode_index"]):
            fm_sel = {"f": freq_index, "mode_index": mode_index}
            e_norm = 0
            for E in e_fields:
                e_norm += np.sum(np.abs(E[fm_sel]) ** 2)
            e_edge = 0
            for E, dim, edge_idx in edge_selections:
                e_edge += np.sum(np.abs(E[{dim: edge_idx, **fm_sel}]) ** 2)
            if e_edge / e_norm > FIELD_DECAY_CUTOFF:
                log.warning(
                    f"Mode field at frequency index {freq_index}, mode index {mode_index} does "
                    "not decay at the plane boundaries."
                )
@staticmethod
def _grid_correction(
    simulation: MODE_SIMULATION_TYPE,
    plane: Box,
    mode_spec: ModeSpec,
    n_complex: ModeIndexDataArray,
    direction: Direction,
) -> tuple[
    tuple[FreqModeDataArray, FreqModeDataArray], tuple[tuple[float, ...], tuple[float, ...]]
]:
    """
    Compute grid correction factors for the mode fields.

    This method calculates the phase correction factors necessary to account for propagation
    on a finite numerical grid along the propagation direction (normal to the mode plane).
    The correction is based on the assumed ``E * exp(1j k r)`` field dependence, where the
    fields are resampled using linear interpolation to precisely match the mode plane position.
    This is needed to correctly compute overlap with fields that come from
    a :class:`.FieldMonitor` placed in the same grid.

    Parameters
    ----------
    simulation : MODE_SIMULATION_TYPE
        Simulation object, which provides the grid structure.
    plane : Box
        The mode plane (its normal and center define the propagation direction and position).
    mode_spec : ModeSpec
        Mode specification with relevant propagation angle and properties.
    n_complex : ModeIndexDataArray
        Complex effective index array for the modes.
    direction : Direction
        Direction of propagation; "+" for forward or "-" for backward.

    Returns
    -------
    tuple
        ``(grid_correction_factors, (primal_distances, dual_distances))``. The first item
        is the result of ``ModeSolverData._grid_correction_factors``, i.e. the correction
        phase factors for the primal (tangential E field) and dual (tangential H field)
        grid locations. The second item holds the signed distances from the plane to the
        grid points used along the normal direction: two per grid, or one if that grid
        has a single point.
    """
    normal_axis = plane.size.index(0.0)
    normal_pos = float(plane.center[normal_axis])
    normal_dim = "xyz"[normal_axis]

    # Primal and dual grid along the normal direction,
    # i.e. locations of the tangential E-field and H-field components, respectively
    grid = simulation.grid
    normal_primal = grid.boundaries.to_list[normal_axis]
    normal_primal = xr.DataArray(normal_primal, coords={normal_dim: normal_primal})
    normal_dual = grid.centers.to_list[normal_axis]
    normal_dual = xr.DataArray(normal_dual, coords={normal_dim: normal_dual})

    def find_closest_distances_to_grid_points(
        normal_pos: float, grid_coords: ArrayFloat1D
    ) -> list[float]:
        """Signed distances from ``normal_pos`` to the pair of adjacent grid points around it.

        When ``normal_pos`` lies outside the grid, or exactly on its last point, the
        outermost adjacent pair is returned. Previously the lookup either wrapped around
        to the opposite end of the grid (negative index) or raised ``IndexError``.
        """
        if grid_coords.size == 1:
            return [float(grid_coords.data[0] - normal_pos)]
        distances = grid_coords.data - normal_pos
        # Signed distance to the closest grid point decides on which side the partner lies:
        # a positive distance means the closest point is above the plane, so pair it with
        # the previous point; otherwise pair it with the next one.
        closest_ind = int(np.argmin(np.abs(distances)))
        lower_ind = closest_ind - 1 if distances[closest_ind] > 0 else closest_ind
        # Clamp so that both lower_ind and lower_ind + 1 are valid, non-wrapping indices.
        lower_ind = min(max(lower_ind, 0), len(distances) - 2)
        return [distances[lower_ind], distances[lower_ind + 1]]

    primal_closest_distances = find_closest_distances_to_grid_points(normal_pos, normal_primal)
    dual_closest_distances = find_closest_distances_to_grid_points(normal_pos, normal_dual)
    grid_correction_factors = ModeSolverData._grid_correction_factors(
        primal_closest_distances,
        dual_closest_distances,
        mode_spec,
        n_complex,
        direction,
        normal_dim,
    )
    return grid_correction_factors, (primal_closest_distances, dual_closest_distances)
@property
def _is_tensorial(self) -> bool:
    """Whether the mode computation should be fully tensorial. This is either due to fully
    anisotropic media, or due to an angled waveguide, in which case the transformed eps and mu
    become tensorial. When ``angle_rotation`` is enabled, the geometry is rotated so that the
    effective angle is zero, avoiding the tensorial path - anisotropic and fully anisotropic
    media are currently not supported with ``angle_rotation``, so we return ``False`` in that
    case. A separate check is done inside the solver, which looks at the actual eps and mu and
    uses a tolerance to determine whether to invoke the tensorial solver, so the actual
    behavior may differ from what's predicted by this property."""
    if abs(self.mode_spec.angle_theta) > 0:
        # Angled propagation is tensorial unless the rotation trick removes the angle.
        return not self.mode_spec.angle_rotation
    return self._has_fully_anisotropic_media
@cached_property
def _intersecting_media(self) -> list:
    """List of media (including simulation background) intersecting the mode plane."""
    scene = self.simulation.scene
    candidate_structures = [scene.background_structure, *self.simulation.volumetric_structures]
    return scene.intersecting_media(self.plane, candidate_structures)
@cached_property
def _has_fully_anisotropic_media(self) -> bool:
    """Check if there are any fully anisotropic media in the plane of the mode.

    The scene's media are checked first, so the plane intersection is only evaluated
    when the scene contains a fully anisotropic medium at all.
    """
    scene_has_any = any(
        isinstance(mat, FullyAnisotropicMedium) for mat in self.simulation.scene.mediums
    )
    if not scene_has_any:
        return False
    return any(
        isinstance(int_mat, FullyAnisotropicMedium) for int_mat in self._intersecting_media
    )
@cached_property
def _has_complex_eps(self) -> bool:
    """Whether the eigenvalue problem is effectively complex-valued.

    Returns ``True`` when PML layers are present in the mode plane (PML introduces
    complex coordinate stretching, making the problem lossy even with real-valued
    materials), or when at least one intersecting medium has an imaginary part of
    epsilon at the lowest, highest, or mean sampled frequency (lossy material).

    A separate check is done inside the solver, which looks at the actual
    eps and mu and uses a tolerance to determine whether to use real or
    complex fields, so the actual behavior may differ from what's predicted
    by this property.
    """
    if any(n > 0 for n in self.mode_spec.num_pml):
        return True

    sampling_freqs = self._sampling_freqs
    check_freqs = np.unique(
        [np.amin(sampling_freqs), np.amax(sampling_freqs), np.mean(sampling_freqs)]
    )
    # NOTE(review): ``isclose(x, 0)`` with the default tolerances (``abs_tol=0``) is only
    # true for ``x == 0``, so any non-zero imaginary part counts as complex here. Confirm
    # whether an absolute tolerance was intended.
    return any(
        not isclose(np.amax(np.abs(np.imag(int_mat.eps_model(freq)))), 0)
        for int_mat in self._intersecting_media
        for freq in check_freqs
    )
@cached_property
def _contain_good_conductor(self) -> bool:
    """Whether modal plane might contain structures made of good conductors with large permittivity
    or permeability values.

    The media of the reduced simulation copy are inspected: PEC and PMC media always
    count, and ``LossyMetalMedium`` counts when the simulation's lossy-metal subpixel
    setting is a ``SurfaceImpedance``.
    """
    sim = self.reduced_simulation_copy.simulation
    apply_sibc = isinstance(sim._subpixel.lossy_metal, SurfaceImpedance)
    return any(
        medium.is_pec or medium.is_pmc or (apply_sibc and isinstance(medium, LossyMetalMedium))
        for medium in sim.scene.mediums
    )
@cached_property
def _precision(self) -> Literal["single", "double"]:
    """single or double precision applied in mode solver.

    An explicit ``mode_spec.precision`` is returned unchanged; ``"auto"`` resolves to
    ``"double"`` when the plane might contain good conductors and to ``"single"`` otherwise.
    """
    requested = self.mode_spec.precision
    if requested != "auto":
        return requested
    return "double" if self._contain_good_conductor else "single"
[docs]
def to_source(
    self,
    source_time: SourceTime,
    direction: Direction = None,
    mode_index: NonNegativeInt = 0,
    num_freqs: PositiveInt = 1,
    **kwargs: Any,
) -> ModeSource:
    """Build a :class:`.ModeSource` that matches this :class:`.ModeSolver`.

    Parameters
    ----------
    source_time: :class:`.SourceTime`
        Specification of the source time-dependence.
    direction : Direction = None
        Whether source will inject in ``"+"`` or ``"-"`` direction relative to plane normal.
        If not specified, uses the direction from the mode solver.
    mode_index : int = 0
        Index into the list of modes returned by mode solver to use in source.
    num_freqs : int = 1
        Forwarded unchanged as the ``num_freqs`` argument of :class:`.ModeSource`.
    **kwargs
        Additional keyword arguments forwarded to the :class:`.ModeSource` constructor.

    Returns
    -------
    :class:`.ModeSource`
        Mode source whose center, size and mode specification are taken from the
        mode solver, combined with the method inputs.
    """
    injection_direction = self.direction if direction is None else direction
    mode_plane = self.plane
    return ModeSource(
        center=mode_plane.center,
        size=mode_plane.size,
        source_time=source_time,
        mode_spec=self.mode_spec,
        mode_index=mode_index,
        direction=injection_direction,
        num_freqs=num_freqs,
        **kwargs,
    )
[docs]
def to_monitor(self, freqs: list[float] | None = None, name: str | None = None) -> ModeMonitor:
    """Build a mode monitor that matches this :class:`.ModeSolver`.

    A :class:`.MicrowaveModeMonitor` is created when the solver has a microwave
    (or microwave terminal) mode specification, a :class:`.ModeMonitor` otherwise.

    Parameters
    ----------
    freqs : list[float]
        Frequencies to include in Monitor (Hz).
        If not specified, passes ``self.freqs``.
    name : str
        Required name of monitor.

    Returns
    -------
    :class:`.ModeMonitor`
        Mode monitor with specifications taken from the ModeSolver instance and the method
        inputs.

    Raises
    ------
    ValueError
        If ``name`` is left as ``None``.
    """
    monitor_freqs = self.freqs if freqs is None else freqs

    if name is None:
        raise ValueError(
            "A 'name' must be passed to 'ModeSolver.to_monitor'. "
            "The default value of 'None' is for backwards compatibility and is not accepted."
        )

    is_microwave = self._has_microwave_mode_spec or self._has_microwave_terminal_mode_spec
    monitor_cls = MicrowaveModeMonitor if is_microwave else ModeMonitor

    mode_plane = self.plane
    return monitor_cls(
        center=mode_plane.center,
        size=mode_plane.size,
        freqs=monitor_freqs,
        mode_spec=self.mode_spec,
        colocate=self.colocate,
        conjugated_dot_product=self.conjugated_dot_product,
        name=name,
    )
[docs]
def to_mode_solver_monitor(
    self,
    name: str,
    colocate: bool | None = None,
    mode_spec: ModeSpec | None = None,
    freqs: list[float] | None = None,
) -> ModeSolverMonitor:
    """Build a mode solver monitor that matches this :class:`.ModeSolver`.

    A :class:`.MicrowaveModeSolverMonitor` is created when the solver has a microwave
    (or microwave terminal) mode specification, a :class:`.ModeSolverMonitor` otherwise.

    Parameters
    ----------
    name : str
        Name of the monitor.
    colocate : bool
        Whether to colocate fields or compute on the Yee grid. If not provided, the value
        set in the :class:`.ModeSolver` instance is used.
    mode_spec : ModeSpec
        Mode specification to use for the monitor.
        If not specified, uses the mode specification from the mode solver.
    freqs : list[float]
        Frequencies to include in Monitor (Hz).
        If not specified, uses the frequencies from the mode solver.

    Returns
    -------
    :class:`.ModeSolverMonitor`
        Mode monitor with specifications taken from the ModeSolver instance and ``name``.
    """
    # Fall back to the solver's own settings for every argument left as ``None``.
    monitor_mode_spec = self.mode_spec if mode_spec is None else mode_spec
    monitor_freqs = self.freqs if freqs is None else freqs
    monitor_colocate = self.colocate if colocate is None else colocate

    is_microwave = self._has_microwave_mode_spec or self._has_microwave_terminal_mode_spec
    monitor_cls = MicrowaveModeSolverMonitor if is_microwave else ModeSolverMonitor

    mode_plane = self.plane
    return monitor_cls(
        size=mode_plane.size,
        center=mode_plane.center,
        mode_spec=monitor_mode_spec,
        freqs=monitor_freqs,
        direction=self.direction,
        colocate=monitor_colocate,
        use_colocated_integration=self.use_colocated_integration,
        conjugated_dot_product=self.conjugated_dot_product,
        name=name,
    )
[docs]
@require_fdtd_simulation
def sim_with_source(
    self,
    source_time: SourceTime,
    direction: Direction = None,
    mode_index: NonNegativeInt = 0,
) -> Simulation:
    """Return a copy of the solver's simulation with a matching mode source appended.

    The source is produced by :meth:`to_source` and added after the simulation's
    existing sources.

    Parameters
    ----------
    source_time: :class:`.SourceTime`
        Specification of the source time-dependence.
    direction : Direction = None
        Whether source will inject in ``"+"`` or ``"-"`` direction relative to plane normal.
        If not specified, uses the direction from the mode solver.
    mode_index : int = 0
        Index into the list of modes returned by mode solver to use in source.

    Returns
    -------
    :class:`.Simulation`
        Copy of the simulation with a :class:`.ModeSource` with specifications taken
        from the ModeSolver instance and the method inputs.
    """
    added_source = self.to_source(
        source_time=source_time, direction=direction, mode_index=mode_index
    )
    base_sim = self.simulation
    return base_sim.updated_copy(sources=[*base_sim.sources, added_source])
[docs]
@require_fdtd_simulation
def sim_with_monitor(
    self,
    freqs: list[float] | None = None,
    name: str | None = None,
) -> Simulation:
    """Return a copy of the solver's simulation with a matching mode monitor appended.

    The monitor is produced by :meth:`to_monitor` and added after the simulation's
    existing monitors.

    Parameters
    ----------
    freqs : list[float] = None
        Frequencies to include in Monitor (Hz).
        If not specified, uses the frequencies from the mode solver.
    name : str
        Required name of monitor.

    Returns
    -------
    :class:`.Simulation`
        Copy of the simulation with a :class:`.ModeMonitor` with specifications taken
        from the ModeSolver instance and the method inputs.
    """
    added_monitor = self.to_monitor(freqs=freqs, name=name)
    base_sim = self.simulation
    return base_sim.updated_copy(monitors=[*base_sim.monitors, added_monitor])
[docs]
def sim_with_mode_solver_monitor(
    self,
    name: str,
) -> Simulation:
    """Return a copy of the solver's simulation with a mode solver monitor appended.

    The monitor is produced by :meth:`to_mode_solver_monitor` and added after the
    simulation's existing monitors.

    Parameters
    ----------
    name : str
        Name of the monitor.

    Returns
    -------
    :class:`.Simulation`
        Copy of the simulation with a :class:`.ModeSolverMonitor` with specifications taken
        from the ModeSolver instance and ``name``.
    """
    added_monitor = self.to_mode_solver_monitor(name=name)
    base_sim = self.simulation
    return base_sim.updated_copy(monitors=[*base_sim.monitors, added_monitor])
[docs]
def plot_field(
    self,
    field_name: str,
    val: Literal["real", "imag", "abs"] = "real",
    scale: PlotScale = "lin",
    eps_alpha: float = 0.2,
    robust: bool = True,
    vmin: float | None = None,
    vmax: float | None = None,
    ax: Ax = None,
    cmap: str | Colormap | None = None,
    **sel_kwargs: Any,
) -> Ax:
    """Plot a field of the solved modes with the simulation structures overlaid.

    All arguments are forwarded to ``plot_field`` of ``self.sim_data``, using the
    name of the monitor stored in ``self.data`` as the field monitor.

    Parameters
    ----------
    field_name : str
        Name of ``field`` component to plot (eg. ``'Ex'``).
        Also accepts ``'E'`` and ``'H'`` to plot the vector magnitudes of the electric and
        magnetic fields, and ``'S'`` for the Poynting vector.
    val : Literal['real', 'imag', 'abs'] = 'real'
        Which part of the field to plot. The value is passed through unchanged; whether
        the underlying plotting routine accepts further options is not checked here.
    scale : PlotScale = 'lin'
        Plot scale option, passed through unchanged to the underlying plotting routine.
    eps_alpha : float = 0.2
        Opacity of the structure permittivity.
        Must be between 0 and 1 (inclusive).
    robust : bool = True
        If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
        to compute the color limits. This helps in visualizing the field patterns especially
        in the presence of a source.
    vmin : float = None
        The lower bound of data range that the colormap covers. If ``None``, they are
        inferred from the data and other keyword arguments.
    vmax : float = None
        The upper bound of data range that the colormap covers. If ``None``, they are
        inferred from the data and other keyword arguments.
    ax : matplotlib.axes._subplots.Axes = None
        matplotlib axes to plot on, if not specified, one is created.
    cmap : Optional[Union[str, Colormap]] = None
        Colormap for visualizing the field values. ``None`` uses the default which infers it
        from the data.
    sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
        These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
        frequency or time dimensions (``f``, ``t``) or `mode_index`, if applicable.
        For the plotting to work appropriately, the resulting data after selection must contain
        only two coordinates with len > 1.
        Furthermore, these should be spatial coordinates (``x``, ``y``, or ``z``).

    Returns
    -------
    matplotlib.axes._subplots.Axes
        The supplied or created matplotlib axes.
    """
    solved_sim_data = self.sim_data
    plot_options = {
        "field_monitor_name": self.data.monitor.name,
        "field_name": field_name,
        "val": val,
        "scale": scale,
        "eps_alpha": eps_alpha,
        "robust": robust,
        "vmin": vmin,
        "vmax": vmax,
        "ax": ax,
        "cmap": cmap,
    }
    return solved_sim_data.plot_field(**plot_options, **sel_kwargs)
[docs]
def plot_field_components(
    self,
    field_names: str | tuple[str, ...],
    mode_indices: int | tuple[int, ...] | None = None,
    val: Literal["real", "imag", "abs"] = "real",
    scale: PlotScale = "lin",
    eps_alpha: float = 0.2,
    robust: bool = True,
    vmin: float | None = None,
    vmax: float | None = None,
    ax: Any = None,
    cmap: str | Colormap | None = None,
    figsize: tuple[float, float] | None = None,
    titles: bool = True,
    show_n_eff: bool = False,
    **sel_kwargs: Any,
) -> tuple[Any, np.ndarray]:
    """Plot multiple field components for one or more modes in a single call.

    The solver is wrapped in a ``ModeSimulation`` (via ``ModeSimulation.from_mode_solver``)
    and paired with ``self.data_raw`` in a ``ModeSimulationData``; every argument is then
    forwarded unchanged to that object's ``plot_field_components``, whose return value is
    returned as is.
    """
    # Imported inside the function rather than at module level.
    from tidy3d.components.mode.data.sim_data import ModeSimulationData
    from tidy3d.components.mode.simulation import ModeSimulation

    wrapped_sim = ModeSimulation.from_mode_solver(self)
    wrapped_data = ModeSimulationData(simulation=wrapped_sim, modes_raw=self.data_raw)

    plot_options = {
        "field_names": field_names,
        "mode_indices": mode_indices,
        "val": val,
        "scale": scale,
        "eps_alpha": eps_alpha,
        "robust": robust,
        "vmin": vmin,
        "vmax": vmax,
        "ax": ax,
        "cmap": cmap,
        "figsize": figsize,
        "titles": titles,
        "show_n_eff": show_n_eff,
    }
    return wrapped_data.plot_field_components(**plot_options, **sel_kwargs)
[docs]
def plot(
    self,
    ax: Ax = None,
    hlim: tuple[float, float] | None = None,
    vlim: tuple[float, float] | None = None,
    fill_structures: bool = True,
    **patch_kwargs: Any,
) -> Ax:
    """Plot the mode plane simulation's components.

    The simulation is plotted on the mode plane with sources, monitors and lumped
    elements made fully transparent, and the mode solver PML is drawn on top.

    Parameters
    ----------
    ax : matplotlib.axes._subplots.Axes = None
        Matplotlib axes to plot on, if not specified, one is created.
    hlim : Tuple[float, float] = None
        The x range if plotting on xy or xz planes, y range if plotting on yz plane.
        Defaults to the horizontal limits of the mode plane.
    vlim : Tuple[float, float] = None
        The z range if plotting on xz or yz planes, y range if plotting on xy plane.
        Defaults to the vertical limits of the mode plane.
    fill_structures : bool = True
        Whether to fill structures with color or just draw outlines.
    **patch_kwargs
        Additional keyword arguments forwarded to the simulation's ``plot`` method.

    Returns
    -------
    matplotlib.axes._subplots.Axes
        The supplied or created matplotlib axes.

    See Also
    ---------

    **Notebooks**
        * `Visualizing geometries in Tidy3D: Plotting Materials <../../notebooks/VizSimulation.html#Plotting-Materials>`_
    """
    # Plane position along its normal, plus the default in-plane limits.
    a_center, plane_hlim, plane_vlim, _ = self._center_and_lims(
        simulation=self.simulation, plane=self.plane
    )
    position = dict(zip("xyz", a_center))

    sim_ax = self.simulation.plot(
        **position,
        hlim=plane_hlim if hlim is None else hlim,
        vlim=plane_vlim if vlim is None else vlim,
        source_alpha=0,
        monitor_alpha=0,
        lumped_element_alpha=0,
        ax=ax,
        fill_structures=fill_structures,
        **patch_kwargs,
    )
    return self.plot_pml(ax=sim_ax)
[docs]
def plot_eps(
    self,
    freq: float | None = None,
    alpha: float | None = None,
    ax: Ax = None,
) -> Ax:
    """Plot the mode plane simulation's components.
    The permittivity is plotted in grayscale based on its value at the specified frequency.

    Parameters
    ----------
    freq : float = None
        Frequency to evaluate the relative permittivity of all mediums.
        If not specified, the middle entry of the mode solver frequencies
        (``self.freqs[len(self.freqs) // 2]``) is used.
    alpha : float = None
        Opacity of the structures being plotted.
        Defaults to the structure default alpha.
    ax : matplotlib.axes._subplots.Axes = None
        Matplotlib axes to plot on, if not specified, one is created.

    Returns
    -------
    matplotlib.axes._subplots.Axes
        The supplied or created matplotlib axes.

    See Also
    ---------

    **Notebooks**
        * `Visualizing geometries in Tidy3D: Plotting Permittivity <../../notebooks/VizSimulation.html#Plotting-Permittivity>`_
    """
    # Plane position along its normal, plus the in-plane limits.
    a_center, h_lim, v_lim, _ = self._center_and_lims(
        simulation=self.simulation, plane=self.plane
    )
    position = dict(zip("xyz", a_center))

    if freq is not None:
        eval_freq = freq
    else:
        # Fall back to the central mode solver frequency.
        eval_freq = self.freqs[len(self.freqs) // 2]

    return self.simulation.plot_eps(
        **position,
        freq=eval_freq,
        alpha=alpha,
        hlim=h_lim,
        vlim=v_lim,
        source_alpha=0,
        monitor_alpha=0,
        lumped_element_alpha=0,
        ax=ax,
    )
[docs]
def plot_structures_eps(
    self,
    freq: float | None = None,
    alpha: float | None = None,
    cbar: bool = True,
    reverse: bool = False,
    ax: Ax = None,
) -> Ax:
    """Plot the mode plane simulation's components.
    The permittivity is plotted in grayscale based on its value at the specified frequency.

    Parameters
    ----------
    freq : float = None
        Frequency to evaluate the relative permittivity of all mediums.
        If not specified, the middle entry of the mode solver frequencies
        (``self.freqs[len(self.freqs) // 2]``) is used.
    alpha : float = None
        Opacity of the structures being plotted.
        Defaults to the structure default alpha.
    cbar : bool = True
        Whether to plot a colorbar for the relative permittivity.
    reverse : bool = False
        If ``False``, the highest permittivity is plotted in black.
        If ``True``, it is plotted in white (suitable for black backgrounds).
    ax : matplotlib.axes._subplots.Axes = None
        Matplotlib axes to plot on, if not specified, one is created.

    Returns
    -------
    matplotlib.axes._subplots.Axes
        The supplied or created matplotlib axes.

    See Also
    ---------

    **Notebooks**
        * `Visualizing geometries in Tidy3D: Plotting Permittivity <../../notebooks/VizSimulation.html#Plotting-Permittivity>`_
    """
    # Plane position along its normal, plus the in-plane limits.
    a_center, h_lim, v_lim, _ = self._center_and_lims(
        simulation=self.simulation, plane=self.plane
    )
    position = dict(zip("xyz", a_center))

    if freq is not None:
        eval_freq = freq
    else:
        # Fall back to the central mode solver frequency.
        eval_freq = self.freqs[len(self.freqs) // 2]

    return self.simulation.plot_structures_eps(
        **position,
        freq=eval_freq,
        alpha=alpha,
        cbar=cbar,
        reverse=reverse,
        hlim=h_lim,
        vlim=v_lim,
        ax=ax,
    )
[docs]
def plot_grid(
    self,
    ax: Ax = None,
    **kwargs: Any,
) -> Ax:
    """Plot the mode plane cell boundaries as lines.

    Parameters
    ----------
    ax : matplotlib.axes._subplots.Axes = None
        Matplotlib axes to plot on, if not specified, one is created.
    **kwargs
        Optional keyword arguments passed to the matplotlib ``LineCollection``.
        For details on accepted values, refer to
        `Matplotlib's documentation <https://tinyurl.com/2p97z4cn>`_.

    Returns
    -------
    matplotlib.axes._subplots.Axes
        The supplied or created matplotlib axes.
    """
    # Plane position along its normal, plus the in-plane limits.
    a_center, h_lim, v_lim, _ = self._center_and_lims(
        simulation=self.simulation, plane=self.plane
    )
    position = dict(zip("xyz", a_center))
    return self.simulation.plot_grid(**position, hlim=h_lim, vlim=v_lim, ax=ax, **kwargs)
@classmethod
def _plane_grid(cls, simulation: Simulation, plane: Box) -> tuple[Coords, Coords]:
    """Return the snapped grid boundaries along the two tangential axes of the mode plane.

    The tangential axes come from ``_center_and_lims`` and the boundaries from the
    grid returned by ``_grid_snapped``.
    """
    tangential_axes = cls._center_and_lims(simulation=simulation, plane=plane)[3]
    boundaries_per_axis = cls._grid_snapped(simulation=simulation, plane=plane).boundaries.to_list
    first_axis, second_axis = tangential_axes[0], tangential_axes[1]
    return boundaries_per_axis[first_axis], boundaries_per_axis[second_axis]
@classmethod
def _effective_num_pml(
    cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec
) -> tuple[NonNegativeFloat, NonNegativeFloat]:
    """Number of cells of the mode solver pml along each tangential axis.

    The requested ``mode_spec.num_pml`` values are capped at the number of grid
    cells available along the corresponding axis of the mode plane grid.
    """
    coord_0, coord_1 = cls._plane_grid(simulation=simulation, plane=plane)
    requested = mode_spec.num_pml
    # ``len(coords) - 1`` is the number of cells spanned by the boundary coordinates.
    available_cells = (len(coord_0) - 1, len(coord_1) - 1)
    return (
        min(requested[0], available_cells[0]),
        min(requested[1], available_cells[1]),
    )
@classmethod
def _pml_thickness(
    cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec
) -> tuple[
    tuple[NonNegativeFloat, NonNegativeFloat],
    tuple[NonNegativeFloat, NonNegativeFloat],
]:
    """Thickness of the mode solver pml in the form
    ((plus0, minus0), (plus1, minus1))
    where 0 and 1 are the lexicographically-ordered tangential axes
    to the mode plane.

    Along an axis with no PML cells both thicknesses are 0. Along an axis with a
    nonzero solver symmetry, the minus-side thickness is set equal to the plus-side one.
    """

    def _axis_thickness(coords, num_layers, symmetry):
        """Return ``(plus, minus)`` PML thickness along one tangential axis."""
        if not num_layers > 0:
            return 0, 0
        # Distance covered by the outermost ``num_layers`` cells on each side.
        plus = coords[-1] - coords[-num_layers - 1]
        minus = coords[num_layers] - coords[0]
        if symmetry != 0:
            minus = plus
        return plus, minus

    solver_symmetry = cls._solver_symmetry(simulation=simulation, plane=plane)
    coord_0, coord_1 = cls._plane_grid(simulation=simulation, plane=plane)
    # PML cell counts, already capped by the grid size.
    num_pml_0, num_pml_1 = cls._effective_num_pml(
        simulation=simulation, plane=plane, mode_spec=mode_spec
    )

    thickness_0 = _axis_thickness(coord_0, num_pml_0, solver_symmetry[0])
    thickness_1 = _axis_thickness(coord_1, num_pml_1, solver_symmetry[1])
    return (thickness_0, thickness_1)
@classmethod
def _mode_plane_size(
    cls, simulation: Simulation, plane: Box
) -> tuple[NonNegativeFloat, NonNegativeFloat]:
    """The size of the mode plane intersected with the simulation.

    Returned as ``(horizontal extent, vertical extent)`` computed from the limits
    given by ``_center_and_lims``.
    """
    limits = cls._center_and_lims(simulation=simulation, plane=plane)
    h_lim, v_lim = limits[1], limits[2]
    width = h_lim[1] - h_lim[0]
    height = v_lim[1] - v_lim[0]
    return width, height
@classmethod
def _mode_plane_size_no_pml(
    cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec
) -> tuple[NonNegativeFloat, NonNegativeFloat]:
    """The size of the remaining portion of the mode plane, after the pml
    has been removed.

    Along each tangential axis, both PML thicknesses from ``_pml_thickness`` are
    subtracted from the corresponding extent returned by ``_mode_plane_size``.
    """
    full_0, full_1 = cls._mode_plane_size(simulation=simulation, plane=plane)
    (plus_0, minus_0), (plus_1, minus_1) = cls._pml_thickness(
        simulation=simulation, plane=plane, mode_spec=mode_spec
    )
    return full_0 - plus_0 - minus_0, full_1 - plus_1 - minus_1
[docs]
def plot_pml(
    self,
    ax: Ax = None,
) -> Ax:
    """Plot the mode plane absorbing boundaries.

    Delegates to ``_plot_pml`` with this solver's simulation, plane and mode
    specification.

    Parameters
    ----------
    ax : matplotlib.axes._subplots.Axes = None
        Matplotlib axes to plot on, if not specified, one is created.

    Returns
    -------
    matplotlib.axes._subplots.Axes
        The supplied or created matplotlib axes.
    """
    solver_setup = {
        "simulation": self.simulation,
        "plane": self.plane,
        "mode_spec": self.mode_spec,
    }
    return self._plot_pml(**solver_setup, ax=ax)
@classmethod
def _plot_pml(
    cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec, ax: Ax = None
) -> Ax:
    """Plot the mode plane absorbing boundaries.

    Up to four rectangles (minus/plus side along each tangential axis) are added to
    ``ax`` as a single patch collection styled with ``plot_params_pml``. Nothing is
    drawn when ``mode_spec.num_pml`` is zero along both axes.
    """
    # matplotlib is imported inside the function rather than at module level.
    from matplotlib.collections import PatchCollection
    from matplotlib.patches import Rectangle

    # In-plane limits of the mode plane.
    _, h_lim, v_lim, _ = cls._center_and_lims(simulation=simulation, plane=plane)

    # Create the axes (with the mode plane limits) if none were supplied.
    if not ax:
        ax = make_ax()
        ax.set_xlim(h_lim)
        ax.set_ylim(v_lim)

    # Requested number of PML layers along each tangential axis.
    layers_0, layers_1 = mode_spec.num_pml[0], mode_spec.num_pml[1]

    (thick_0_plus, thick_0_minus), (thick_1_plus, thick_1_minus) = cls._pml_thickness(
        simulation=simulation, plane=plane, mode_spec=mode_spec
    )

    # Mode plane width and height.
    plane_w, plane_h = cls._mode_plane_size(simulation=simulation, plane=plane)

    if layers_0 <= 0 and layers_1 <= 0:
        return ax

    left, right = h_lim[0], h_lim[1]
    bottom, top = v_lim[0], v_lim[1]

    # (thickness, lower-left corner, width, height) for each candidate PML strip.
    strips = (
        (thick_0_minus, (left, bottom), thick_0_minus, plane_h),
        (thick_0_plus, (right - thick_0_plus, bottom), thick_0_plus, plane_h),
        (thick_1_minus, (left, bottom), plane_w, thick_1_minus),
        (thick_1_plus, (left, top - thick_1_plus), plane_w, thick_1_plus),
    )
    pml_patches = [
        Rectangle(corner, strip_w, strip_h)
        for thickness, corner, strip_w, strip_h in strips
        if thickness > 0
    ]

    collection = PatchCollection(
        pml_patches,
        alpha=plot_params_pml.alpha,
        facecolor=plot_params_pml.facecolor,
        edgecolor=plot_params_pml.edgecolor,
        hatch=plot_params_pml.hatch,
        zorder=plot_params_pml.zorder,
    )
    ax.add_collection(collection)
    return ax
@staticmethod
def _center_and_lims(simulation: Simulation, plane: Box) -> tuple[list, list, list, list]:
    """Get the mode plane center and limits.

    Parameters
    ----------
    simulation : :class:`.Simulation`
        Simulation whose ``bounds`` (``(rmin, rmax)``) restrict the plane limits.
    plane : :class:`.Box`
        Mode plane; the axis along which its ``size`` is ``0.0`` is taken as the normal.

    Returns
    -------
    tuple[list, list, list, list]
        ``(a_center, h_lim, v_lim, t_axes)`` where ``a_center`` has the plane center
        along the normal axis and ``None`` elsewhere, ``h_lim`` / ``v_lim`` are the
        ``[min, max]`` limits of the plane along the first / second tangential axis
        clipped to the simulation bounds, and ``t_axes`` are the tangential axes.
    """
    normal_axis = plane.size.index(0.0)
    n_axis, t_axes = plane.pop_axis([0, 1, 2], normal_axis)

    # Only the coordinate along the normal is set; the in-plane entries stay ``None``.
    a_center = [None, None, None]
    a_center[n_axis] = plane.center[n_axis]

    # ``pop_axis`` is reached through ``plane`` for both calls, keeping the block
    # self-contained.
    _, sim_mins = plane.pop_axis(simulation.bounds[0], axis=n_axis)
    _, sim_maxs = plane.pop_axis(simulation.bounds[1], axis=n_axis)

    limits = []
    for t_axis, sim_min, sim_max in zip(t_axes, sim_mins, sim_maxs):
        half_size = plane.size[t_axis] / 2
        plane_min = plane.center[t_axis] - half_size
        plane_max = plane.center[t_axis] + half_size
        # Intersect the plane extent with the simulation extent. Comparing absolute
        # values instead would only be valid for simulations straddling the origin.
        limits.append([max(plane_min, sim_min), min(plane_max, sim_max)])

    h_lim, v_lim = limits
    return a_center, h_lim, v_lim, t_axes
def _validate_modes_size(self) -> None:
    """Make sure that the total size of the modes fields is not too large.

    The estimate is the solver storage size of a mode solver monitor built from
    this solver.

    Raises
    ------
    SetupError
        If the estimated storage exceeds ``MAX_MODES_DATA_SIZE_GB``.
    """
    probe_monitor = self.to_mode_solver_monitor(name=MODE_MONITOR_NAME)
    num_cells = self.simulation._monitor_num_cells(probe_monitor)
    size_bytes = probe_monitor._storage_size_solver(num_cells=num_cells, tmesh=[])
    # Convert to GB.
    total_size = size_bytes / 1e9
    if total_size <= MAX_MODES_DATA_SIZE_GB:
        return
    raise SetupError(
        f"Mode solver has {total_size:.2f}GB of estimated storage, "
        f"a maximum of {MAX_MODES_DATA_SIZE_GB:.2f}GB is allowed. Consider making the "
        "mode plane smaller, or decreasing the resolution or number of requested "
        "frequencies or modes."
    )
def _validate_auto_impedance_spec(self) -> None:
    """Raise error if the microwave mode specification with ``AutoImpedanceSpec`` will
    fail to instantiate.

    Does nothing unless the solver has a microwave mode specification; otherwise the
    check is delegated to ``mode_spec._validate_auto_impedance_setup``.
    """
    if self._has_microwave_mode_spec:
        sim = self.simulation
        mode_plane = self.plane
        self.mode_spec._validate_auto_impedance_setup(
            center=mode_plane.center,
            size=mode_plane.size,
            colocate=self.colocate,
            volumetric_structures=sim.volumetric_structures,
            grid=sim.grid,
            symmetry=sim.symmetry,
            simulation_geometry=sim.simulation_geometry,
            label=" for mode solver",
            interior_disjoint_geometries=ModePlaneAnalyzer.apply_interior_disjoint_geometries(
                sim.structure_priority_mode
            ),
        )
[docs]
def validate_pre_upload(self) -> None:
    """Validate the fully initialized mode solver is ok for upload to our servers.

    Runs the mode data size check followed by the auto impedance specification check;
    either may raise.
    """
    pre_upload_checks = (self._validate_modes_size, self._validate_auto_impedance_spec)
    for run_check in pre_upload_checks:
        run_check()
@cached_property
def reduced_simulation_copy(self) -> Self:
    """Strip objects not used by the mode solver from simulation object.
    This might significantly reduce upload time in the presence of custom mediums.

    Returns
    -------
    Self
        Mode solver whose simulation is a subsection of the original one around the
        mode plane, without sources, monitors or internal absorbers, and with
        absorbing boundaries replaced by PEC.
    """
    # For now, an EME simulation is handled by converting it to an FDTD simulation,
    # because a planar subsection of an EME simulation cannot be taken.
    # Eventually this will convert to ModeSimulation.
    if isinstance(self.simulation, EMESimulation):
        return self.as_fdtd_mode_solver.reduced_simulation_copy

    # Extra cells along the normal direction are preserved so that there is enough
    # data for subpixel.
    boundaries = self._get_output_grid(
        simulation=self.simulation,
        plane=self.plane,
        keep_additional_layers=True,
        truncate_symmetry=False,
    ).boundaries
    lower_corner = (boundaries.x[0], boundaries.y[0], boundaries.z[0])
    upper_corner = (boundaries.x[-1], boundaries.y[-1], boundaries.z[-1])
    region = Box.from_bounds(rmin=lower_corner, rmax=upper_corner)

    def _without_absorber(edge):
        """Swap PML / absorber boundaries for PEC; keep any other boundary."""
        if isinstance(edge, (PML, StablePML, Absorber)):
            return PECBoundary()
        return edge

    # Remove PML, absorbers, etc., to avoid unnecessary cells.
    original_bspec = self.simulation.boundary_spec
    stripped = {}
    for axis in "xyz":
        axis_boundary = original_bspec[axis]
        stripped[axis] = Boundary(
            plus=_without_absorber(axis_boundary.plus),
            minus=_without_absorber(axis_boundary.minus),
        )
    reduced_bspec = BoundarySpec(x=stripped["x"], y=stripped["y"], z=stripped["z"])

    # Extract the sub-simulation, removing everything irrelevant.
    reduced_sim = self.simulation.subsection(
        region=region,
        monitors=(),
        sources=(),
        internal_absorbers=(),
        warn_symmetry_expansion=False,  # we already warn upon mode solver creation
        grid_spec="identical",
        boundary_spec=reduced_bspec,
        remove_outside_custom_mediums=True,
        remove_outside_structures=True,
        include_pml_cells=True,
        validate_geometries=False,
        deep_copy=False,
        low_freq_smoothing=None,
    )

    # The mode solver is validated only on a simulation where every geometry has been
    # replaced by its bounding box, so that geometry validation is skipped.
    boxed_structures = tuple(
        structure.updated_copy(geometry=structure.geometry.bounding_box, deep=False)
        for structure in reduced_sim.structures
    )
    # The simulation itself was already validated in ``subsection``.
    boxed_sim = reduced_sim.updated_copy(structures=boxed_structures, deep=False, validate=False)
    # This copy runs the mode solver validation against the bounding-box simulation.
    validated_solver = self.updated_copy(simulation=boxed_sim, deep=False)
    # Put the real sub-simulation back without validating again.
    return validated_solver.updated_copy(simulation=reduced_sim, deep=False, validate=False)
[docs]
def to_fdtd_mode_solver(self) -> ModeSolver:
    """Construct a new :class:`.ModeSolver` by converting ``simulation``
    from a :class:`.EMESimulation` to an FDTD :class:`.Simulation`.
    Only used as a workaround until :class:`.EMESimulation` is natively supported in the
    :class:`.ModeSolver` webapi.

    Raises
    ------
    ValidationError
        If ``simulation`` is not an :class:`.EMESimulation`.
    """
    eme_sim = self.simulation
    if isinstance(eme_sim, EMESimulation):
        return self.updated_copy(simulation=eme_sim._as_fdtd_sim)
    raise ValidationError(
        "The method 'to_fdtd_mode_solver' is only needed "
        "when the 'simulation' is an 'EMESimulation'."
    )
@cached_property
def as_fdtd_mode_solver(self) -> ModeSolver:
    """Construct a new :class:`.ModeSolver` by converting ``simulation``
    from a :class:`.EMESimulation` to an FDTD :class:`.Simulation`.
    Only used as a workaround until :class:`.EMESimulation` is natively supported in the
    :class:`.ModeSolver` webapi.

    Simply returns the result of :meth:`to_fdtd_mode_solver`.
    """
    fdtd_solver = self.to_fdtd_mode_solver()
    return fdtd_solver
def _patch_data(self, data: ModeSolverData) -> None:
    """
    Patch the :class:`.ModeSolver` with the provided data so that
    it will be used everywhere instead of locally-computed data.

    This function is available as a workaround while we transition
    to the new :class:`.ModeSimulation` interface. It is used
    in the webapi to make functions like ``plot_field`` operate
    on the data from the remote mode solver rather than data from
    the local mode solver. It can also be used manually if needed.

    Parameters
    ----------
    data : :class:`.ModeSolverData`
        The mode solver data to be used, typically the result
        of a remote mode solver run.
    """
    cache = self._cached_properties
    cache["data_raw"] = data
    # Drop cached entries derived from the raw data, if present.
    for stale_key in ("data", "sim_data"):
        cache.pop(stale_key, None)
[docs]
def plot_3d(self, width: int = 800, height: int = 800) -> None:
    """Render 3D plot of ``ModeSolver`` (in jupyter notebook only).

    The call is forwarded to ``plot_3d`` of the underlying simulation and its
    result is returned.

    Parameters
    ----------
    width : int = 800
        width of the 3d view dom's size
    height : int = 800
        height of the 3d view dom's size
    """
    view_size = {"width": width, "height": height}
    return self.simulation.plot_3d(**view_size)