Source code for tidy3d.components.data.sim_data

"""Simulation Level Data"""

from __future__ import annotations

import json
import pathlib
from abc import ABC
from collections import defaultdict
from typing import Callable, Tuple, Union

import h5py
import numpy as np
import pydantic.v1 as pd
import xarray as xr

from ...constants import C_0, inf
from ...exceptions import DataError, FileError, Tidy3dKeyError
from ...log import log
from ..autograd.utils import split_list
from ..base import JSON_TAG, Tidy3dBaseModel
from ..base_sim.data.sim_data import AbstractSimulationData
from ..file_util import replace_values
from ..monitor import Monitor
from ..simulation import Simulation
from ..source import GaussianPulse, SourceType
from ..structure import Structure
from ..types import Ax, Axis, ColormapType, FieldVal, PlotScale, annotate_type
from ..viz import add_ax_if_none, equal_aspect
from .data_array import FreqDataArray
from .monitor_data import (
    AbstractFieldData,
    FieldTimeData,
    MonitorDataType,
    MonitorDataTypes,
)

# maps each monitor class (the type of the ``monitor`` field) to its monitor data class
DATA_TYPE_MAP = {
    data_cls.__fields__["monitor"].type_: data_cls for data_cls in MonitorDataTypes
}

# maps monitor type (string) to the class of the corresponding data
DATA_TYPE_NAME_MAP = {
    data_cls.__fields__["monitor"].type_.__name__: data_cls for data_cls in MonitorDataTypes
}

# residuals below this are considered good fits for broadband adjoint source creation
RESIDUAL_CUTOFF_ADJOINT = 1.0e-6


class AdjointSourceInfo(Tidy3dBaseModel):
    """Container handed to the autograd pipeline describing the adjoint sources.

    Holds the processed adjoint sources together with the normalization settings
    (a post-run scaling factor and a flag for simulation normalization) that depend
    on which adjoint source pipeline was used.
    """

    # processed adjoint sources to be placed in the adjoint simulation
    sources: Tuple[annotate_type(SourceType), ...] = pd.Field(
        ...,
        description="Set of processed sources to include in the adjoint simulation.",
        title="Adjoint Sources",
    )

    # scalar or per-frequency factor applied to the adjoint fields after the run
    post_norm: Union[float, FreqDataArray] = pd.Field(
        ...,
        description=(
            "Factor to multiply the adjoint fields by after running "
            "given the adjoint source pipeline used."
        ),
        title="Post Normalization Values",
    )

    # whether the adjoint simulation itself must be normalized
    normalize_sim: bool = pd.Field(
        ...,
        description=(
            "Whether the adjoint simulation needs to be normalized "
            "given the adjoint source pipeline used."
        ),
        title="Normalize Adjoint Simulation",
    )


class AbstractYeeGridSimulationData(AbstractSimulationData, ABC):
    """Data from an :class:`.AbstractYeeGridSimulation` involving
    electromagnetic fields on a Yee grid.

    Notes
    -----

        The ``SimulationData`` objects store a copy of the original :class:`.Simulation`, so it can be recovered if the
        ``SimulationData`` is loaded in a new session and the :class:`.Simulation` is no longer in memory.

        More importantly, the ``SimulationData`` contains a reference to the data for each of the monitors within the
        original :class:`.Simulation`. This data can be accessed directly using the name given to the monitors initially.
    """

    def load_field_monitor(self, monitor_name: str) -> AbstractFieldData:
        """Load the data of a monitor and check that it contains field data.

        Parameters
        ----------
        monitor_name : str
            Name of the monitor in the original :class:`Simulation`.

        Returns
        -------
        AbstractFieldData
            The data stored for ``monitor_name``.

        Raises
        ------
        DataError
            If the data stored for ``monitor_name`` is not an :class:`.AbstractFieldData`.
        """
        mon_data = self[monitor_name]
        if not isinstance(mon_data, AbstractFieldData):
            raise DataError(
                f"data for monitor '{monitor_name}' does not contain field data "
                f"as it is a '{type(mon_data)}'."
            )
        return mon_data

    def at_centers(self, field_monitor_name: str) -> xr.Dataset:
        """Return xarray.Dataset representation of field monitor data colocated at Yee cell centers.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.

        Returns
        -------
        xarray.Dataset
            Dataset containing all of the fields in the data interpolated to center locations on
            the Yee grid.
        """

        monitor_data = self.load_field_monitor(field_monitor_name)
        return monitor_data.at_coords(monitor_data.colocation_centers)

    def _at_boundaries(self, monitor_data: AbstractFieldData) -> xr.Dataset:
        """Return xarray.Dataset representation of field monitor data colocated at Yee cell
        boundaries.

        Parameters
        ----------
        monitor_data : AbstractFieldData
            Field monitor data to be colocated.

        Returns
        -------
        xarray.Dataset
            Dataset containing all of the fields in the data interpolated to boundary locations on
            the Yee grid.
        """

        if monitor_data.monitor.colocate:
            # With ``colocate=True`` no interpolation is done here; the stored field components
            # are only packaged into a dataset.
            # TODO: this still errors if monitor_data.colocate is allowed to be ``True`` in the
            # adjoint plugin, and the monitor data is tracked in a gradient computation. It seems
            # interpolating does something to the arrays that makes the JAX chain work.
            return monitor_data.package_colocate_results(monitor_data.field_components)

        # colocate to monitor grid boundaries
        return monitor_data.at_coords(monitor_data.colocation_boundaries)

    def at_boundaries(self, field_monitor_name: str) -> xr.Dataset:
        """Return xarray.Dataset representation of field monitor data colocated at Yee cell
        boundaries.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.

        Returns
        -------
        xarray.Dataset
            Dataset containing all of the fields in the data interpolated to boundary locations on
            the Yee grid.
        """

        # colocate to monitor grid boundaries
        return self._at_boundaries(self.load_field_monitor(field_monitor_name))

    def _get_poynting_vector(self, field_monitor_data: AbstractFieldData) -> xr.Dataset:
        """Return ``xarray.Dataset`` of the Poynting vector computed from colocated fields.

        The fields are first colocated with ``_at_boundaries`` (Yee cell boundaries, or the
        monitor's own colocation when ``colocate=True``).

        Calculated values represent the instantaneous Poynting vector for time-domain fields and the
        complex vector for frequency-domain: ``S = 1/2 E × conj(H)``.

        Only the available components are returned, e.g., if the indicated monitor doesn't include
        field component `"Ex"`, then `"Sy"` and `"Sz"` will not be calculated.

        Parameters
        ----------
        field_monitor_data: AbstractFieldData
            Field monitor data from which to extract Poynting vector.

        Returns
        -------
        xarray.Dataset
            Dataset with data variables ``'Sx'``, ``'Sy'``, ``'Sz'`` (those that can be computed).
        """
        field_dataset = self._at_boundaries(field_monitor_data)

        time_domain = isinstance(field_monitor_data, FieldTimeData)

        poynting_components = {}

        dims = "xyz"
        for axis, dim in enumerate(dims):
            # cyclic ordering: S_dim = E_1 H_2 - E_2 H_1 with (dim, dim_1, dim_2) right-handed
            dim_1 = dims[axis - 2]
            dim_2 = dims[axis - 1]

            required_components = [f + c for f in "EH" for c in (dim_1, dim_2)]
            if not all(field_cmp in field_dataset for field_cmp in required_components):
                continue

            e_1 = field_dataset.data_vars["E" + dim_1]
            e_2 = field_dataset.data_vars["E" + dim_2]
            h_1 = field_dataset.data_vars["H" + dim_1]
            h_2 = field_dataset.data_vars["H" + dim_2]
            poynting_components["S" + dim] = (
                e_1 * h_2 - e_2 * h_1
                if time_domain
                else 0.5 * (e_1 * h_2.conj() - e_2 * h_1.conj())
            )

            # 2D monitors have grid correction factors that can be different from 1. For Poynting,
            # it is always the product of a primal-located field and dual-located field, so the
            # total grid correction factor is the product of the two
            grid_correction = (
                field_monitor_data.grid_dual_correction * field_monitor_data.grid_primal_correction
            )
            poynting_components["S" + dim] *= grid_correction

        return xr.Dataset(poynting_components)

    def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset:
        """Return ``xarray.Dataset`` of the Poynting vector computed from colocated fields.

        The fields are first colocated with ``_at_boundaries`` (Yee cell boundaries, or the
        monitor's own colocation when ``colocate=True``).

        Calculated values represent the instantaneous Poynting vector for time-domain fields and the
        complex vector for frequency-domain: ``S = 1/2 E × conj(H)``.

        Only the available components are returned, e.g., if the indicated monitor doesn't include
        field component `"Ex"`, then `"Sy"` and `"Sz"` will not be calculated.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.

        Returns
        -------
        xarray.Dataset
            Dataset with data variables ``'Sx'``, ``'Sy'``, ``'Sz'`` (those that can be computed).
        """
        field_monitor_data = self.load_field_monitor(field_monitor_name)
        return self._get_poynting_vector(field_monitor_data=field_monitor_data)

    def _get_scalar_field(
        self,
        field_monitor_name: str,
        field_name: str,
        val: FieldVal,
        phase: float = 0.0,
    ) -> xr.DataArray:
        """Return ``xarray.DataArray`` of a derived scalar field of a given monitor.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.
        field_name : str
            Name of the derived field component: one of `('E', 'H', 'S', 'Sx', 'Sy', 'Sz')`.
        val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
            Which part of the field to plot.
        phase : float = 0.0
            Optional phase to apply to the fields (no effect on the Poynting vector).

        Returns
        -------
        xarray.DataArray
            DataArray containing the requested derived scalar field.
        """
        field_monitor_data = self.load_field_monitor(field_monitor_name)
        return self._get_scalar_field_from_data(
            field_monitor_data, field_name=field_name, val=val, phase=phase
        )

    def _get_scalar_field_from_data(
        self,
        field_monitor_data: AbstractFieldData,
        field_name: str,
        val: FieldVal,
        phase: float = 0.0,
    ) -> xr.DataArray:
        """Return ``xarray.DataArray`` of a derived scalar field from field monitor data.

        Parameters
        ----------
        field_monitor_data : AbstractFieldData
            Field monitor data from which to extract scalar field.
        field_name : str
            Name of the derived field component: one of `('E', 'H', 'S', 'Sx', 'Sy', 'Sz')`.
        val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
            Which part of the field to plot. For the vectors ``'E'``, ``'H'``, ``'S'`` the
            magnitude of the chosen part is returned; ``'phase'`` is not defined for them.
        phase : float = 0.0
            Optional phase applied to the ``'E'``/``'H'`` fields. The Poynting vector
            ``S = 1/2 E × conj(H)`` is invariant under a global phase of the fields, so it is
            not applied to ``'S'`` or its components.

        Returns
        -------
        xarray.DataArray
            DataArray containing the requested derived scalar field.

        Raises
        ------
        Tidy3dKeyError
            If ``field_name`` is not supported, a Poynting component is unavailable, or ``val``
            is not supported for a vector field.
        DataError
            If a Cartesian component needed for ``'E'``, ``'H'`` or ``'S'`` is missing.
        """
        is_poynting = field_name.startswith("S")

        # validate before doing any (potentially expensive) colocation
        if not is_poynting and field_name not in ("E", "H"):
            raise Tidy3dKeyError(
                f"Derived field name must be one of 'E', 'H', 'S', 'Sx', 'Sy', or 'Sz', received "
                f"'{field_name}'."
            )

        if is_poynting:
            dataset = self._get_poynting_vector(field_monitor_data)
            if len(field_name) > 1:
                # a single Poynting component such as 'Sx'
                if field_name not in dataset:
                    raise Tidy3dKeyError(f"Poynting component {field_name} not available")
                derived_data = dataset[field_name]
                derived_data.name = field_name
                return self._field_component_value(derived_data, val)
            # 'S': phase deliberately not applied (Poynting vector is phase invariant)
        else:
            dataset = self.apply_phase(data=self._at_boundaries(field_monitor_data), phase=phase)

        # Gather vector components
        required_components = [field_name + c for c in "xyz"]
        if not all(field_cmp in dataset for field_cmp in required_components):
            raise DataError(
                f"Field monitor must contain '{field_name}x', '{field_name}y', and "
                f"'{field_name}z' fields to compute '{field_name}'."
            )
        field_components = [dataset[c] for c in required_components]

        # Apply the requested transformation
        val = val.lower()
        if val in ("real", "re"):
            derived_data = sum(f.real**2 for f in field_components) ** 0.5
            derived_data.name = f"|Re{{{field_name}}}|"

        elif val in ("imag", "im"):
            derived_data = sum(f.imag**2 for f in field_components) ** 0.5
            derived_data.name = f"|Im{{{field_name}}}|"

        elif val == "abs":
            derived_data = sum(abs(f) ** 2 for f in field_components) ** 0.5
            derived_data.name = f"|{field_name}|"

        elif val == "abs^2":
            derived_data = sum(abs(f) ** 2 for f in field_components)
            if hasattr(derived_data, "name"):
                derived_data.name = f"|{field_name}|²"

        elif val == "phase":
            raise Tidy3dKeyError(f"Phase is not defined for complex vector {field_name}")

        else:
            raise Tidy3dKeyError(
                f"'val' of {val} not supported. "
                "Must be one of 'real', 'imag', 'abs', 'abs^2', or 'phase'."
            )

        return derived_data

    def get_intensity(self, field_monitor_name: str) -> xr.DataArray:
        """return `xarray.DataArray` of the intensity of a field monitor at Yee cell centers.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.

        Returns
        -------
        xarray.DataArray
            DataArray containing the electric intensity ``|E|^2`` of the field-like monitor.
        """
        return self._get_scalar_field(
            field_monitor_name=field_monitor_name, field_name="E", val="abs^2"
        )

    @classmethod
    def mnt_data_from_file(cls, fname: str, mnt_name: str, **parse_obj_kwargs) -> MonitorDataType:
        """Loads data for a specific monitor from a .hdf5 file with data for a ``SimulationData``.

        Parameters
        ----------
        fname : str
            Full path to an hdf5 file containing :class:`.SimulationData` data.
        mnt_name : str
            ``.name`` of the monitor to load the data from.
        **parse_obj_kwargs
            Keyword arguments passed to either pydantic's ``parse_obj`` function when loading model.

        Returns
        -------
        :class:`MonitorData`
            Monitor data corresponding to the ``mnt_name`` type.

        Raises
        ------
        ValueError
            If the file is not ``.hdf5``, has no ``data`` group, the monitor type is unknown,
            or no monitor named ``mnt_name`` is found.

        Example
        -------
        >>> field_data = your_simulation_data.mnt_data_from_file(fname='folder/data.hdf5', mnt_name="field") # doctest: +SKIP
        """

        if pathlib.Path(fname).suffix != ".hdf5":
            raise ValueError("'mnt_data_from_file' only works with '.hdf5' files.")

        # open file read-only and ensure it has data
        with h5py.File(fname, "r") as f_handle:
            if "data" not in f_handle:
                raise ValueError(f"could not find data in the supplied file {fname}")

            # get the monitor list from the json string
            json_string = f_handle[JSON_TAG][()]
            json_dict = json.loads(json_string)
            monitor_list = json_dict["simulation"]["monitors"]

            # data groups are keyed by the monitor's index in the simulation's monitor list
            for monitor_index_str, _mnt_data in f_handle["data"].items():
                monitor_dict = monitor_list[int(monitor_index_str)]
                if monitor_dict["name"] != mnt_name:
                    continue

                # try to grab the monitor data type
                monitor_type_str = monitor_dict["type"]
                if monitor_type_str not in DATA_TYPE_NAME_MAP:
                    raise ValueError(f"Could not find data type '{monitor_type_str}'.")
                monitor_data_type = DATA_TYPE_NAME_MAP[monitor_type_str]

                # load the monitor data from the file using the group_path
                group_path = f"data/{monitor_index_str}"
                return monitor_data_type.from_file(
                    fname, group_path=group_path, **parse_obj_kwargs
                )

        raise ValueError(f"No monitor with name '{mnt_name}' found in data file.")

    @staticmethod
    def apply_phase(
        data: Union[xr.DataArray, xr.Dataset], phase: float = 0.0
    ) -> Union[xr.DataArray, xr.Dataset]:
        """Multiply complex-valued xarray data by ``exp(1j * phase)``.

        The input is never modified; a new object is returned when the phase is applied.

        Parameters
        ----------
        data : Union[xr.DataArray, xr.Dataset]
            Data to which the phase is applied.
        phase : float = 0.0
            Phase in radians.

        Returns
        -------
        Union[xr.DataArray, xr.Dataset]
            The phase-shifted data, or ``data`` itself if ``phase`` is zero or the data has no
            complex dtype (in which case a warning is logged).
        """
        if phase == 0.0:
            return data

        # decide by dtype, not by values: complex data with zero imaginary part still rotates
        if isinstance(data, xr.Dataset):
            is_complex = any(np.iscomplexobj(var) for var in data.data_vars.values())
        else:
            is_complex = np.iscomplexobj(data)

        if not is_complex:
            log.warning(
                f"Non-zero phase of {phase} specified but the data being plotted is "
                "real-valued. The phase will be ignored in the plot."
            )
            return data

        # out-of-place so arrays owned by the monitor data are never modified
        return data * np.exp(1j * phase)

    def plot_field_monitor_data(
        self,
        field_monitor_data: AbstractFieldData,
        field_name: str,
        val: FieldVal = "real",
        scale: PlotScale = "lin",
        eps_alpha: float = 0.2,
        phase: float = 0.0,
        robust: bool = True,
        vmin: float = None,
        vmax: float = None,
        ax: Ax = None,
        shading: str = "flat",
        **sel_kwargs,
    ) -> Ax:
        """Plot the field data for a monitor with simulation plot overlaid.

        Parameters
        ----------
        field_monitor_data : AbstractFieldData
            Field monitor data to plot.
        field_name : str
            Name of ``field`` component to plot (eg. `'Ex'`).
            Also accepts ``'E'`` and ``'H'`` to plot the vector magnitudes of the electric and
            magnetic fields, and ``'S'`` for the Poynting vector.
        val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
            Which part of the field to plot.
        scale : Literal['lin', 'dB']
            Plot in linear or logarithmic (dB) scale.
        eps_alpha : float = 0.2
            Opacity of the structure permittivity.
            Must be between 0 and 1 (inclusive).
        phase : float = 0.0
            Optional phase (radians) to apply to the fields.
            Only has an effect on frequency-domain fields.
        robust : bool = True
            If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
            to compute the color limits. This helps in visualizing the field patterns especially
            in the presence of a source.
        vmin : float = None
            The lower bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        vmax : float = None
            The upper bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        ax : matplotlib.axes._subplots.Axes = None
            matplotlib axes to plot on, if not specified, one is created.
        shading: str = 'flat'
            Shading argument for Xarray plot method ('flat','nearest','goraud')
        sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
            These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
            frequency or time dimensions (``f``, ``t``) or ``mode_index``, if applicable.
            For the plotting to work appropriately, the resulting data after selection must contain
            only two coordinates with len > 1.
            Furthermore, these should be spatial coordinates (``x``, ``y``, or ``z``).

        Returns
        -------
        matplotlib.axes._subplots.Axes
            The supplied or created matplotlib axes.

        Raises
        ------
        DataError
            If ``field_name`` is not in the data, or the selected data is not a 2D spatial slice.
        """
        # deprecated intensity
        if field_name == "int":
            log.warning(
                "'int' field name is deprecated and will be removed in the future. Please use "
                "field_name='E' and val='abs^2' for the same effect."
            )
            field_name = "E"
            val = "abs^2"

        if field_name in ("E", "H") or field_name.startswith("S"):
            # Derived fields
            field_data = self._get_scalar_field_from_data(
                field_monitor_data, field_name, val, phase=phase
            )
        else:
            # Direct field component (e.g. Ex)
            if field_name not in field_monitor_data.field_components:
                raise DataError(f"field_name '{field_name}' not found in data.")
            # ``rename`` returns a new array so the stored component's name is left untouched
            field_component = field_monitor_data.field_components[field_name].rename(field_name)
            field_component = self.apply_phase(data=field_component, phase=phase)
            field_data = self._field_component_value(field_component, val)

        if scale == "dB":
            if val == "phase":
                log.warning("Plotting phase component in log scale masks the phase sign.")
            # 're' / 'im' are accepted aliases of 'real' / 'imag'
            val_key = {"re": "real", "im": "imag"}.get(val, val)
            # power-like quantities use 10*log10, field amplitudes use 20*log10
            db_factor = {
                ("S", "real"): 10,
                ("S", "imag"): 10,
                ("S", "abs"): 10,
                ("S", "abs^2"): 5,
                ("S", "phase"): 1,
                ("E", "abs^2"): 10,
                ("H", "abs^2"): 10,
            }.get((field_name[0], val_key), 20)
            field_data = db_factor * np.log10(np.abs(field_data))
            field_data.name += " (dB)"
            cmap_type = "sequential"
        elif val == "phase":
            cmap_type = "cyclic"
        elif len(field_name) == 2 and val in ("real", "imag", "re", "im"):
            cmap_type = "divergent"
        else:
            cmap_type = "sequential"

        # interp out any monitor.size==0 dimensions
        monitor = field_monitor_data.monitor
        thin_dims = {
            "xyz"[dim]: monitor.center[dim]
            for dim in range(3)
            if monitor.size[dim] == 0 and "xyz"[dim] not in sel_kwargs
        }
        for axis, pos in thin_dims.items():
            if axis not in field_data.coords:
                continue
            if field_data.coords[axis].size <= 1:
                field_data = field_data.sel(**{axis: pos}, method="nearest")
            else:
                field_data = field_data.interp(**{axis: pos}, kwargs=dict(bounds_error=True))

        # warn about new API changes and replace the values
        if "freq" in sel_kwargs:
            log.warning(
                "'freq' supplied to 'plot_field', frequency selection key renamed to 'f' and "
                "'freq' will error in future release, please update your local script to use "
                "'f=value'."
            )
            sel_kwargs["f"] = sel_kwargs.pop("freq")
        if "time" in sel_kwargs:
            log.warning(
                "'time' supplied to 'plot_field', time selection key renamed to 't' and "
                "'time' will error in future release, please update your local script to use "
                "'t=value'."
            )
            sel_kwargs["t"] = sel_kwargs.pop("time")

        # discrete index coordinates are always selected exactly, never interpolated
        index_coords = ("eme_port_index", "eme_cell_index", "sweep_index", "mode_index")

        # select the extra coordinates out of the data from user-specified kwargs
        for coord_name, coord_val in sel_kwargs.items():
            if field_data.coords[coord_name].size <= 1 or coord_name in index_coords:
                field_data = field_data.sel(**{coord_name: coord_val}, method=None)
            else:
                field_data = field_data.interp(
                    **{coord_name: coord_val}, kwargs=dict(bounds_error=True)
                )

        # before dropping coordinates, check if a frequency can be derived from the data that can
        # be used to plot material permittivity
        if "f" in sel_kwargs:
            freq_eps_eval = sel_kwargs["f"]
        elif "f" in field_data.coords:
            # ravel so that a 0-d (scalar) 'f' coordinate is handled too
            freq_eps_eval = np.ravel(field_data.coords["f"].values)[0]
        else:
            freq_eps_eval = None

        field_data = field_data.squeeze(drop=True)
        non_scalar_coords = {name: c for name, c in field_data.coords.items() if c.size > 1}

        # assert the data is valid for plotting
        if len(non_scalar_coords) != 2:
            raise DataError(
                f"Data after selection has {len(non_scalar_coords)} coordinates "
                f"({list(non_scalar_coords.keys())}), "
                "must be 2 spatial coordinates for plotting on plane. "
                "Please add keyword arguments to `plot_field()` to select out the other coords."
            )

        spatial_coords_in_data = {
            coord_name: (coord_name in non_scalar_coords) for coord_name in "xyz"
        }

        if sum(spatial_coords_in_data.values()) != 2:
            raise DataError(
                "All coordinates in the data after selection must be spatial (x, y, z), "
                f"given {list(non_scalar_coords.keys())}."
            )

        # get the spatial coordinate corresponding to the plane
        planar_coord = next(name for name, c in spatial_coords_in_data.items() if not c)
        axis = "xyz".index(planar_coord)
        if planar_coord in field_data.coords:
            position = float(field_data.coords[planar_coord])
        else:
            position = monitor.center[axis]

        return self.plot_scalar_array(
            field_data=field_data,
            axis=axis,
            position=position,
            freq=freq_eps_eval,
            eps_alpha=eps_alpha,
            robust=robust,
            vmin=vmin,
            vmax=vmax,
            cmap_type=cmap_type,
            ax=ax,
            shading=shading,
            infer_intervals=shading == "flat",
        )

    def plot_field(
        self,
        field_monitor_name: str,
        field_name: str,
        val: FieldVal = "real",
        scale: PlotScale = "lin",
        eps_alpha: float = 0.2,
        phase: float = 0.0,
        robust: bool = True,
        vmin: float = None,
        vmax: float = None,
        ax: Ax = None,
        shading: str = "flat",
        **sel_kwargs,
    ) -> Ax:
        """Plot the field data for a monitor with simulation plot overlaid.

        Parameters
        ----------
        field_monitor_name : str
            Name of :class:`.FieldMonitor`, :class:`.FieldTimeData`, or :class:`.ModeSolverData`
            to plot.
        field_name : str
            Name of ``field`` component to plot (eg. `'Ex'`).
            Also accepts ``'E'`` and ``'H'`` to plot the vector magnitudes of the electric and
            magnetic fields, and ``'S'`` for the Poynting vector.
        val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
            Which part of the field to plot.
        scale : Literal['lin', 'dB']
            Plot in linear or logarithmic (dB) scale.
        eps_alpha : float = 0.2
            Opacity of the structure permittivity.
            Must be between 0 and 1 (inclusive).
        phase : float = 0.0
            Optional phase (radians) to apply to the fields.
            Only has an effect on frequency-domain fields.
        robust : bool = True
            If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
            to compute the color limits. This helps in visualizing the field patterns especially
            in the presence of a source.
        vmin : float = None
            The lower bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        vmax : float = None
            The upper bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        ax : matplotlib.axes._subplots.Axes = None
            matplotlib axes to plot on, if not specified, one is created.
        shading: str = 'flat'
            Shading argument for Xarray plot method ('flat','nearest','goraud')
        sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
            These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
            frequency or time dimensions (``f``, ``t``) or ``mode_index``, if applicable.
            For the plotting to work appropriately, the resulting data after selection must contain
            only two coordinates with len > 1.
            Furthermore, these should be spatial coordinates (``x``, ``y``, or ``z``).

        Returns
        -------
        matplotlib.axes._subplots.Axes
            The supplied or created matplotlib axes.
        """

        field_monitor_data = self.load_field_monitor(field_monitor_name)

        return self.plot_field_monitor_data(
            field_monitor_data=field_monitor_data,
            field_name=field_name,
            val=val,
            scale=scale,
            eps_alpha=eps_alpha,
            phase=phase,
            robust=robust,
            vmin=vmin,
            vmax=vmax,
            ax=ax,
            shading=shading,
            **sel_kwargs,
        )

    @equal_aspect
    @add_ax_if_none
    def plot_scalar_array(
        self,
        field_data: xr.DataArray,
        axis: Axis,
        position: float,
        freq: float = None,
        eps_alpha: float = 0.2,
        robust: bool = True,
        vmin: float = None,
        vmax: float = None,
        cmap_type: ColormapType = "divergent",
        ax: Ax = None,
        **kwargs,
    ) -> Ax:
        """Plot the field data for a monitor with simulation plot overlaid.

        Parameters
        ----------
        field_data: xr.DataArray
            DataArray with the field data to plot.
            Must be a scalar field.
        axis: Axis
            Axis normal to the plotting plane.
        position: float
            Position along the axis.
        freq: float = None
            Frequency at which the permittivity is evaluated at (if dispersive).
            By default, chooses permittivity as frequency goes to infinity.
        eps_alpha : float = 0.2
            Opacity of the structure permittivity.
            Must be between 0 and 1 (inclusive).
        robust : bool = True
            If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
            to compute the color limits. This helps in visualizing the field patterns especially
            in the presence of a source.
        vmin : float = None
            The lower bound of data range that the colormap covers. If `None`, they are
            inferred from the data and other keyword arguments. Overridden to ``-pi`` for
            ``cmap_type='cyclic'``.
        vmax : float = None
            The upper bound of data range that the colormap covers. If `None`, they are
            inferred from the data and other keyword arguments. Overridden to ``pi`` for
            ``cmap_type='cyclic'``.
        cmap_type : Literal["divergent", "sequential", "cyclic"] = "divergent"
            Type of color map to use for plotting.
        ax : matplotlib.axes._subplots.Axes = None
            matplotlib axes to plot on, if not specified, one is created.
        **kwargs : Extra arguments to ``DataArray.plot``.

        Returns
        -------
        matplotlib.axes._subplots.Axes
            The supplied or created matplotlib axes.

        Raises
        ------
        ValueError
            If ``cmap_type`` is not one of the supported values.
        """

        # select the cross section data
        interp_kwarg = {"xyz"[axis]: position}

        if cmap_type == "divergent":
            cmap = "RdBu"
            center = 0.0
            eps_reverse = False
        elif cmap_type == "sequential":
            cmap = "magma"
            center = False
            eps_reverse = True
        elif cmap_type == "cyclic":
            cmap = "twilight"
            vmin = -np.pi
            vmax = np.pi
            center = False
            eps_reverse = False
        else:
            raise ValueError(
                f"'cmap_type' of {cmap_type!r} not supported. "
                "Must be one of 'divergent', 'sequential', or 'cyclic'."
            )

        # plot the field on the two in-plane axes
        xy_coord_labels = list("xyz")
        xy_coord_labels.pop(axis)
        x_coord_label, y_coord_label = xy_coord_labels
        field_data.plot(
            ax=ax,
            x=x_coord_label,
            y=y_coord_label,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            robust=robust,
            center=center,
            cbar_kwargs={"label": field_data.name},
            **kwargs,
        )

        # plot the simulation epsilon
        ax = self.simulation.plot_structures_eps(
            freq=freq,
            cbar=False,
            alpha=eps_alpha,
            reverse=eps_reverse,
            ax=ax,
            **interp_kwarg,
        )

        # set the limits based on the xarray coordinates min and max
        x_coord_values = field_data.coords[x_coord_label]
        y_coord_values = field_data.coords[y_coord_label]
        ax.set_xlim(float(x_coord_values.min()), float(x_coord_values.max()))
        ax.set_ylim(float(y_coord_values.min()), float(y_coord_values.max()))

        return ax


class SimulationData(AbstractYeeGridSimulationData):
    """Stores data from a collection of :class:`.Monitor` objects in a :class:`.Simulation`.

    Notes
    -----

    The ``SimulationData`` objects store a copy of the original :class:`.Simulation`, so it can be
    recovered if the ``SimulationData`` is loaded in a new session and the :class:`.Simulation` is
    no longer in memory.

    More importantly, the ``SimulationData`` contains a reference to the data for each of the
    monitors within the original :class:`.Simulation`. This data can be accessed directly using the
    name given to the monitors initially.

    Examples
    --------

    Standalone example:

    >>> import tidy3d as td
    >>> num_modes = 5
    >>> x = [-1,1,3]
    >>> y = [-2,0,2,4]
    >>> z = [-3,-1,1,3,5]
    >>> f = [2e14, 3e14]
    >>> coords = dict(x=x[:-1], y=y[:-1], z=z[:-1], f=f)
    >>> grid = td.Grid(boundaries=td.Coords(x=x, y=y, z=z))
    >>> scalar_field = td.ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
    >>> field_monitor = td.FieldMonitor(
    ...     size=(2,4,6),
    ...     freqs=[2e14, 3e14],
    ...     name='field',
    ...     fields=['Ex'],
    ...     colocate=True,
    ... )
    >>> sim = td.Simulation(
    ...     size=(2, 4, 6),
    ...     grid_spec=td.GridSpec(wavelength=1.0),
    ...     monitors=[field_monitor],
    ...     run_time=2e-12,
    ...     sources=[
    ...         td.UniformCurrentSource(
    ...             size=(0, 0, 0),
    ...             center=(0, 0.5, 0),
    ...             polarization="Hx",
    ...             source_time=td.GaussianPulse(
    ...                 freq0=2e14,
    ...                 fwidth=4e13,
    ...             ),
    ...         )
    ...     ],
    ... )
    >>> field_data = td.FieldData(monitor=field_monitor, Ex=scalar_field, grid_expanded=grid)
    >>> sim_data = td.SimulationData(simulation=sim, data=(field_data,))

    To save and load the :class:`SimulationData` object.

    .. code-block:: python

        sim_data.to_file(fname='path/to/file.hdf5') # Save a SimulationData object to a HDF5 file
        sim_data = SimulationData.from_file(fname='path/to/file.hdf5') # Load a SimulationData object from a HDF5 file.

    See Also
    --------

    **Notebooks:**
        * `Quickstart <../../notebooks/StartHere.html>`_: Usage in a basic simulation flow.
        * `Performing visualization of simulation data <../../notebooks/VizData.html>`_
        * `Advanced monitor data manipulation and visualization <../../notebooks/XarrayTutorial.html>`_
    """

    simulation: Simulation = pd.Field(
        ...,
        title="Simulation",
        description="Original :class:`.Simulation` associated with the data.",
    )

    data: Tuple[annotate_type(MonitorDataType), ...] = pd.Field(
        ...,
        title="Monitor Data",
        description="List of :class:`.MonitorData` instances "
        "associated with the monitors of the original :class:`.Simulation`.",
    )

    diverged: bool = pd.Field(
        False,
        title="Diverged",
        description="A boolean flag denoting whether the simulation run diverged.",
    )

    @property
    def final_decay_value(self) -> float:
        """Returns value of the field decay at the final time step.

        The value is parsed from the last line of ``self.log`` that contains ``"field decay"``;
        ``1.0`` is returned when no such line exists.

        Raises
        ------
        DataError
            If no log string is stored on this object.
        """
        log_str = self.log
        if log_str is None:
            raise DataError(
                "No log string in the SimulationData object, can't find final decay value."
            )
        decay_lines = [line for line in log_str.split("\n") if "field decay" in line]
        if not decay_lines:
            return 1.0
        # the number follows the "field decay: " marker on the most recent decay report
        return float(decay_lines[-1].split("field decay: ")[-1])

    def source_spectrum(self, source_index: int) -> Callable:
        """Get a spectrum normalization function for a given source index.

        Parameters
        ----------
        source_index : int
            Index into ``simulation.sources``. ``None`` (or a simulation without sources)
            yields ``np.ones_like``, i.e. no normalization.

        Returns
        -------
        Callable
            Function mapping frequencies to the (amplitude- and phase-free) source spectrum.
        """
        if source_index is None or len(self.simulation.sources) == 0:
            return np.ones_like

        source = self.simulation.sources[source_index]
        source_time = source.source_time
        times = self.simulation.tmesh
        dt = self.simulation.dt

        # plug in monitor_data frequency domain information
        def source_spectrum_fn(freqs):
            """Source amplitude as function of frequency."""
            spectrum = source_time.spectrum(times, freqs, dt)

            # Remove user defined amplitude and phase from the normalization
            # such that they would still have an effect on the output fields.
            # In other words, we are only normalizing out the arbitrary part of the spectrum
            # that depends on things like freq0, fwidth and offset.
            return spectrum / source_time.amplitude / np.exp(1j * source_time.phase)

        return source_spectrum_fn

    def renormalize(self, normalize_index: int) -> SimulationData:
        """Return a copy of the :class:`.SimulationData` with a different source used for the normalization.

        Parameters
        ----------
        normalize_index : int
            Index of the source to normalize by, or ``None`` to remove source normalization.

        Raises
        ------
        DataError
            If ``normalize_index`` is out of bounds for the simulation's source list.
        """

        num_sources = len(self.simulation.sources)
        if normalize_index == self.simulation.normalize_index or num_sources == 0:
            # already normalized to that index
            return self.copy()

        if normalize_index and (normalize_index < 0 or normalize_index >= num_sources):
            # normalize index out of bounds for source list
            raise DataError(
                f"normalize_index {normalize_index} out of bounds for list of sources "
                f"of length {num_sources}"
            )

        # build both spectrum functions once instead of once per monitor
        new_spectrum_fn = self.source_spectrum(normalize_index)
        old_spectrum_fn = self.source_spectrum(self.simulation.normalize_index)

        def source_spectrum_fn(freqs):
            """Normalization function that also removes previous normalization if needed."""
            return new_spectrum_fn(freqs) / old_spectrum_fn(freqs)

        # Make a new monitor_data list with renormalized data
        data_normalized = [mnt_data.normalize(source_spectrum_fn) for mnt_data in self.data]

        simulation = self.simulation.copy(update=dict(normalize_index=normalize_index))

        return self.copy(update=dict(simulation=simulation, data=data_normalized))

    def split_adjoint_data(self: SimulationData, num_mnts_original: int) -> tuple[list, list]:
        """Split data list into the original monitor data and the adjoint monitor data.

        The first ``num_mnts_original`` entries of ``self.data`` are returned as the original
        data; everything after them is returned as adjoint data. The log message assumes the
        adjoint part is made of equal numbers of field and permittivity monitors.
        """
        data_all = list(self.data)
        num_mnts_adjoint = (len(data_all) - num_mnts_original) // 2

        log.info(
            f" -> {num_mnts_original} monitors, {num_mnts_adjoint} adjoint field monitors, {num_mnts_adjoint} adjoint eps monitors."
        )

        data_original, data_adjoint = split_list(data_all, index=num_mnts_original)

        return data_original, data_adjoint

    def split_original_fwd(self, num_mnts_original: int) -> Tuple[SimulationData, SimulationData]:
        """Split this simulation data into original and fwd data from number of original mnts."""

        # split the data and monitors into the original ones & adjoint gradient ones (for 'fwd')
        data_original, data_fwd = self.split_adjoint_data(num_mnts_original=num_mnts_original)
        monitors_orig, monitors_fwd = split_list(self.simulation.monitors, index=num_mnts_original)

        # reconstruct the simulation data for the user, using original sim, and data for original mnts
        sim_original = self.simulation.updated_copy(monitors=monitors_orig)
        sim_data_original = self.updated_copy(
            simulation=sim_original,
            data=data_original,
            deep=False,
        )

        # construct the 'forward' simulation and its data, which is only used for gradient calc.
        sim_fwd = self.simulation.updated_copy(monitors=monitors_fwd)
        sim_data_fwd = self.updated_copy(
            simulation=sim_fwd,
            data=data_fwd,
            deep=False,
        )

        return sim_data_original, sim_data_fwd

    def make_adjoint_sim(
        self,
        data_vjp_paths: set[tuple],
        adjoint_monitors: list[Monitor],
    ) -> Simulation | None:
        """Make the adjoint simulation from the original simulation and the VJP-containing data.

        Returns ``None`` when the VJP data produces no adjoint sources.
        """

        sim_original = self.simulation

        # generate the adjoint sources {mnt_name : list[Source]} and flatten them
        sources_adj_dict = self.make_adjoint_sources(data_vjp_paths=data_vjp_paths)
        adj_srcs = [src for src_list in sources_adj_dict.values() for src in src_list]

        if not any(adj_srcs):
            return None

        adjoint_source_info = self.process_adjoint_sources(adj_srcs=adj_srcs)

        # grab boundary conditions with flipped Bloch vectors (for adjoint)
        bc_adj = sim_original.boundary_spec.flipped_bloch_vecs

        # fields to update the 'fwd' simulation with to make it 'adj'
        sim_adj_update_dict = dict(
            sources=adjoint_source_info.sources,
            boundary_spec=bc_adj,
            monitors=adjoint_monitors,
            post_norm=adjoint_source_info.post_norm,
        )

        if not adjoint_source_info.normalize_sim:
            sim_adj_update_dict["normalize_index"] = None

        # set the ADJ grid spec wavelength to the original wavelength (for same meshing)
        grid_spec_original = sim_original.grid_spec
        if sim_original.sources and grid_spec_original.wavelength is None:
            wavelength_original = grid_spec_original.wavelength_from_sources(sim_original.sources)
            grid_spec_adj = grid_spec_original.updated_copy(wavelength=wavelength_original)
            sim_adj_update_dict["grid_spec"] = grid_spec_adj

        return sim_original.updated_copy(**sim_adj_update_dict)

    def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, list[SourceType]]:
        """Generate all of the non-zero sources for the adjoint simulation given the VJP data.

        Each entry of ``data_vjp_paths`` is unpacked as ``(_, index, dataset_name)``, where
        ``index`` points into ``self.data``. The result maps monitor names to the list of
        adjoint sources produced by the corresponding monitor data.
        """

        # map of index into 'self.data' to the list of datasets we need adjoint sources for
        adj_src_map = defaultdict(list)
        for _, index, dataset_name in data_vjp_paths:
            adj_src_map[index].append(dataset_name)

        # gather a dict of adjoint sources for every monitor data in the VJP that needs one
        sources_adj_all = defaultdict(list)
        for data_index, dataset_names in adj_src_map.items():
            mnt_data = self.data[data_index]
            sources_adj = mnt_data.make_adjoint_sources(
                dataset_names=dataset_names, fwidth=self.fwidth_adj
            )
            sources_adj_all[mnt_data.monitor.name] = sources_adj

        return sources_adj_all

    @property
    def fwidth_adj(self) -> float:
        """``fwidth`` of the forward normalization source, used as default for adjoint sources.

        Uses the source at ``simulation.normalize_index`` (index 0 when that is ``None``).
        """
        normalize_index_fwd = self.simulation.normalize_index or 0
        return self.simulation.sources[normalize_index_fwd].source_time.fwidth

    def process_adjoint_sources(self, adj_srcs: list[SourceType]) -> AdjointSourceInfo:
        """Compute list of final sources along with a post run normalization for adj fields.

        Three cases are handled:

        * one unique frequency: sources are used as-is (with a 0/1 frequency mask when the
          simulation has several adjoint frequencies);
        * one spatial port with one source per frequency: a single broadband source plus
          complex post-normalization amplitudes;
        * otherwise: a multi-frequency fit (currently raises ``NotImplementedError``).
        """

        # sources are grouped by their hash with the time dependence replaced by a fixed
        # placeholder, i.e. sources that differ only in their source time share a key
        hashes_to_sources = {}
        hashes_to_src_times = defaultdict(list)

        tmp_src_time = GaussianPulse(freq0=C_0, fwidth=inf)
        for src in adj_srcs:
            tmp_src = src.updated_copy(source_time=tmp_src_time)
            tmp_src_hash = tmp_src._hash_self()
            hashes_to_sources[tmp_src_hash] = src
            hashes_to_src_times[tmp_src_hash].append(src.source_time)

        num_ports = len(hashes_to_src_times)
        unique_freqs = {src.source_time.freq0 for src in adj_srcs}
        num_unique_freqs = len(unique_freqs)

        # next, figure out which treatment / normalization to apply
        if num_unique_freqs == 1:
            log.info("Adjoint source creation: one unique frequency, no normalization.")
            freqs_adj = self.simulation.freqs_adjoint
            # if many adjoint freqs, but only 1 unique, need to mask out the non-contributors
            if len(freqs_adj) > 1:
                (unique_freq,) = unique_freqs
                coords = dict(f=freqs_adj)
                data = [1 if f == unique_freq else 0 for f in freqs_adj]
                post_norm = xr.DataArray(data, coords=coords)
                return AdjointSourceInfo(sources=adj_srcs, post_norm=post_norm, normalize_sim=True)
            return AdjointSourceInfo(sources=adj_srcs, post_norm=1.0, normalize_sim=True)

        if num_ports == 1 and len(adj_srcs) == num_unique_freqs:
            log.info("Adjoint source creation: one spatial port detected.")
            adj_srcs, post_norm = self.process_adjoint_sources_broadband(adj_srcs)
            return AdjointSourceInfo(sources=adj_srcs, post_norm=post_norm, normalize_sim=True)

        # if several spatial ports and several frequencies, try to fit
        log.info("Adjoint source creation: trying multifrequency fit.")
        adj_srcs, post_norm = self.process_adjoint_sources_fit(
            adj_srcs=adj_srcs,
            hashes_to_src_times=hashes_to_src_times,
            hashes_to_sources=hashes_to_sources,
        )
        return AdjointSourceInfo(sources=adj_srcs, post_norm=post_norm, normalize_sim=False)

    # ---- SIMPLE APPROACH ----

    def process_adjoint_sources_broadband(
        self, adj_srcs: list[SourceType]
    ) -> tuple[list[SourceType], xr.DataArray]:
        """Process adjoint sources for the case of several sources at the same freq.

        Returns a single broadband source (spatial part of ``adj_srcs[0]``) together with the
        complex per-frequency amplitudes to multiply the adjoint fields with after the run.
        """

        src_broadband = self._make_broadband_source(adj_srcs=adj_srcs)
        post_norm_amps = self._make_post_norm_amps(adj_srcs=adj_srcs)

        log.info(
            "Several adjoint sources, from one monitor. "
            "Only difference between them is the source time. "
            "Constructing broadband adjoint source and performing post-run normalization "
            f"of fields with {len(post_norm_amps)} frequencies."
        )

        return [src_broadband], post_norm_amps

    def _make_broadband_source(self, adj_srcs: list[SourceType]) -> SourceType:
        """Make a broadband source for a set of adjoint sources.

        Copies the source time of the forward normalization source onto ``adj_srcs[0]``.
        """
        source_index = self.simulation.normalize_index or 0
        src_time_base = self.simulation.sources[source_index].source_time.copy()
        return adj_srcs[0].updated_copy(source_time=src_time_base)

    @staticmethod
    def _make_post_norm_amps(adj_srcs: list[SourceType]) -> xr.DataArray:
        """Make a ``DataArray`` containing the complex amplitudes to multiply with adjoint field.

        One entry per source, indexed by the source's ``freq0`` along coordinate ``f``.
        """
        freqs = []
        amps_complex = []
        for src in adj_srcs:
            src_time = src.source_time
            freqs.append(src_time.freq0)
            amps_complex.append(src_time.amplitude * np.exp(1j * src_time.phase))

        coords = dict(f=freqs)
        return xr.DataArray(np.array(amps_complex), coords=coords)

    # ---- FITTING APPROACH ----

    def process_adjoint_sources_fit(
        self,
        adj_srcs: list[SourceType],
        hashes_to_src_times: dict[str, list[GaussianPulse]],
        hashes_to_sources: dict[str, SourceType],
    ) -> tuple[list[SourceType], float]:
        """Process the adjoint sources using a least squared fit to the derivative data.

        Raises
        ------
        NotImplementedError
            Always, for now: multi-port, multi-frequency adjoint sources are not supported yet.
        """

        raise NotImplementedError(
            "Can't perform multi-frequency autograd with several adjoint sources yet. "
            "In the meantime, please construct a single 'Simulation' per output data "
            "(can be multi-frequency) and run in parallel using 'web.run_async'. For example, "
            "if your problem has 'P' outuput ports, e.g. waveguides, please make a 'Simulation' "
            "corresponding to the objective function contribution at each port."
        )

        # NOTE: the code below is currently unreachable; it is kept as the intended
        # implementation of the fitting approach once it is enabled.

        # new adjoint sources
        new_adj_srcs = []
        for src_hash, source_times in hashes_to_src_times.items():
            src = hashes_to_sources[src_hash]
            new_sources = self.correct_adjoint_sources(
                src=src, fwidth=self.fwidth_adj, source_times=source_times
            )
            new_adj_srcs += new_sources

        # compute amplitudes of each adjoint source, and the norm
        adj_src_amps = [src.source_time.amp_complex for src in new_adj_srcs]
        norm_amps = np.linalg.norm(adj_src_amps)

        # normalize all of the adjoint sources by this and return the normalization term used
        adj_srcs_norm = []
        for src in new_adj_srcs:
            src_time = src.source_time
            amp = src_time.amp_complex
            src_time_norm = src_time.from_amp_complex(amp=amp / norm_amps)
            src_nrm = src.updated_copy(source_time=src_time_norm)
            adj_srcs_norm.append(src_nrm)

        return adj_srcs_norm, norm_amps

    def correct_adjoint_sources(
        self, src: SourceType, fwidth: float, source_times: list[GaussianPulse]
    ) -> list[SourceType]:
        """Correct a set of spectrally overlapping adjoint sources to give correct E_adj.

        Solves (least squares) for the amplitudes of Gaussian pulses of bandwidth ``fwidth``,
        centered at each ``source_times[i].freq0``, whose summed spectra reproduce the
        requested complex amplitudes at those frequencies.

        Parameters
        ----------
        src : SourceType
            Source whose spatial part is reused for every corrected source.
        fwidth : float
            Bandwidth of the Gaussian pulses used for the corrected sources.
        source_times : list[GaussianPulse]
            Requested source times (frequency and complex amplitude) for ``src``.

        Raises
        ------
        ValueError
            If the relative fit residual exceeds ``RESIDUAL_CUTOFF_ADJOINT``.
        """

        freqs = [st.freq0 for st in source_times]
        times = self.simulation.tmesh
        dt = self.simulation.dt

        def get_spectrum(source_time: GaussianPulse, freqs: list[float]) -> np.ndarray:
            """Get the spectrum of a source time at the given frequencies."""
            return source_time.spectrum(times=times, freqs=freqs, dt=dt)

        def get_coupling_matrix(fwidth: float) -> np.ndarray:
            """Matrix coupling the spectra of Gaussian pulses centered at each adjoint freq."""
            return np.array(
                [
                    get_spectrum(
                        source_time=GaussianPulse(freq0=source_time.freq0, fwidth=fwidth),
                        freqs=freqs,
                    )
                    for source_time in source_times
                ]
            ).T

        amps_adj = np.array([src_time.amp_complex for src_time in source_times])

        def get_amps_corrected(fwidth: float) -> tuple[np.ndarray, float]:
            """New set of adjoint source amps that generate the desired response at each f."""
            J_coupling = get_coupling_matrix(fwidth=fwidth)
            amps_adj_new, *_ = np.linalg.lstsq(J_coupling, amps_adj, rcond=None)
            residual_norm = np.linalg.norm(J_coupling @ amps_adj_new - amps_adj)
            amps_norm = np.linalg.norm(amps_adj)
            # relative residual; fall back to the absolute one when every target amp is zero
            if amps_norm > 0:
                residual_norm = residual_norm / amps_norm
            return amps_adj_new, residual_norm

        # get the corrected amplitudes for the requested bandwidth
        amps_corrected, res_norm = get_amps_corrected(fwidth)

        if res_norm > RESIDUAL_CUTOFF_ADJOINT:
            raise ValueError(
                f"Residual of {res_norm:.5e} found when trying to fit adjoint source spectrum. "
                f"This is above our accuracy cutoff of {RESIDUAL_CUTOFF_ADJOINT:.5e} and therefore "
                "we are not able to process this adjoint simulation in a broadband way. "
                "To fix, split your simulation into a set of simulations, one for each port, and "
                "run parallel, broadband simulations using 'web.run_async'. "
            )

        # construct the new adjoint sources with the corrected amplitudes
        return [
            src.updated_copy(source_time=src_time.from_amp_complex(amp=amp, fwidth=fwidth))
            for src_time, amp in zip(source_times, amps_corrected)
        ]

    def get_adjoint_data(self, structure_index: int, data_type: str) -> MonitorDataType:
        """Grab the field or permittivity data for a given structure index.

        The monitor name is built by ``Structure.get_monitor_name`` and looked up on ``self``.
        """
        monitor_name = Structure.get_monitor_name(index=structure_index, data_type=data_type)
        return self[monitor_name]

    def to_mat_file(self, fname: str, **kwargs):
        """Output the ``SimulationData`` object as ``.mat`` MATLAB file.

        Parameters
        ----------
        fname : str
            Full path to the output file. Should include ``.mat`` file extension.
        **kwargs : dict, optional
            Extra arguments to ``scipy.io.savemat``: see ``scipy`` documentation for more detail.

        Raises
        ------
        FileError
            If ``fname`` has no extension or its final extension is not ``.mat``.
        ValueError
            If the data dictionary is passed via ``kwargs`` (it is built from ``self``), or if
            saving fails for any reason.

        Example
        -------
        >>> simData.to_mat_file('/path/to/file/data.mat') # doctest: +SKIP
        """

        # Check .mat file extension is given. Only the final suffix is relevant: a name with no
        # extension must give a FileError (not an IndexError), and extra dots such as
        # 'run.v2.mat' must not cause a valid .mat name to be rejected.
        extension = pathlib.Path(fname).suffix.lower()
        if len(extension) == 0:
            raise FileError(f"File '{fname}' missing extension.")
        if extension != ".mat":
            raise FileError(f"File '{fname}' should have a .mat extension.")

        # The data dictionary is built below; 'savemat' names that argument 'mdict'.
        # The legacy 'm_dict' spelling is still rejected for backward compatibility.
        for reserved_key in ("m_dict", "mdict"):
            if reserved_key in kwargs:
                raise ValueError(
                    f"'{reserved_key}' is automatically determined by 'to_mat_file', can't pass to 'savemat'."
                )

        # Get SimData object as dictionary
        sim_dict = self.dict()

        # Remove NoneType values from dict
        # Built from theory discussed in https://github.com/scipy/scipy/issues/3488
        modified_sim_dict = replace_values(sim_dict, None, [])

        try:
            from scipy.io import savemat

            savemat(fname, modified_sim_dict, **kwargs)
        except Exception as e:
            raise ValueError(
                "Could not save supplied 'SimulationData' to file. As this is an experimental feature, we may not be able to support the contents of your dataset. If you receive this error, please feel free to raise an issue on our front end repository so we can investigate."
            ) from e