"""EME simulation data"""
from __future__ import annotations
from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import pydantic.v1 as pd
from ....exceptions import SetupError
from ...base import cached_property
from ...data.data_array import EMEScalarFieldDataArray, EMESMatrixDataArray
from ...data.monitor_data import FieldData, ModeData, ModeSolverData
from ...data.sim_data import AbstractYeeGridSimulationData
from ...types import annotate_type
from ..simulation import EMESimulation
from .dataset import EMESMatrixDataset
from .monitor_data import EMEFieldData, EMEModeSolverData, EMEMonitorDataType
[docs]
class EMESimulationData(AbstractYeeGridSimulationData):
"""Data associated with an EME simulation."""
simulation: EMESimulation = pd.Field(
..., title="EME simulation", description="EME simulation associated with this data."
)
data: Tuple[annotate_type(EMEMonitorDataType), ...] = pd.Field(
...,
title="Monitor Data",
description="List of EME monitor data "
"associated with the monitors of the original :class:`.EMESimulation`.",
)
smatrix: Optional[EMESMatrixDataset] = pd.Field(
None, title="S Matrix", description="Scattering matrix of the EME simulation."
)
port_modes: Optional[EMEModeSolverData] = pd.Field(
None,
title="Port Modes",
description="Modes associated with the two ports of the EME device. "
"The scattering matrix is expressed in this basis.",
)
def _extract_mode_solver_data(
self, data: EMEModeSolverData, eme_cell_index: int, sweep_index: int = None
) -> ModeSolverData:
"""Extract :class:`.ModeSolverData` at a given ``eme_cell_index``.
Assumes the :class:`.EMEModeSolverMonitor` spans the entire simulation and has
no downsampling.
"""
update_dict = dict(data._grid_correction_dict, **data.field_components)
update_dict.update({"n_complex": data.n_complex})
update_dict = {
key: field.sel(eme_cell_index=eme_cell_index, drop=True)
for key, field in update_dict.items()
}
sweep_in_data = "sweep_index" in data.n_complex.coords
if sweep_index is not None and sweep_in_data:
update_dict = {
key: field.isel(sweep_index=sweep_index, drop=True)
for key, field in update_dict.items()
}
if (
"sweep_index" in update_dict["n_complex"].dims
and len(update_dict["n_complex"].sweep_index) == 1
):
update_dict = {
key: field.squeeze(dim="sweep_index") for key, field in update_dict.items()
}
monitor = self.simulation.mode_solver_monitors[eme_cell_index]
monitor = monitor.updated_copy(
colocate=data.monitor.colocate,
)
grid_expanded = self.simulation.discretize_monitor(monitor=monitor)
return ModeSolverData(**update_dict, monitor=monitor, grid_expanded=grid_expanded)
@cached_property
def port_modes_tuple(self) -> Tuple[ModeSolverData, ModeSolverData]:
"""Port modes as a tuple ``(port_modes_1, port_modes_2)``."""
if self.port_modes is None:
raise SetupError(
"The field 'port_modes' is 'None'. Please set 'store_port_modes' "
"to 'True' in 'EMESimulation' and re-run the simulation."
)
if self.simulation._sweep_modes:
raise SetupError(
"The port modes vary with 'sweep_index'. "
"Use 'EMESimulationData.port_modes_list_sweep' instead."
)
num_cells = self.simulation.eme_grid.num_cells
port_modes_1 = self._extract_mode_solver_data(data=self.port_modes, eme_cell_index=0)
port_modes_2 = self._extract_mode_solver_data(
data=self.port_modes, eme_cell_index=num_cells - 1
)
return port_modes_1, port_modes_2
@cached_property
def port_modes_list_sweep(self) -> List[Tuple[ModeSolverData, ModeSolverData]]:
"""Port modes as a list of tuples ``(port_modes_1, port_modes_2)``.
There is one entry for every sweep index if the port modes vary with sweep index."""
if self.port_modes is None:
raise SetupError(
"The field 'port_modes' is 'None'. Please set 'store_port_modes' "
"to 'True' in 'EMESimulation' and re-run the simulation."
)
if self.simulation._sweep_modes:
sweep_indices = np.arange(self.simulation.sweep_spec.num_sweep)
else:
sweep_indices = [0]
port_modes_list = []
for sweep_index in sweep_indices:
num_cells = self.simulation.eme_grid.num_cells
port_modes_1 = self._extract_mode_solver_data(
data=self.port_modes, eme_cell_index=0, sweep_index=sweep_index
)
port_modes_2 = self._extract_mode_solver_data(
data=self.port_modes, eme_cell_index=num_cells - 1, sweep_index=sweep_index
)
port_modes_list.append((port_modes_1, port_modes_2))
return port_modes_list
[docs]
def smatrix_in_basis(
self, modes1: Union[FieldData, ModeData] = None, modes2: Union[FieldData, ModeData] = None
) -> EMESMatrixDataset:
"""Express the scattering matrix in the provided basis.
Change of basis is done by computing overlaps between provided modes and port modes.
Parameters
----------
modes1: Union[FieldData, ModeData]
New modal basis for port 1. If None, use port_modes.
modes2: Union[FieldData, ModeData]
New modal basis for port 2. If None, use port_modes.
Returns
-------
:class:`.EMESMatrixDataset`
The scattering matrix of the EME simulation, but expressed in the basis
of the provided modes, rather than in the basis of ``port_modes`` used
in computation.
"""
if self.port_modes is None:
raise SetupError(
"Cannot convert the EME scattering matrix to the provided "
"basis, because 'port_modes' is 'None'. Please set 'store_port_modes' "
"to 'True' and re-run the simulation."
)
port_modes1, port_modes2 = self.port_modes_list_sweep[0]
modes1_provided = modes1 is not None
modes2_provided = modes2 is not None
if not modes1_provided:
modes1 = port_modes1
if not modes2_provided:
modes2 = port_modes2
f1 = list(modes1.field_components.values())[0].f.values
f2 = list(modes2.field_components.values())[0].f.values
f = np.array(sorted(set(f1).intersection(f2).intersection(self.simulation.freqs)))
modes_in_1 = "mode_index" in list(modes1.field_components.values())[0].coords
modes_in_2 = "mode_index" in list(modes2.field_components.values())[0].coords
if modes_in_1:
mode_index_1 = list(modes1.field_components.values())[0].mode_index.to_numpy()
else:
mode_index_1 = [0]
if modes_in_2:
mode_index_2 = list(modes2.field_components.values())[0].mode_index.to_numpy()
else:
mode_index_2 = [0]
sweep = "sweep_index" in self.smatrix.S11.coords
if sweep:
sweep_indices = self.smatrix.S11.sweep_index.to_numpy()
else:
sweep_indices = [0]
data11 = np.zeros(
(len(f), len(sweep_indices), len(mode_index_1), len(mode_index_1)), dtype=complex
)
data12 = np.zeros(
(len(f), len(sweep_indices), len(mode_index_1), len(mode_index_2)), dtype=complex
)
data21 = np.zeros(
(len(f), len(sweep_indices), len(mode_index_2), len(mode_index_1)), dtype=complex
)
data22 = np.zeros(
(len(f), len(sweep_indices), len(mode_index_2), len(mode_index_2)), dtype=complex
)
for sweep_index in sweep_indices:
S11 = self.smatrix.S11.sel(f=f, sweep_index=sweep_index)
S12 = self.smatrix.S12.sel(f=f, sweep_index=sweep_index)
S21 = self.smatrix.S21.sel(f=f, sweep_index=sweep_index)
S22 = self.smatrix.S22.sel(f=f, sweep_index=sweep_index)
# nans in S-matrix indicate invalid EME modes
# we skip these in change of basis
nan_inds1 = np.argwhere(np.any(np.isnan(S11.to_numpy()), axis=0))
nan_inds2 = np.argwhere(np.any(np.isnan(S22.to_numpy()), axis=0))
keep_inds1 = np.setdiff1d(np.arange(len(S11.mode_index_in)), nan_inds1)
keep_inds2 = np.setdiff1d(np.arange(len(S22.mode_index_in)), nan_inds2)
keep_mode_inds1 = [S11.mode_index_in[i] for i in keep_inds1]
keep_mode_inds2 = [S22.mode_index_in[i] for i in keep_inds2]
S11 = S11.sel(mode_index_in=keep_mode_inds1, mode_index_out=keep_mode_inds1)
S12 = S12.sel(mode_index_in=keep_mode_inds2, mode_index_out=keep_mode_inds1)
S21 = S21.sel(mode_index_in=keep_mode_inds1, mode_index_out=keep_mode_inds2)
S22 = S22.sel(mode_index_in=keep_mode_inds2, mode_index_out=keep_mode_inds2)
if self.simulation._sweep_modes:
port_modes1, port_modes2 = self.port_modes_list_sweep[sweep_index]
if modes1_provided:
overlaps1 = modes1.outer_dot(port_modes1, conjugate=False)
if not modes_in_1:
overlaps1 = overlaps1.expand_dims(dim={"mode_index_0": mode_index_1}, axis=1)
O1 = overlaps1.sel(f=f, mode_index_1=keep_mode_inds1)
O1out = O1.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old")
O1in = O1.rename(mode_index_0="mode_index_in", mode_index_1="mode_index_in_old")
S11 = S11.rename(
mode_index_in="mode_index_in_old", mode_index_out="mode_index_out_old"
)
S12 = S12.rename(mode_index_out="mode_index_out_old")
S21 = S21.rename(mode_index_in="mode_index_in_old")
# this exception handling is needed because xarray renamed dims kwarg to dim
# but we want to keep supporting old xarray
try:
S11 = O1out.dot(S11, dim="mode_index_out_old").dot(
O1in, dim="mode_index_in_old"
)
S12 = O1out.dot(S12, dim="mode_index_out_old")
S21 = S21.dot(O1in, dim="mode_index_in_old")
except TypeError:
S11 = O1out.dot(S11, dims="mode_index_out_old").dot(
O1in, dims="mode_index_in_old"
)
S12 = O1out.dot(S12, dims="mode_index_out_old")
S21 = S21.dot(O1in, dims="mode_index_in_old")
if modes2_provided:
overlaps2 = modes2.outer_dot(port_modes2, conjugate=False)
if not modes_in_2:
overlaps2 = overlaps2.expand_dims(dim={"mode_index_0": mode_index_2}, axis=1)
O2 = overlaps2.sel(f=f, mode_index_1=keep_mode_inds2)
O2out = O2.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old")
O2in = O2.rename(mode_index_0="mode_index_in", mode_index_1="mode_index_in_old")
S12 = S12.rename(mode_index_in="mode_index_in_old")
S21 = S21.rename(mode_index_out="mode_index_out_old")
S22 = S22.rename(
mode_index_in="mode_index_in_old", mode_index_out="mode_index_out_old"
)
# same for this exception handling
try:
S12 = S12.dot(O2in, dim="mode_index_in_old")
S21 = O2out.dot(S21, dim="mode_index_out_old")
S22 = O2out.dot(S22, dim="mode_index_out_old").dot(
O2in, dim="mode_index_in_old"
)
except TypeError:
S12 = S12.dot(O2in, dims="mode_index_in_old")
S21 = O2out.dot(S21, dims="mode_index_out_old")
S22 = O2out.dot(S22, dims="mode_index_out_old").dot(
O2in, dims="mode_index_in_old"
)
data11[:, sweep_index, :, :] = S11.to_numpy()
data12[:, sweep_index, :, :] = S12.to_numpy()
data21[:, sweep_index, :, :] = S21.to_numpy()
data22[:, sweep_index, :, :] = S22.to_numpy()
coords11 = dict(
f=f, sweep_index=sweep_indices, mode_index_out=mode_index_1, mode_index_in=mode_index_1
)
coords12 = dict(
f=f, sweep_index=sweep_indices, mode_index_out=mode_index_1, mode_index_in=mode_index_2
)
coords21 = dict(
f=f, sweep_index=sweep_indices, mode_index_out=mode_index_2, mode_index_in=mode_index_1
)
coords22 = dict(
f=f, sweep_index=sweep_indices, mode_index_out=mode_index_2, mode_index_in=mode_index_2
)
xrS11 = EMESMatrixDataArray(data11, coords=coords11)
xrS12 = EMESMatrixDataArray(data12, coords=coords12)
xrS21 = EMESMatrixDataArray(data21, coords=coords21)
xrS22 = EMESMatrixDataArray(data22, coords=coords22)
if not sweep:
xrS11 = xrS11.drop_vars("sweep_index")
xrS12 = xrS12.drop_vars("sweep_index")
xrS21 = xrS21.drop_vars("sweep_index")
xrS22 = xrS22.drop_vars("sweep_index")
if not modes_in_1:
xrS11 = xrS11.drop_vars(("mode_index_out", "mode_index_in"))
xrS12 = xrS12.drop_vars("mode_index_out")
xrS21 = xrS21.drop_vars("mode_index_in")
if not modes_in_2:
xrS12 = xrS12.drop_vars("mode_index_in")
xrS21 = xrS21.drop_vars("mode_index_out")
xrS22 = xrS22.drop_vars(("mode_index_out", "mode_index_in"))
smatrix = EMESMatrixDataset(S11=xrS11, S12=xrS12, S21=xrS21, S22=xrS22)
return smatrix
[docs]
def field_in_basis(
self,
field: EMEFieldData,
modes: Union[FieldData, ModeData] = None,
port_index: Literal[0, 1] = 0,
) -> EMEFieldData:
"""Express the electromagnetic field in the provided basis.
Change of basis is done by computing overlaps between provided modes and port modes.
Parameters
----------
field: EMEFieldData
EME field to express in new basis.
modes: Union[FieldData, ModeData]
New modal basis. If None, use port_modes.
port_index: Literal[0, 1]
Port to excite.
Returns
-------
:class:`.EMEFieldData`
The propagated electromagnetic fied expressed in the basis
of the provided modes, rather than in the basis of ``port_modes`` used
in computation.
"""
if self.port_modes is None:
raise SetupError(
"Cannot convert the EME field to the provided "
"basis, because 'port_modes' is 'None'. Please set 'store_port_modes' "
"to 'True' and re-run the simulation."
)
sweep_in_field = "sweep_index" in list(field.field_components.values())[0].coords
new_fields = {}
if sweep_in_field:
sweep_indices = list(field.field_components.values())[0].sweep_index.to_numpy()
else:
sweep_indices = [0]
port_modes = self.port_modes_list_sweep[0][port_index]
modes_provided = modes is not None
if not modes_provided:
modes = self.port_modes_list_sweep[0][port_index]
modes_present = "mode_index" in list(modes.field_components.values())[0].coords
if modes_present:
mode_index = list(modes.field_components.values())[0].mode_index.to_numpy()
else:
mode_index = [0]
f1 = list(modes.field_components.values())[0].f.values
f2 = list(field.field_components.values())[0].f.values
f = np.array(sorted(set(f1).intersection(f2).intersection(self.simulation.freqs)))
# set up field arrays
field_data = {}
field_coords = {}
for field_key, field_comp in field.field_components.items():
shape = list(field_comp.shape)
shape[-1] = len(mode_index)
shape[-2] = 1
field_data[field_key] = np.empty(shape, dtype=complex)
field_data[field_key][:] = np.nan
field_coords[field_key] = dict(
x=field_comp.x.to_numpy(),
y=field_comp.y.to_numpy(),
z=field_comp.z.to_numpy(),
f=field_comp.f.to_numpy(),
sweep_index=sweep_indices,
eme_port_index=[port_index],
mode_index=mode_index,
)
# populate the arrays
for sweep_index in sweep_indices:
if self.simulation._sweep_modes:
port_modes = self.port_modes_list_sweep[sweep_index][port_index]
if modes_provided:
overlaps = modes.outer_dot(port_modes, conjugate=False)
if not modes_present:
overlaps = overlaps.expand_dims(dim={"mode_index_0": [0]}, axis=1)
overlaps = overlaps.sel(f=f)
for field_key, field_comp in field.field_components.items():
field_comp_data = field_comp.sel(f=f).to_numpy()
if modes_provided:
# we loop here to avoid memory issues from broadcasting
field_data[field_key][..., sweep_index, 0, :] = 0
for mode_index_old in field_comp.mode_index:
field_comp_curr = field_comp_data[
..., sweep_index, port_index, mode_index_old
]
overlap = overlaps.sel(mode_index_1=mode_index_old).to_numpy()
# some nans in field are fine, but all nans means invalid mode
if np.all(np.isnan(field_comp_curr)):
continue
# nans in overlap mean invalid port mode
if np.any(np.isnan(overlap)):
continue
field_data[field_key][..., sweep_index, 0, :] += (
field_comp_curr[..., None] * overlap[None, None, None, :, :]
)
else:
field_data[field_key][..., sweep_index, 0, :] = field_comp_data[
..., sweep_index, port_index, :
]
for field_key in field.field_components.keys():
new_fields[field_key] = EMEScalarFieldDataArray(
field_data[field_key], coords=field_coords[field_key]
)
if not modes_present:
new_fields[field_key] = new_fields[field_key].drop_vars("mode_index")
if not sweep_in_field:
new_fields[field_key] = new_fields[field_key].drop_vars("sweep_index")
return field.updated_copy(**new_fields)