Multi-objective adjoint optimization

To install the jax module required for this feature, we recommend running pip install "tidy3d[jax]".

In this notebook, we will show how to use the adjoint plugin efficiently for objectives involving several simulations.

One common application of this involves defining an objective function that may depend on several different definitions of your structures, each with some geometric or material modification. For example, including the performance of devices with slightly larger or smaller feature sizes into one’s objective can serve to make optimization more robust to fabrication errors. For more details, see this paper.

import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pylab as plt

import tidy3d as td
import tidy3d.plugins.adjoint as tda


First, let’s define the frequency that our objective will depend on.

freq0 = 2e14
wavelength = td.C_0 / freq0
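Tidy3D measures lengths in micrometers, so td.C_0 is the speed of light in µm/s (about 2.998e14). The check below is standalone: it hard-codes that constant instead of importing tidy3d, and it shows that freq0 = 2e14 Hz corresponds to a wavelength of about 1.5 µm and an optical period of 5 fs.

```python
# Speed of light in micrometers per second, matching Tidy3D's length unit (um).
C_0 = 2.99792458e14  # um/s

freq0 = 2e14  # Hz
wavelength = C_0 / freq0  # um
period = 1.0 / freq0  # s

print(f"wavelength = {wavelength:.4f} um")
print(f"period = {period:.1e} s")
```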

Now we set up some physical parameters.

We will place a td.Box at the center of the domain, with a point dipole source on one side and a diffraction monitor on the other.

The objective will involve summing the power in the 0th diffraction order.

The gradient of this objective will be computed with respect to the permittivity of the box.

We will change the size of the box along y by dy and define a combined objective function that returns the average power over three cases: the box size perturbed by -dy, 0, or +dy. This emulates a robustness treatment based on erosion and dilation of the device, as described in the paper linked at the top of this notebook.
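Let $P_s(\epsilon)$ denote the 0th-order power when the box permittivity is $\epsilon$ and its y-size is $l_y + s\,dy$. The combined objective and its derivative are then:

$$
F(\epsilon) = \frac{1}{3}\sum_{s \in \{-1,\,0,\,+1\}} P_s(\epsilon),
\qquad
\frac{dF}{d\epsilon} = \frac{1}{3}\sum_{s \in \{-1,\,0,\,+1\}} \frac{dP_s}{d\epsilon}
$$

Averaging is linear, so the gradient of the combined objective is the average of the three individual gradients. Each perturbed geometry therefore needs one forward and one adjoint simulation of its own.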

permittivity_val = 2.0

# box sizes (original design)
lx = wavelength
ly = wavelength
lz = wavelength

# amount by which we will change the box size in y
dy = ly / 5.0

buffer = 2 * wavelength

Lx = lx + 2 * buffer
Ly = ly + dy + 2 * buffer
Lz = lz + 2 * buffer

src_pos_x = -Lx / 2 + buffer / 2
mnt_pos_x = +Lx / 2 - buffer / 2
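Working in units of the wavelength (set to 1 below for readability) makes the geometry easy to check. The domain is 5 × 5.2 × 5 wavelengths, and the source and monitor sit 1.5 wavelengths on either side of the box center. Because Ly includes the extra dy, even the dilated box keeps the full 2-wavelength buffer on each side in y. This sketch simply recomputes the parameters above with plain Python.

```python
# Standalone recomputation of the geometry, in units of the wavelength (wavelength = 1).
wavelength = 1.0

lx = ly = lz = wavelength
dy = ly / 5.0
buffer = 2 * wavelength

Lx = lx + 2 * buffer
Ly = ly + dy + 2 * buffer
Lz = lz + 2 * buffer

src_pos_x = -Lx / 2 + buffer / 2
mnt_pos_x = +Lx / 2 - buffer / 2

# y-size of the box for each perturbation sign: eroded (-1), nominal (0), dilated (+1)
box_sizes_y = {s: ly + s * dy for s in (-1, 0, 1)}

# Gap between the widest (dilated) box and each domain edge along y
min_gap_y = (Ly - max(box_sizes_y.values())) / 2

print(f"domain = ({Lx}, {Ly}, {Lz}), source x = {src_pos_x}, monitor x = {mnt_pos_x}")
print(f"box y-sizes = {box_sizes_y}, smallest y gap = {min_gap_y}")
```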

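def make_sim(permittivity: float, dy_sign: int) -> tda.JaxSimulation:
    """Make a simulation as a function of the box permittivity and the sign of the y-size change."""

    # y-size of the box: eroded (-1), nominal (0), or dilated (+1)
    box_size = ly + dy_sign * dy

    # differentiable box whose permittivity is the optimization parameter
    box = tda.JaxStructure(
        geometry=tda.JaxBox(
            center=(0.0, 0.0, 0.0),
            size=(lx, box_size, lz),
        ),
        medium=tda.JaxMedium(permittivity=permittivity),
    )

    # point dipole source; polarization and pulse bandwidth are assumed values
    src = td.PointDipole(
        center=(src_pos_x, 0, 0),
        polarization="Ey",
        source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10),
    )

    # diffraction monitor, read by post_process() under the name "diffraction"
    mnt = td.DiffractionMonitor(
        center=(mnt_pos_x, 0, 0),
        size=(0, td.inf, td.inf),
        freqs=[freq0],
        name="diffraction",
    )

    return tda.JaxSimulation(
        size=(Lx, Ly, Lz),
        input_structures=[box],
        output_monitors=[mnt],
        sources=[src],
        grid_spec=td.GridSpec.auto(wavelength=td.C_0 / freq0),
        boundary_spec=td.BoundarySpec(
            x=td.Boundary.pml(), y=td.Boundary.periodic(), z=td.Boundary.periodic()
        ),
        run_time=200 / src.source_time.fwidth,
    )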
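The run_time in make_sim is set to 200 inverse source bandwidths. The quick arithmetic check below assumes fwidth = freq0 / 10, which is a hypothetical value because the bandwidth is not stated in the text. Under that assumption, run_time comes out to 1e-11 s (10 ps), or 2000 optical periods at freq0.

```python
freq0 = 2e14  # Hz, as defined above
fwidth = freq0 / 10  # Hz, ASSUMED Gaussian pulse bandwidth (hypothetical value)

run_time = 200 / fwidth  # s, same formula as in make_sim
n_periods = run_time * freq0  # number of optical periods at freq0 within run_time

print(f"run_time = {run_time:.1e} s ({n_periods:.0f} periods)")
```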

Let’s make a simulation for each of the perturbed size values and visualize them.

f, axes = plt.subplots(1, 3, tight_layout=True, figsize=(10, 4))

for ax, dy_sign in zip(axes, (-1, 0, 1)):
    jax_sim = make_sim(permittivity=permittivity_val, dy_sign=dy_sign)
    ax = jax_sim.plot(z=0, ax=ax)


Define Objective

Now let’s define our objective function. First, we define how to post-process a JaxSimulationData to obtain the desired power.

def post_process(sim_data: tda.JaxSimulationData) -> float:
    """0th order diffracted power."""
    amp = sim_data["diffraction"].amps.sel(orders_x=0, orders_y=0)
    return jnp.sum(jnp.abs(amp.values) ** 2)
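The selection and sum in post_process can be mimicked in plain NumPy on a mock amplitude array. The axis layout (orders_x, orders_y, f, polarization) and all numbers below are hypothetical. The point is that once the (0, 0) order is selected, the remaining frequency and polarization axes are summed. Power in both polarizations therefore counts toward the objective, while power in higher orders is excluded.

```python
import numpy as np

# Hypothetical diffraction amplitudes with axes (orders_x, orders_y, f, polarization):
# orders -1, 0, 1 in each direction, one frequency, two polarizations.
orders = np.array([-1, 0, 1])
amps = np.zeros((3, 3, 1, 2), dtype=complex)
amps[1, 1, 0, :] = [1.0 + 1.0j, 0.5]  # the (0, 0) order
amps[0, 1, 0, :] = [0.3, 0.0]  # a higher order, excluded by the selection

# Equivalent of .sel(orders_x=0, orders_y=0): use the index where the order equals 0.
i0 = int(np.flatnonzero(orders == 0)[0])
amp_00 = amps[i0, i0]  # remaining shape: (f, polarization)

# Sum |amp|^2 over the remaining frequency and polarization axes.
power_00 = np.sum(np.abs(amp_00) ** 2)
print(f"0th-order power = {power_00:.4f}")
```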

Next, we write the combined multi-objective function over all of the dy values. We use the plugins.adjoint.web.run_async function to run the list of three simulations concurrently.

def objective(permittivity: float) -> float:
    """Average of 0th order diffracted power over all dy_sign values."""
    sims = [make_sim(permittivity, dy_sign=dy_sign) for dy_sign in (-1, 0, 1)]
    sim_data_list = tda.web.run_async(sims, path_dir="data", verbose=True)
    powers = [post_process(sim_data) for sim_data in sim_data_list]
    return jnp.mean(jnp.array(powers))

Multi-Objective Gradient Calculation

Finally, we are ready to use jax.value_and_grad to differentiate this function.

grad_objective = jax.value_and_grad(objective)

power_average, grad_power_min = grad_objective(permittivity_val)
print(f"average power = {power_average:.2e}")
print(f"derivative of average power wrt permittivity = {grad_power_min:.2e}")

[10:58:35] Created task '0_fwd' with task_id
[10:58:36] Created task '1_fwd' with task_id
[10:58:37] Created task '2_fwd' with task_id
[10:58:39] Started working on Batch.
[10:58:40] Maximum FlexCredit cost: 0.075 for the whole batch. Use
           'Batch.real_cost()' to get the billed FlexCredit cost after the Batch
           has completed.
[10:58:59] Batch complete.
[10:59:01] loading SimulationData from
[10:59:02] loading SimulationData from
[10:59:03] loading SimulationData from
[10:59:04] Created task '0_adj' with task_id
[10:59:05] Created task '1_adj' with task_id
           Created task '2_adj' with task_id
[10:59:07] Started working on Batch.
[10:59:09] Maximum FlexCredit cost: 0.075 for the whole batch. Use
           'Batch.real_cost()' to get the billed FlexCredit cost after the Batch
           has completed.
[10:59:28] Batch complete.
[10:59:30] loading SimulationData from
[10:59:31] loading SimulationData from
[10:59:32] loading SimulationData from
average power = 2.62e+00
derivative of average power wrt permittivity = -8.79e-01

Sanity Checking: Manual Loop over Size

Now we will compute the multi-objective gradient by brute force, looping over the dy values and computing the individual gradient contributions one at a time.

def grad_manual(permittivity: float) -> float:
    """Gradient of the 0th order diffracted power, averaged over all dy_sign values."""

    grad_avg = 0.0

    for dy_sign in (-1, 0, 1):
        print(f"working on dy_sign = {dy_sign}")

        def objective_fn(p):
            sim = make_sim(p, dy_sign=dy_sign)
            sim_data = tda.web.run(sim, task_name=f"dy_sign={dy_sign}", verbose=False)
            return post_process(sim_data)

        grad_fn = jax.grad(objective_fn)

        gradient = grad_fn(permittivity)
        grad_avg += gradient / 3.0

    return grad_avg

grad_man = grad_manual(permittivity_val)

working on dy_sign = -1
working on dy_sign = 0
working on dy_sign = 1

We can see that they match, as expected.

print(f"gradient (batched) = {grad_power_min:.4e}")
print(f"gradient (looped) = {grad_man:.4e}")

gradient (batched) = -8.7947e-01
gradient (looped) = -8.7947e-01
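The agreement is expected because differentiation is linear: the derivative of an average equals the average of the derivatives. The NumPy check below illustrates this identity using central finite differences instead of the adjoint method. Its three per-geometry "powers" are hypothetical smooth functions, not Tidy3D results.

```python
import numpy as np

def power(eps: float, s: int) -> float:
    """Hypothetical stand-in for the 0th-order power of geometry s (not a Tidy3D result)."""
    return np.sin(eps + 0.1 * s) ** 2 + 0.05 * s * eps

def avg_power(eps: float) -> float:
    """Average of the three stand-in powers, mirroring objective()."""
    return np.mean([power(eps, s) for s in (-1, 0, 1)])

def central_diff(f, x: float, h: float = 1e-6) -> float:
    """Second-order central finite-difference derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2 * h)

eps0 = 2.0

# "Batched": differentiate the averaged objective once.
grad_batched = central_diff(avg_power, eps0)

# "Looped": differentiate each term separately, then average (like grad_manual).
grad_looped = np.mean([central_diff(lambda e, s=s: power(e, s), eps0) for s in (-1, 0, 1)])

print(f"batched = {grad_batched:.6e}, looped = {grad_looped:.6e}")
```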


The main thing to note is that, with plugins.adjoint.web.run_async, all of the individual simulations were uploaded at roughly the same time.

This means the server can work on them concurrently, rather than waiting for each previously uploaded simulation to finish. For applications involving several simulations, this can substantially reduce the total wall-clock time. Ideally, the total approaches the duration of the slowest single simulation rather than the sum of all of them.
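As a local analogy for why concurrent submission helps, the sketch below runs three hypothetical "simulations" with Python's standard-library ThreadPoolExecutor. Each one just sleeps, so it is not how the Tidy3D server schedules tasks. Run one after another, the tasks take roughly the sum of their durations. Submitted together, they take roughly as long as the longest one.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def fake_simulation(task_id: int, duration_s: float = 0.2) -> int:
    """Hypothetical stand-in for a remote solve: waits, then returns its id."""
    time.sleep(duration_s)
    return task_id

task_ids = [0, 1, 2]

# Sequential: each task waits for the previous one to finish.
t0 = time.perf_counter()
sequential_results = [fake_simulation(i) for i in task_ids]
sequential_time = time.perf_counter() - t0

# Concurrent: all tasks are submitted up front and wait in parallel.
t0 = time.perf_counter()
with ThreadPoolExecutor(max_workers=len(task_ids)) as pool:
    concurrent_results = list(pool.map(fake_simulation, task_ids))
concurrent_time = time.perf_counter() - t0

print(f"sequential: {sequential_time:.2f} s, concurrent: {concurrent_time:.2f} s")
```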

Note: Native support for multi-frequency output monitors was added in Tidy3D 2.5. Previously, we recommended the run_async approach described here for multi-frequency objectives, but this is no longer necessary. That said, for objectives whose output monitors have very tightly spaced frequencies, a batch approach like this one may still be advantageous. The multi-frequency adjoint approach requires a run_time that scales inversely with the spacing between frequency points, which can lead to long simulation run times, so the overall wait time may be shorter with a batch approach.
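To make the inverse scaling concrete: if run_time is proportional to 1 / Δf, then every tenfold reduction in frequency spacing multiplies the required run time by ten. The sketch below assumes exactly that proportionality and uses hypothetical spacings relative to freq0. It only tabulates the ratio and does not model Tidy3D's actual run_time requirement.

```python
freq0 = 2e14  # Hz

def relative_run_time(delta_f: float, reference_delta_f: float) -> float:
    """Run time relative to a reference, ASSUMING run_time is proportional to 1 / delta_f."""
    return reference_delta_f / delta_f

reference = freq0 / 100  # hypothetical reference frequency spacing
for delta_f in (freq0 / 100, freq0 / 1_000, freq0 / 10_000):
    ratio = relative_run_time(delta_f, reference)
    print(f"delta_f = {delta_f:.1e} Hz -> {ratio:.0f}x the reference run time")
```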
