To install the jax module required for this feature, we recommend running pip install "tidy3d[jax]".

In this notebook, we will show how to use the adjoint plugin efficiently for objectives involving several simulations.

One common application, which we explore here, is an objective function that depends on several frequencies.

import numpy as np
import jax.numpy as jnp
import jax

import tidy3d as td
import tidy3d.plugins.adjoint as tda
from tidy3d.plugins.adjoint.web import run_async


## Setup

First, let’s define the frequencies that our objective will depend on.

freq_min = 1e14
freq_max = 2e14
num_freqs = 2
freqs = np.linspace(freq_min, freq_max, num_freqs)

wavelength_max = td.C_0 / freq_min
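As a quick check on these numbers: Tidy3D uses micrometers as its length unit, so td.C_0 is the speed of light in µm/s. The sketch below recomputes the free-space wavelength at each frequency with plain NumPy. The constant C_0_UM_PER_S is written out here as an assumption standing in for td.C_0, so the snippet runs without Tidy3D.

```python
import numpy as np

# Speed of light in micrometers per second (assumed to match td.C_0).
C_0_UM_PER_S = 299792458.0 * 1e6

# Same frequency grid as above.
freq_min = 1e14  # Hz
freq_max = 2e14  # Hz
freqs = np.linspace(freq_min, freq_max, 2)

# Free-space wavelength in micrometers at each frequency.
wavelengths = C_0_UM_PER_S / freqs

# The longest wavelength (lowest frequency) sets the size of the geometry.
wavelength_max = C_0_UM_PER_S / freq_min

for f, lam in zip(freqs, wavelengths):
    print(f"f = {f:.2e} Hz -> wavelength = {lam:.3f} um")
```

This gives about 2.998 µm at freq_min and 1.499 µm at freq_max, so the box defined next is roughly 3 µm on a side.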



Now we set up some physical parameters.

We will place a box (a tda.JaxBox) at the center of the domain, with a point dipole source on one side and a diffraction monitor on the other.

The objective is the power in the 0th diffraction order, averaged over all of the frequencies defined above.

The gradient of this objective will be computed with respect to the permittivity of the box.

permittivity_val = 2.0

lx = wavelength_max
ly = wavelength_max
lz = wavelength_max

buffer = 2 * wavelength_max

Lx = lx + 2 * buffer
Ly = ly + 2 * buffer
Lz = lz + 2 * buffer

src_pos_x = -Lx / 2 + buffer / 2
mnt_pos_x = +Lx / 2 - buffer / 2
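With buffer = 2 * wavelength_max and a box of side wavelength_max, the domain is five wavelengths long in every direction, and the source and monitor each sit halfway through their buffer regions. The short calculation below simply repeats the definitions above with lengths expressed in units of wavelength_max, to make the layout explicit.

```python
# Work in units of wavelength_max so the layout is easy to read.
wavelength_max = 1.0

lx = ly = lz = wavelength_max  # box side length
buffer = 2 * wavelength_max    # space between the box and the domain edge

Lx = lx + 2 * buffer  # total domain length along x

src_pos_x = -Lx / 2 + buffer / 2
mnt_pos_x = +Lx / 2 - buffer / 2

# Distance between the source (or monitor) plane and the nearest box face.
gap_src = (-lx / 2) - src_pos_x
gap_mnt = mnt_pos_x - (lx / 2)

print(f"Lx = {Lx}, source at x = {src_pos_x}, monitor at x = {mnt_pos_x}")
print(f"source-to-box gap = {gap_src}, box-to-monitor gap = {gap_mnt}")
```

So the domain is 5 wavelengths long, the source and monitor sit at x = -1.5 and x = +1.5 wavelengths, and each is one wavelength away from the box. Multiply by wavelength_max (about 3 µm) for physical units.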


def make_sim(permittivity: float, freq: float) -> tda.JaxSimulation:
    """Make a simulation as a function of the box permittivity and the frequency."""

    # Box whose permittivity is the differentiable design parameter.
    box = tda.JaxStructure(
        geometry=tda.JaxBox(center=(0.0, 0.0, 0.0), size=(lx, ly, lz)),
        medium=tda.JaxMedium(permittivity=permittivity),
    )

    # Point dipole on the -x side, pulsed around the requested frequency.
    src = td.PointDipole(
        center=(src_pos_x, 0, 0),
        polarization="Ey",
        source_time=td.GaussianPulse(
            freq0=freq,
            fwidth=freq / 10,
        ),
    )

    # Diffraction monitor spanning the full y-z cross section on the +x side.
    mnt = td.DiffractionMonitor(
        center=(mnt_pos_x, 0, 0),
        size=(0, td.inf, td.inf),
        freqs=[freq],
        name="diffraction",
    )

    # Differentiable structures go in input_structures; monitors that feed
    # the objective go in output_monitors.
    return tda.JaxSimulation(
        size=(Lx, Ly, Lz),
        input_structures=[box],
        output_monitors=[mnt],
        sources=[src],
        grid_spec=td.GridSpec.auto(wavelength=td.C_0 / freq),
        boundary_spec=td.BoundarySpec(
            x=td.Boundary.pml(), y=td.Boundary.periodic(), z=td.Boundary.periodic()
        ),
        run_time=5000 / freq,  # 5000 optical periods at this frequency
    )
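Because the y and z boundaries are periodic, the field reaching the diffraction monitor decomposes into discrete orders (m_y, m_z). With zero Bloch phase, an order propagates when (m_y λ / L_y)² + (m_z λ / L_z)² < 1, so larger periods and shorter wavelengths admit more orders, and the objective keeps only (0, 0). The helper below is a hypothetical sketch (propagating_orders is not a Tidy3D function) that enumerates the orders for the period used here, Ly = Lz = 5 · wavelength_max, at the two simulated frequencies.

```python
import itertools
from fractions import Fraction


def propagating_orders(period_y, period_z, wavelength):
    """List the (m_y, m_z) diffraction orders that propagate at normal incidence.

    Assumes periodic boundaries with zero Bloch phase, so order (m_y, m_z) has
    transverse wavevector 2*pi*(m_y / period_y, m_z / period_z). Exact rational
    arithmetic avoids misclassifying orders that sit right at cutoff.
    """
    py, pz, lam = Fraction(period_y), Fraction(period_z), Fraction(wavelength)
    max_y = int(py / lam)  # any |m_y| larger than this is evanescent
    max_z = int(pz / lam)
    orders = []
    for m_y, m_z in itertools.product(range(-max_y, max_y + 1), range(-max_z, max_z + 1)):
        # (k_t / k0)^2 for this order; strictly below 1 means propagating.
        kt2 = (m_y * lam / py) ** 2 + (m_z * lam / pz) ** 2
        if kt2 < 1:
            orders.append((m_y, m_z))
    return orders


# Lengths in units of wavelength_max; the period is Ly = Lz = 5 * wavelength_max.
period = 5.0
for label, wavelength in [("freq_min", 1.0), ("freq_max", 0.5)]:
    orders = propagating_orders(period, period, wavelength)
    print(f"{label}: {len(orders)} propagating orders, (0, 0) included: {(0, 0) in orders}")
```

Doubling the frequency roughly quadruples the number of open channels, which is one reason the 0th-order power can differ noticeably between the two frequencies.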



Let’s try it out by making a JaxSimulation at the lowest frequency and plotting its cross section at z = 0.

jax_sim = make_sim(permittivity=permittivity_val, freq=freq_min)
ax = jax_sim.plot(z=0)

## Define Objective

Now let’s define our objective function for a single simulation result.

def post_process(sim_data: tda.JaxSimulationData) -> float:
    """0th-order diffracted power."""
    # Select the first (only) frequency and the (0, 0) diffraction order.
    amp = sim_data["diffraction"].amps.isel(f=0).sel(orders_x=0, orders_y=0)
    return abs(amp.values) ** 2
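The .sel(orders_x=0, orders_y=0) call selects the order by its label rather than its array position. Diffraction orders are labeled with signed integers, so the 0th order typically sits in the middle of the array, not at index 0. The NumPy sketch below mimics that lookup on a hypothetical amplitude array (random values, not Tidy3D output); select_order is an illustrative helper, and the power is formed exactly as in post_process.

```python
import numpy as np

# Hypothetical order labels, stored symmetrically around zero.
orders_x = np.arange(-3, 4)  # -3 ... 3
orders_y = np.arange(-2, 3)  # -2 ... 2

# Hypothetical complex amplitudes with axes (orders_x, orders_y).
rng = np.random.default_rng(0)
shape = (orders_x.size, orders_y.size)
amps = rng.normal(size=shape) + 1j * rng.normal(size=shape)


def select_order(amps, orders_x, orders_y, mx, my):
    """Label-based lookup, analogous to .sel(orders_x=mx, orders_y=my)."""
    ix = int(np.flatnonzero(orders_x == mx)[0])
    iy = int(np.flatnonzero(orders_y == my)[0])
    return amps[ix, iy]


amp0 = select_order(amps, orders_x, orders_y, 0, 0)
power0 = abs(amp0) ** 2  # same form as post_process

print(f"(0, 0) is stored at index ({np.flatnonzero(orders_x == 0)[0]}, {np.flatnonzero(orders_y == 0)[0]})")
print(f"0th-order power = {power0:.3f}")
```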



We can now put everything together and define the full objective over all frequencies using the plugins.adjoint.web.run_async function, which submits all of the simulations as one batch and returns one data object per simulation.

def objective(permittivity: float) -> float:
    """Average of 0th-order diffracted power over all frequencies."""
    # Build one simulation per frequency and run them all as a single batch.
    sim_list = [make_sim(permittivity, freq) for freq in freqs]
    sim_data_list = run_async(sim_list, path_dir="data", verbose=True)
    # Post-process each result and average over frequency.
    power = [post_process(sim_data) for sim_data in sim_data_list]
    return jnp.sum(jnp.array(power)) / len(freqs)
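The benefit of run_async comes from submitting every simulation before waiting on any of them. The standard-library sketch below illustrates that timing argument with a hypothetical fake_solve function that just sleeps; it stands in for a remote simulation and is not part of Tidy3D. This is only an analogy: run_async submits tasks to the Tidy3D servers rather than using local threads.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_solve(freq: float, duration: float = 0.2) -> float:
    """Hypothetical stand-in for one remote simulation: wait, then return a value."""
    time.sleep(duration)
    return freq * 1e-14


freqs = [1e14, 2e14]

# Sequential: each job starts only after the previous one finishes.
t0 = time.perf_counter()
sequential = [fake_solve(f) for f in freqs]
t_sequential = time.perf_counter() - t0

# Concurrent: all jobs are submitted up front; map() keeps the input order.
t0 = time.perf_counter()
with ThreadPoolExecutor(max_workers=len(freqs)) as pool:
    concurrent_results = list(pool.map(fake_solve, freqs))
t_concurrent = time.perf_counter() - t0

print(f"sequential: {t_sequential:.2f} s, concurrent: {t_concurrent:.2f} s")
```

The results are identical either way; only the wall-clock time changes, tracking the slowest job instead of the sum of all jobs.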



Let’s first call this function directly to check that evaluating the objective on its own (without gradients) works.

power_average = objective(permittivity=permittivity_val)
print(f"average power (freq) = {power_average:.2e}")


[16:32:08] Created task '0' with task_id                           webapi.py:188
'fdve-dcd16c19-9334-4efc-8e38-ed712de442bbv1'.

           View task using web UI at                               webapi.py:190
dcd16c19-9334-4efc-8e38-ed712de442bbv1'.

[16:32:09] Created task '1' with task_id                           webapi.py:188
'fdve-23e8e41f-1342-49b0-8118-60809a049bccv1'.

           View task using web UI at                               webapi.py:190
23e8e41f-1342-49b0-8118-60809a049bccv1'.

[16:32:10] Started working on Batch.                            container.py:475

[16:32:20] Maximum FlexCredit cost: 0.090 for the whole batch.  container.py:479
Use 'Batch.real_cost()' to get the billed FlexCredit
cost after the Batch has completed.

[16:32:45] Batch complete.                                      container.py:522

[16:32:46] loading SimulationData from                             webapi.py:590
data/fdve-dcd16c19-9334-4efc-8e38-ed712de442bbv1.hdf5

           loading SimulationData from                             webapi.py:590
data/fdve-23e8e41f-1342-49b0-8118-60809a049bccv1.hdf5

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

average power (freq) = 6.21e-01


Now we are ready to use jax.grad to differentiate this function with respect to the permittivity.

grad_objective = jax.grad(objective)
grad_power_average = grad_objective(permittivity_val)

print(f"derivative of average power wrt permittivity = {grad_power_average:.2e}")


[16:32:47] Created task '0' with task_id                           webapi.py:188
'fdve-9974418c-1a79-4829-a152-74f4e225e104v1'.

           View task using web UI at                               webapi.py:190
9974418c-1a79-4829-a152-74f4e225e104v1'.

[16:32:48] Created task '1' with task_id                           webapi.py:188
'fdve-3da4125d-4fda-4cf0-9117-b751dc8b6f43v1'.

           View task using web UI at                               webapi.py:190
3da4125d-4fda-4cf0-9117-b751dc8b6f43v1'.

[16:32:50] Started working on Batch.                            container.py:475

[16:32:59] Maximum FlexCredit cost: 0.091 for the whole batch.  container.py:479
Use 'Batch.real_cost()' to get the billed FlexCredit
cost after the Batch has completed.

[16:33:25] Batch complete.                                      container.py:522

[16:33:26] loading SimulationData from                             webapi.py:590
data/fdve-9974418c-1a79-4829-a152-74f4e225e104v1.hdf5

[16:33:27] loading SimulationData from                             webapi.py:590
data/fdve-3da4125d-4fda-4cf0-9117-b751dc8b6f43v1.hdf5

[16:33:28] Created task '0' with task_id                           webapi.py:188
'fdve-2467993a-1e58-4383-8740-18feb13f9ec6v1'.

           View task using web UI at                               webapi.py:190
2467993a-1e58-4383-8740-18feb13f9ec6v1'.

           Created task '1' with task_id                           webapi.py:188
'fdve-bb75b4a1-50ca-465d-9c53-dd36c6f734a8v1'.

           View task using web UI at                               webapi.py:190
bb75b4a1-50ca-465d-9c53-dd36c6f734a8v1'.

[16:33:30] Started working on Batch.                            container.py:475

[16:33:40] Maximum FlexCredit cost: 0.090 for the whole batch.  container.py:479
Use 'Batch.real_cost()' to get the billed FlexCredit
cost after the Batch has completed.

[16:34:06] Batch complete.                                      container.py:522

derivative of average power wrt permittivity = -1.24e+00
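A single jax.grad call over the batched objective should agree with looping over frequencies because of linearity: the gradient of an average is the average of the gradients. The JAX sketch below checks this on a hypothetical closed-form toy_power(eps, freq) that stands in for post_process applied to a single-frequency simulation; no simulation is involved.

```python
import jax
import jax.numpy as jnp

freqs = jnp.array([1e14, 2e14])


def toy_power(eps, freq):
    """Hypothetical smooth stand-in for the 0th-order power at one frequency."""
    return jnp.sin(eps * freq / 1e14) ** 2


def batched_objective(eps):
    # Average over all frequencies, mirroring objective() above.
    return jnp.mean(jnp.stack([toy_power(eps, f) for f in freqs]))


eps = 2.0
grad_batched = jax.grad(batched_objective)(eps)

# Per-frequency gradients (w.r.t. eps), then averaged, mirroring grad_manual() below.
grad_looped = sum(jax.grad(toy_power)(eps, f) for f in freqs) / len(freqs)

print(f"batched: {grad_batched:.6f}, looped: {grad_looped:.6f}")
```

For this toy function the exact answer is (sin 4 + 2 sin 8) / 2, and both routes reproduce it to float32 precision.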


## Sanity Checking: Manual Loop over Frequency

As a check, we now implement the brute-force approach: loop over the frequencies, compute the gradient of each single-frequency objective with its own simulation run, and average the individual contributions.

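from tidy3d.plugins.adjoint.web import run


def grad_manual(permittivity: float) -> float:
    """Gradient of the frequency-averaged 0th-order power, one frequency at a time."""

    def objective(permittivity, freq):
        # Single-frequency objective: one simulation, run on its own.
        sim = make_sim(permittivity, freq)
        sim_data = run(sim, task_name=f"freq={freq:.2e}", verbose=False)
        return jnp.sum(post_process(sim_data))

    grad_total = 0.0
    for freq in freqs:
        print(f"working on freq = {freq:.2e} (Hz)")
        obj_fn = lambda x: objective(x, freq=freq)
        # Each call waits for its own forward and adjoint simulations.
        grad_total += jax.grad(obj_fn)(permittivity)

    # Average the per-frequency contributions, matching objective().
    return grad_total / len(freqs)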

grad_man = grad_manual(permittivity_val)


working on freq = 1.00e+14 (Hz)
working on freq = 2.00e+14 (Hz)
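The manual loop above waits on each run in turn, whereas the batched version groups all forward simulations into one batch and all adjoint simulations into a second batch (the two "Started working on Batch" entries in the log of the gradient call). The arithmetic below counts the resulting rounds of waiting. It assumes one forward and one adjoint simulation per frequency and that a submitted batch runs fully in parallel on the server; sequential_rounds is a hypothetical helper.

```python
def sequential_rounds(num_freqs: int, batched: bool) -> int:
    """Number of back-to-back waits for one gradient evaluation.

    Assumes each frequency needs one forward and one adjoint simulation, and
    that a submitted batch runs fully in parallel on the server.
    """
    if batched:
        return 2  # one forward batch, then one adjoint batch
    return 2 * num_freqs  # forward + adjoint for each frequency, one after another


for n in (2, 10, 50):
    looped = sequential_rounds(n, batched=False)
    batched = sequential_rounds(n, batched=True)
    print(f"num_freqs = {n:2d}: looped = {looped:3d} rounds, batched = {batched} rounds")
```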


Finally, we check that the two gradients match.

print(f"gradient (batched) = {grad_power_average:.4e}")
print(f"gradient (looped)  = {grad_man:.4e}")


gradient (batched) = -1.2444e+00
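Because the two gradients come from separate solver runs, they are best compared with a relative tolerance rather than exact equality. A minimal sketch, assuming a relative tolerance of 1e-3 is adequate (an illustrative choice, not a Tidy3D recommendation); gradients_match is a hypothetical helper, and in the notebook you would pass grad_power_average and grad_man.

```python
import numpy as np


def gradients_match(grad_a: float, grad_b: float, rtol: float = 1e-3) -> bool:
    """True if two gradient estimates agree to within a relative tolerance."""
    return bool(np.isclose(grad_a, grad_b, rtol=rtol, atol=0.0))


# Illustrative values only, not results from the simulations above.
print(gradients_match(1.0, 1.0005))
print(gradients_match(1.0, 1.01))
```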


### Takeaways

The main thing to note here is that, using plugins.adjoint.web.run_async, all of the individual simulations were uploaded at roughly the same time.

This means that the server can run them concurrently instead of waiting for each previously uploaded simulation to finish. For objectives that involve many simulations, this can reduce the total wall-clock time substantially compared with running the simulations one after another.

While we focused this example on a multi-frequency objective, this basic strategy can be broadly applied to other multi-objective problems.

For instance, this approach is useful whenever the objective depends on results from several slightly different simulations, such as dilated or contracted versions of a structure, random variations, or other perturbed instances.
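The same pattern extends to any list of simulation variants: build every variant up front, submit them together, and combine the results. The sketch below only assembles a hypothetical parameter grid for a robustness study that crosses the two frequencies with three dilation offsets. In practice each entry would be turned into a simulation by a function like make_sim (extended with a hypothetical dilation argument) and the whole list passed to run_async.

```python
import itertools

import numpy as np

freqs = np.linspace(1e14, 2e14, 2)
dilations = [-0.01, 0.0, 0.01]  # hypothetical edge offsets, in um

# One entry per simulation that would go into a single batch.
variants = [
    {"freq": float(f), "dilation": d}
    for f, d in itertools.product(freqs, dilations)
]

# Equal weights give a plain average over all variants, as in objective().
weights = np.full(len(variants), 1.0 / len(variants))

for v, w in zip(variants, weights):
    print(f"freq = {v['freq']:.2e} Hz, dilation = {v['dilation']:+.2f} um, weight = {w:.3f}")
```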
