Source code for tidy3d.plugins.invdes.design

# container for everything defining the inverse design

from __future__ import annotations

import abc
import typing

import autograd.numpy as anp
import numpy as np
import pydantic.v1 as pd

import tidy3d as td
from tidy3d.components.autograd import get_static
from tidy3d.exceptions import ValidationError
from tidy3d.plugins.expressions.metrics import Metric, generate_validation_data
from tidy3d.plugins.expressions.types import ExpressionType

from .base import InvdesBaseModel
from .region import DesignRegionType
from .validators import check_pixel_size

PostProcessFnType = typing.Callable[[td.SimulationData], float]


class AbstractInverseDesign(InvdesBaseModel, abc.ABC):
    """Container for an inverse design problem."""

    design_region: DesignRegionType = pd.Field(
        ...,
        title="Design Region",
        description="Region within which we will optimize the simulation.",
    )

    task_name: str = pd.Field(
        ...,
        title="Task Name",
        description="Task name to use in the objective function when running the ``JaxSimulation``.",
    )

    verbose: bool = pd.Field(
        False,
        title="Task Verbosity",
        description="If ``True``, will print the regular output from ``web`` functions.",
    )

    metric: typing.Optional[ExpressionType] = pd.Field(
        None,
        title="Objective Metric",
        description="Serializable expression defining the objective function.",
    )

    def make_objective_fn(
        self, post_process_fn: typing.Optional[typing.Callable] = None, maximize: bool = True
    ) -> typing.Callable[[anp.ndarray], tuple[float, dict]]:
        """Construct the objective function for this InverseDesign object."""

        if (post_process_fn is None) and (self.metric is None):
            raise ValueError("Either 'post_process_fn' or 'metric' must be provided.")

        if (post_process_fn is not None) and (self.metric is not None):
            raise ValueError("Provide only one of 'post_process_fn' or 'metric', not both.")

        direction_multiplier = 1 if maximize else -1

        def objective_fn(params: anp.ndarray, aux_data: dict = None) -> float:
            """Full objective function."""
            data = self.to_simulation_data(params=params)

            if self.metric is None:
                post_process_val = post_process_fn(data)
            elif isinstance(data, td.SimulationData):
                post_process_val = self.metric.evaluate(data)
            elif getattr(data, "type", None) == "BatchData":
                raise NotImplementedError("Metrics currently do not support 'BatchData'")
            else:
                raise ValueError(f"Invalid data type: {type(data)}")

            penalty_value = self.design_region.penalty_value(params)
            objective_fn_val = direction_multiplier * post_process_val - penalty_value

            # Store auxiliary data if provided
            if aux_data is not None:
                aux_data["penalty"] = get_static(penalty_value)
                aux_data["post_process_val"] = get_static(post_process_val)
                aux_data["objective_fn_val"] = get_static(objective_fn_val) * direction_multiplier
                if isinstance(data, td.SimulationData):
                    aux_data["sim_data"] = data.to_static()
                else:
                    aux_data["sim_data"] = {k: v.to_static() for k, v in data.items()}
                aux_data["params"] = params

            return objective_fn_val

        return objective_fn

    @property
    def initial_simulation(self) -> td.Simulation:
        """Return a simulation with the initial design region parameters."""
        initial_params = self.design_region.initial_parameters
        return self.to_simulation(initial_params)

    def run(self, simulation, **kwargs) -> td.SimulationData:
        """Run a single tidy3d simulation."""
        from tidy3d.web import run

        kwargs.setdefault("verbose", self.verbose)
        kwargs.setdefault("task_name", self.task_name)
        return run(simulation, **kwargs)

    def run_async(self, simulations, **kwargs) -> web.BatchData:  # noqa: F821
        """Run a batch of tidy3d simulations."""
        from tidy3d.web import run_async

        kwargs.setdefault("verbose", self.verbose)
        return run_async(simulations, **kwargs)


[docs] class InverseDesign(AbstractInverseDesign): """Container for an inverse design problem.""" simulation: td.Simulation = pd.Field( ..., title="Base Simulation", description="Simulation without the design regions or monitors used in the objective fn.", ) output_monitor_names: typing.Tuple[str, ...] = pd.Field( None, title="Output Monitor Names", description="Optional names of monitors whose data the differentiable output depends on." "If this field is left ``None``, the plugin will try to add all compatible monitors to " "``JaxSimulation.output_monitors``. While this will work, there may be warnings if the " "monitors are not compatible with the ``adjoint`` plugin, for example if there are " "``FieldMonitor`` instances with ``.colocate != False``.", ) _check_sim_pixel_size = check_pixel_size("simulation") @pd.root_validator(pre=False) def _validate_model(cls, values: dict) -> dict: cls._validate_metric(values) return values @staticmethod def _validate_metric(values: dict) -> dict: metric_expr = values.get("metric") if not metric_expr: return values simulation = values.get("simulation") for metric in metric_expr.filter(Metric): InverseDesign._validate_metric_monitor_name(metric, simulation) InverseDesign._validate_metric_mode_index(metric, simulation) InverseDesign._validate_metric_f(metric, simulation) InverseDesign._validate_metric_data(metric_expr, simulation) return values @staticmethod def _validate_metric_monitor_name(metric: Metric, simulation: td.Simulation) -> None: """Validate that the monitor name of the metric exists in the simulation.""" monitor = next((m for m in simulation.monitors if m.name == metric.monitor_name), None) if monitor is None: raise ValidationError( f"Monitor named '{metric.monitor_name}' associated with the metric not found in the simulation monitors." ) @staticmethod def _validate_metric_mode_index(metric: Metric, simulation: td.Simulation) -> None: """Validate that the mode index of the metric is within the bounds of the monitor's ``ModeSpec.num_modes``.""" monitor = next((m for m in simulation.monitors if m.name == metric.monitor_name), None) if metric.mode_index >= monitor.mode_spec.num_modes: raise ValidationError( f"Mode index '{metric.mode_index}' for metric associated with monitor " f"'{metric.monitor_name}' is out of bounds. " f"Maximum allowed mode index is '{monitor.mode_spec.num_modes - 1}'." ) @staticmethod def _validate_metric_f(metric: Metric, simulation: td.Simulation) -> None: """Validate that the frequencies of the metric are present in the monitor.""" monitor = next((m for m in simulation.monitors if m.name == metric.monitor_name), None) if metric.f is not None: metric_f_list = [metric.f] if isinstance(metric.f, float) else metric.f if len(metric_f_list) != 1: raise ValidationError("Only a single frequency is supported for the metric.") for freq in metric_f_list: if not any(np.isclose(freq, monitor.freqs, atol=1.0)): raise ValidationError( f"Frequency '{freq}' for metric associated with monitor " f"'{metric.monitor_name}' not found in monitor frequencies." ) else: if len(monitor.freqs) != 1: raise ValidationError( f"Monitor '{metric.monitor_name}' must contain only a single frequency when metric.f is None." ) @staticmethod def _validate_metric_data(expr: ExpressionType, simulation: td.Simulation) -> None: """Validate that expression can be evaluated and returns a real scalar.""" data = generate_validation_data(expr) try: result = expr(data) except Exception as e: raise ValidationError(f"Failed to evaluate the metric expression: {str(e)}") from e if len(np.ravel(result)) > 1: raise ValidationError( f"The expression must return a scalar value or an array of length 1 (got {result})." ) if not np.all(np.isreal(result)): raise ValidationError( f"The expression must return a real (not complex) value (got {result})." )
[docs] def is_output_monitor(self, monitor: td.Monitor) -> bool: """Whether a monitor is added to the ``JaxSimulation`` as an ``output_monitor``.""" output_mnt_types = td.components.simulation.OutputMonitorTypes if self.output_monitor_names is None: return any(isinstance(monitor, mnt_type) for mnt_type in output_mnt_types) return monitor.name in self.output_monitor_names
[docs] def separate_output_monitors(self, monitors: typing.Tuple[td.Monitor]) -> dict: """Separate monitors into output_monitors and regular monitors.""" monitor_fields = dict(monitors=[], output_monitors=[]) for monitor in monitors: key = "output_monitors" if self.is_output_monitor(monitor) else "monitors" monitor_fields[key].append(monitor) return monitor_fields
[docs] def to_simulation(self, params: anp.ndarray) -> td.Simulation: """Convert the ``InverseDesign`` to a corresponding ``td.Simulation`` with traced fields.""" # construct the design region to a regular structure design_region_structure = self.design_region.to_structure(params) # construct mesh override structures and a new grid spec, if applicable grid_spec = self.simulation.grid_spec mesh_override_structure = self.design_region.mesh_override_structure if mesh_override_structure: override_structures = list(self.simulation.grid_spec.override_structures) override_structures += [mesh_override_structure] grid_spec = grid_spec.updated_copy(override_structures=override_structures) return self.simulation.updated_copy( structures=list(self.simulation.structures) + [design_region_structure], grid_spec=grid_spec, )
[docs] def to_simulation_data(self, params: anp.ndarray, **kwargs) -> td.SimulationData: """Convert the ``InverseDesign`` to a ``td.Simulation`` and run it.""" simulation = self.to_simulation(params=params) return self.run(simulation, **kwargs)
[docs] class InverseDesignMulti(AbstractInverseDesign): """``InverseDesign`` with multiple simulations and corresponding postprocess functions.""" simulations: typing.Tuple[td.Simulation, ...] = pd.Field( ..., title="Base Simulations", description="Set of simulation without the design regions or monitors used in the objective fn.", ) output_monitor_names: typing.Tuple[typing.Union[typing.Tuple[str, ...], None], ...] = pd.Field( None, title="Output Monitor Names", description="Optional names of monitors whose data the differentiable output depends on." "If this field is left ``None``, the plugin will try to add all compatible monitors to " "``JaxSimulation.output_monitors``. While this will work, there may be warnings if the " "monitors are not compatible with the ``adjoint`` plugin, for example if there are " "``FieldMonitor`` instances with ``.colocate != False``.", ) _check_sim_pixel_size = check_pixel_size("simulations") @pd.root_validator() def _check_lengths(cls, values): """Check the lengths of all of the multi fields.""" keys = ("simulations", "post_process_fns", "output_monitor_names", "override_structure_dl") multi_dict = {key: values.get(key) for key in keys} sizes = {key: len(val) for key, val in multi_dict.items() if val is not None} if len(set(sizes.values())) != 1: raise ValueError( f"'MultiInverseDesign' requires that the fields {keys} must either " "have the same length or be left ``None``, if optional. Given fields with " "corresponding sizes of '{sizes}'." ) return values @property def task_names(self) -> list[str]: """Task names associated with each of the simulations.""" return [f"{self.task_name}_{i}" for i in range(len(self.simulations))] @property def designs(self) -> typing.List[InverseDesign]: """List of individual ``InverseDesign`` objects corresponding to this instance.""" designs_list = [] for i, (task_name, sim) in enumerate(zip(self.task_names, self.simulations)): des_i = InverseDesign( design_region=self.design_region, simulation=sim, verbose=self.verbose, task_name=task_name, ) if self.output_monitor_names is not None: des_i = des_i.updated_copy(output_monitor_names=self.output_monitor_names[i]) designs_list.append(des_i) return designs_list
[docs] def to_simulation(self, params: anp.ndarray) -> dict[str, td.Simulation]: """Convert the ``InverseDesign`` to a corresponding dict of ``td.Simulation``s.""" simulation_list = [design.to_simulation(params) for design in self.designs] return dict(zip(self.task_names, simulation_list))
[docs] def to_simulation_data(self, params: anp.ndarray, **kwargs) -> web.BatchData: # noqa: F821 """Convert the ``InverseDesignMulti`` to a set of ``td.Simulation``s and run async.""" simulations = self.to_simulation(params) return self.run_async(simulations, **kwargs)
InverseDesignType = typing.Union[InverseDesign, InverseDesignMulti]