## Contents

Note: Tidy3D now supports automatic differentiation natively through autograd. The jax-based adjoint plugin will be deprecated from 2.7 onwards. To see this notebook implemented in the new feature, see this notebook.

To install the jax module required for this feature, we recommend running pip install "tidy3d[jax]".

In this notebook, we will show how to use the adjoint plugin efficiently for objectives involving several simulations.

One common application of this involves defining an objective function that may depend on several different definitions of your structures, each with some geometric or material modification. For example, including the performance of devices with slightly larger or smaller feature sizes into oneβs objective can serve to make optimization more robust to fabrication errors. For more details, see this paper.

[1]:

import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pylab as plt

import tidy3d as td


## Setup#

First, letβs define the frequency that our objective will depend on

[2]:

freq0 = 2e14
wavelength = td.C_0 / freq0


Now we set up some physical parameters.

We will be putting a td.Box in the center of a domain with a point source on one side and a diffraction monitor on the other.

The objective will involve summing the power of the 0th order diffraction order.

The gradient of this objective will be computed with respect to the permittivity of the box.

We will adjust the size of the box by dy in the y direction and define a combined objective function that returns the average power when the box is either perturbed by +dy, 0, or -dy, which emulates a treatment for enhancing device robustness via dilation and erosion effects, as described in the paper linked at the top of this notebook.

[3]:

permittivity_val = 2.0

# box sizes (original design)
lx = wavelength
ly = wavelength
lz = wavelength

# amount by which we will change the box size in y
dy = ly / 5.0

buffer = 2 * wavelength

Lx = lx + 2 * buffer
Ly = ly + dy + 2 * buffer
Lz = lz + 2 * buffer

src_pos_x = -Lx / 2 + buffer / 2
mnt_pos_x = +Lx / 2 - buffer / 2

[4]:

def make_sim(permittivity: float, dy_sign: int) -> tda.JaxSimulation:
"""Make a simulation as a function of the box permittivity and the frequency."""

box_size = ly + dy_sign * dy

box = tda.JaxStructure(
geometry=tda.JaxBox(
center=(0.0, 0.0, 0.0),
size=(lx, box_size, lz)
),
medium=tda.JaxMedium(permittivity=permittivity),
)

src = td.PointDipole(
center=(src_pos_x, 0, 0),
polarization="Ey",
source_time=td.GaussianPulse(
freq0=freq0,
fwidth=freq0/10,
),
)

mnt = td.DiffractionMonitor(
center=(mnt_pos_x, 0, 0),
size=(0, td.inf, td.inf),
freqs=[freq0],
name="diffraction",
)

return tda.JaxSimulation(
size=(Lx, Ly, Lz),
input_structures=[box],
output_monitors=[mnt],
sources=[src],
grid_spec=td.GridSpec.auto(wavelength=td.C_0 / freq0),
boundary_spec=td.BoundarySpec(
x=td.Boundary.pml(), y=td.Boundary.periodic(), z=td.Boundary.periodic()
),
run_time=200 / src.source_time.fwidth,
)


Letβs make a simulation for each of the perturbed size values and visualize them.

[5]:

f, axes = plt.subplots(1,3, tight_layout=True, figsize=(10, 4))

jax_sim = make_sim(permittivity=permittivity_val, dy_sign=dy_sign)
ax = jax_sim.plot(z=0, ax=ax)
ax.set_title(f"size[y]={jax_sim.input_structures[0].geometry.size[1]:.2f}")


## Define Objective#

Now letβs define our objective function, first we defined how to postprocess a SimulationData to give the desired power.

[6]:

def post_process(sim_data: tda.JaxSimulationData) -> float:
"""O-th order diffracted power."""
amp = sim_data["diffraction"].amps.sel(orders_x=0, orders_y=0)
return jnp.sum(jnp.abs(amp.values) ** 2)


And then we write our combined, multi-objective over all of the dy values. We use the plugins.adjoint.web.run_async function to run a list of these three simulations simultaneously.

[7]:

def objective(permittivity: float) -> float:
"""Average of O-th order diffracted power over all dy_sign values."""
sim_data_list = tda.web.run_async(sims, path_dir="data", verbose=True)
powers = [post_process(sim_data) for sim_data in sim_data_list]
return jnp.mean(jnp.array(powers))


Finally, we are ready to use jax.value_and_grad to differentiate this function.

[8]:

grad_objective = jax.value_and_grad(objective)

print(f"average power = {power_average:.2e}")
print(f"derivative of average power wrt permittivity = {grad_power_min:.2e}")

[10:58:35] Created task '0_fwd' with task_id
'fdve-b20f9e17-3302-4b67-8159-c83354a27ea1v1'.

           View task using web UI at
4b67-8159-c83354a27ea1v1'.

[10:58:36] Created task '1_fwd' with task_id
'fdve-aee667ea-d54c-47dd-b1d1-b7652e74ccddv1'.

           View task using web UI at
47dd-b1d1-b7652e74ccddv1'.

[10:58:37] Created task '2_fwd' with task_id

           View task using web UI at

[10:58:39] Started working on Batch.

[10:58:40] Maximum FlexCredit cost: 0.075 for the whole batch. Use
'Batch.real_cost()' to get the billed FlexCredit cost after the Batch
has completed.

[10:58:59] Batch complete.

[10:59:01] loading SimulationData from
data/fdve-b20f9e17-3302-4b67-8159-c83354a27ea1v1.hdf5

[10:59:02] loading SimulationData from
data/fdve-aee667ea-d54c-47dd-b1d1-b7652e74ccddv1.hdf5

[10:59:03] loading SimulationData from

[10:59:04] Created task '0_adj' with task_id
'fdve-c2264f3e-360a-499a-b410-f84ef336d4ecv1'.

           View task using web UI at
499a-b410-f84ef336d4ecv1'.

[10:59:05] Created task '1_adj' with task_id
'fdve-0914c5d7-6da1-4127-8b7c-4107cb45dce4v1'.

           View task using web UI at
4127-8b7c-4107cb45dce4v1'.

           Created task '2_adj' with task_id
'fdve-1ac8f06a-d66a-43de-b7a3-035bb6ca86b9v1'.

           View task using web UI at
43de-b7a3-035bb6ca86b9v1'.

[10:59:07] Started working on Batch.

[10:59:09] Maximum FlexCredit cost: 0.075 for the whole batch. Use
'Batch.real_cost()' to get the billed FlexCredit cost after the Batch
has completed.

[10:59:28] Batch complete.

[10:59:30] loading SimulationData from
data/fdve-c2264f3e-360a-499a-b410-f84ef336d4ecv1.hdf5

[10:59:31] loading SimulationData from
data/fdve-0914c5d7-6da1-4127-8b7c-4107cb45dce4v1.hdf5

[10:59:32] loading SimulationData from
data/fdve-1ac8f06a-d66a-43de-b7a3-035bb6ca86b9v1.hdf5

average power = 2.62e+00
derivative of average power wrt permittivity = -8.79e-01


## Sanity Checking: Manual Loop over size#

Now we will implement the brute force approach to computing the multi-objective gradient by the naive approach of looping over dy values and computing the individual gradient contributions one by one.

[9]:

def grad_manual(permittivity: float) -> float:
"""Average of O-th order diffracted power over all dy_sign values."""

print(f"working on dy_sign = {dy_sign}")

def objective_fn(p):
sims = make_sim(p, dy_sign=dy_sign)
return post_process(sim_data)


[10]:

grad_man = grad_manual(permittivity_val)

working on dy_sign = -1
working on dy_sign = 0
working on dy_sign = 1


We can see that they match, as expected.

[11]:

print(f"gradient (batched) = {grad_power_min:.4e}")

gradient (batched) = -8.7947e-01


### Takeaways#

The main thing to note here is that, using plugins.adjoint.web.run_async, all of the individual simulations were uploaded at roughly the same time.

This means that the server is able to work on them concurrently rather than needing to wait for the previously uploaded one to finish. The time savings for applications with several simulations can be enormous.

Note: Native support for multi-frequency output monitors was added in Tidy3D 2.5. Previously it was recommended that users use the run_async approach described here for multi-frequency objectives, but this is no longer necessary. That being said, for objectives with very tightly-packed frequency spacing in the output monitors, using a batch approach such as described here may be advantageous as the multi-frequency adjoint approach requires a run_time that scales inversely with the difference between frequency points, potentially requiring long simulation run times. The overall wait time may therefore be shorter using a batch approach.

[ ]:



[ ]: