Adjoint Plugin: 4 Multi-Objective Gradients
In this notebook, we will show how to use the adjoint plugin efficiently for objectives involving several simulations.
One common application of this, which we will explore, is objective functions that depend on several frequencies.
[1]:
import numpy as np
import jax.numpy as jnp
import jax
import tidy3d as td
import tidy3d.plugins.adjoint as tda
Setup
First, let’s define the frequencies that our objective will depend on.
[2]:
freq_min = 1e14
freq_max = 2e14
num_freqs = 2
freqs = np.linspace(freq_min, freq_max, num_freqs)
wavelength_max = td.C_0 / freq_min
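For a sense of scale, the two frequencies can be converted to free-space wavelengths. The short sketch below does this in plain Python. It assumes that td.C_0 is the speed of light expressed in micrometers per second (tidy3d works in microns); the constant C_0_UM_PER_S is a stand-in written out here for illustration, not a value imported from tidy3d.

```python
# Stand-in for td.C_0: speed of light in vacuum, in micrometers per second (assumed unit).
C_0_UM_PER_S = 2.99792458e14

freq_min = 1e14
freq_max = 2e14
num_freqs = 2

# Pure-Python equivalent of np.linspace(freq_min, freq_max, num_freqs) for num_freqs >= 2.
step = (freq_max - freq_min) / (num_freqs - 1)
freqs = [freq_min + i * step for i in range(num_freqs)]

# Free-space wavelength for each frequency, in micrometers.
wavelengths_um = [C_0_UM_PER_S / f for f in freqs]

for f, wl in zip(freqs, wavelengths_um):
    print(f"f = {f:.2e} Hz -> wavelength = {wl:.3f} um")
```

The longest wavelength (at freq_min) is the one used as wavelength_max to size the domain.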
Now we set up some physical parameters.
We will place a box (a tda.JaxBox, the differentiable counterpart of td.Box) at the center of the domain, with a point dipole source on one side and a diffraction monitor on the other.
The objective is the power in the 0th diffraction order, averaged over all frequencies defined above.
The gradient of this objective will be computed with respect to the permittivity of the box.
[3]:
permittivity_val = 2.0
lx = wavelength_max
ly = wavelength_max
lz = wavelength_max
buffer = 2 * wavelength_max
Lx = lx + 2 * buffer
Ly = ly + 2 * buffer
Lz = lz + 2 * buffer
src_pos_x = -Lx / 2 + buffer / 2
mnt_pos_x = +Lx / 2 - buffer / 2
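To make the layout along x concrete, the sketch below repeats the arithmetic above in units of wavelength_max (set to 1.0 purely for readability). The names src_gap and mnt_gap are introduced here for illustration and are not part of the notebook.

```python
# Work in units of wavelength_max so the layout is easy to read.
wavelength_max = 1.0

lx = wavelength_max
buffer = 2 * wavelength_max
Lx = lx + 2 * buffer

src_pos_x = -Lx / 2 + buffer / 2
mnt_pos_x = +Lx / 2 - buffer / 2

# Distance from each plane to the nearest face of the box (faces at x = -lx/2 and x = +lx/2).
src_gap = (-lx / 2) - src_pos_x
mnt_gap = mnt_pos_x - (lx / 2)

print(f"Lx = {Lx}, source at x = {src_pos_x}, monitor at x = {mnt_pos_x}")
print(f"source-to-box gap = {src_gap}, box-to-monitor gap = {mnt_gap}")
```

With these choices the domain is five wavelengths long in x, and the source and monitor each sit halfway through their buffer region: one wavelength_max from the box face and one from the domain edge.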
[4]:
def make_sim(permittivity: float, freq: float) -> tda.JaxSimulation:
    """Make a simulation as a function of the box permittivity and the frequency."""
    box = tda.JaxStructure(
        geometry=tda.JaxBox(center=(0.0, 0.0, 0.0), size=(lx, ly, lz)),
        medium=tda.JaxMedium(permittivity=permittivity),
    )
    src = td.PointDipole(
        center=(src_pos_x, 0, 0),
        polarization="Ey",
        source_time=td.GaussianPulse(
            freq0=freq,
            fwidth=freq / 10,
        ),
    )
    mnt = td.DiffractionMonitor(
        center=(mnt_pos_x, 0, 0),
        size=(0, td.inf, td.inf),
        freqs=[freq],
        name="diffraction",
    )
    return tda.JaxSimulation(
        size=(Lx, Ly, Lz),
        input_structures=[box],
        output_monitors=[mnt],
        sources=[src],
        grid_spec=td.GridSpec.auto(wavelength=td.C_0 / freq),
        boundary_spec=td.BoundarySpec(
            x=td.Boundary.pml(), y=td.Boundary.periodic(), z=td.Boundary.periodic()
        ),
        run_time=5000 / freq,
    )
Let’s try it out by making a JAX simulation at the lowest frequency and plotting it.
[5]:
jax_sim = make_sim(permittivity=permittivity_val, freq=freq_min)
ax = jax_sim.plot(z=0)
Define Objective
Now let’s define our objective function for a single simulation result.
[6]:
def post_process(sim_data: tda.JaxSimulationData) -> float:
    """0-th order diffracted power."""
    # Select the single frequency in the monitor and the (0, 0) diffraction order.
    amp = sim_data["diffraction"].amps.isel(f=0).sel(orders_x=0, orders_y=0)
    return abs(amp.values) ** 2
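The power returned by post_process is the squared magnitude of a complex amplitude, so it does not depend on the phase of that amplitude. The sketch below illustrates this with a made-up amplitude value; the number 0.6 + 0.8j is hypothetical and is not taken from any simulation.

```python
import cmath
import math

# Hypothetical complex amplitude of the (0, 0) diffraction order, chosen for illustration.
amp = 0.6 + 0.8j

# Power is the squared magnitude, i.e. real**2 + imag**2.
power = abs(amp) ** 2

# Applying an arbitrary phase shift leaves the power unchanged.
rotated = amp * cmath.exp(1j * math.pi / 3)
power_rotated = abs(rotated) ** 2

print(f"power = {power:.3f}, power after phase shift = {power_rotated:.3f}")
```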
We can now put everything together and define the full objective over all frequencies, using the plugins.adjoint.web.run_async function to run the simulations as a batch.
[7]:
def objective(permittivity: float) -> float:
    """Average of 0-th order diffracted power over all frequencies."""
    # One simulation per frequency, all submitted together as a batch.
    sim_list = [make_sim(permittivity, freq) for freq in freqs]
    sim_data_list = tda.web.run_async(sim_list, path_dir="data", verbose=True)
    power = [post_process(sim_data) for sim_data in sim_data_list]
    return jnp.sum(jnp.array(power)) / len(freqs)
Let’s first run this function as is, to make sure the forward evaluation of the objective works before we differentiate it.
[8]:
power_average = objective(permittivity=permittivity_val)
print(f"average power (freq) = {power_average:.2e}")
average power (freq) = 7.64e-01
Multi-Frequency Gradient Calculation
Finally, we are ready to use jax.grad to differentiate this function.
[9]:
grad_objective = jax.grad(objective)
grad_power_average = grad_objective(permittivity_val)
print(f"derivative of average power wrt permittivity = {grad_power_average:.2e}")
derivative of average power wrt permittivity = -1.82e+00
Sanity Checking: Manual Loop over Frequency
As a check, we now compute the multi-frequency gradient the brute-force way: we loop over the frequencies, differentiate each single-frequency objective separately, and average the individual gradient contributions.
[10]:
def grad_manual(permittivity: float) -> float:
    """Gradient of the average 0-th order diffracted power, computed one frequency at a time."""

    def objective(permittivity, freq):
        """0-th order diffracted power at a single frequency."""
        sim = make_sim(permittivity, freq)
        sim_data = tda.web.run(sim, task_name=f"f={freq:.2e}", verbose=False)
        return jnp.sum(post_process(sim_data))

    total_grad = 0.0
    for freq in freqs:
        print(f"working on freq = {freq:.2e} (Hz)")
        # obj_fn is differentiated and called within this iteration,
        # so it uses the current value of freq.
        obj_fn = lambda x: objective(x, freq=freq)
        grad_fn = jax.grad(obj_fn)
        gradient = grad_fn(permittivity)
        total_grad += gradient
    return total_grad / len(freqs)
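One Python detail in the loop above: a lambda that refers to the loop variable freq looks that variable up when it is called, not when it is defined. In grad_manual this is harmless, because each obj_fn is differentiated and called inside the same iteration. The sketch below, which uses hypothetical names and a trivial function in place of a simulation, shows the pitfall that would appear if the functions were collected first and called later, together with the common default-argument fix.

```python
freqs = [1e14, 2e14]

# Late binding: every lambda sees the final value of `freq` once the loop has ended.
late = []
for freq in freqs:
    late.append(lambda x: x * freq)

# Default-argument binding: each lambda stores the value of `freq` at definition time.
bound = []
for freq in freqs:
    bound.append(lambda x, freq=freq: x * freq)

late_results = [fn(1.0) for fn in late]
bound_results = [fn(1.0) for fn in bound]

print("late binding: ", late_results)
print("default-bound:", bound_results)
```

Both entries of late_results equal the last frequency, while bound_results holds one value per frequency.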
[11]:
grad_man = grad_manual(permittivity_val)
working on freq = 1.00e+14 (Hz)
working on freq = 2.00e+14 (Hz)
Finally, we can check that the batched and looped gradients match.
[12]:
print(f"gradient (batched) = {grad_power_average:.4e}")
print(f"gradient (looped) = {grad_man:.4e}")
gradient (batched) = -1.8176e+00
gradient (looped) = -1.8176e+00
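The agreement between the two results reflects the linearity of differentiation: the derivative of an average equals the average of the derivatives. The sketch below checks this on a toy problem in plain Python. The function toy_power is a hypothetical smooth stand-in for a single-frequency simulation result, and central finite differences take the place of jax.grad, so none of this reproduces the numbers above.

```python
import math


def toy_power(permittivity: float, freq: float) -> float:
    """Hypothetical smooth stand-in for one simulation's 0-th order power."""
    x = permittivity * freq / 1e14
    return 1.0 / (1.0 + x * x)


def central_diff(fn, x: float, h: float = 1e-6) -> float:
    """Central finite-difference estimate of d(fn)/dx at x."""
    return (fn(x + h) - fn(x - h)) / (2 * h)


freqs = [1e14, 2e14]
eps = 2.0


def averaged(permittivity: float) -> float:
    """Toy objective averaged over all frequencies."""
    return sum(toy_power(permittivity, f) for f in freqs) / len(freqs)


# "Batched": differentiate the averaged objective once.
grad_batched = central_diff(averaged, eps)

# "Looped": differentiate each frequency separately, then average the gradients.
grad_looped = sum(
    central_diff(lambda p, f=f: toy_power(p, f), eps) for f in freqs
) / len(freqs)

print(f"gradient (batched) = {grad_batched:.6e}")
print(f"gradient (looped)  = {grad_looped:.6e}")
```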
Takeaways
The main thing to note here is that, with plugins.adjoint.web.run_async, all of the individual simulations are uploaded at roughly the same time.
The server can therefore work on them concurrently instead of waiting for each previously uploaded simulation to finish. For objectives that involve many simulations, the time savings can be substantial.
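As a rough illustration of why concurrent submission helps, the sketch below compares running several slow tasks one after another with running them at the same time. It is only an analogy: fake_simulation is a hypothetical stand-in that sleeps for a fixed time, and the thread pool is not a description of how run_async is implemented.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_simulation(freq: float) -> float:
    """Hypothetical stand-in for a remote solve: waits, then returns a dummy result."""
    time.sleep(0.2)
    return freq / 1e14


freqs = [1e14, 1.25e14, 1.5e14, 1.75e14, 2e14]

# One after another: the total time is roughly the sum of the individual waits.
start = time.perf_counter()
sequential = [fake_simulation(f) for f in freqs]
t_sequential = time.perf_counter() - start

# All at once: the total time is roughly that of the slowest single task.
start = time.perf_counter()
with ThreadPoolExecutor(max_workers=len(freqs)) as pool:
    concurrent = list(pool.map(fake_simulation, freqs))
t_concurrent = time.perf_counter() - start

print(f"sequential: {t_sequential:.2f} s, concurrent: {t_concurrent:.2f} s")
```

The results are identical in both cases; only the wall-clock time differs.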
While this example focused on a multi-frequency objective, the same basic strategy applies broadly to other multi-objective problems.
For instance, it is useful when the objective depends on the results of several slightly different simulations, such as dilated or contracted versions of a structure or instances with random variations.