from __future__ import annotations
from collections import deque
from typing import Any, Callable, Literal, Optional, Union
import autograd.numpy as np
import pydantic.v1 as pd
from autograd import value_and_grad
from numpy.typing import NDArray
from scipy.optimize import minimize
import tidy3d as td
from tidy3d.components.autograd.functions import _straight_through_clip
from tidy3d.components.base import Tidy3dBaseModel
from tidy3d.components.grid.grid import Coords
from tidy3d.plugins.autograd.constants import BETA_DEFAULT, ETA_DEFAULT
from tidy3d.plugins.autograd.types import KernelType, PaddingType
from .filters import make_filter
from .projections import tanh_projection
class FilterAndProject(Tidy3dBaseModel):
    """A class that combines filtering and projection operations.

    The kernel is defined either by a physical ``radius`` together with the grid
    spacing ``dl``, or directly by its pixel size ``size_px``. Calling the instance
    filters the input array, applies a tanh projection, and clips the result to
    ``[0, 1]``.
    """

    # ``radius`` and ``dl`` must accept ``None``: ``make_filter_and_project`` defaults
    # both to ``None`` and forwards them here, so a kernel given only by ``size_px``
    # would otherwise fail pydantic validation before reaching the filter factory.
    radius: Optional[Union[float, tuple[float, ...]]] = pd.Field(
        None, title="Radius", description="The radius of the kernel."
    )
    dl: Optional[Union[float, tuple[float, ...]]] = pd.Field(
        None, title="Grid Spacing", description="The grid spacing."
    )
    size_px: Optional[Union[int, tuple[int, ...]]] = pd.Field(
        None, title="Size in Pixels", description="The size of the kernel in pixels."
    )
    beta: pd.NonNegativeFloat = pd.Field(
        BETA_DEFAULT, title="Beta", description="The beta parameter for the tanh projection."
    )
    eta: pd.NonNegativeFloat = pd.Field(
        ETA_DEFAULT, title="Eta", description="The eta parameter for the tanh projection."
    )
    filter_type: KernelType = pd.Field(
        "conic", title="Filter Type", description="The type of filter to create."
    )
    padding: PaddingType = pd.Field(
        "reflect", title="Padding", description="The padding mode to use."
    )

    def __call__(
        self, array: NDArray, beta: Optional[float] = None, eta: Optional[float] = None
    ) -> NDArray:
        """Apply the filter and projection to an input array.

        Parameters
        ----------
        array : np.ndarray
            The input array to filter and project.
        beta : float = None
            The beta parameter for the tanh projection. If None, uses the instance's beta.
        eta : float = None
            The eta parameter for the tanh projection. If None, uses the instance's eta.

        Returns
        -------
        np.ndarray
            The filtered and projected array, clipped to ``[0, 1]``.
        """
        # The filter is rebuilt from the model fields on every call.
        filter_instance = make_filter(
            radius=self.radius,
            dl=self.dl,
            size_px=self.size_px,
            filter_type=self.filter_type,
            padding=self.padding,
        )
        filtered = filter_instance(array)

        # Per-call overrides take precedence over the stored projection parameters.
        beta = beta if beta is not None else self.beta
        eta = eta if eta is not None else self.eta

        projected = tanh_projection(filtered, beta, eta)
        return _straight_through_clip(projected, a_min=0.0, a_max=1.0)
def make_filter_and_project(
    radius: Optional[Union[float, tuple[float, ...]]] = None,
    dl: Optional[Union[float, tuple[float, ...]]] = None,
    *,
    size_px: Optional[Union[int, tuple[int, ...]]] = None,
    beta: float = BETA_DEFAULT,
    eta: float = ETA_DEFAULT,
    filter_type: KernelType = "conic",
    padding: PaddingType = "reflect",
) -> Callable:
    """Create a function that filters and projects an array.

    Parameters
    ----------
    radius : float | tuple[float, ...] = None
        The radius of the kernel. Used together with ``dl`` when ``size_px`` is not given.
    dl : float | tuple[float, ...] = None
        The grid spacing.
    size_px : int | tuple[int, ...] = None
        The size of the kernel in pixels.
    beta : float = BETA_DEFAULT
        The beta parameter for the tanh projection.
    eta : float = ETA_DEFAULT
        The eta parameter for the tanh projection.
    filter_type : KernelType = "conic"
        The type of filter to create.
    padding : PaddingType = "reflect"
        The padding mode to use.

    Returns
    -------
    Callable
        A :class:`FilterAndProject` instance; calling it on an array filters, projects,
        and clips the array.

    See Also
    --------
    :class:`~parametrizations.FilterAndProject`.
    """
    return FilterAndProject(
        radius=radius,
        dl=dl,
        size_px=size_px,
        beta=beta,
        eta=eta,
        filter_type=filter_type,
        padding=padding,
    )
def initialize_params_from_simulation(
    sim: td.Simulation,
    param_to_structure: Callable[..., td.Structure],
    params0: np.ndarray,
    *,
    freq: Optional[float] = None,
    outside_handling: Literal["extrapolate", "mask", "nan"] = "mask",
    maxiter: int = 100,
    bounds: tuple[Optional[float], Optional[float]] = (0.0, 1.0),
    rel_improve_tol: float = 1e-3,
    verbose: bool = False,
    **param_kwargs: Any,
) -> np.ndarray:
    """Initialize design parameters to match base simulation permittivity in a region.

    Builds an objective that compares the base simulation's permittivity against
    the permittivity generated by a user-supplied parameterization and minimizes
    the misfit ``0.5 * ||eps_base - eps_design||^2 / ||eps_base||`` (evaluated only
    over covered points when a mask is used) using L-BFGS-B.

    Notes
    -----
    - The simulation is not modified. If you need finer sampling inside the design
      region, add a :class:`.MeshOverrideStructure` to ``sim.grid_spec`` first.
    - The callable ``param_to_structure`` controls the parameterization and must return
      a :class:`.Structure` (typically with a :class:`.CustomMedium`). Its permittivity
      dataset (and coordinates) defines the grid used for comparison.
    - The base simulation permittivity is sampled once on an extended subgrid covering
      the design geometry and then interpolated onto the design region coordinates
      using coordinate-aware interpolation (no per-iteration interpolation).
    - Early stopping uses a single knob ``rel_improve_tol``; optimization stops once the
      relative improvement over a small fixed window (3 iterations) falls below this
      value, and the best iterate seen so far is returned.
    - Points outside the base-epsilon coverage can be handled via ``outside_handling``:
      ``'extrapolate'`` (use nearest extrapolation, default in earlier versions),
      ``'mask'`` (ignore outside points using a coverage mask; default), or
      ``'nan'`` (non-extended grid; treat outside as NaN and ignore in the loss).

    Parameters
    ----------
    sim : :class:`.Simulation`
        Base simulation without the design structure.
    param_to_structure : Callable[..., :class:`.Structure`]
        Function mapping parameters to a :class:`.Structure` whose medium permittivity is
        used for comparison.
    params0 : :class:`numpy.ndarray`
        Initial parameter array (2D or 3D). Values may take any range, consistent with
        ``bounds`` and the expectations of ``param_to_structure``.
    freq : float, optional
        Frequency at which permittivity is evaluated. If ``None``, uses infinite frequency.
    outside_handling : {"extrapolate", "mask", "nan"} = "mask"
        Strategy for points where design coordinates fall outside the sampled base epsilon:
        - "extrapolate": include them using nearest extrapolation.
        - "mask": include only points within the coverage bounds of ``sim.epsilon`` on the
          extended subgrid.
        - "nan": sample on a non-extended subgrid and ignore points where interpolation
          returns NaN.
    maxiter : int = 100
        Maximum number of L-BFGS-B iterations.
    bounds : tuple[float | None, float | None] = (0.0, 1.0)
        Element-wise parameter bounds, e.g. ``(0.0, 1.0)`` or ``(-1.0, 1.0)``. Use ``None``
        to indicate an unbounded side.
    rel_improve_tol : float = 1e-3
        Early-stop tolerance on relative improvement over a small window.
    verbose : bool = False
        If ``True``, prints SciPy optimizer messages.
    **param_kwargs
        Extra keyword arguments forwarded to every ``param_to_structure`` call.

    Returns
    -------
    :class:`numpy.ndarray`
        Optimized parameters with the same shape as ``params0``.

    Raises
    ------
    ValueError
        If ``outside_handling`` is not one of ``'extrapolate'``, ``'mask'``, ``'nan'``.

    Examples
    --------
    Initialize a small parameter array in a 2D simulation using a minimal,
    differentiable parameterization that maps parameters to a permittivity field.

    >>> import autograd.numpy as np
    >>> import tidy3d as td
    >>> from tidy3d.plugins.autograd.invdes import initialize_params_from_simulation
    >>> sim = td.Simulation(
    ...     size=(2.0, 2.0, 0.0),
    ...     grid_spec=td.GridSpec.uniform(dl=1.0),
    ...     boundary_spec=td.BoundarySpec.pml(x=True, y=True),
    ...     run_time=1e-12,
    ... )
    >>> box = td.Box(center=(0, 0, 0), size=(1.0, 1.0, 1.0))
    >>> params0 = np.zeros((3, 3))
    >>> def param_to_structure(p):
    ...     # Map params -> density in [0, 1], then to eps in [1, 4].
    ...     density = 0.5 * (1 + np.tanh(p))
    ...     eps = 1.0 + 3.0 * density
    ...     eps3d = eps.reshape((*p.shape, 1))
    ...     return td.Structure.from_permittivity_array(geometry=box, eps_data=eps3d)
    >>> params = initialize_params_from_simulation(
    ...     sim=sim,
    ...     param_to_structure=param_to_structure,
    ...     params0=params0,
    ...     maxiter=3,
    ... )  # doctest: +SKIP
    >>> params.shape  # doctest: +SKIP
    (3, 3)
    """
    # Validate cheap arguments before invoking the (possibly expensive) user callable.
    if outside_handling not in ("extrapolate", "mask", "nan"):
        raise ValueError("'outside_handling' must be one of {'extrapolate', 'mask', 'nan'}.")

    structure_init = param_to_structure(params0, **param_kwargs)

    # Build base epsilon and interpolate onto design coords. The "nan" mode samples a
    # non-extended subgrid so that uncovered design points interpolate to NaN.
    extend_flag = outside_handling != "nan"
    subgrid = sim.discretize(structure_init.geometry, extend=extend_flag)
    eps_base_da = sim.epsilon_on_grid(grid=subgrid, coord_key="centers", freq=freq)
    design_eps_da = structure_init.medium.permittivity
    design_coords = Coords(
        x=np.array(design_eps_da.coords["x"]),
        y=np.array(design_eps_da.coords["y"]),
        z=np.array(design_eps_da.coords["z"]),
    )

    if outside_handling == "nan":
        eps_base_interp = design_coords.spatial_interp(
            array=eps_base_da, interp_method="linear", fill_value=np.nan
        ).data
        mask = np.isfinite(eps_base_interp)
    else:
        eps_base_interp = design_coords.spatial_interp(
            array=eps_base_da, interp_method="linear", fill_value="extrapolate"
        ).data
        if outside_handling == "mask":
            # Build mask from coverage bounds of base epsilon coordinates (x, y, z axes).
            xb, yb, zb = (np.array(eps_base_da.coords[d]) for d in ("x", "y", "z"))
            xd, yd, zd = (np.array(design_eps_da.coords[d]) for d in ("x", "y", "z"))
            mask_x = (xd >= np.min(xb)) & (xd <= np.max(xb))
            mask_y = (yd >= np.min(yb)) & (yd <= np.max(yb))
            mask_z = (zd >= np.min(zb)) & (zd <= np.max(zb))
            mask = (mask_x[:, None, None]) & (mask_y[None, :, None]) & (mask_z[None, None, :])
        else:
            mask = None

    if mask is not None:
        covered = int(np.sum(mask))
        total = int(np.prod(mask.shape))
        frac = covered / max(total, 1)
        if frac < 0.9:
            td.log.warning(
                f"Only {frac:.1%} of design points are covered by base epsilon sampling. "
                "Consider adding a 'MeshOverrideStructure' or adjusting design coordinates."
            )

    # Normalize by the norm of the (covered) base permittivity; guard against zero.
    base_selected = eps_base_interp if mask is None else eps_base_interp[mask]
    denom = np.sqrt(np.sum(base_selected.real**2 + base_selected.imag**2))
    denom = 1.0 if denom == 0 else denom

    def loss_fn(params_vec: np.ndarray) -> float:
        """Normalized squared misfit between base and design permittivity."""
        params = params_vec.reshape(params0.shape)
        structure = param_to_structure(params, **param_kwargs)
        eps_design = structure.medium.permittivity.data
        residual = eps_base_interp - eps_design
        if mask is not None:
            # Drop uncovered points (including NaN entries in "nan" mode).
            residual = residual[mask]
        return 0.5 * np.sum(residual.real**2 + residual.imag**2) / denom

    val_and_grad = value_and_grad(loss_fn)

    # Most recent objective evaluation, so the callback can reuse it instead of
    # re-running the forward model at the iterate L-BFGS-B just evaluated.
    last_eval: dict[str, Any] = {"x": None, "val": None}

    def objective(x: np.ndarray) -> tuple[Any, np.ndarray]:
        """Return (value, gradient) for SciPy and remember the evaluated point."""
        val, grad = val_and_grad(x)
        # Copy: the optimizer may reuse/modify its x buffer in place.
        last_eval["x"] = x.copy()
        last_eval["val"] = val
        return val, grad

    WINDOW = 3
    state = {
        "vals": deque(maxlen=WINDOW),
        "best_val": np.inf,
        "best_x": params0.ravel().copy(),
        "stopped": False,
    }

    def callback(xk: np.ndarray) -> None:
        """Track the best iterate and stop once relative improvement stalls."""
        cached_x = last_eval["x"]
        if cached_x is not None and cached_x.shape == xk.shape and np.all(cached_x == xk):
            val = last_eval["val"]
        else:
            val = loss_fn(xk)
        if val < state["best_val"]:
            state["best_val"] = val
            state["best_x"] = xk.copy()
        state["vals"].append(val)
        if len(state["vals"]) == WINDOW:
            v_then = state["vals"][0]
            v_now = state["vals"][-1]
            rel_improve = (v_then - v_now) / max(abs(v_then), 1e-12)
            if rel_improve < rel_improve_tol:
                state["stopped"] = True
                raise StopIteration

    bounds_list = [bounds] * params0.size
    try:
        result = minimize(
            fun=objective,
            x0=params0.ravel(),
            method="L-BFGS-B",
            jac=True,
            bounds=bounds_list,
            callback=callback,
            options={"maxiter": maxiter, "disp": verbose},
        )
    except StopIteration:
        # Older SciPy versions propagate StopIteration raised by the callback.
        x_final = state["best_x"]
    else:
        # Newer SciPy catches the callback's StopIteration and returns the latest
        # iterate; honor the early stop consistently by returning the best one seen.
        x_final = state["best_x"] if state["stopped"] else result.x
    return x_final.reshape(params0.shape)