# Source code for tidy3d.plugins.smatrix.component_modelers.modal

"""Tool for generating an S matrix automatically from a Tidy3d simulation and modal port definitions."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import autograd.numpy as np
from pydantic import Field

from tidy3d.components.base import cached_property
from tidy3d.components.index import SimulationMap
from tidy3d.components.source.time import GaussianPulse
from tidy3d.components.types import Complex
from tidy3d.components.viz import add_ax_if_none, equal_aspect
from tidy3d.constants import GLANCING_CUTOFF
from tidy3d.exceptions import SetupError
from tidy3d.plugins.smatrix.ports.modal import (
    AbstractGaussianPort,
    GaussianPort,
    ModalPortType,
)
from tidy3d.plugins.smatrix.types import Element, MatrixIndex

from .base import FWIDTH_FRAC, AbstractComponentModeler

if TYPE_CHECKING:
    from tidy3d.components.data.sim_data import SimulationData
    from tidy3d.components.monitor import AbstractOverlapMonitor
    from tidy3d.components.simulation import Simulation
    from tidy3d.components.source.field import DirectionalSource
    from tidy3d.components.types import Ax
    from tidy3d.plugins.smatrix.ports.modal import Port


class ModalComponentModeler(AbstractComponentModeler):
    """A tool for modeling devices and computing scattering matrix elements.

    Notes
    -----
    This class orchestrates the process of running multiple simulations to
    derive the scattering matrix (S-matrix) of a component. It uses modal or
    Gaussian sources and monitors defined by a set of ports.

    See Also
    --------

    **Notebooks**
        * `Computing the scattering matrix of a device <../../notebooks/SMatrix.html>`_
    """

    ports: tuple[ModalPortType, ...] = Field(
        (),
        title="Ports",
        description="Collection of ports describing the scattering matrix elements. "
        "For each input mode, one simulation will be run with a modal source.",
    )

    run_only: tuple[MatrixIndex, ...] | None = Field(
        None,
        title="Run Only",
        description="Set of matrix indices that define the simulations to run. "
        "If ``None``, simulations will be run for all indices in the scattering matrix. "
        "If a tuple is given, simulations will be run only for the given matrix indices.",
    )

    element_mappings: tuple[tuple[Element, Element, Complex], ...] = Field(
        (),
        title="Element Mappings",
        description="Tuple of S matrix element mappings, each described by a tuple of "
        "(input_element, output_element, coefficient), where the coefficient is the "
        "element_mapping coefficient describing the relationship between the input and output "
        "matrix element. If all elements of a given column of the scattering matrix are defined "
        "by ``element_mappings``, the simulation corresponding to this column is skipped automatically.",
    )

    @property
    def base_sim(self) -> Simulation:
        """The base simulation."""
        return self.simulation

    @cached_property
    def sim_dict(self) -> SimulationMap:
        """Generates all :class:`.Simulation` objects for the S-matrix calculation.

        One simulation is created for every ``(port_name, mode_index)`` in
        :attr:`matrix_indices_run_sim`. Each simulation has a single source that excites
        that mode on a shifted copy of the port (see :meth:`shift_port`). Its monitors are
        the base simulation's monitors plus one overlap monitor per port.

        Returns
        -------
        SimulationMap
            Mapping from task name to the corresponding :class:`.Simulation`.
        """
        # The monitor set is identical for every task, so build it only once.
        port_monitors = [self.to_monitor(port=port) for port in self.ports]
        all_monitors = (*self.simulation.monitors, *port_monitors)

        sim_dict = {}
        for port_name, mode_index in self.matrix_indices_run_sim:
            port = self.get_port_by_name(port_name=port_name)
            port_source = self.shift_port(port=port)
            mode_source = self.to_source(port=port_source, mode_index=mode_index)
            sim_copy = self.simulation.copy(
                update={"sources": [mode_source], "monitors": all_monitors}
            )
            # Task names come from the *unshifted* port.
            task_name = self.get_task_name(port=port, mode_index=mode_index)
            sim_dict[task_name] = sim_copy
        return SimulationMap(keys=tuple(sim_dict.keys()), values=tuple(sim_dict.values()))

    def _construct_matrix_indices_monitor(
        self, ports: tuple[ModalPortType, ...]
    ) -> tuple[MatrixIndex, ...]:
        """Construct matrix indices for monitoring from modal ports.

        Parameters
        ----------
        ports : tuple[Port, ...]
            Tuple of Port objects.

        Returns
        -------
        tuple[MatrixIndex, ...]
            Tuple of (port_name, mode_index) pairs, ordered by port and then by mode
            index in ``range(port.num_modes)``.
        """
        return tuple(
            (port.name, mode_index) for port in ports for mode_index in range(port.num_modes)
        )

    @cached_property
    def matrix_indices_monitor(self) -> tuple[MatrixIndex, ...]:
        """Returns a tuple of all possible matrix indices for monitoring.

        Each matrix index is a tuple of (port_name, mode_index).

        Returns
        -------
        Tuple[MatrixIndex, ...]
            A tuple of all possible matrix indices for the monitoring ports.
        """
        return self._construct_matrix_indices_monitor(self.ports)

    @cached_property
    def matrix_indices_source(self) -> tuple[MatrixIndex, ...]:
        """Tuple of all the source matrix indices, which may be less than the total number of ports."""
        return super().matrix_indices_source

    @cached_property
    def matrix_indices_run_sim(self) -> tuple[MatrixIndex, ...]:
        """Tuple of all the matrix indices that will be used to run simulations."""
        return super().matrix_indices_run_sim

    @cached_property
    def port_names(self) -> tuple[list[str], list[str]]:
        """Returns lists of port names for outputs and inputs.

        Returns
        -------
        Tuple[List[str], List[str]]
            ``(output_port_names, input_port_names)``. The output names come from
            :attr:`matrix_indices_monitor` and the input names from
            :attr:`matrix_indices_source`. Each list is unique and in first-seen order.
        """

        def get_port_names(matrix_elements: tuple[MatrixIndex, ...]) -> list[str]:
            """Unique port names from (port_name, mode_index) pairs, in first-seen order."""
            # dict keys dedupe in O(n) while preserving insertion order.
            return list(dict.fromkeys(port_name for port_name, _ in matrix_elements))

        port_names_in = get_port_names(self.matrix_indices_source)
        port_names_out = get_port_names(self.matrix_indices_monitor)
        return port_names_out, port_names_in

    def to_monitor(self, port: ModalPortType) -> AbstractOverlapMonitor:
        """Creates an overlap monitor from a given port (modal or gaussian).

        The monitor is created by the port itself and records at :attr:`freqs`.
        """
        return port.to_monitor(freqs=self.freqs)

    def to_source(
        self, port: ModalPortType, mode_index: int, num_freqs: int = 1, **kwargs: Any
    ) -> DirectionalSource:
        """Creates a source from a given port (modal or gaussian).

        The center frequency is the midpoint of :attr:`freqs`. The bandwidth is the
        larger of the frequency span and ``freq0 * FWIDTH_FRAC``. If
        ``custom_source_time`` is set, it is used as the source time dependence.
        Otherwise a :class:`.GaussianPulse` with that center and bandwidth is used.

        Parameters
        ----------
        port : ModalPortType
            Port to build the source from.
        mode_index : int
            Mode to excite. Forwarded to ``port.to_source``.
        num_freqs : int = 1
            Forwarded to ``port.to_source``.
        **kwargs
            Extra keyword arguments forwarded to ``port.to_source``.

        Returns
        -------
        DirectionalSource
            The source produced by ``port.to_source``.
        """
        freqs = self.freqs
        freq0 = 0.5 * (max(freqs) + min(freqs))
        fdiff = max(freqs) - min(freqs)
        # Guarantee a non-zero bandwidth even for a single frequency.
        fwidth = max(fdiff, freq0 * FWIDTH_FRAC)
        source_time = self.custom_source_time
        if source_time is None:
            source_time = GaussianPulse(
                freq0=freq0,
                fwidth=fwidth,
                remove_dc_component=self.remove_dc_component,
            )
        return port.to_source(
            freq0=freq0,
            fwidth=fwidth,
            mode_index=mode_index,
            num_freqs=num_freqs,
            source_time=source_time,
            **kwargs,
        )

    def shift_port(self, port: ModalPortType) -> ModalPortType:
        """Generates a new port shifted slightly in the normal direction.

        This is to ensure that the source is placed just inside the simulation domain,
        away from the PML.

        For Gaussian ports the center is also moved laterally, so that it slides along
        the tilted beam axis. The waist distance(s) are corrected by the same axial
        displacement, so the beam waist stays at the same location in space.

        Parameters
        ----------
        port : Port
            The port to shift.

        Returns
        -------
        Port
            A new :class:`.Port` object with its center shifted.

        Raises
        ------
        SetupError
            If a Gaussian port has ``|cos(angle_theta)| < GLANCING_CUTOFF``.
        """
        # Signed coordinate displacement along the port's normal axis.
        shift_value = self._shift_value_signed(port=port, simulation=self.simulation)
        center_shifted = list(port.center)
        normal_dim, plane_dims = port.pop_axis([0, 1, 2], port.size.index(0.0))
        center_shifted[normal_dim] += shift_value

        update = {}
        if isinstance(port, AbstractGaussianPort):
            theta = port.angle_theta
            phi = port.angle_phi
            cos_theta = np.cos(theta)
            if np.abs(cos_theta) < GLANCING_CUTOFF:
                raise SetupError(
                    "Cannot shift Gaussian port at glancing incidence. "
                    "Adjust angle_theta or use a different injection axis."
                )
            # Move laterally so the center slides along the beam axis, not just the normal.
            tan_theta = np.sin(theta) / cos_theta
            center_shifted[plane_dims[0]] += shift_value * tan_theta * np.cos(phi)
            center_shifted[plane_dims[1]] += shift_value * tan_theta * np.sin(phi)

            # The waist distance is measured along the propagation direction, and that
            # direction is reversed for "-" ports. Convert the normal-axis coordinate
            # shift into a signed displacement along the propagation axis before
            # applying it. Without the sign, "-" ports would have the waist moved the
            # wrong way.
            propagation_sign = -1.0 if port.direction == "-" else 1.0
            axial_shift = propagation_sign * shift_value / cos_theta
            if isinstance(port, GaussianPort):
                update["waist_distance"] = port.waist_distance + axial_shift
            else:
                # Astigmatic ports carry two waist distances; both shift by the same amount.
                waist_distances = list(port.waist_distances)
                waist_distances[0] = waist_distances[0] + axial_shift
                waist_distances[1] = waist_distances[1] + axial_shift
                update["waist_distances"] = waist_distances

        return port.updated_copy(center=center_shifted, **update)

    def _sources_for_plot(self) -> tuple[DirectionalSource, ...]:
        """Builds one mode-0 source per port, used only by the plotting helpers.

        Gaussian ports ignore ``mode_index``. Unlike :attr:`sim_dict`, the ports are
        not shifted here.
        """
        return tuple(self.to_source(port=port, mode_index=0) for port in self.ports)

    @equal_aspect
    @add_ax_if_none
    def plot_sim(
        self,
        x: float | None = None,
        y: float | None = None,
        z: float | None = None,
        ax: Ax = None,
    ) -> Ax:
        """Plots the simulation with all sources added for troubleshooting.

        This method creates a temporary simulation with all mode sources activated to
        help visualize the setup.

        Parameters
        ----------
        x : float, optional
            The x-coordinate of the cross-section, by default None.
        y : float, optional
            The y-coordinate of the cross-section, by default None.
        z : float, optional
            The z-coordinate of the cross-section, by default None.
        ax : Ax, optional
            The matplotlib axes to plot on, by default None.

        Returns
        -------
        Ax
            The matplotlib axes with the plot.
        """
        sim_plot = self.simulation.copy(update={"sources": self._sources_for_plot()})
        return sim_plot.plot(x=x, y=y, z=z, ax=ax)

    @equal_aspect
    @add_ax_if_none
    def plot_sim_eps(
        self,
        x: float | None = None,
        y: float | None = None,
        z: float | None = None,
        ax: Ax = None,
        **kwargs: Any,
    ) -> Ax:
        """Plots the permittivity of the simulation with all sources.

        This method is useful for visualizing the device geometry along with the
        placement of the sources.

        Parameters
        ----------
        x : float, optional
            The x-coordinate of the cross-section, by default None.
        y : float, optional
            The y-coordinate of the cross-section, by default None.
        z : float, optional
            The z-coordinate of the cross-section, by default None.
        ax : Ax, optional
            The matplotlib axes to plot on, by default None.
        **kwargs
            Additional keyword arguments passed to the plotter.

        Returns
        -------
        Ax
            The matplotlib axes with the plot.
        """
        sim_plot = self.simulation.copy(update={"sources": self._sources_for_plot()})
        return sim_plot.plot_eps(x=x, y=y, z=z, ax=ax, **kwargs)

    def _normalization_factor(self, port_source: Port, sim_data: SimulationData) -> complex:
        """Computes the normalization amplitude for the input mode.

        This is used to normalize the S-matrix elements.

        Parameters
        ----------
        port_source : Port
            The port that was excited.
        sim_data : SimulationData
            The data from the simulation run.

        Returns
        -------
        complex
            The complex amplitude(s) of the injected mode, measured in the excited
            port's monitor. Values are selected at every frequency in :attr:`freqs`,
            in ``port_source.direction``.
        """
        port_monitor_data = sim_data[port_source.name]
        # Some sources (e.g. GaussianBeam) have no 'mode_index'; default to 0.
        src = sim_data.simulation.sources[0]
        mode_index = getattr(src, "mode_index", 0)
        normalize_amps = port_monitor_data.amps.sel(
            f=np.array(self.freqs),
            direction=port_source.direction,
            mode_index=mode_index,
        )
        return normalize_amps.values

    @cached_property
    def max_mode_index(self) -> tuple[int, int]:
        """Returns the maximum mode indices for the out and in ports.

        Returns
        -------
        Tuple[int, int]
            A tuple containing the maximum mode index for the output ports and the
            maximum mode index for the input ports.
        """

        def get_max_mode_indices(matrix_elements: tuple[MatrixIndex, ...]) -> int:
            """Get the maximum mode index for a collection of (port name, mode index)."""
            return max(mode_index for _, mode_index in matrix_elements)

        max_mode_index_out = get_max_mode_indices(self.matrix_indices_monitor)
        max_mode_index_in = get_max_mode_indices(self.matrix_indices_source)
        return max_mode_index_out, max_mode_index_in

    def task_name_from_index(self, matrix_index: MatrixIndex) -> str:
        """Compute task name for a given (port_name, mode_index) without constructing simulations."""
        port_name, mode_index = matrix_index
        port = self.get_port_by_name(port_name=port_name)
        return self.get_task_name(port=port, mode_index=mode_index)