Adjoint Plugin: 4 Multi-Objective Gradients

In this notebook, we will show how to use the adjoint plugin efficiently for objectives involving several simulations.

One common case, which we explore here, is an objective function that depends on several frequencies.

import numpy as np
import jax.numpy as jnp
import jax

import tidy3d as td
import tidy3d.plugins.adjoint as tda


First, let’s define the frequencies that our objective will depend on:

freq_min = 1e14
freq_max = 2e14
num_freqs = 2
freqs = np.linspace(freq_min, freq_max, num_freqs)

wavelength_max = td.C_0 / freq_min

Now we set up some physical parameters.

We will be putting a td.Box in the center of a domain with a point source on one side and a diffraction monitor on the other.

The objective will be the power in the 0th diffraction order, averaged over all of the frequencies defined above.

The gradient of this objective will be computed with respect to the permittivity of the box.
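Written out, the objective is the mean 0th-order power over the $N_f$ frequencies $f_i$. Here $a_{00}(\epsilon, f_i)$ is our own shorthand (not a Tidy3D name) for the complex 0th-order diffraction amplitude at frequency $f_i$, viewed as a function of the box permittivity $\epsilon$. Differentiation is linear, so the gradient of the average equals the average of the per-frequency gradients. The looped sanity check later in this notebook relies on exactly this property.

```latex
F(\epsilon) = \frac{1}{N_f} \sum_{i=1}^{N_f} \left| a_{00}(\epsilon, f_i) \right|^2,
\qquad
\frac{dF}{d\epsilon} = \frac{1}{N_f} \sum_{i=1}^{N_f} \frac{d}{d\epsilon} \left| a_{00}(\epsilon, f_i) \right|^2,
\qquad
\frac{d}{d\epsilon} \left| a_{00} \right|^2 = 2\,\mathrm{Re}\!\left( a_{00}^{*} \, \frac{d a_{00}}{d\epsilon} \right).
```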

permittivity_val = 2.0

lx = wavelength_max
ly = wavelength_max
lz = wavelength_max

buffer = 2 * wavelength_max

Lx = lx + 2 * buffer
Ly = ly + 2 * buffer
Lz = lz + 2 * buffer

src_pos_x = -Lx / 2 + buffer / 2
mnt_pos_x = +Lx / 2 - buffer / 2
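The cell below is a quick check of the layout. It recomputes the geometry and expresses it in units of the longest wavelength. It assumes Tidy3D's default length unit of micrometers, so the speed of light is written as 2.99792458e14 µm/s. That constant is ours and stands in for td.C_0, and the function name is hypothetical.

```python
# Hypothetical stand-alone check of the layout above; no Tidy3D required.
# Assumes lengths in micrometers, so C_0 ~ 2.998e14 um/s stands in for td.C_0.
C_0_UM_PER_S = 2.99792458e14


def layout_in_wavelengths(freq_min: float) -> dict:
    """Recompute the notebook's geometry and express it in units of wavelength_max."""
    wavelength_max = C_0_UM_PER_S / freq_min
    lx = wavelength_max                # box side length
    buffer = 2 * wavelength_max        # gap between box faces and domain edges
    Lx = lx + 2 * buffer               # total domain length along x
    src_pos_x = -Lx / 2 + buffer / 2   # source plane, middle of the -x buffer
    mnt_pos_x = +Lx / 2 - buffer / 2   # monitor plane, middle of the +x buffer
    return {
        "wavelength_max_um": wavelength_max,
        "Lx": Lx / wavelength_max,
        "src_x": src_pos_x / wavelength_max,
        "mnt_x": mnt_pos_x / wavelength_max,
        "box_to_src": (-lx / 2 - src_pos_x) / wavelength_max,
    }


layout = layout_in_wavelengths(1e14)
for key, value in layout.items():
    print(f"{key} = {value:.3f}")
```

For freq_min = 1e14 Hz, the domain is 5 wavelengths long. The source sits at -1.5 wavelengths and the monitor at +1.5, so each plane is one wavelength from its nearest box face and one wavelength from the x boundary.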

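def make_sim(permittivity: float, freq: float) -> tda.JaxSimulation:
    """Make a simulation as a function of the box permittivity and the frequency."""

    # Box whose permittivity is the differentiable design parameter.
    box = tda.JaxStructure(
        geometry=tda.JaxBox(center=(0.0, 0.0, 0.0), size=(lx, ly, lz)),
        medium=tda.JaxMedium(permittivity=permittivity),
    )

    # Point source on the -x side, centered on the current frequency.
    src = td.PointDipole(
        center=(src_pos_x, 0, 0),
        polarization="Ey",
        source_time=td.GaussianPulse(
            freq0=freq,
            fwidth=freq / 10,
        ),
    )

    # Diffraction monitor on the +x side, recording only this frequency.
    mnt = td.DiffractionMonitor(
        center=(mnt_pos_x, 0, 0),
        size=(0, td.inf, td.inf),
        freqs=[freq],
        name="diffraction",
    )

    return tda.JaxSimulation(
        size=(Lx, Ly, Lz),
        input_structures=[box],
        output_monitors=[mnt],
        sources=[src],
        grid_spec=td.GridSpec.auto(wavelength=td.C_0 / freq),
        boundary_spec=td.BoundarySpec(
            x=td.Boundary.pml(), y=td.Boundary.periodic(), z=td.Boundary.periodic()
        ),
        run_time=5000 / freq,
    )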

Let’s try it out by making a jax simulation and plotting it:

jax_sim = make_sim(permittivity=permittivity_val, freq=freq_min)
ax = jax_sim.plot(z=0)


Define Objective

Now let’s define our objective function for a single simulation result.

def post_process(sim_data: tda.JaxSimulationData) -> float:
    """0th-order diffracted power."""
    amp = sim_data["diffraction"].amps.isel(f=0).sel(orders_x=0, orders_y=0)
    return abs(amp.values) ** 2
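In post_process, .sel(orders_x=0, orders_y=0) selects the diffraction order by its label, not by its array position. By contrast, .isel(f=0) picks the first frequency by position, and each simulation here is built for a single frequency. The diffraction orders typically run from negative to positive integers, so the 0th order sits in the middle of the amplitude array. The NumPy sketch below mimics that label-based lookup. The order range, the amplitude values, and the helper power_in_order are all made up for illustration.

```python
import numpy as np

# Made-up example: orders -2..2 along each periodic axis.
orders_x = np.arange(-2, 3)
orders_y = np.arange(-2, 3)

# Invented complex amplitudes indexed as amps[ix, iy].
rng = np.random.default_rng(0)
amps = rng.normal(size=(5, 5)) + 1j * rng.normal(size=(5, 5))
amps[2, 2] = 0.6 + 0.3j  # plant a known value at the (0, 0) order


def power_in_order(amps, orders_x, orders_y, mx=0, my=0):
    """Label-based lookup, analogous to .sel(orders_x=mx, orders_y=my)."""
    ix = int(np.flatnonzero(orders_x == mx)[0])
    iy = int(np.flatnonzero(orders_y == my)[0])
    return abs(amps[ix, iy]) ** 2


p00 = power_in_order(amps, orders_x, orders_y)
print(f"0th-order power = {p00:.2f}")  # |0.6 + 0.3j|^2 = 0.45
print(f"positional amps[0, 0] is order ({orders_x[0]}, {orders_y[0]})")
```

Indexing amps[0, 0] positionally would return the (-2, -2) order instead. This is why a label-based selection is the safer choice.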

We can then put everything together and define the full objective over all frequencies using the plugins.adjoint.web.run_async function, which submits all of the simulations together rather than one at a time.

def objective(permittivity: float) -> float:
    """Average of 0th-order diffracted power over all frequencies."""
    sim_list = [make_sim(permittivity, freq) for freq in freqs]
    sim_data_list = tda.web.run_async(sim_list, path_dir="data", verbose=True)
    power = [post_process(sim_data) for sim_data in sim_data_list]
    return jnp.sum(jnp.array(power)) / len(freqs)

Let’s first run this function as is, to check that it correctly computes the objective value on its own.

power_average = objective(permittivity=permittivity_val)
print(f"average power (freq) = {power_average:.2e}")

[14:33:32] loading SimulationData from data/fdve-b62f8de1-c777-4e30-8d10-8470d2d48355v1.hdf5
[14:33:33] loading SimulationData from data/fdve-e6a82480-ac08-4a3a-ac7a-17a3995c7afev1.hdf5
average power (freq) = 7.64e-01

Multi-Frequency Gradient Calculation

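Finally, we are ready to use jax.grad to differentiate this function. Computing the gradient requires an additional adjoint simulation for each frequency. This is why four simulation results are loaded below: two forward and two adjoint.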

grad_objective = jax.grad(objective)

grad_power_average = grad_objective(permittivity_val)
print(f"derivative of average power wrt permittivity = {grad_power_average:.2e}")

[14:34:08] loading SimulationData from data/fdve-e21fb23f-b291-4996-8d8b-ddf3ff439b07v1.hdf5
[14:34:10] loading SimulationData from data/fdve-a51abc4a-f53f-4486-9434-4d62cac117c3v1.hdf5
[14:34:45] loading SimulationData from data/fdve-8665fc82-4c7d-4f9a-b487-02e456044bb9v1.hdf5
[14:34:53] loading SimulationData from data/fdve-2833c5f4-8c2c-41ae-a4a5-35d26af96bd8v1.hdf5
derivative of average power wrt permittivity = -1.82e+00

Sanity Checking: Manual Loop over Frequency

Now we implement the brute-force approach. We loop over the frequencies, compute the gradient of each single-frequency objective separately, and average the individual contributions.

def grad_manual(permittivity: float) -> float:
    """Gradient of the average 0th-order diffracted power, computed one frequency at a time."""

    def objective(permittivity, freq):
        sim = make_sim(permittivity, freq)
        sim_data = tda.web.run(sim, task_name=f"f={freq:.2e}", verbose=False)
        return jnp.sum(post_process(sim_data))

    total_grad = 0.0
    for freq in freqs:
        print(f"working on freq = {freq:.2e} (Hz)")
        obj_fn = lambda x: objective(x, freq=freq)
        grad_fn = jax.grad(obj_fn)
        gradient = grad_fn(permittivity)
        total_grad += gradient

    return total_grad / len(freqs)

grad_man = grad_manual(permittivity_val)

working on freq = 1.00e+14 (Hz)
working on freq = 2.00e+14 (Hz)

Finally, we check that the two gradients match.

print(f"gradient (batched) = {grad_power_average:.4e}")
print(f"gradient (looped) = {grad_man:.4e}")

gradient (batched) = -1.8176e+00
gradient (looped) = -1.8176e+00
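The two gradients agree because of linearity: the derivative of an average is the average of the derivatives. The sketch below checks this on a toy problem in NumPy. The function toy_power is an entirely made-up smooth stand-in for the 0th-order power as a function of permittivity and frequency. Central finite differences take the place of jax.grad.

```python
import numpy as np

toy_freqs = np.linspace(1e14, 2e14, 2)


def toy_power(eps, freq):
    """Hypothetical smooth stand-in for |a_00(eps, f)|^2 (not a Tidy3D result)."""
    phase = 2 * np.pi * (freq / 1e14) * np.sqrt(eps)
    return 0.5 * (1 + np.cos(phase)) / eps


def averaged(eps):
    """Batched-style objective: average toy power over all frequencies."""
    return np.mean([toy_power(eps, f) for f in toy_freqs])


def fd_grad(fn, x, h=1e-6):
    """Central finite-difference derivative of a scalar function."""
    return (fn(x + h) - fn(x - h)) / (2 * h)


eps0 = 2.0
# Gradient of the average vs. average of the per-frequency gradients.
grad_batched = fd_grad(averaged, eps0)
grad_looped = np.mean([fd_grad(lambda e, f=f: toy_power(e, f), eps0) for f in toy_freqs])
print(f"gradient (batched) = {grad_batched:.6e}")
print(f"gradient (looped)  = {grad_looped:.6e}")
```

The default argument f=f binds each frequency at lambda creation time. This avoids Python's late-binding pitfall if the lambdas were stored and called later.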


The main point to note is that, with plugins.adjoint.web.run_async, all of the individual simulations were uploaded at roughly the same time.

This lets the server work on them concurrently instead of waiting for each previously uploaded simulation to finish. For objectives that involve many simulations, this can cut the total wall-clock time substantially compared with running them one after another.
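The sketch below illustrates why concurrent submission matters. It mimics a batch of remote jobs with time.sleep and runs them twice: first one after another, then all at once with concurrent.futures.ThreadPoolExecutor. The fake_solver function and its 0.2 s duration are placeholders for a server-side simulation. This is an analogy for what run_async achieves, not its actual implementation.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_solver(freq):
    """Placeholder for a remote simulation: waits, then returns a dummy result."""
    time.sleep(0.2)  # pretend the server needs 0.2 s per job
    return freq / 1e14


freqs_hz = [1e14, 2e14, 3e14, 4e14]

# Sequential: each job starts only after the previous one has finished.
t0 = time.perf_counter()
results_seq = [fake_solver(f) for f in freqs_hz]
t_seq = time.perf_counter() - t0

# Concurrent: all jobs are submitted up front and wait at the same time.
t0 = time.perf_counter()
with ThreadPoolExecutor(max_workers=len(freqs_hz)) as pool:
    results_par = list(pool.map(fake_solver, freqs_hz))
t_par = time.perf_counter() - t0

print(f"sequential: {t_seq:.2f} s, concurrent: {t_par:.2f} s")
```

With four jobs, the sequential run takes roughly four times as long. pool.map also returns results in the same order as its inputs, so each result can still be paired with its frequency.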

While we focused this example on a multi-frequency objective, this basic strategy can be broadly applied to other multi-objective problems.

For example, this general approach can be very useful if the objective depends on results from several slightly different simulations, such as dilated or contracted versions of a structure, random variations, or other instances.
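The same pattern applies to any family of related simulations: build all variants, run them as one batch, post-process each result, and reduce to a scalar. The sketch below writes that pattern as a small higher-order function. Every name in it (batched_objective, make_variant, run_batch, figure_of_merit) is hypothetical. The toy callables stand in for make_sim, tda.web.run_async and post_process, applied here to a few dilations of a box width.

```python
from typing import Any, Callable, List, Sequence


def batched_objective(
    param: Any,
    variants: Sequence[Any],
    make_variant: Callable[[Any, Any], Any],
    run_batch: Callable[[List[Any]], List[Any]],
    figure_of_merit: Callable[[Any], float],
) -> float:
    """Build every variant, run them as one batch, and average the figures of merit."""
    sims = [make_variant(param, v) for v in variants]
    results = run_batch(sims)  # one submission for the whole list
    values = [figure_of_merit(r) for r in results]
    return sum(values) / len(values)


# Toy usage: a "simulation" is just a box width after dilation by delta.
dilations = [-0.01, 0.0, 0.01]
width = 1.0


def make_variant(w, delta):
    return w + 2 * delta  # dilate both faces by delta


def run_batch(sims):
    return [s ** 2 for s in sims]  # stand-in "solver": area of a square


def figure_of_merit(area):
    return area


value = batched_objective(width, dilations, make_variant, run_batch, figure_of_merit)
print(f"average over dilations = {value:.4f}")
```

With freqs as the variants, this has the same shape as the objective function defined earlier. A differentiable version would do the final reduction with jnp, as that function does.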
