# Source code for tidy3d.plugins.smatrix.component_modelers.base

"""Base class for generating an S matrix automatically from tidy3d simulations and port definitions."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Optional, Union

from pydantic import Field, field_validator, model_validator

from tidy3d.components.base import Tidy3dBaseModel, cached_property
from tidy3d.components.geometry.utils import _shift_value_signed
from tidy3d.components.simulation import Simulation
from tidy3d.components.types import Complex, FreqArray
from tidy3d.components.types.time import SourceTimeType
from tidy3d.components.validators import (
    assert_unique_names,
    validate_freqs_min,
    validate_freqs_not_empty,
    validate_freqs_unique,
)
from tidy3d.config import config
from tidy3d.constants import HERTZ
from tidy3d.exceptions import SetupError, Tidy3dKeyError
from tidy3d.log import log
from tidy3d.plugins.smatrix.ports.modal import Port
from tidy3d.plugins.smatrix.ports.types import LumpedPortType, TerminalPortType
from tidy3d.plugins.smatrix.ports.wave import WavePort
from tidy3d.plugins.smatrix.types import Element, MatrixIndex, NetworkElement, NetworkIndex

if TYPE_CHECKING:
    from pydantic import ValidationInfo

    from tidy3d.compat import Self
    from tidy3d.plugins.smatrix import MicrowaveSMatrixData
    from tidy3d.plugins.smatrix.ports.modal import ModalPortDataArray
    from tidy3d.plugins.smatrix.ports.types import PortType
    from tidy3d.web.core.types import PayType
# fwidth of gaussian pulse in units of central frequency
FWIDTH_FRAC = 1.0 / 10
# default value of the ``path_dir`` argument of ``AbstractComponentModeler.run`` ("." = current directory)
DEFAULT_DATA_DIR = "."

# a single matrix index: modal-port index (``MatrixIndex``) or network/terminal index (``NetworkIndex``)
IndexType = Union[MatrixIndex, NetworkIndex]
# a single matrix element, compared against ``(row_index, col_index)`` pairs of matrix indices
ElementType = Union[Element, NetworkElement]
# allowed task-name format tags; the meaning of "RF" / "PF" is not established in this module — TODO confirm
TaskNameFormat = Literal["RF", "PF"]


class AbstractComponentModeler(ABC, Tidy3dBaseModel):
    """Tool for modeling devices and computing port parameters."""

    name: str = Field(
        "",
        title="Name",
    )

    simulation: Simulation = Field(
        title="Simulation",
        description="Simulation describing the device without any sources present.",
    )

    ports: tuple[Union[Port, TerminalPortType], ...] = Field(
        (),
        title="Ports",
        description="Collection of ports describing the scattering matrix elements. "
        "For each input mode, one simulation will be run with a modal source.",
    )

    freqs: FreqArray = Field(
        title="Frequencies",
        description="Array or list of frequencies at which to compute port parameters.",
        json_schema_extra={"units": HERTZ},
    )

    remove_dc_component: bool = Field(
        True,
        title="Remove DC Component",
        description="Whether to remove the DC component in the Gaussian pulse spectrum. "
        "If ``True``, the Gaussian pulse is modified at low frequencies to zero out the "
        "DC component, which is usually desirable so that the fields will decay. However, "
        "for broadband simulations, it may be better to have non-vanishing source power "
        "near zero frequency. Setting this to ``False`` results in an unmodified Gaussian "
        "pulse spectrum which can have a nonzero DC component.",
    )

    run_only: Optional[tuple[IndexType, ...]] = Field(
        None,
        title="Run Only",
        description="Set of matrix indices that define the simulations to run. "
        "If ``None``, simulations will be run for all indices in the scattering matrix. "
        "If a tuple is given, simulations will be run only for the given matrix indices.",
    )

    element_mappings: tuple[tuple[ElementType, ElementType, Complex], ...] = Field(
        (),
        title="Element Mappings",
        description="Tuple of S matrix element mappings, each described by a tuple of "
        "(input_element, output_element, coefficient), where the coefficient is the "
        "element_mapping coefficient describing the relationship between the input and output "
        "matrix element. If all elements of a given column of the scattering matrix are defined "
        "by ``element_mappings``, the simulation corresponding to this column is skipped automatically.",
    )

    custom_source_time: Optional[SourceTimeType] = Field(
        None,
        title="Custom Source Time",
        description="If provided, this will be used as specification of the source time-dependence in simulations. "
        "Otherwise, a default source time will be constructed.",
    )

    @model_validator(mode="before")
    @classmethod
    def _warn_refactor_2_10(cls, data: dict) -> dict:
        """Log (once) that this class was refactored in tidy3d v2.10.0; ``data`` is passed through unchanged."""
        log.warning(
            f"'{cls.__name__}' was refactored (tidy3d 'v2.10.0'). Existing functionality is available differently. Please consult the migration documentation: https://docs.flexcompute.com/projects/tidy3d/en/latest/api/microwave/microwave_migration.html",
            log_once=True,
        )
        return data

    @field_validator("simulation")
    @classmethod
    def _sim_has_no_sources(cls, val: Simulation) -> Simulation:
        """Make sure simulation has no sources as they interfere with tool.

        Raises
        ------
        SetupError
            If ``val.sources`` is non-empty.
        """
        if len(val.sources) > 0:
            raise SetupError(f"'{cls.__name__}.simulation' must not have any sources.")
        return val

    @field_validator("element_mappings")
    @classmethod
    def _validate_element_mappings(
        cls,
        element_mappings: tuple[tuple[ElementType, ElementType, Complex], ...],
        info: ValidationInfo,
    ) -> tuple[tuple[ElementType, ElementType, Complex], ...]:
        """Validate that each source index referenced in element_mappings is included in run_only.

        Skipped when ``run_only`` is ``None`` (or absent from ``info.data`` because it
        failed its own validation).

        Raises
        ------
        SetupError
            If the source (column) index of any input or output element is not in ``run_only``.
        """
        run_only = info.data.get("run_only")
        if run_only is None:
            return element_mappings

        valid_set = set(run_only)
        # An element is a (row index, column index) pair; the column index is the source
        # index (see ``matrix_indices_run_sim``), so only ``element[1]`` is checked here.
        invalid_indices = {
            element[1]
            for input_element, output_element, _ in element_mappings
            for element in (input_element, output_element)
            if element[1] not in valid_set
        }

        if invalid_indices:
            raise SetupError(
                f"'element_mappings' references source index(es) {invalid_indices} "
                f"that are not present in run_only: {run_only}."
            )
        return element_mappings

    @model_validator(mode="after")
    def _validate_run_only(self) -> Self:
        """Validate that run_only entries are unique and exist in matrix_indices_monitor.

        Raises
        ------
        SetupError
            If ``run_only`` contains duplicates, or indices not produced by
            ``_construct_matrix_indices_monitor(self.ports)``.
        """
        val = self.run_only
        if val is None:
            return self

        # Check uniqueness; duplicates are collected in order of first repetition so the
        # error message is deterministic (iterating a ``set`` is not).
        seen = set()
        duplicates = []
        for idx in val:
            if idx in seen and idx not in duplicates:
                duplicates.append(idx)
            seen.add(idx)
        if duplicates:
            raise SetupError(
                f"'run_only' contains duplicate entries: {duplicates}. "
                "Each index must appear only once."
            )

        # Check membership - use the same helper as ``matrix_indices_monitor`` so both agree
        valid_indices = set(self._construct_matrix_indices_monitor(self.ports))
        invalid_indices = [idx for idx in val if idx not in valid_indices]

        if invalid_indices:
            raise SetupError(
                f"'run_only' contains indices {invalid_indices} that are not present in "
                f"'matrix_indices_monitor'. Valid indices are: {sorted(valid_indices)}"
            )

        return self

    _freqs_not_empty = validate_freqs_not_empty()
    _freqs_lower_bound = validate_freqs_min()
    _freqs_unique = validate_freqs_unique()

    @model_validator(mode="after")
    def _freqs_in_custom_source_time(self) -> Self:
        """Make sure freqs is in the range of the custom source time.

        Only logs a warning (does not raise) when ``freqs`` extends outside
        ``custom_source_time._frequency_range_sigma_cached``.
        """
        val = self.custom_source_time
        if val is None:
            return self
        freq_range = val._frequency_range_sigma_cached
        freqs = self.freqs

        if freq_range[0] > min(freqs) or max(freqs) > freq_range[1]:
            log.warning(
                "Custom source time does not cover all 'freqs'.",
            )
        return self

    @staticmethod
    def get_task_name(port: PortType, mode_index: Optional[int] = None) -> str:
        """Generates a standardized task name from a port object.

        This method creates a unique string identifier for a simulation task based on a
        port and, if applicable, a specified mode index.

        Parameters
        ----------
        port : PortType
            The port object from which to derive the base name.
        mode_index : Optional[int], optional
            If provided, this index is appended to the port name (e.g., 'port_1@1').
            Defaults to `None`, in which case the first mode is chosen by default
            (``port._mode_indices[0]`` for a ``WavePort``, ``0`` for other modal ports).

        Returns
        -------
        str
            The formatted task name string.

        Raises
        ------
        ValueError
            If `mode_index` is specified for a lumped port.
        """
        if isinstance(port, LumpedPortType):
            if mode_index is not None:
                raise ValueError(
                    "'mode_index' should not be specified for a lumped port, "
                    f"but was passed with value '{mode_index}'."
                )
            return f"{port.name}"
        elif isinstance(port, WavePort):
            # WavePorts default to first mode index
            if mode_index is not None:
                return f"{port.name}@{mode_index}"
            return f"{port.name}@{port._mode_indices[0]}"
        else:
            # Modal ports default to 0
            if mode_index is not None:
                return f"{port.name}@{mode_index}"
            return f"{port.name}@0"

    def get_port_by_name(self, port_name: str) -> Union[Port, TerminalPortType]:
        """Get the port from the name.

        Parameters
        ----------
        port_name : str
            Name of the port to look up in ``self.ports``.

        Returns
        -------
        Union[Port, TerminalPortType]
            The first port in ``self.ports`` whose ``name`` equals ``port_name``.

        Raises
        ------
        Tidy3dKeyError
            If no port has that name.
        """
        port = next((port for port in self.ports if port.name == port_name), None)
        if port is None:
            raise Tidy3dKeyError(f'Port "{port_name}" not found.')
        return port

    @staticmethod
    @abstractmethod
    def _construct_matrix_indices_monitor(ports: tuple) -> tuple[IndexType, ...]:
        """Construct matrix indices for monitoring from ports.

        This helper method is used by both the matrix_indices_monitor property and the
        run_only validator to ensure consistency.

        Parameters
        ----------
        ports : tuple
            Tuple of port objects.

        Returns
        -------
        tuple[IndexType, ...]
            Tuple of matrix indices for monitoring.
        """

    @property
    @abstractmethod
    def matrix_indices_monitor(self) -> tuple[IndexType, ...]:
        """Abstract property for all matrix indices that will be used to collect data."""

    @cached_property
    def matrix_indices_source(self) -> tuple[IndexType, ...]:
        """Tuple of all the source matrix indices, which may be less than the total number of ports.

        Equal to ``run_only`` when it is set, otherwise to ``matrix_indices_monitor``.
        """
        if self.run_only is not None:
            return self.run_only
        return self.matrix_indices_monitor

    @cached_property
    def matrix_indices_run_sim(self) -> tuple[IndexType, ...]:
        """Tuple of the source matrix indices whose simulations actually need to be run.

        A source index (a column of the matrix) is skipped when every element of that
        column, ``(row_index, col_index)`` for all ``row_index`` in
        ``matrix_indices_monitor``, appears as an output element in ``element_mappings``.
        The result is always a tuple, matching the branch without ``element_mappings``.
        """
        if not self.element_mappings:
            return self.matrix_indices_source

        # all the (i, j) pairs in `S_ij` that are tagged as covered by `element_mappings`;
        # a set makes each membership test below O(1)
        elements_determined_by_map = {element_out for (_, element_out, _) in self.element_mappings}

        # keep each column (source index) that has at least one row element not covered by a mapping
        return tuple(
            col_index
            for col_index in self.matrix_indices_source
            if not all(
                (row_index, col_index) in elements_determined_by_map
                for row_index in self.matrix_indices_monitor
            )
        )

    def _shift_value_signed(self, port: Union[Port, WavePort]) -> float:
        """How far (signed) to shift the source from the monitor.

        Delegates to the module-level ``_shift_value_signed`` helper with ``shift=-2``
        on this modeler's simulation grid and bounds.
        """
        return _shift_value_signed(
            obj=port,
            grid=self.simulation.grid,
            bounds=self.simulation.bounds,
            direction=port.direction,
            shift=-2,
            name=f"Port {port.name}",
        )

    unique_port_names = assert_unique_names("ports")

    def run(
        self,
        path_dir: str = DEFAULT_DATA_DIR,
        *,
        folder_name: str = "default",
        callback_url: Optional[str] = None,
        verbose: bool = True,
        solver_version: Optional[str] = None,
        pay_type: Union[PayType, str] = "AUTO",
        priority: Optional[int] = None,
        local_gradient: bool = False,
        max_num_adjoint_per_fwd: Optional[int] = None,
    ) -> Union[ModalPortDataArray, MicrowaveSMatrixData]:
        """Run the modeler through the local runner and return its scattering matrix.

        Parameters
        ----------
        path_dir : str
            Directory forwarded to the local runner. Defaults to ``DEFAULT_DATA_DIR``.
        folder_name, callback_url, verbose, solver_version, pay_type, priority, local_gradient
            Forwarded unchanged to ``tidy3d.plugins.smatrix.run._run_local``.
        max_num_adjoint_per_fwd : Optional[int]
            Forwarded to the local runner; ``None`` falls back to
            ``config.adjoint.max_adjoint_per_fwd``.

        Returns
        -------
        Union[ModalPortDataArray, MicrowaveSMatrixData]
            The result of ``smatrix()`` on the data returned by the local runner.

        Notes
        -----
        Deprecated: use ``web.run(modeler)`` instead, which returns a
        ``ComponentModelerData``; get the scattering matrix via ``data.smatrix()``.
        """
        log.warning(
            "'ComponentModeler.run()' is deprecated and will be removed in a future release. "
            "Use web.run(modeler) instead. 'web.run' returns a 'ComponentModelerData' object; "
            "get the scattering matrix via 'data.smatrix()'.",
            log_once=True,
        )
        # imported lazily, presumably to avoid an import cycle with the run module
        from tidy3d.plugins.smatrix.run import _run_local

        if max_num_adjoint_per_fwd is None:
            max_num_adjoint_per_fwd = config.adjoint.max_adjoint_per_fwd

        data = _run_local(
            self,
            path_dir=path_dir,
            folder_name=folder_name,
            callback_url=callback_url,
            verbose=verbose,
            solver_version=solver_version,
            pay_type=pay_type,
            priority=priority,
            local_gradient=local_gradient,
            max_num_adjoint_per_fwd=max_num_adjoint_per_fwd,
        )
        return data.smatrix()

    def validate_pre_upload(self: Self) -> None:
        """Validate the modeler before upload.

        Delegates to ``self.base_sim.validate_pre_upload(source_required=False)``.
        """
        # NOTE(review): ``base_sim`` is not defined in this class; presumably provided by subclasses.
        self.base_sim.validate_pre_upload(source_required=False)
# Ask pydantic to rebuild the model schema now that every module-level name is defined;
# presumably needed because annotations are postponed strings here
# (``from __future__ import annotations``) — confirm.
AbstractComponentModeler.model_rebuild()