"""Defines jax-compatible mediums."""
from __future__ import annotations
from typing import Dict, Tuple, Union, Callable, Optional
from abc import ABC
import pydantic.v1 as pd
import numpy as np
from jax.tree_util import register_pytree_node_class
import xarray as xr
from ....components.types import Bound, Literal
from ....components.medium import Medium, AnisotropicMedium, CustomMedium
from ....components.geometry.base import Geometry
from ....components.data.monitor_data import FieldData
from ....exceptions import SetupError
from ....constants import CONDUCTIVITY
from .base import JaxObject
from .types import JaxFloat
from .data.data_array import JaxDataArray
from .data.dataset import JaxPermittivityDataset
# number of integration points per unit wavelength in material
PTS_PER_WVL_INTEGRATION = 20
# maximum number of pixels allowed in each component of a JaxCustomMedium
MAX_NUM_CELLS_CUSTOM_MEDIUM = 250_000
class AbstractJaxMedium(ABC, JaxObject):
"""Holds some utility functions for Jax medium types."""
def _get_volume_disc(
self, grad_data: FieldData, sim_bounds: Bound, wvl_mat: float
) -> Tuple[Dict[str, np.ndarray], float]:
"""Get the coordinates and volume element for the inside of the corresponding structure."""
# find intersecting volume between structure and simulation
mnt_bounds = grad_data.monitor.geometry.bounds
rmin, rmax = Geometry.bounds_intersection(mnt_bounds, sim_bounds)
# assemble volume coordinates and differential volume element
d_vol = 1.0
vol_coords = {}
for coord_name, min_edge, max_edge in zip("xyz", rmin, rmax):
size = max_edge - min_edge
# don't discretize this dimension if there is no thickness along it
if size == 0:
vol_coords[coord_name] = [max_edge]
continue
# update the volume element value
num_cells_dim = int(size * PTS_PER_WVL_INTEGRATION / wvl_mat) + 1
d_len = size / num_cells_dim
d_vol *= d_len
# construct the interpolation coordinates along this dimension
coords_interp = np.linspace(min_edge + d_len / 2, max_edge - d_len / 2, num_cells_dim)
vol_coords[coord_name] = coords_interp
return vol_coords, d_vol
@staticmethod
def make_inside_mask(vol_coords: Dict[str, np.ndarray], inside_fn: Callable) -> xr.DataArray:
"""Make a 3D mask of where the volume coordinates are inside a supplied function."""
meshgrid_args = [vol_coords[dim] for dim in "xyz" if dim in vol_coords]
vol_coords_meshgrid = np.meshgrid(*meshgrid_args, indexing="ij")
inside_kwargs = dict(zip("xyz", vol_coords_meshgrid))
values = inside_fn(**inside_kwargs)
return xr.DataArray(values, coords=vol_coords)
def e_mult_volume(
self,
field: Literal["Ex", "Ey", "Ez"],
grad_data_fwd: FieldData,
grad_data_adj: FieldData,
vol_coords: Dict[str, np.ndarray],
d_vol: float,
inside_fn: Callable,
) -> xr.DataArray:
"""Get the E_fwd * E_adj * dV field distribution inside of the discretized volume."""
e_fwd = grad_data_fwd.field_components[field]
e_adj = grad_data_adj.field_components[field]
e_dotted = e_fwd * e_adj
inside_mask = self.make_inside_mask(vol_coords=vol_coords, inside_fn=inside_fn)
isel_kwargs = {
key: [0]
for key, value in vol_coords.items()
if isinstance(value, float) or len(value) <= 1
}
interp_kwargs = {key: value for key, value in vol_coords.items() if key not in isel_kwargs}
fields_eval = e_dotted.isel(**isel_kwargs).interp(**interp_kwargs, assume_sorted=True)
inside_mask = inside_mask.isel(**isel_kwargs)
mask_dV = inside_mask * d_vol
fields_eval = fields_eval.assign_coords(**mask_dV.coords)
return mask_dV * fields_eval
def d_eps_map(
self,
grad_data_fwd: FieldData,
grad_data_adj: FieldData,
sim_bounds: Bound,
wvl_mat: float,
inside_fn: Callable,
) -> xr.DataArray:
"""Mapping of gradient w.r.t. permittivity at each point in discretized volume."""
vol_coords, d_vol = self._get_volume_disc(
grad_data=grad_data_fwd, sim_bounds=sim_bounds, wvl_mat=wvl_mat
)
e_mult_sum = 0.0
for field in ("Ex", "Ey", "Ez"):
e_mult_sum += self.e_mult_volume(
field=field,
grad_data_fwd=grad_data_fwd,
grad_data_adj=grad_data_adj,
vol_coords=vol_coords,
d_vol=d_vol,
inside_fn=inside_fn,
)
return e_mult_sum
[docs]
@register_pytree_node_class
class JaxMedium(Medium, AbstractJaxMedium):
"""A :class:`.Medium` registered with jax."""
_tidy3d_class = Medium
permittivity_jax: JaxFloat = pd.Field(
1.0,
title="Permittivity",
description="Relative permittivity of the medium. May be a ``jax`` ``Array``.",
stores_jax_for="permittivity",
)
conductivity_jax: JaxFloat = pd.Field(
0.0,
title="Conductivity",
description="Electric conductivity. Defined such that the imaginary part of the complex "
"permittivity at angular frequency omega is given by conductivity/omega.",
units=CONDUCTIVITY,
stores_jax_for="conductivity",
)
[docs]
def store_vjp(
self,
grad_data_fwd: FieldData,
grad_data_adj: FieldData,
sim_bounds: Bound,
wvl_mat: float,
inside_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
) -> JaxMedium:
"""Returns the gradient of the medium parameters given forward and adjoint field data."""
# integrate the dot product of each E component over the volume, update vjp for epsilon
d_eps_map = self.d_eps_map(
grad_data_fwd=grad_data_fwd,
grad_data_adj=grad_data_adj,
sim_bounds=sim_bounds,
wvl_mat=wvl_mat,
inside_fn=inside_fn,
)
vjp_eps_complex = d_eps_map.sum(dim=("x", "y", "z"))
vjp_eps = 0.0
vjp_sigma = 0.0
for freq in d_eps_map.coords["f"]:
vjp_eps_complex_f = vjp_eps_complex.sel(f=freq)
_vjp_eps, _vjp_sigma = self.eps_complex_to_eps_sigma(vjp_eps_complex_f, freq)
vjp_eps += _vjp_eps
vjp_sigma += _vjp_sigma
return self.copy(
update=dict(
permittivity_jax=vjp_eps,
conductivity_jax=vjp_sigma,
)
)
[docs]
@register_pytree_node_class
class JaxAnisotropicMedium(AnisotropicMedium, AbstractJaxMedium):
"""A :class:`.Medium` registered with jax."""
_tidy3d_class = AnisotropicMedium
xx: JaxMedium = pd.Field(
...,
title="XX Component",
description="Medium describing the xx-component of the diagonal permittivity tensor.",
jax_field=True,
)
yy: JaxMedium = pd.Field(
...,
title="YY Component",
description="Medium describing the yy-component of the diagonal permittivity tensor.",
jax_field=True,
)
zz: JaxMedium = pd.Field(
...,
title="ZZ Component",
description="Medium describing the zz-component of the diagonal permittivity tensor.",
jax_field=True,
)
[docs]
def store_vjp(
self,
grad_data_fwd: FieldData,
grad_data_adj: FieldData,
sim_bounds: Bound,
wvl_mat: float,
inside_fn: Callable,
) -> JaxMedium:
"""Returns the gradient of the medium parameters given forward and adjoint field data."""
# integrate the dot product of each E component over the volume, update vjp for epsilon
vol_coords, d_vol = self._get_volume_disc(
grad_data=grad_data_fwd, sim_bounds=sim_bounds, wvl_mat=wvl_mat
)
vjp_fields = {}
for component in "xyz":
field_name = "E" + component
component_name = component + component
e_mult_dim = self.e_mult_volume(
field=field_name,
grad_data_fwd=grad_data_fwd,
grad_data_adj=grad_data_adj,
vol_coords=vol_coords,
d_vol=d_vol,
inside_fn=inside_fn,
)
vjp_eps_complex_ii = e_mult_dim.sum(dim=("x", "y", "z"))
freq = e_mult_dim.coords["f"][0]
vjp_eps_ii = 0.0
vjp_sigma_ii = 0.0
for freq in e_mult_dim.coords["f"]:
vjp_eps_complex_ii_f = vjp_eps_complex_ii.sel(f=freq)
_vjp_eps_ii, _vjp_sigma_ii = self.eps_complex_to_eps_sigma(
vjp_eps_complex_ii_f, freq
)
vjp_eps_ii += _vjp_eps_ii
vjp_sigma_ii += _vjp_sigma_ii
vjp_medium = self.components[component_name]
vjp_fields[component_name] = vjp_medium.updated_copy(
permittivity_jax=vjp_eps_ii,
conductivity_jax=vjp_sigma_ii,
)
return self.copy(update=vjp_fields)
[docs]
@register_pytree_node_class
class JaxCustomMedium(CustomMedium, AbstractJaxMedium):
"""A :class:`.CustomMedium` registered with ``jax``.
Note: The gradient calculation assumes uniform field across the pixel.
Therefore, the accuracy degrades as the pixel size becomes large
with respect to the field variation.
"""
_tidy3d_class = CustomMedium
eps_dataset: Optional[JaxPermittivityDataset] = pd.Field(
None,
title="Permittivity Dataset",
description="User-supplied dataset containing complex-valued permittivity "
"as a function of space. Permittivity distribution over the Yee-grid will be "
"interpolated based on the data nearest to the grid location.",
jax_field=True,
)
@pd.root_validator(pre=True)
def _pre_deprecation_dataset(cls, values):
"""Don't allow permittivity as a field until we support it."""
if values.get("permittivity") or values.get("conductivity"):
raise SetupError(
"'permittivity' and 'conductivity' are not yet supported in adjoint plugin. "
"Please continue to use the 'eps_dataset' field to define the component "
"of the permittivity tensor."
)
return values
@pd.validator("eps_dataset", always=True)
def _is_not_too_large(cls, val):
"""Ensure number of pixels does not surpass a set amount."""
for field_dim in "xyz":
field_name = f"eps_{field_dim}{field_dim}"
data_array = val.field_components[field_name]
coord_lens = [len(data_array.coords[key]) for key in "xyz"]
num_cells_dim = np.prod(coord_lens)
if num_cells_dim > MAX_NUM_CELLS_CUSTOM_MEDIUM:
raise SetupError(
"For the adjoint plugin, each component of the 'JaxCustomMedium.eps_dataset' "
f"is restricted to have a maximum of {MAX_NUM_CELLS_CUSTOM_MEDIUM} cells. "
f"Detected {num_cells_dim} grid cells in the '{field_name}' component ."
)
return val
@pd.validator("eps_dataset", always=True)
def _eps_dataset_single_frequency(cls, val):
"""Override of inherited validator. (still needed)"""
return val
@pd.validator("eps_dataset", always=True)
def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(cls, val, values):
"""Override of inherited validator."""
return val
[docs]
def store_vjp(
self,
grad_data_fwd: FieldData,
grad_data_adj: FieldData,
sim_bounds: Bound,
wvl_mat: float,
inside_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
) -> JaxMedium:
"""Returns the gradient of the medium parameters given forward and adjoint field data."""
# get the boundaries of the intersection of the CustomMedium and the Simulation
mnt_bounds = grad_data_fwd.monitor.geometry.bounds
bounds_intersect = Geometry.bounds_intersection(mnt_bounds, sim_bounds)
# get the grids associated with the user-supplied coordinates within these bounds
grids = self.grids(bounds=bounds_intersect)
vjp_field_components = {}
for dim in "xyz":
eps_field_name = f"eps_{dim}{dim}"
# grab the original data and its coordinates
orig_data_array = self.eps_dataset.field_components[eps_field_name]
coords = orig_data_array.coords
grid = grids[eps_field_name]
d_sizes = grid.sizes
d_sizes = [d_sizes.x, d_sizes.y, d_sizes.z]
# construct the coordinates for interpolation and selection within the custom medium
# TODO: extend this to all points within the volume.
interp_coords = {}
sum_axes = []
for dim_index, dim_pt in enumerate("xyz"):
coord_dim = coords[dim_pt]
# if it's uniform / single pixel along this dim
if len(np.array(coord_dim)) == 1:
# discretize along this edge like a regular volume
# compute the length of the pixel within the sim bounds
r_min_coords, r_max_coords = grid.boundaries.to_list[dim_index]
r_min_sim, r_max_sim = np.array(sim_bounds).T[dim_index]
r_min = max(r_min_coords, r_min_sim)
r_max = min(r_max_coords, r_max_sim)
size = abs(r_max - r_min)
# compute the length element along the dim, handling case of sim.size=0
if size > 0:
# discretize according to PTS_PER_WVL
num_cells_dim = int(size * PTS_PER_WVL_INTEGRATION / wvl_mat) + 1
d_len = size / num_cells_dim
coords_interp = np.linspace(
r_min + d_len / 2, r_max - d_len / 2, num_cells_dim
)
else:
# just interpolate at the single position, dL=1 to normalize out
d_len = 1.0
coords_interp = np.array([(r_min + r_max) / 2.0])
# construct the interpolation coordinates along this dimension
d_sizes[dim_index] = np.array([d_len])
interp_coords[dim_pt] = coords_interp
# only sum this dimension if there are multiple points
sum_axes.append(dim_pt)
# otherwise
else:
# just evaluate at the original data coords
interp_coords[dim_pt] = coord_dim
# outer product all dimensions to get a volume element mask
d_vols = np.einsum("i, j, k -> ijk", *d_sizes)
# grab the corresponding dotted fields at these interp_coords and sum over len-1 pixels
field_name = "E" + dim
e_dotted = (
self.e_mult_volume(
field=field_name,
grad_data_fwd=grad_data_fwd,
grad_data_adj=grad_data_adj,
vol_coords=interp_coords,
d_vol=d_vols,
inside_fn=inside_fn,
)
.sum(sum_axes)
.sum(dim="f")
)
# reshape values to the expected vjp shape to be more safe
vjp_shape = tuple(len(coord) for _, coord in coords.items())
# make sure this has the same dtype as the original
dtype_orig = np.array(orig_data_array.values).dtype
vjp_values = e_dotted.values.reshape(vjp_shape)
if dtype_orig.kind == "f":
vjp_values = vjp_values.real
vjp_values = vjp_values.astype(dtype_orig)
# construct a DataArray storing the vjp
vjp_data_array = JaxDataArray(values=vjp_values, coords=coords)
vjp_field_components[eps_field_name] = vjp_data_array
# package everything into dataset
vjp_eps_dataset = JaxPermittivityDataset(**vjp_field_components)
return self.copy(update=dict(eps_dataset=vjp_eps_dataset))
JaxMediumType = Union[JaxMedium, JaxAnisotropicMedium, JaxCustomMedium]
JAX_MEDIUM_MAP = {
Medium: JaxMedium,
AnisotropicMedium: JaxAnisotropicMedium,
CustomMedium: JaxCustomMedium,
}