"""Simulation Level Data"""
from __future__ import annotations
import json
import pathlib
import re
from abc import ABC
from collections import defaultdict
from typing import Callable, Optional, Union
import h5py
import numpy as np
import pydantic.v1 as pd
import xarray as xr
from tidy3d.components.autograd.utils import split_list
from tidy3d.components.base import JSON_TAG, Tidy3dBaseModel, cached_property
from tidy3d.components.base_sim.data.sim_data import AbstractSimulationData
from tidy3d.components.file_util import replace_values
from tidy3d.components.monitor import Monitor
from tidy3d.components.simulation import Simulation
from tidy3d.components.source.current import CustomCurrentSource
from tidy3d.components.source.time import GaussianPulse
from tidy3d.components.source.utils import SourceType
from tidy3d.components.structure import Structure
from tidy3d.components.types import Ax, Axis, ColormapType, FieldVal, PlotScale, annotate_type
from tidy3d.components.viz import add_ax_if_none, equal_aspect
from tidy3d.constants import C_0, inf
from tidy3d.exceptions import DataError, FileError, Tidy3dKeyError
from tidy3d.log import log
from .data_array import FreqDataArray, TimeDataArray
from .monitor_data import AbstractFieldData, FieldTimeData, MonitorDataType, MonitorDataTypes
# maps monitor type (class) to the class of the corresponding data
DATA_TYPE_MAP = {
    data_cls.__fields__["monitor"].type_: data_cls for data_cls in MonitorDataTypes
}

# maps monitor type (string) to the class of the corresponding data
DATA_TYPE_NAME_MAP = {
    data_cls.__fields__["monitor"].type_.__name__: data_cls for data_cls in MonitorDataTypes
}

# residuals below this are considered good fits for broadband adjoint source creation
RESIDUAL_CUTOFF_ADJOINT = 1e-6

# for adjoint source, the minimum number of FWIDTH between the center frequency and zero
NUM_ADJOINT_FWIDTH_TO_ZERO = 3

# for broadband adjoint source, the minimum number of FWIDTH to reach the lowest frequency
# that is covered by the broadband pulse
NUM_ADJOINT_FWIDTH_TO_FMIN = 0.5
class AdjointSourceInfo(Tidy3dBaseModel):
    """Stores information about the adjoint sources to pass to autograd pipeline."""

    # processed sources that the adjoint simulation will contain
    sources: tuple[annotate_type(SourceType), ...] = pd.Field(
        ...,
        description="Set of processed sources to include in the adjoint simulation.",
        title="Adjoint Sources",
    )

    # scalar or per-frequency factor applied to the adjoint fields after the run
    post_norm: Union[float, FreqDataArray] = pd.Field(
        ...,
        description=(
            "Factor to multiply the adjoint fields by after running given the adjoint "
            "source pipeline used."
        ),
        title="Post Normalization Values",
    )

    # flag telling the pipeline whether the adjoint simulation must be normalized
    normalize_sim: bool = pd.Field(
        ...,
        description=(
            "Whether the adjoint simulation needs to be normalized given the adjoint "
            "source pipeline used."
        ),
        title="Normalize Adjoint Simulation",
    )
[docs]
class AbstractYeeGridSimulationData(AbstractSimulationData, ABC):
    """Data from an :class:`.AbstractYeeGridSimulation` involving
    electromagnetic fields on a Yee grid.

    Notes
    -----

    The ``SimulationData`` objects store a copy of the original :class:`.Simulation`, so it can be
    recovered if the ``SimulationData`` is loaded in a new session and the :class:`.Simulation` is
    no longer in memory.

    More importantly, the ``SimulationData`` contains a reference to the data for each of the
    monitors within the original :class:`.Simulation`. This data can be accessed directly using the
    name given to the monitors initially.
    """

    def load_field_monitor(self, monitor_name: str) -> AbstractFieldData:
        """Load monitor and raise exception if not a field monitor.

        Parameters
        ----------
        monitor_name : str
            Name of the monitor in the original :class:`Simulation`.

        Returns
        -------
        AbstractFieldData
            The data stored under ``monitor_name``.

        Raises
        ------
        DataError
            If the data stored under ``monitor_name`` is not field data.
        """
        mon_data = self[monitor_name]
        if not isinstance(mon_data, AbstractFieldData):
            raise DataError(
                f"data for monitor '{monitor_name}' does not contain field data "
                f"as it is a '{type(mon_data)}'."
            )
        return mon_data

    def at_centers(self, field_monitor_name: str) -> xr.Dataset:
        """Return xarray.Dataset representation of field monitor data colocated at Yee cell centers.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.

        Returns
        -------
        xarray.Dataset
            Dataset containing all of the fields in the data interpolated to center locations on
            the Yee grid.
        """
        monitor_data = self.load_field_monitor(field_monitor_name)
        return monitor_data.at_coords(monitor_data.colocation_centers)

    def _at_boundaries(self, monitor_data: AbstractFieldData) -> xr.Dataset:
        """Return xarray.Dataset representation of field monitor data colocated at Yee cell
        boundaries.

        Parameters
        ----------
        monitor_data : AbstractFieldData
            Monitor data to be co-located.

        Returns
        -------
        xarray.Dataset
            Dataset containing all of the fields in the data interpolated to boundary locations on
            the Yee grid.
        """
        if monitor_data.monitor.colocate:
            # TODO: this still errors if monitor_data.colocate is allowed to be ``True`` in the
            # adjoint plugin, and the monitor data is tracked in a gradient computation. It seems
            # interpolating does something to the arrays that makes the JAX chain work.
            # Colocating monitors are only repackaged here, not re-interpolated
            # (presumably the stored fields are already colocated -- confirm).
            return monitor_data.package_colocate_results(monitor_data.field_components)

        # colocate to monitor grid boundaries
        return monitor_data.at_coords(monitor_data.colocation_boundaries)

    def at_boundaries(self, field_monitor_name: str) -> xr.Dataset:
        """Return xarray.Dataset representation of field monitor data colocated at Yee cell
        boundaries.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.

        Returns
        -------
        xarray.Dataset
            Dataset containing all of the fields in the data interpolated to boundary locations on
            the Yee grid.
        """
        # colocate to monitor grid boundaries
        return self._at_boundaries(self.load_field_monitor(field_monitor_name))

    def _get_poynting_vector(self, field_monitor_data: AbstractFieldData) -> xr.Dataset:
        """Return ``xarray.Dataset`` of the Poynting vector.

        Calculated values represent the instantaneous Poynting vector ``S = E × H`` for
        time-domain fields and the complex vector ``S = 1/2 E × conj(H)`` for frequency-domain
        fields. The fields are first colocated to the monitor's Yee cell boundaries
        (see ``_at_boundaries``).

        Only the available components are returned, e.g., if the indicated monitor doesn't include
        field component `"Ex"`, then `"Sy"` and `"Sz"` will not be calculated.

        Parameters
        ----------
        field_monitor_data: AbstractFieldData
            Field monitor data from which to extract Poynting vector.

        Returns
        -------
        xarray.Dataset
            Dataset with the available components among ``"Sx"``, ``"Sy"``, ``"Sz"``.
        """
        field_dataset = self._at_boundaries(field_monitor_data)
        time_domain = isinstance(field_monitor_data, FieldTimeData)

        poynting_components = {}
        dims = "xyz"
        for axis, dim in enumerate(dims):
            # cyclic neighbours of ``dim``: S_dim = E_1 H_2 - E_2 H_1
            dim_1 = dims[axis - 2]
            dim_2 = dims[axis - 1]

            required_components = [f + c for f in "EH" for c in (dim_1, dim_2)]
            if not all(field_cmp in field_dataset for field_cmp in required_components):
                continue

            e_1 = field_dataset.data_vars["E" + dim_1]
            e_2 = field_dataset.data_vars["E" + dim_2]
            h_1 = field_dataset.data_vars["H" + dim_1]
            h_2 = field_dataset.data_vars["H" + dim_2]
            poynting_components["S" + dim] = (
                e_1 * h_2 - e_2 * h_1
                if time_domain
                else 0.5 * (e_1 * h_2.conj() - e_2 * h_1.conj())
            )

            # 2D monitors have grid correction factors that can be different from 1. For Poynting,
            # it is always the product of a primal-located field and dual-located field, so the
            # total grid correction factor is the product of the two
            grid_correction = (
                field_monitor_data.grid_dual_correction * field_monitor_data.grid_primal_correction
            )
            poynting_components["S" + dim] *= grid_correction

        return xr.Dataset(poynting_components)

    def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset:
        """Return ``xarray.Dataset`` of the Poynting vector.

        Calculated values represent the instantaneous Poynting vector ``S = E × H`` for
        time-domain fields and the complex vector ``S = 1/2 E × conj(H)`` for frequency-domain
        fields. The fields are first colocated to the monitor's Yee cell boundaries.

        Only the available components are returned, e.g., if the indicated monitor doesn't include
        field component `"Ex"`, then `"Sy"` and `"Sz"` will not be calculated.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.

        Returns
        -------
        xarray.Dataset
            Dataset with the available components among ``"Sx"``, ``"Sy"``, ``"Sz"``.
        """
        field_monitor_data = self.load_field_monitor(field_monitor_name)
        return self._get_poynting_vector(field_monitor_data=field_monitor_data)

    def _get_scalar_field(
        self,
        field_monitor_name: str,
        field_name: str,
        val: FieldVal,
        phase: float = 0.0,
    ):
        """Return ``xarray.DataArray`` of a derived scalar field of a given monitor.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.
        field_name : str
            Name of the derived field component: one of `('E', 'H', 'S', 'Sx', 'Sy', 'Sz')`.
        val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
            Which part of the field to plot.
        phase : float = 0.0
            Optional phase (radians) applied to the E and H fields.

        Returns
        -------
        xarray.DataArray
            The requested scalar quantity, computed from fields colocated to the Yee cell
            boundaries.
        """
        field_monitor_data = self.load_field_monitor(field_monitor_name)
        return self._get_scalar_field_from_data(
            field_monitor_data, field_name=field_name, val=val, phase=phase
        )

    def _get_scalar_field_from_data(
        self,
        field_monitor_data: AbstractFieldData,
        field_name: str,
        val: FieldVal,
        phase: float = 0.0,
    ):
        """Return ``xarray.DataArray`` of a derived scalar field of given monitor data.

        Parameters
        ----------
        field_monitor_data : AbstractFieldData
            Field monitor data from which to extract scalar field.
        field_name : str
            Name of the derived field component: one of `('E', 'H', 'S', 'Sx', 'Sy', 'Sz')`.
        val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
            Which part of the field to plot.
        phase : float = 0.0
            Optional phase (radians) applied to the E and H fields. It has no effect on the
            Poynting vector, where a global phase cancels in ``E × conj(H)``.

        Returns
        -------
        xarray.DataArray
            The requested scalar quantity, computed from fields colocated to the Yee cell
            boundaries.

        Raises
        ------
        Tidy3dKeyError
            If ``field_name`` or ``val`` is not supported, or a Poynting component is unavailable.
        DataError
            If a vector magnitude is requested but not all three components are present.
        """
        if field_name[0] == "S":
            dataset = self._get_poynting_vector(field_monitor_data)
            if len(field_name) > 1:
                if field_name in dataset:
                    derived_data = dataset[field_name]
                    derived_data.name = field_name
                    return self._field_component_value(derived_data, val)

                raise Tidy3dKeyError(f"Poynting component {field_name} not available")
        else:
            dataset = self._at_boundaries(field_monitor_data)
            # A global phase on E and H cancels in S = 1/2 E x conj(H), so it is only applied
            # to the field datasets.
            dataset = self.apply_phase(data=dataset, phase=phase)

        if field_name in ("E", "H", "S"):
            # Gather vector components
            required_components = [field_name + c for c in "xyz"]
            if not all(field_cmp in dataset for field_cmp in required_components):
                raise DataError(
                    f"Field monitor must contain '{field_name}x', '{field_name}y', and "
                    f"'{field_name}z' fields to compute '{field_name}'."
                )
            field_components = (dataset[c] for c in required_components)

            # Apply the requested transformation
            val = val.lower()
            if val in ("real", "re"):
                derived_data = sum(f.real**2 for f in field_components) ** 0.5
                derived_data.name = f"|Re{{{field_name}}}|"
            elif val in ("imag", "im"):
                derived_data = sum(f.imag**2 for f in field_components) ** 0.5
                derived_data.name = f"|Im{{{field_name}}}|"
            elif val == "abs":
                derived_data = sum(abs(f) ** 2 for f in field_components) ** 0.5
                derived_data.name = f"|{field_name}|"
            elif val == "abs^2":
                derived_data = sum(abs(f) ** 2 for f in field_components)
                if hasattr(derived_data, "name"):
                    derived_data.name = f"|{field_name}|²"
            elif val == "phase":
                raise Tidy3dKeyError(f"Phase is not defined for complex vector {field_name}")
            else:
                raise Tidy3dKeyError(
                    f"'val' of {val} not supported. "
                    "Must be one of 'real', 'imag', 'abs', 'abs^2', or 'phase'."
                )
            return derived_data

        raise Tidy3dKeyError(
            f"Derived field name must be one of 'E', 'H', 'S', 'Sx', 'Sy', or 'Sz', received "
            f"'{field_name}'."
        )

    def get_intensity(self, field_monitor_name: str) -> xr.DataArray:
        """Return `xarray.DataArray` of the intensity ``|E|^2`` of a field monitor.

        Parameters
        ----------
        field_monitor_name : str
            Name of field monitor used in the original :class:`Simulation`.

        Returns
        -------
        xarray.DataArray
            DataArray containing the electric intensity of the field-like monitor, computed
            from fields colocated to the Yee cell boundaries.
        """
        return self._get_scalar_field(
            field_monitor_name=field_monitor_name, field_name="E", val="abs^2"
        )

    @classmethod
    def mnt_data_from_file(cls, fname: str, mnt_name: str, **parse_obj_kwargs) -> MonitorDataType:
        """Loads data for a specific monitor from a .hdf5 file with data for a ``SimulationData``.

        Parameters
        ----------
        fname : str
            Full path to an hdf5 file containing :class:`.SimulationData` data.
        mnt_name : str
            ``.name`` of the monitor to load the data from.
        **parse_obj_kwargs
            Keyword arguments passed on to the monitor data's ``from_file`` method.

        Returns
        -------
        :class:`MonitorData`
            Monitor data corresponding to the ``mnt_name`` type.

        Raises
        ------
        ValueError
            If the file is not ``.hdf5``, has no data, the monitor type is unknown, or no
            monitor named ``mnt_name`` is found.

        Example
        -------
        >>> field_data = SimulationData.mnt_data_from_file(fname='folder/data.hdf5', mnt_name="field") # doctest: +SKIP
        """
        if pathlib.Path(fname).suffix != ".hdf5":
            raise ValueError("'mnt_data_from_file' only works with '.hdf5' files.")

        # open file (read-only) and ensure it has data
        with h5py.File(fname, "r") as f_handle:
            if "data" not in f_handle:
                raise ValueError(f"could not find data in the supplied file {fname}")

            # get the monitor list from the json string
            json_string = f_handle[JSON_TAG][()]
            json_dict = json.loads(json_string)
            monitor_list = json_dict["simulation"]["monitors"]

            # data groups are keyed by the monitor's index in the simulation's monitor list
            for monitor_index_str, _mnt_data in f_handle["data"].items():
                # grab the monitor data for this data element
                monitor_dict = monitor_list[int(monitor_index_str)]

                # if a match on the monitor name
                if monitor_dict["name"] == mnt_name:
                    # try to grab the monitor data type
                    monitor_type_str = monitor_dict["type"]
                    if monitor_type_str not in DATA_TYPE_NAME_MAP:
                        raise ValueError(f"Could not find data type '{monitor_type_str}'.")
                    monitor_data_type = DATA_TYPE_NAME_MAP[monitor_type_str]

                    # load the monitor data from the file using the group_path
                    group_path = f"data/{monitor_index_str}"
                    return monitor_data_type.from_file(
                        fname, group_path=group_path, **parse_obj_kwargs
                    )

        raise ValueError(f"No monitor with name '{mnt_name}' found in data file.")

    @staticmethod
    def apply_phase(
        data: Union[xr.DataArray, xr.Dataset], phase: float = 0.0
    ) -> Union[xr.DataArray, xr.Dataset]:
        """Apply a phase to xarray data.

        Parameters
        ----------
        data : Union[xr.DataArray, xr.Dataset]
            Data to multiply by ``exp(1j * phase)``. It is never modified in place.
        phase : float = 0.0
            Phase in radians.

        Returns
        -------
        Union[xr.DataArray, xr.Dataset]
            A new object multiplied by ``exp(1j * phase)`` if ``phase`` is non-zero and the
            data has complex values; otherwise ``data`` itself. A warning is logged when a
            non-zero phase is ignored.
        """
        if phase == 0.0:
            return data

        # ``Dataset.values`` is the mapping ``values()`` method, not an array, so the complex
        # check has to be done on each data variable individually.
        if isinstance(data, xr.Dataset):
            arrays = [data_var.values for data_var in data.data_vars.values()]
        else:
            arrays = [data.values]

        if any(np.any(np.iscomplex(array)) for array in arrays):
            # Out-of-place multiplication: an in-place ``*=`` would rotate the caller's arrays
            # (e.g. the data stored in the monitor data) every time a phase is applied.
            with xr.set_options(keep_attrs=True):
                return data * np.exp(1j * phase)

        log.warning(
            f"Non-zero phase of {phase} specified but the data being plotted is "
            "real-valued. The phase will be ignored in the plot."
        )
        return data

    def plot_field_monitor_data(
        self,
        field_monitor_data: AbstractFieldData,
        field_name: str,
        val: FieldVal = "real",
        scale: PlotScale = "lin",
        eps_alpha: float = 0.2,
        phase: float = 0.0,
        robust: bool = True,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        ax: Ax = None,
        shading: str = "flat",
        **sel_kwargs,
    ) -> Ax:
        """Plot the field data for a monitor with simulation plot overlaid.

        Parameters
        ----------
        field_monitor_data : AbstractFieldData
            Field monitor data to plot.
        field_name : str
            Name of ``field`` component to plot (eg. `'Ex'`).
            Also accepts ``'E'`` and ``'H'`` to plot the vector magnitudes of the electric and
            magnetic fields, and ``'S'`` for the Poynting vector.
        val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
            Which part of the field to plot.
        scale : Literal['lin', 'dB']
            Plot in linear or logarithmic (dB) scale.
        eps_alpha : float = 0.2
            Opacity of the structure permittivity.
            Must be between 0 and 1 (inclusive).
        phase : float = 0.0
            Optional phase (radians) to apply to the fields.
            Only has an effect on frequency-domain fields.
        robust : bool = True
            If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
            to compute the color limits. This helps in visualizing the field patterns especially
            in the presence of a source.
        vmin : float = None
            The lower bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        vmax : float = None
            The upper bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        ax : matplotlib.axes._subplots.Axes = None
            matplotlib axes to plot on, if not specified, one is created.
        shading: str = 'flat'
            Shading argument for Xarray plot method ('flat','nearest','gouraud')
        sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
            These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
            frequency or time dimensions (``f``, ``t``) or ``mode_index``, if applicable.
            For the plotting to work appropriately, the resulting data after selection must contain
            only two coordinates with len > 1.
            Furthermore, these should be spatial coordinates (``x``, ``y``, or ``z``).

        Returns
        -------
        matplotlib.axes._subplots.Axes
            The supplied or created matplotlib axes.
        """
        # get the DataArray corresponding to the monitor_name and field_name
        # deprecated intensity
        if field_name == "int":
            log.warning(
                "'int' field name is deprecated and will be removed in the future. Please use "
                "field_name='E' and val='abs^2' for the same effect."
            )
            field_name = "E"
            val = "abs^2"

        if field_name in ("E", "H") or field_name[0] == "S":
            # Derived fields
            field_data = self._get_scalar_field_from_data(
                field_monitor_data, field_name, val, phase=phase
            )
        else:
            # Direct field component (e.g. Ex)
            if field_name not in field_monitor_data.field_components:
                raise DataError(f"field_name '{field_name}' not found in data.")
            field_component = field_monitor_data.field_components[field_name]
            field_component.name = field_name
            field_component = self.apply_phase(data=field_component, phase=phase)
            field_data = self._field_component_value(field_component, val)

        if scale == "dB":
            if val == "phase":
                log.warning("Plotting phase component in log scale masks the phase sign.")
            # dB is 10*log10 of a power-like quantity: 20 for field amplitudes, 10 for |E|^2,
            # |H|^2 and S, 5 for |S|^2.
            db_factor = {
                ("S", "real"): 10,
                ("S", "imag"): 10,
                ("S", "abs"): 10,
                ("S", "abs^2"): 5,
                ("S", "phase"): 1,
                ("E", "abs^2"): 10,
                ("H", "abs^2"): 10,
            }.get((field_name[0], val), 20)
            field_data = db_factor * np.log10(np.abs(field_data))
            # ``+=`` on an unnamed array would raise a TypeError
            if field_data.name is not None:
                field_data.name += " (dB)"
            cmap_type = "sequential"
        elif val == "phase":
            cmap_type = "cyclic"
        elif len(field_name) == 2 and val in ("real", "imag", "re", "im"):
            cmap_type = "divergent"
        else:
            cmap_type = "sequential"

        # interp out any monitor.size==0 dimensions
        monitor = field_monitor_data.monitor
        thin_dims = {
            "xyz"[dim]: monitor.center[dim]
            for dim in range(3)
            if monitor.size[dim] == 0 and "xyz"[dim] not in sel_kwargs
        }
        for axis, pos in thin_dims.items():
            if axis not in field_data.coords:
                continue
            if field_data.coords[axis].size <= 1:
                field_data = field_data.sel(**{axis: pos}, method="nearest")
            else:
                field_data = field_data.interp(**{axis: pos}, kwargs={"bounds_error": True})

        # warn about new API changes and replace the values
        if "freq" in sel_kwargs:
            log.warning(
                "'freq' supplied to 'plot_field', frequency selection key renamed to 'f' and "
                "'freq' will error in future release, please update your local script to use "
                "'f=value'."
            )
            sel_kwargs["f"] = sel_kwargs.pop("freq")
        if "time" in sel_kwargs:
            log.warning(
                "'time' supplied to 'plot_field', time selection key renamed to 't' and "
                "'time' will error in future release, please update your local script to use "
                "'t=value'."
            )
            sel_kwargs["t"] = sel_kwargs.pop("time")

        # select the extra coordinates out of the data from user-specified kwargs;
        # index-like coordinates are selected exactly, continuous ones are interpolated
        exact_select_coords = ("eme_port_index", "eme_cell_index", "sweep_index", "mode_index")
        for coord_name, coord_val in sel_kwargs.items():
            interp_val = np.array(coord_val)
            if interp_val.size == 1:
                interp_val = interp_val.item()
            if field_data.coords[coord_name].size <= 1 or coord_name in exact_select_coords:
                field_data = field_data.sel(**{coord_name: interp_val}, method=None)
            else:
                field_data = field_data.interp(
                    **{coord_name: interp_val}, kwargs={"bounds_error": True}
                )

        # before dropping coordinates, check if a frequency can be derived from the data that can
        # be used to plot material permittivity
        if "f" in sel_kwargs:
            freq_eps_eval = sel_kwargs["f"]
        elif "f" in field_data.coords:
            freq_eps_eval = field_data.coords["f"].values[0]
        else:
            freq_eps_eval = None

        field_data = field_data.squeeze(drop=True)
        non_scalar_coords = {name: c for name, c in field_data.coords.items() if c.size > 1}

        # assert the data is valid for plotting
        if len(non_scalar_coords) != 2:
            raise DataError(
                f"Data after selection has {len(non_scalar_coords)} coordinates "
                f"({list(non_scalar_coords.keys())}), "
                "must be 2 spatial coordinates for plotting on plane. "
                "Please add keyword arguments to `plot_field()` to select out the other coords."
            )

        spatial_coords_in_data = {
            coord_name: (coord_name in non_scalar_coords) for coord_name in "xyz"
        }

        if sum(spatial_coords_in_data.values()) != 2:
            raise DataError(
                "All coordinates in the data after selection must be spatial (x, y, z), "
                f" given {non_scalar_coords.keys()}."
            )

        # get the spatial coordinate corresponding to the plane
        planar_coord = [name for name, c in spatial_coords_in_data.items() if c is False][0]
        axis = "xyz".index(planar_coord)
        if planar_coord in field_data.coords:
            position = float(field_data.coords[planar_coord])
        else:
            position = monitor.center[axis]

        return self.plot_scalar_array(
            field_data=field_data,
            axis=axis,
            position=position,
            freq=freq_eps_eval,
            eps_alpha=eps_alpha,
            robust=robust,
            vmin=vmin,
            vmax=vmax,
            cmap_type=cmap_type,
            ax=ax,
            shading=shading,
            infer_intervals=shading == "flat",
        )

    def plot_field(
        self,
        field_monitor_name: str,
        field_name: str,
        val: FieldVal = "real",
        scale: PlotScale = "lin",
        eps_alpha: float = 0.2,
        phase: float = 0.0,
        robust: bool = True,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        ax: Ax = None,
        shading: str = "flat",
        **sel_kwargs,
    ) -> Ax:
        """Plot the field data for a monitor with simulation plot overlaid.

        Parameters
        ----------
        field_monitor_name : str
            Name of :class:`.FieldMonitor`, :class:`.FieldTimeData`, or :class:`.ModeSolverData`
            to plot.
        field_name : str
            Name of ``field`` component to plot (eg. `'Ex'`).
            Also accepts ``'E'`` and ``'H'`` to plot the vector magnitudes of the electric and
            magnetic fields, and ``'S'`` for the Poynting vector.
        val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
            Which part of the field to plot.
        scale : Literal['lin', 'dB']
            Plot in linear or logarithmic (dB) scale.
        eps_alpha : float = 0.2
            Opacity of the structure permittivity.
            Must be between 0 and 1 (inclusive).
        phase : float = 0.0
            Optional phase (radians) to apply to the fields.
            Only has an effect on frequency-domain fields.
        robust : bool = True
            If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
            to compute the color limits. This helps in visualizing the field patterns especially
            in the presence of a source.
        vmin : float = None
            The lower bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        vmax : float = None
            The upper bound of data range that the colormap covers. If ``None``, they are
            inferred from the data and other keyword arguments.
        ax : matplotlib.axes._subplots.Axes = None
            matplotlib axes to plot on, if not specified, one is created.
        shading: str = 'flat'
            Shading argument for Xarray plot method ('flat','nearest','gouraud')
        sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
            These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
            frequency or time dimensions (``f``, ``t``) or ``mode_index``, if applicable.
            For the plotting to work appropriately, the resulting data after selection must contain
            only two coordinates with len > 1.
            Furthermore, these should be spatial coordinates (``x``, ``y``, or ``z``).

        Returns
        -------
        matplotlib.axes._subplots.Axes
            The supplied or created matplotlib axes.
        """
        field_monitor_data = self.load_field_monitor(field_monitor_name)
        return self.plot_field_monitor_data(
            field_monitor_data=field_monitor_data,
            field_name=field_name,
            val=val,
            scale=scale,
            eps_alpha=eps_alpha,
            phase=phase,
            robust=robust,
            vmin=vmin,
            vmax=vmax,
            ax=ax,
            shading=shading,
            **sel_kwargs,
        )

    @equal_aspect
    @add_ax_if_none
    def plot_scalar_array(
        self,
        field_data: xr.DataArray,
        axis: Axis,
        position: float,
        freq: Optional[float] = None,
        eps_alpha: float = 0.2,
        robust: bool = True,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        cmap_type: ColormapType = "divergent",
        ax: Ax = None,
        **kwargs,
    ) -> Ax:
        """Plot the field data for a monitor with simulation plot overlaid.

        Parameters
        ----------
        field_data: xr.DataArray
            DataArray with the field data to plot.
            Must be a scalar field.
        axis: Axis
            Axis normal to the plotting plane.
        position: float
            Position along the axis.
        freq: float = None
            Frequency at which the permittivity is evaluated at (if dispersive).
            By default, chooses permittivity as frequency goes to infinity.
        eps_alpha : float = 0.2
            Opacity of the structure permittivity.
            Must be between 0 and 1 (inclusive).
        robust : bool = True
            If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
            to compute the color limits. This helps in visualizing the field patterns especially
            in the presence of a source.
        vmin : float = None
            The lower bound of data range that the colormap covers. If `None`, they are
            inferred from the data and other keyword arguments. Overridden with ``-pi`` for
            ``cmap_type="cyclic"``.
        vmax : float = None
            The upper bound of data range that the colormap covers. If `None`, they are
            inferred from the data and other keyword arguments. Overridden with ``pi`` for
            ``cmap_type="cyclic"``.
        cmap_type : Literal["divergent", "sequential", "cyclic"] = "divergent"
            Type of color map to use for plotting.
        ax : matplotlib.axes._subplots.Axes = None
            matplotlib axes to plot on, if not specified, one is created.
        **kwargs : Extra arguments to ``DataArray.plot``.

        Returns
        -------
        matplotlib.axes._subplots.Axes
            The supplied or created matplotlib axes.

        Raises
        ------
        ValueError
            If ``cmap_type`` is not one of the supported values.
        """
        # plane location passed on to the permittivity plot
        interp_kwarg = {"xyz"[axis]: position}

        if cmap_type == "divergent":
            cmap = "RdBu"
            center = 0.0
            eps_reverse = False
        elif cmap_type == "sequential":
            cmap = "magma"
            center = False
            eps_reverse = True
        elif cmap_type == "cyclic":
            cmap = "twilight"
            vmin = -np.pi
            vmax = np.pi
            center = False
            eps_reverse = False
        else:
            raise ValueError(
                "'cmap_type' must be one of 'divergent', 'sequential', or 'cyclic', "
                f"got '{cmap_type}'."
            )

        # plot the field
        xy_coord_labels = list("xyz")
        xy_coord_labels.pop(axis)
        x_coord_label, y_coord_label = xy_coord_labels[0], xy_coord_labels[1]
        field_data.plot(
            ax=ax,
            x=x_coord_label,
            y=y_coord_label,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            robust=robust,
            center=center,
            cbar_kwargs={"label": field_data.name},
            **kwargs,
        )

        # plot the simulation epsilon
        ax = self.simulation.plot_structures_eps(
            freq=freq,
            cbar=False,
            alpha=eps_alpha,
            reverse=eps_reverse,
            ax=ax,
            **interp_kwarg,
        )

        # set the limits based on the xarray coordinates min and max (vectorized reductions
        # rather than Python-level iteration over every coordinate value)
        x_coord_values = field_data.coords[x_coord_label]
        y_coord_values = field_data.coords[y_coord_label]
        ax.set_xlim(float(x_coord_values.min()), float(x_coord_values.max()))
        ax.set_ylim(float(y_coord_values.min()), float(y_coord_values.max()))

        return ax
[docs]
class SimulationData(AbstractYeeGridSimulationData):
"""Stores data from a collection of :class:`.Monitor` objects in a :class:`.Simulation`.
Notes
-----
The ``SimulationData`` objects store a copy of the original :class:`.Simulation`:, so it can be recovered if the
``SimulationData`` is loaded in a new session and the :class:`.Simulation` is no longer in memory.
More importantly, the ``SimulationData`` contains a reference to the data for each of the monitors within the
original :class:`.Simulation`. This data can be accessed directly using the name given to the monitors initially.
Examples
--------
Standalone example:
>>> import tidy3d as td
>>> num_modes = 5
>>> x = [-1,1,3]
>>> y = [-2,0,2,4]
>>> z = [-3,-1,1,3,5]
>>> f = [2e14, 3e14]
>>> coords = dict(x=x[:-1], y=y[:-1], z=z[:-1], f=f)
>>> grid = td.Grid(boundaries=td.Coords(x=x, y=y, z=z))
>>> scalar_field = td.ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
>>> field_monitor = td.FieldMonitor(
... size=(2,4,6),
... freqs=[2e14, 3e14],
... name='field',
... fields=['Ex'],
... colocate=True,
... )
>>> sim = td.Simulation(
... size=(2, 4, 6),
... grid_spec=td.GridSpec(wavelength=1.0),
... monitors=[field_monitor],
... run_time=2e-12,
... sources=[
... td.UniformCurrentSource(
... size=(0, 0, 0),
... center=(0, 0.5, 0),
... polarization="Hx",
... source_time=td.GaussianPulse(
... freq0=2e14,
... fwidth=4e13,
... ),
... )
... ],
... )
>>> field_data = td.FieldData(monitor=field_monitor, Ex=scalar_field, grid_expanded=grid)
>>> sim_data = td.SimulationData(simulation=sim, data=(field_data,))
To save and load the :class:`SimulationData` object.
.. code-block:: python
sim_data.to_file(fname='path/to/file.hdf5') # Save a SimulationData object to a HDF5 file
sim_data = SimulationData.from_file(fname='path/to/file.hdf5') # Load a SimulationData object from a HDF5 file.
See Also
--------
**Notebooks:**
* `Quickstart <../../notebooks/StartHere.html>`_: Usage in a basic simulation flow.
* `Performing visualization of simulation data <../../notebooks/VizData.html>`_
* `Advanced monitor data manipulation and visualization <../../notebooks/XarrayTutorial.html>`_
"""
# the simulation that produced this data
simulation: Simulation = pd.Field(
    ...,
    description="Original :class:`.Simulation` associated with the data.",
    title="Simulation",
)

# one entry per recorded monitor
data: tuple[annotate_type(MonitorDataType), ...] = pd.Field(
    ...,
    description=(
        "List of :class:`.MonitorData` instances associated with the monitors of the "
        "original :class:`.Simulation`."
    ),
    title="Monitor Data",
)

# set when the run diverged; defaults to False
diverged: bool = pd.Field(
    False,
    description="A boolean flag denoting whether the simulation run diverged.",
    title="Diverged",
)
@cached_property
def field_decay(self) -> TimeDataArray:
    """Returns a TimeDataArray of field decay values over time steps.

    The values are parsed from the lines of ``self.log`` that match
    ``- Time step <n> / ... field decay: <value>``.

    Returns
    -------
    TimeDataArray
        Decay values whose ``t`` coordinate holds the integer time-step numbers parsed from
        the log (not physical times). Empty if no line matches.

    Raises
    ------
    DataError
        If there is no log string.
    """
    log_text = self.log
    if log_text is None:
        raise DataError(
            "No log string in the SimulationData object, can't extract field decay."
        )

    decay_pattern = r"- Time step\s+(\d+)\s+/.*?field decay:\s*([0-9.eE+-]+)"
    step_decay_pairs = re.findall(decay_pattern, log_text)
    time_steps = [int(step) for step, _ in step_decay_pairs]
    decay_values = [float(decay) for _, decay in step_decay_pairs]
    return TimeDataArray(decay_values, coords={"t": time_steps})
@property
def final_decay_value(self) -> float:
    """Returns value of the field decay at the final time step.

    Falls back to ``1.0`` (with a warning) when no decay values were parsed from the log.
    """
    decay = self.field_decay
    if len(decay) > 0:
        return float(decay.values[-1])
    log.warning("No field decay values found, using 1.0 as final decay value.")
    return 1.0
[docs]
def source_spectrum(self, source_index: int) -> Callable:
    """Get a spectrum normalization function for a given source index.

    Parameters
    ----------
    source_index : int
        Index into ``self.simulation.sources``; ``None`` means no normalization.

    Returns
    -------
    Callable
        ``np.ones_like`` if ``source_index`` is ``None`` or there are no sources; otherwise a
        function of frequencies returning the source-time spectrum (evaluated on the
        simulation's ``tmesh`` with step ``dt``) divided by the source's ``amplitude`` and by
        ``exp(1j * phase)``.
    """
    sim = self.simulation
    if source_index is None or not sim.sources:
        return np.ones_like

    src_time = sim.sources[source_index].source_time
    time_mesh = sim.tmesh
    time_step = sim.dt

    def source_spectrum_fn(freqs):
        """Source amplitude as function of frequency."""
        raw_spectrum = src_time.spectrum(time_mesh, freqs, time_step)
        # The user-defined amplitude and phase are divided back out so that they still affect
        # the output fields; only the arbitrary part of the spectrum (set by things like freq0,
        # fwidth and offset) is normalized out.
        return raw_spectrum / src_time.amplitude / np.exp(1j * src_time.phase)

    return source_spectrum_fn
[docs]
def renormalize(self, normalize_index: int) -> SimulationData:
    """Return a copy of the :class:`.SimulationData` with a different source used for the
    normalization.

    Parameters
    ----------
    normalize_index : int
        Index of the source to normalize to. With ``None`` the previous normalization is
        only removed, since ``source_spectrum(None)`` is ``np.ones_like``.

    Returns
    -------
    SimulationData
        Copy whose monitor data is rescaled by new_spectrum / old_spectrum and whose
        simulation carries the new ``normalize_index``; a plain copy if the index is unchanged
        or there are no sources.

    Raises
    ------
    DataError
        If a non-zero ``normalize_index`` is outside ``[0, number of sources)``.
    """
    sim = self.simulation
    n_sources = len(sim.sources)

    # already normalized to that index (or nothing to normalize against)
    if normalize_index == sim.normalize_index or n_sources == 0:
        return self.copy()

    # normalize index out of bounds for source list
    if normalize_index and not 0 <= normalize_index < n_sources:
        raise DataError(
            f"normalize_index {normalize_index} out of bounds for list of sources "
            f"of length {n_sources}"
        )

    def source_spectrum_fn(freqs):
        """Normalization function that also removes previous normalization if needed."""
        target_fn = self.source_spectrum(normalize_index)
        previous_fn = self.source_spectrum(self.simulation.normalize_index)
        return target_fn(freqs) / previous_fn(freqs)

    renormalized = [mnt_data.normalize(source_spectrum_fn) for mnt_data in self.data]
    updated_sim = sim.copy(update={"normalize_index": normalize_index})
    return self.copy(update={"simulation": updated_sim, "data": renormalized})
def _split_adjoint_data(self: SimulationData, num_mnts_original: int) -> tuple[list, list]:
    """Split ``self.data`` into the original monitor data and the trailing adjoint data.

    The first ``num_mnts_original`` entries belong to the original monitors; all later
    entries are returned together as the second list. For the log message only, the
    trailing entries are counted as half adjoint-field and half adjoint-eps monitors.
    """
    all_data = list(self.data)
    num_mnts_adjoint = (len(all_data) - num_mnts_original) // 2
    log.info(
        f" -> {num_mnts_original} monitors, {num_mnts_adjoint} adjoint field monitors, {num_mnts_adjoint} adjoint eps monitors."
    )
    head, tail = split_list(all_data, index=num_mnts_original)
    return head, tail
def _split_original_fwd(self, num_mnts_original: int) -> tuple[SimulationData, SimulationData]:
    """Separate the user-facing data from the data used only for the gradient ('fwd') pass.

    The first ``num_mnts_original`` monitors and their data make up the original
    ``SimulationData``. The remaining monitors and data make up the 'fwd' one. Both reuse
    ``self.simulation`` with only ``monitors`` replaced, and both copies are made with
    ``deep=False``.
    """
    data_original, data_fwd = self._split_adjoint_data(num_mnts_original=num_mnts_original)
    monitors_orig, monitors_fwd = split_list(self.simulation.monitors, index=num_mnts_original)

    def _with_monitors(monitors, data):
        """Shallow copy of ``self`` whose simulation holds ``monitors`` and data ``data``."""
        sim = self.simulation.updated_copy(monitors=monitors)
        return self.updated_copy(simulation=sim, data=data, deep=False)

    sim_data_original = _with_monitors(monitors_orig, data_original)
    sim_data_fwd = _with_monitors(monitors_fwd, data_fwd)
    return sim_data_original, sim_data_fwd
def _make_adjoint_sims(
    self,
    data_vjp_paths: set[tuple],
    adjoint_monitors: list[Monitor],
) -> list[Simulation]:
    """Build the adjoint simulations from the original simulation and the VJP data.

    An empty list is returned when there is nothing to simulate: no VJP paths, no adjoint
    sources generated from them, or no adjoint source groups after processing. Otherwise
    one simulation is created per ``AdjointSourceInfo`` group.
    """
    if not data_vjp_paths:
        return []

    sim_original = self.simulation

    # {monitor name: adjoint sources} for every monitor data the VJP touches
    sources_adj_dict = self._make_adjoint_sources(data_vjp_paths=data_vjp_paths)
    if not sources_adj_dict:
        return []

    adj_srcs = [src for src_list in sources_adj_dict.values() for src in src_list]
    adjoint_source_infos = self._process_adjoint_sources(adj_srcs=adj_srcs)
    if not adjoint_source_infos:
        return []

    # adjoint runs use the boundary conditions with flipped Bloch vectors
    bc_adj = sim_original.boundary_spec.flipped_bloch_vecs

    # If the forward grid wavelength was inferred from its sources, pin it explicitly so
    # the adjoint simulations (which get different sources) are meshed the same way.
    grid_spec_original = sim_original.grid_spec
    pin_wavelength = bool(sim_original.sources) and grid_spec_original.wavelength is None
    if pin_wavelength:
        wavelength_original = grid_spec_original.wavelength_from_sources(sim_original.sources)
        grid_spec_adj = grid_spec_original.updated_copy(wavelength=wavelength_original)

    adj_sims = []
    for info in adjoint_source_infos:
        # adjoint monitors record exactly the frequencies of this group's post-normalization
        monitors = [mnt.updated_copy(freqs=info.post_norm.f) for mnt in adjoint_monitors]
        update = dict(
            sources=info.sources,
            boundary_spec=bc_adj,
            monitors=monitors,
            post_norm=info.post_norm,
        )
        if not info.normalize_sim:
            update["normalize_index"] = None
        if pin_wavelength:
            update["grid_spec"] = grid_spec_adj
        adj_sims.append(sim_original.updated_copy(**update))

    log.info(f"Created {len(adj_sims)} adjoint simulations.")
    return adj_sims
def _make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, list[SourceType]]:
    """Generate all of the non-zero sources for the adjoint simulation given the VJP data.

    Parameters
    ----------
    data_vjp_paths : set[tuple]
        3-tuples whose second entry is an index into ``self.data`` and whose third entry
        names a dataset of that monitor data needing an adjoint source.

    Returns
    -------
    dict[str, list[SourceType]]
        Maps each involved monitor name to the adjoint sources its monitor data built.
    """
    # map of index into 'self.data' to the list of datasets we need adjoint sources for
    adj_src_map = defaultdict(list)
    for _, data_index, dataset_name in data_vjp_paths:
        adj_src_map[data_index].append(dataset_name)

    # Gather the adjoint sources for every monitor data in the VJP that needs one. A plain
    # dict (values are only ever assigned) so that looking up an unrelated monitor name
    # raises instead of silently inserting an empty entry.
    sources_adj_all = {}
    for data_index, dataset_names in adj_src_map.items():
        mnt_data = self.data[data_index]
        sources_adj = mnt_data._make_adjoint_sources(
            dataset_names=dataset_names, fwidth=self._fwidth_adj
        )
        sources_adj_all[mnt_data.monitor.name] = sources_adj
        log.info(
            f"Created {len(sources_adj)} adjoint sources for monitor '{mnt_data.monitor.name}'."
        )
    return sources_adj_all
@property
def _fwidth_adj(self) -> float:
# fwidth of forward pass, try as default for adjoint
normalize_index_fwd = self.simulation.normalize_index or 0
return self.simulation.sources[normalize_index_fwd].source_time.fwidth
@staticmethod
def _adjoint_src_width_single(adj_srcs: list[SourceType]) -> list[SourceType]:
    """Cap each adjoint source's ``fwidth`` so its pulse decays before zero frequency.

    The cap is ``freq0 / NUM_ADJOINT_FWIDTH_TO_ZERO``; sources already narrower keep their
    ``fwidth``. Returns a new list built from ``updated_copy`` of each source.
    """

    def _with_capped_width(adj_src):
        src_time = adj_src.source_time
        max_fwidth = src_time.freq0 / NUM_ADJOINT_FWIDTH_TO_ZERO
        capped = np.minimum(max_fwidth, src_time.fwidth)
        return adj_src.updated_copy(source_time=src_time.updated_copy(fwidth=capped))

    return [_with_capped_width(adj_src) for adj_src in adj_srcs]
def _process_adjoint_sources(self, adj_srcs: list[SourceType]) -> list[AdjointSourceInfo]:
    """Compute list of final sources along with a post run normalization for adj fields.

    The sources first get an ``fwidth`` capped so they decay before zero frequency
    (``_adjoint_src_width_single``). They are then grouped in whichever way yields fewer
    groups:

    * by frequency: one group per unique ``freq0``, with a unit post-normalization at
      that frequency;
    * by port (sources identical except for their source time): each group is merged by
      ``_process_adjoint_sources_broadband`` into one broadband source plus per-frequency
      post-normalization amplitudes.

    Parameters
    ----------
    adj_srcs : list[SourceType]
        Adjoint sources, possibly spanning several ports and frequencies.

    Returns
    -------
    list[AdjointSourceInfo]
        One entry per source group.
    """
    adj_srcs_process_fwidth = self._adjoint_src_width_single(adj_srcs)

    # A "port" is identified by hashing a copy of the source whose time dependence is
    # replaced by a fixed placeholder pulse, so only the spatial part contributes.
    tmp_src_time = GaussianPulse(freq0=C_0, fwidth=inf)
    hashes_to_sources = {}
    hashes_to_src_times = defaultdict(list)
    # Sources sharing a center frequency, keyed by freq0 in order of first appearance.
    freqs_to_sources = defaultdict(list)
    for src in adj_srcs_process_fwidth:
        tmp_src_hash = src.updated_copy(source_time=tmp_src_time)._hash_self()
        hashes_to_sources[tmp_src_hash] = src
        hashes_to_src_times[tmp_src_hash].append(src.source_time)
        freqs_to_sources[src.source_time.freq0].append(src)

    # Group sources by frequency or port, whichever gives fewer groups
    num_ports = len(hashes_to_src_times)
    num_unique_freqs = len(freqs_to_sources)
    log.info(f"Found {num_ports} spatial ports and {num_unique_freqs} unique frequencies.")

    adjoint_infos = []
    if num_unique_freqs <= num_ports:
        log.info("Grouping adjoint sources by frequency.")
        for freq0, group in freqs_to_sources.items():
            # single-frequency group: unit post-normalization at this group's frequency
            post_norm = xr.DataArray(data=np.array([1 + 0j]), coords={"f": [freq0]})
            adjoint_infos.append(
                AdjointSourceInfo(sources=group, post_norm=post_norm, normalize_sim=True)
            )
    else:
        log.info("Grouping adjoint sources by port.")
        # Warn if the forward simulation had symmetry and we are grouping by port, which
        # means the individual adjoint simulations may not respect the original symmetry.
        if np.any(np.abs(self.simulation.symmetry) > 0):
            log.warning(
                "The adjoint simulations for this problem are being broken into "
                "multiple simulations that may not individually respect the symmetry of the "
                "initial simulation. Gradients may be unreliable and it is recommended to "
                "optimize this problem without utilizing symmetry."
            )
        for src_hash, src_times in hashes_to_src_times.items():
            base_src = hashes_to_sources[src_hash]
            group = [base_src.updated_copy(source_time=src_time) for src_time in src_times]
            processed_srcs, post_norm = self._process_adjoint_sources_broadband(group)
            adjoint_infos.append(
                AdjointSourceInfo(
                    sources=processed_srcs, post_norm=post_norm, normalize_sim=True
                )
            )

    log.info(f"Created {len(adjoint_infos)} adjoint source groups.")
    return adjoint_infos
def _process_adjoint_sources_broadband(
    self, adj_srcs: list[SourceType]
) -> tuple[list[SourceType], xr.DataArray]:
    """Merge sources that differ only in their source time into one broadband source.

    Returns a one-element list holding the broadband source, together with the complex
    amplitudes (from ``_make_post_norm_amps``) used for post-run normalization of the
    adjoint fields.
    """
    broadband_src = self._make_broadband_source(adj_srcs=adj_srcs)
    amps = self._make_post_norm_amps(adj_srcs=adj_srcs)
    num_freqs = len(amps)
    log.info(
        "Several adjoint sources, from one monitor. "
        "Only difference between them is the source time. "
        "Constructing broadband adjoint source and performing post-run normalization "
        f"of fields with {num_freqs} frequencies."
    )
    return [broadband_src], amps
@staticmethod
def _adjoint_src_width_broadband(adj_srcs: list[SourceType]) -> tuple[float, float]:
    """Choose the center frequency and ``fwidth`` of a broadband adjoint pulse.

    The pulse is centered midway between the lowest and highest adjoint ``freq0``. Its
    width is the larger of:

    * ``middle_f0 / NUM_ADJOINT_FWIDTH_TO_ZERO``, the width at which the pulse still
      decays by zero frequency, used as a floor so closely spaced frequencies do not give
      an overly narrow pulse;
    * ``(middle_f0 - min_f0) / NUM_ADJOINT_FWIDTH_TO_FMIN``, the width needed to reach the
      lowest adjoint frequency.

    When the second term wins for a ``CustomCurrentSource``, the spectrum reaches 0 Hz and
    a warning is logged.

    Parameters
    ----------
    adj_srcs : list[SourceType]
        Non-empty list of adjoint sources whose ``source_time.freq0`` must be covered.

    Returns
    -------
    tuple[float, float]
        ``(freq0, fwidth)`` of the broadband pulse.
    """
    adj_srcs_f0 = [adj_src.source_time.freq0 for adj_src in adj_srcs]
    min_f0 = np.min(adj_srcs_f0)
    max_f0 = np.max(adj_srcs_f0)
    middle_f0 = 0.5 * (max_f0 + min_f0)

    # width of source to sufficiently decay by zero frequency
    decay_by_f0_fwidth = middle_f0 / NUM_ADJOINT_FWIDTH_TO_ZERO
    # width of source to sufficiently cover all adjoint frequencies
    fwidth_to_min_f0 = (middle_f0 - min_f0) / NUM_ADJOINT_FWIDTH_TO_FMIN

    # log warning if the adjoint pulse width is not sufficiently decayed by zero frequency
    # which may cause some issues in the adjoint accuracy when using field sources
    if (fwidth_to_min_f0 > decay_by_f0_fwidth) and isinstance(adj_srcs[0], CustomCurrentSource):
        log.warning(
            "Adjoint source generated with a frequency spectrum that extends to or overlaps with 0 Hz. "
            "This can introduce errors into the gradient computation."
        )

    # Choose a wider pulse width in frequency especially when the min/max frequencies
    # for the broadband pulse might be very close together
    adj_src_fwidth = np.maximum(decay_by_f0_fwidth, fwidth_to_min_f0)
    return middle_f0, adj_src_fwidth
def _make_broadband_source(self, adj_srcs: list[SourceType]) -> SourceType:
"""Make a broadband source for a set of adjoint sources."""
adj_src_f0, adj_src_fwidth = self._adjoint_src_width_broadband(adj_srcs)
source_index = self.simulation.normalize_index or 0
src_time_base = self.simulation.sources[source_index].source_time.updated_copy(
amplitude=1.0, phase=0.0
)
src_broadband = adj_srcs[0].updated_copy(
source_time=src_time_base.updated_copy(freq0=adj_src_f0, fwidth=adj_src_fwidth)
)
return src_broadband
@staticmethod
def _make_post_norm_amps(adj_srcs: list[SourceType]) -> xr.DataArray:
    """Collect each source's complex amplitude, indexed by its center frequency.

    Returns a ``DataArray`` with coordinate ``"f"`` (each source's ``freq0``) holding
    ``amplitude * exp(1j * phase)`` of each source time, in input order. These are the
    factors the adjoint fields are multiplied with.
    """
    src_times = [src.source_time for src in adj_srcs]
    freqs = [st.freq0 for st in src_times]
    amps_complex = np.array([st.amplitude * np.exp(1j * st.phase) for st in src_times])
    return xr.DataArray(amps_complex, coords={"f": freqs})
def _get_adjoint_data(self, structure_index: int, data_type: str) -> MonitorDataType:
    """Fetch the adjoint monitor data of kind ``data_type`` for one structure.

    The monitor name is produced by ``Structure._get_monitor_name`` and looked up through
    ``self[...]``.
    """
    return self[Structure._get_monitor_name(index=structure_index, data_type=data_type)]
[docs]
def to_mat_file(self, fname: str, **kwargs):
    """Output the ``SimulationData`` object as ``.mat`` MATLAB file.

    Parameters
    ----------
    fname : str
        Full path to the output file. Should include ``.mat`` file extension.
    **kwargs : dict, optional
        Extra arguments to ``scipy.io.savemat``: see ``scipy`` documentation for more detail.
        ``mdict`` (or ``m_dict``) may not be passed since the data dictionary is built here;
        ``long_field_names`` defaults to ``True``.

    Raises
    ------
    FileError
        If ``fname`` has no extension or its (final) extension is not ``.mat``.
    ValueError
        If ``mdict``/``m_dict`` is passed, or if scipy cannot be imported or ``savemat`` fails.

    Example
    -------
    >>> simData.to_mat_file('/path/to/file/data.mat') # doctest: +SKIP
    """
    # Check .mat file extension is given. ``suffix`` is the final extension (empty when
    # there is none), so "data.v2.mat" is accepted and "data" reaches the explicit error.
    extension = pathlib.Path(fname).suffix.lower()
    if not extension:
        raise FileError(f"File '{fname}' missing extension.")
    if extension != ".mat":
        raise FileError(f"File '{fname}' should have a .mat extension.")

    # The data dictionary is built here, so it can't also be supplied via kwargs.
    # 'mdict' is savemat's actual parameter name; 'm_dict' is still rejected as before.
    for reserved_key in ("m_dict", "mdict"):
        if reserved_key in kwargs:
            raise ValueError(
                f"'{reserved_key}' is automatically determined by 'to_mat_file', can't pass to 'savemat'."
            )

    # Get SimData object as dictionary
    sim_dict = self.dict()

    # set long field names true by default, otherwise it won't save fields with > 31 characters
    kwargs.setdefault("long_field_names", True)

    # Remove NoneType values from dict
    # Built from theory discussed in https://github.com/scipy/scipy/issues/3488
    modified_sim_dict = replace_values(sim_dict, None, [])

    try:
        from scipy.io import savemat

        savemat(fname, modified_sim_dict, **kwargs)
    except Exception as e:
        raise ValueError(
            "Could not save supplied 'SimulationData' to file. As this is an experimental feature, we may not be able to support the contents of your dataset. If you receive this error, please feel free to raise an issue on our front end repository so we can investigate."
        ) from e