Adjoint optimization of a wavelength division multiplexer#

In this notebook, we will use a multi-objective optimization to design a wavelength division multiplexer (WDM).

In short, this device takes in broadband light and directs light of different wavelengths to different output ports.

This demo combines the basic setup from our 3rd tutorial, on a mode converter, with the multi-frequency objective introduced in our 4th tutorial.

If you are unfamiliar with inverse design, we also recommend our intro to inverse design tutorials and our primer on automatic differentiation with tidy3d.

# first import tidy3d, its adjoint plugin, numpy, and jax.

import tidy3d as td
import tidy3d.plugins.adjoint as tda
import numpy as np
import jax.numpy as jnp
import jax


First we set up our basic simulation.

We have an input waveguide connected to a square design region, which has two output waveguides.

The square design region is a custom medium with a pixellated permittivity grid that we wish to optimize such that input light of different wavelengths is directed to different output ports.

As this is an SOI device, we typically define the design region and waveguides as silicon sitting on an SiO2 substrate. For this demo, we make a 2D simulation, but it can easily be made 3D by changing the Lz parameter, adding dimension to the structures, and adding a substrate.

# material information
n_si = 3.49
n_sio2 = 1.45 # not used in 2D
n_air = 1

# design output wavelengths
wavelength_top = 1.300
wavelength_bot = 1.550

# and their corresponding frequencies and spectral information
freq_top = td.C_0 / wavelength_top
freq_bot = td.C_0 / wavelength_bot
freq0 = (freq_top + freq_bot) / 2.0
fwidth = abs(freq_bot - freq_top)
run_time = 100 / fwidth

# create dictionaries to reference these later by string key 'top' or 'bot'
freqs = dict(top=freq_top, bot=freq_bot)
wavelengths = dict(top=wavelength_top, bot=wavelength_bot)

# size of design region
lx = 2.8
ly = 2.8
lz = td.inf # in 2D, we say the size of components is inf but the size of simulation is 0.

# size of waveguides
wg_width = 0.3
wg_length = 1.5
wg_spacing = 0.8

# spacing between design region and PML in y
buffer = 1.5

# size of simulation
Lx = lx + wg_length * 2
Ly = ly + buffer * 2
Lz = 0.0

# fabrication constraint (feature size and projection strength)
radius = 0.150
beta = 30

# resolution information
min_steps_per_wvl = 25
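
# define the waveguide ports (silicon boxes extending past the simulation edges)
wg_in = td.Structure(
    geometry=td.Box(
        center=(-Lx/2, 0, 0),
        size=(wg_length * 2, wg_width, lz),
    ),
    medium=td.Medium(permittivity=n_si**2),
)

wg_top = td.Structure(
    geometry=td.Box(
        center=(+Lx/2, +wg_width/2+wg_spacing/2, 0),
        size=(wg_length * 2, wg_width, lz),
    ),
    medium=td.Medium(permittivity=n_si**2),
)

wg_bot = td.Structure(
    geometry=td.Box(
        center=(+Lx/2, -wg_width/2-wg_spacing/2, 0),
        size=(wg_length * 2, wg_width, lz),
    ),
    medium=td.Medium(permittivity=n_si**2),
)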

# and a field monitor that measures fields on the z=0 plane
# (named "field" so it can be looked up by that key when plotting later)
fld_mnt = td.FieldMonitor(
    size=(td.inf, td.inf, 0),
    freqs=[freq_top, freq_bot],
    name="field",
)
[12:58:10] WARNING: Default value for the field monitor 'colocate' setting has  
           changed to 'True' in Tidy3D 2.4.0. All field components will be      
           colocated to the grid boundaries. Set to 'False' to get the raw      
           fields on the Yee grid instead.                                      

Note: we can ignore this warning, as it will be resolved in versions after 2.4.0.
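
As a quick sanity check on the spectral settings above, the sketch below repeats the wavelength-to-frequency arithmetic in plain Python. It hard-codes the speed of light in μm/s (the unit convention behind td.C_0) so that it runs without tidy3d; it is an illustration of the arithmetic only, not part of the simulation setup.

```python
# Speed of light in um/s; this mirrors the unit convention of td.C_0.
C_0 = 2.99792458e14

wavelength_top = 1.300  # um
wavelength_bot = 1.550  # um

# A shorter wavelength corresponds to a higher frequency.
freq_top = C_0 / wavelength_top
freq_bot = C_0 / wavelength_bot

# Source is centered between the two design frequencies,
# with a bandwidth equal to their separation.
freq0 = (freq_top + freq_bot) / 2.0
fwidth = abs(freq_bot - freq_top)

# The run time is set to 100 inverse bandwidths.
run_time = 100 / fwidth

print(f"freq_top = {freq_top:.4e} Hz")
print(f"freq_bot = {freq_bot:.4e} Hz")
print(f"fwidth   = {fwidth:.4e} Hz")
print(f"run_time = {run_time:.4e} s")
```

The run time works out to a few picoseconds, which is why it is expressed relative to 1 / fwidth rather than typed in directly.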

Define design region#

Here we define the design region as a pixellated grid of permittivity values.

We first define the overall geometry as a JaxBox and also the number of pixels in x and y.

nx = 55
ny = 55

design_region_geo = tda.JaxBox(
    size=(lx, ly, lz),
    center=(0, 0, 0),
)

Next we write a function to give us the pixellated array as a function of our parameters through our filtering and projection methods, which are used to make the resulting structures easier to fabricate. For more details, refer to our 4th lecture in the inverse design 101 lecture series, which focuses on fabrication constraints.

We also wrap this function in another one that generates the entire JaxStructure corresponding to the design region, for convenience later.

from tidy3d.plugins.adjoint.utils.filter import ConicFilter

conic_filter = ConicFilter(radius=radius, design_region_dl=lx/nx)

# note: params is an array of shape (nx, ny) that stores values nominally between 0 (air) and 1 (silicon)

def tanh_projection(x, eta=0.5):
    tanhbn = jnp.tanh(beta * eta)
    num = tanhbn + jnp.tanh(beta * (x - eta))
    den = tanhbn + jnp.tanh(beta * (1 - eta))
    return num / den

def filter_project(x, eta=0.5):
    x = conic_filter.evaluate(x)
    return tanh_projection(x, eta=eta)

def pre_process(params):
    """Get the permittivity values (1, eps_wg) array as a function of the parameters (0,1)"""
    params1 = filter_project(params)
    params2 = filter_project(params1)
    return params2

def make_eps(params):
    params = pre_process(params)
    eps_values = 1.0001 + (n_si**2 - 1.0001) * params
    return eps_values

def make_custom_medium(params):
    """Make JaxCustomMedium as a function of provided parameters."""
    eps = make_eps(params).reshape((nx, ny, 1, 1))

    xs = list(jnp.linspace(-lx/2, lx/2, nx))
    ys = list(jnp.linspace(-ly/2, ly/2, ny))
    zs = [0]
    freqs = [freq0]
    coords = dict(x=xs, y=ys, z=zs, f=freqs)

    eps_dataset = tda.JaxDataArray(values=eps, coords=coords)

    medium = tda.JaxCustomMedium(

    struct = tda.JaxStructure(

    return struct
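
To see what the projection step does in isolation, here is a minimal NumPy sketch of the same tanh projection followed by the linear map to permittivity. It leaves out the conic filter and the JAX tracing, and the helper name to_permittivity is introduced here for illustration only; the formulas are copied from tanh_projection and make_eps above.

```python
import numpy as np

n_si = 3.49


def tanh_projection(x, beta, eta=0.5):
    """Push values in [0, 1] toward 0 or 1 around the threshold eta."""
    tanhbn = np.tanh(beta * eta)
    num = tanhbn + np.tanh(beta * (x - eta))
    den = tanhbn + np.tanh(beta * (1 - eta))
    return num / den


def to_permittivity(density):
    """Linear map from density (0 = air, 1 = silicon) to relative permittivity."""
    return 1.0001 + (n_si**2 - 1.0001) * density


# Compare a weak projection (beta=1) against the tutorial's value (beta=30).
x = np.linspace(0.0, 1.0, 5)
for beta in (1, 30):
    projected = tanh_projection(x, beta)
    eps = to_permittivity(projected)
    print(f"beta={beta:>2}: density={np.round(projected, 3)}, eps={np.round(eps, 2)}")
```

With a small beta the projection is close to the identity, while a large beta leaves only values near the threshold in a gray state. The endpoints 0 and 1 and the midpoint 0.5 are fixed points of the projection for eta=0.5, whatever the value of beta.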

Define base simulation#

With all of these functions and variables defined, we can write a single function to return our “base” JaxSimulation as a function of our design parameters. This function first constructs the design region and then creates a JaxSimulation with all of the basic parameters.

Note, we don’t yet have a source or monitors for injecting and measuring our fields, but will add those next after running the mode solver.

def make_sim_base(params):

    input_struct = make_custom_medium(params)

    return tda.JaxSimulation(
        size=(Lx, Ly, Lz),, wavelength=wavelength_top),
        structures=[wg_in, wg_top, wg_bot],
        boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True if Lz else False),

Let’s test out our function. We’ll make an initially random array of parameters between 0 and 1 and generate the base simulation to plot and inspect.

params0 = np.random.random((nx, ny))
sim_base = make_sim_base(params0)
ax = sim_base.plot_eps(z=0, monitor_alpha=0.0)

It all looks good, so now we add the bits that define the optimization.

Adding Mode Sources and Monitors#

Solving modes#

First, we need to create our ModeSource and ModeMonitor objects that inject and measure the modes that we are interested in optimizing.

We’ll use tidy3d’s ModeSolver together with its remote run function, which gets more accurate results by running on Flexcompute’s servers.

from tidy3d.plugins.mode import ModeSolver
from tidy3d.plugins.mode.web import run as run_mode_solver

# we'll ask for 4 modes just to inspect
num_modes = 4

# let's define how large the mode planes are and how far they are from the PML relative to the design region
mode_size = (0, 1.8 * wg_spacing + wg_width, max([Lz, lz, 3]))
space_fraction = 0.2

# make a plane corresponding to where we wish to measure the input mode
plane_in = td.Box(
    center=(-Lx/2 + space_fraction * wg_length, 0, 0),

# construct the mode solver using our base sim (converted from `JaxSimulation` to regular `Simulation`) + our plane
mode_solver = ModeSolver(

Next we run the mode solver on the servers.

mode_data = run_mode_solver(mode_solver)
[12:58:12] Mode solver created with
[12:58:15] Mode solver status: queued
[12:58:16] Mode solver status: running
[12:58:30] Mode solver status: success

And visualize the results.

import matplotlib.pyplot as plt
print("Effective index of computed modes: ", jnp.array(mode_data.n_eff))

fig, axs = plt.subplots(num_modes, 3, figsize=(4*num_modes, 10), tight_layout=True)
for mode_ind in range(num_modes):
    for field_ind, field_name in enumerate(("Ex", "Ey", "Ez")):
        field = mode_data.field_components[field_name].sel(mode_index=mode_ind)
        ax = axs[mode_ind, field_ind]
        ax.set_title(f"{field_name}, index={mode_ind}")
Effective index of computed modes:  [[3.141966  2.8063042 1.9542309 1.0627401]]

We identify mode_index=0 as the first order mode that is out of plane of the device. Let’s choose to optimize our device with respect to this as the mode of interest for both the input and output.

We re-set the ModeSpec to only compute the number of modes we need (1) and also update our ModeSolver accordingly.

mode_index = 0
mode_spec = td.ModeSpec(num_modes=mode_index+1)
mode_solver = mode_solver.updated_copy(mode_spec=mode_spec)

Make input and output mode sources and monitors#

Next, we will generate the input ModeSource and output ModeMonitor objects using the convenience methods defined in the ModeSolver.

Because our plane was defined at the input port, we’ll modify the centers of the ModeMonitors to place them at the output ports to the right of the device.

# make source
mode_src = mode_solver.to_source(

# make a basic monitor
mode_mnt = mode_solver.to_monitor(

# construct the proper centers for the monitors at the 'top' and 'bot' ports
mnt_center_top = list(
mnt_center_bot = list(
mnt_center_top[0] =[0]
mnt_center_bot[0] =[0]
mnt_center_top[1] =[1]
mnt_center_bot[1] =[1]

# make a dictionary of names and frequencies to refer to later by key
mnt_names = dict(top="mode_top", bot="mode_bot")
mnt_freqs = dict(top=freq_top, bot=freq_bot)

# make two updated copies of the mode monitor with the proper frequencies, centers, and names
mode_mnt_top = mode_mnt.updated_copy(freqs=mnt_freqs["top"], center=mnt_center_top, name=mnt_names["top"])
mode_mnt_bot = mode_mnt.updated_copy(freqs=mnt_freqs["bot"], center=mnt_center_bot, name=mnt_names["bot"])

# make another dictionary mapping the keys to the monitors
mode_mnts = dict(top=mode_mnt_top, bot=mode_mnt_bot)

Add flux monitors#

For plotting later, we’ll add a couple of FluxMonitor objects at the output ports to measure the total flux over a large spectrum. With this data, we should be able to clearly see the difference in transmission for each of the ports at the design wavelengths and get an idea about the device bandwidth.

Nf = 121
freqs_flux = np.linspace(freq_bot - fwidth, freq_top + fwidth, Nf)

flux_mnt_names = dict(top="flux_top", bot="flux_bot")

flux_mnt_top = td.FluxMonitor(,

flux_mnt_bot = td.FluxMonitor(,
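
To get a feel for the spectral window covered by freqs_flux, the sketch below converts the same linspace to wavelengths using plain NumPy. The speed of light is hard-coded in μm/s so the snippet runs without tidy3d; everything else is copied from the definitions above.

```python
import numpy as np

C_0 = 2.99792458e14  # speed of light in um/s, same units as td.C_0

freq_top = C_0 / 1.300
freq_bot = C_0 / 1.550
fwidth = abs(freq_bot - freq_top)

# Same sweep as in the tutorial: one bandwidth of margin on either side.
Nf = 121
freqs_flux = np.linspace(freq_bot - fwidth, freq_top + fwidth, Nf)

# Convert to wavelength (um). Frequencies ascend, so wavelengths descend.
wavelengths_flux = C_0 / freqs_flux

print(f"wavelength range: {wavelengths_flux.min():.3f} um to {wavelengths_flux.max():.3f} um")

# Uniform spacing in frequency is not uniform in wavelength.
spacing = np.abs(np.diff(wavelengths_flux))
print(f"spacing at long wavelengths:  {spacing[0] * 1000:.2f} nm")
print(f"spacing at short wavelengths: {spacing[-1] * 1000:.2f} nm")
```

Both design wavelengths sit well inside this window, and the samples are denser at the short-wavelength end, which is worth keeping in mind when reading the transmission plot against a wavelength axis.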

Add to simulation#

Finally, we will wrap our previous make_sim_base() function in a new one that adds our new objects to this base simulation given a key top_or_bot that can be "top" or "bot". Since the adjoint plugin can currently only support a single frequency, we need to make a new JaxSimulation for each of the ModeMonitor outputs. So this function gives us a way to toggle which simulation to generate.

def make_sim(params, top_or_bot: str):

    output_monitor = mode_mnts[top_or_bot]

    sim_base = make_sim_base(params)
    return sim_base.updated_copy(
        monitors=tuple(list(sim_base.monitors) + [flux_mnt_top, flux_mnt_bot])

Let’s make the two simulations corresponding to "bot" and "top" and inspect them to ensure they were built as expected.

sim_bot = make_sim(params0, "bot")
sim_top = make_sim(params0, "top")

sims = dict(top=sim_top, bot=sim_bot)

Note: the FluxMonitor objects are overlaying the output ModeMonitor objects, but we can identify the ModeMonitors by their arrows. We verify that the monitors and sources are in the proper locations and we are ready to move on.

f, (ax1, ax2) = plt.subplots(1, 2, tight_layout=True, figsize=(10, 4))

ax1 = sim_top.plot_eps(z=0.01, ax=ax1)
ax2 = sim_bot.plot_eps(z=0.01, ax=ax2)

Defining objective function#

With our simulation fully defined as a function of our parameters, we are ready to define our objective function.

Computing power transmission#

In this case, it is quite simple: we measure the transmitted power in our output waveguide mode. We wish to maximize transmission to the top port at the “top” wavelength (1300 nm) and maximize transmission to the bottom port at the “bot” wavelength (1550 nm).

We first define a function to measure each of these depending on the top_or_bot key.

def measure_power(sim_data, top_or_bot: str):
    mnt_name = mnt_names[top_or_bot]
    mnt_data = sim_data[mnt_name]
    amps = mnt_data.amps
    amp = amps.sel(direction="+", mode_index=0, f=freqs[top_or_bot])
    # use the selected forward-going amplitude rather than the full amps array
    return jnp.sum(jnp.abs(amp) ** 2)

Next we add a penalty to produce structures that are invariant under erosion and dilation, which is a useful approach to implementing minimum length scale features.

def penalty(params, delta_eps=0.49):
    params = pre_process(params)
    dilate_fn = lambda x: filter_project(x, eta=0.5-delta_eps)
    eroded_fn = lambda x: filter_project(x, eta=0.5+delta_eps)

    params_dilate_erode = eroded_fn(dilate_fn(params))
    params_erode_dilate = dilate_fn(eroded_fn(params))
    diff = params_dilate_erode - params_erode_dilate
    return jnp.linalg.norm(diff) / jnp.linalg.norm(jnp.ones_like(diff))
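
The behavior of this penalty is easier to see on simple inputs. Below is a self-contained NumPy sketch of the same erosion/dilation comparison. The conic filter here is a hand-written stand-in (a normalized cone kernel with edge padding), not tidy3d's ConicFilter, and the pre_process step is skipped for brevity, so treat it as an approximation of the idea rather than the exact function above.

```python
import numpy as np

radius, dl, beta = 0.150, 2.8 / 55, 30  # same values as in the tutorial


def conic_filter(x):
    """Stand-in conic filter: weighted average with weights 1 - r / radius."""
    r_pix = int(np.ceil(radius / dl))
    offsets = np.arange(-r_pix, r_pix + 1) * dl
    dx, dy = np.meshgrid(offsets, offsets, indexing="ij")
    kernel = np.clip(1.0 - np.hypot(dx, dy) / radius, 0.0, None)
    kernel /= kernel.sum()
    padded = np.pad(x, r_pix, mode="edge")
    nx, ny = x.shape
    out = np.zeros((nx, ny))
    for i in range(kernel.shape[0]):
        for j in range(kernel.shape[1]):
            out += kernel[i, j] * padded[i : i + nx, j : j + ny]
    return out


def tanh_projection(x, eta):
    tanhbn = np.tanh(beta * eta)
    return (tanhbn + np.tanh(beta * (x - eta))) / (tanhbn + np.tanh(beta * (1 - eta)))


def filter_project(x, eta):
    return tanh_projection(conic_filter(x), eta)


def penalty(density, delta_eps=0.49):
    """Normalized difference between dilate-then-erode and erode-then-dilate."""
    dilate = lambda x: filter_project(x, eta=0.5 - delta_eps)
    erode = lambda x: filter_project(x, eta=0.5 + delta_eps)
    diff = erode(dilate(density)) - dilate(erode(density))
    return np.linalg.norm(diff) / np.linalg.norm(np.ones_like(diff))


solid = np.ones((20, 20))       # fully binarized, one large feature
gray = 0.5 * np.ones((20, 20))  # completely undecided design
print(f"penalty(solid) = {penalty(solid):.3e}")
print(f"penalty(gray)  = {penalty(gray):.3e}")
```

A uniform solid block is unchanged by either sequence, so its penalty is essentially zero. A uniform gray region is pushed to solid by dilation and to void by erosion, so the two results disagree everywhere and the penalty is close to one. This is why the term discourages both intermediate values and features that do not survive erosion.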

Writing objective function#

Then we write an objective function that calls this once per key, using the run_async functionality to run the simulations in parallel over the two frequencies.

Note, we return the list of JaxSimulationData as the second output. The reason for this is that we might wish to access our flux and field data later on. jax gives an option has_aux to use only the first output for differentiation while letting the user have access to the 2nd “auxiliary” output, which we will make use of.

def objective(params, verbose=False) -> float:
    sim_list = [make_sim(params, top_or_bot) for top_or_bot in ["top", "bot"]]
    sim_data_list = tda.web.run_async(sim_list, path_dir="data", verbose=verbose)
    powers = [measure_power(sim_data, top_or_bot) for sim_data, top_or_bot in zip(sim_data_list, ["top", "bot"])]
    J = jnp.sum(jnp.array(powers)) / len(powers) - penalty(params)
    return J, sim_data_list
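
The way the two port powers and the penalty combine into a single scalar can be illustrated with made-up numbers. The helper combine below is hypothetical and only reproduces the arithmetic of the J line above in NumPy; the power and penalty values are invented for illustration and are not simulation results.

```python
import numpy as np


def combine(powers, penalty_value):
    """Average the per-port powers and subtract the fabrication penalty."""
    powers = np.asarray(powers, dtype=float)
    return powers.sum() / len(powers) - penalty_value


# Hypothetical (top, bot) transmitted powers and penalty values.
J_balanced = combine([0.6, 0.6], penalty_value=0.1)
J_lopsided = combine([1.0, 0.2], penalty_value=0.1)
J_ideal = combine([1.0, 1.0], penalty_value=0.0)
J_poor = combine([0.05, 0.05], penalty_value=0.5)

print(f"balanced: {J_balanced:.2f}")
print(f"lopsided: {J_lopsided:.2f}")
print(f"ideal:    {J_ideal:.2f}")
print(f"poor:     {J_poor:.2f}")
```

Two things follow from this form. Because the powers are averaged, the objective does not distinguish a balanced device from a lopsided one with the same total, and since the powers are non-negative, J can only be negative when the penalty exceeds the mean transmitted power.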

Differentiating objective#

Finally, we can simply use jax to transform this objective function into a function that returns our objective function value, the auxiliary data, and our gradient, which we will feed to the optimizer.

grad_fn = jax.value_and_grad(objective, has_aux=True)

Let’s try out our gradient function with verbosity on for just this run.

(J, sim_data_list), grad = grad_fn(params0, verbose=True)
[12:58:34] Created task '0' with task_id
[12:58:35] Created task '1' with task_id
[12:58:40] Started working on Batch.
[12:58:48] Maximum FlexCredit cost: 0.050 for the whole batch. Use
           'Batch.real_cost()' to get the billed FlexCredit cost after the Batch
           has completed.
[12:59:41] Batch complete.
[12:59:45] loading SimulationData from
[12:59:48] loading SimulationData from
[12:59:49] Created task '0' with task_id
[12:59:50] Created task '1' with task_id
[12:59:54] Started working on Batch.
[13:00:03] Maximum FlexCredit cost: 0.050 for the whole batch. Use
           'Batch.real_cost()' to get the billed FlexCredit cost after the Batch
           has completed.
[13:00:34] Batch complete.

Run Optimization#

Finally, we are ready to optimize our device. We will make use of the optax package to define an optimizer using the adam method, as we’ve done in the previous adjoint tutorials.

We record a history of objective function values, simulation data, and parameters for visualization later.

import optax

# hyperparameters
num_steps = 50
learning_rate = 7e-3

# initialize adam optimizer with starting parameters
params = params0
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# store history
Js = []
params_history = [params0]
data_history = []

for i in range(num_steps):

    # compute gradient and current objective function value
    (value, data), gradient = grad_fn(params)

    # outputs
    print(f"step = {i + 1}")
    print(f"\tJ = {value:.4e}")
    print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")

    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)
    params = optax.apply_updates(params, updates)

    # save history
    Js.append(value)
    params_history.append(params)
    data_history.append(data)
step = 1
        J = -4.7329e-01
        grad_norm = 4.4967e-01
step = 2
        J = -3.9410e-01
        grad_norm = 8.4733e-01
step = 3
        J = -2.8600e-01
        grad_norm = 1.2904e+00
step = 4
        J = -2.2443e-01
        grad_norm = 4.7716e+00
step = 5
        J = -2.2247e-01
        grad_norm = 2.3129e+00
step = 6
        J = -1.7047e-01
        grad_norm = 1.5507e+00
step = 7
        J = -6.5531e-02
        grad_norm = 2.1438e+00
step = 8
        J = 2.6483e-02
        grad_norm = 3.2374e+00
step = 9
        J = 7.3216e-02
        grad_norm = 2.6327e+00
step = 10
        J = 1.2012e-01
        grad_norm = 3.7335e+00
step = 11
        J = 1.7554e-01
        grad_norm = 2.6433e+00
step = 12
        J = 2.2288e-01
        grad_norm = 2.2682e+00
step = 13
        J = 2.3549e-01
        grad_norm = 4.2679e+00
step = 14
        J = 2.7424e-01
        grad_norm = 3.9962e+00
step = 15
        J = 2.9780e-01
        grad_norm = 3.5945e+00
step = 16
        J = 3.3670e-01
        grad_norm = 2.0813e+00
step = 17
        J = 3.5610e-01
        grad_norm = 2.6826e+00
step = 18
        J = 3.6166e-01
        grad_norm = 2.9888e+00
step = 19
        J = 3.8068e-01
        grad_norm = 3.1492e+00
step = 20
        J = 4.0101e-01
        grad_norm = 2.0711e+00
step = 21
        J = 4.1796e-01
        grad_norm = 2.6197e+00
step = 22
        J = 4.2581e-01
        grad_norm = 5.0041e+00
step = 23
        J = 4.2845e-01
        grad_norm = 4.3646e+00
step = 24
        J = 4.5843e-01
        grad_norm = 2.3381e+00
step = 25
        J = 4.3961e-01
        grad_norm = 6.1606e+00
step = 26
        J = 4.6821e-01
        grad_norm = 3.2926e+00
step = 27
        J = 4.7678e-01
        grad_norm = 3.8982e+00
step = 28
        J = 4.8196e-01
        grad_norm = 6.0061e+00
step = 29
        J = 5.0639e-01
        grad_norm = 3.3902e+00
step = 30
        J = 5.1525e-01
        grad_norm = 3.4171e+00
step = 31
        J = 5.1659e-01
        grad_norm = 3.9614e+00
step = 32
        J = 5.3002e-01
        grad_norm = 1.5894e+00
step = 33
        J = 5.2986e-01
        grad_norm = 2.1878e+00
step = 34
        J = 5.3621e-01
        grad_norm = 1.7290e+00
step = 35
        J = 5.4549e-01
        grad_norm = 1.9306e+00
step = 36
        J = 5.4732e-01
        grad_norm = 2.0833e+00
step = 37
        J = 5.5743e-01
        grad_norm = 1.0262e+00
step = 38
        J = 5.5981e-01
        grad_norm = 2.0251e+00
step = 39
        J = 5.6720e-01
        grad_norm = 1.7411e+00
step = 40
        J = 5.7623e-01
        grad_norm = 1.0897e+00
step = 41
        J = 5.8176e-01
        grad_norm = 9.9590e-01
step = 42
        J = 5.8916e-01
        grad_norm = 1.0190e+00
step = 43
        J = 5.9337e-01
        grad_norm = 2.1709e+00
step = 44
        J = 5.9594e-01
        grad_norm = 4.0222e+00
step = 45
        J = 5.8571e-01
        grad_norm = 5.9844e+00
step = 46
        J = 5.9020e-01
        grad_norm = 6.6299e+00
step = 47
        J = 6.3105e-01
        grad_norm = 1.8183e+00
step = 48
        J = 6.1828e-01
        grad_norm = 4.9559e+00
step = 49
        J = 6.2594e-01
        grad_norm = 4.7662e+00
step = 50
        J = 6.4317e-01
        grad_norm = 1.0012e+00
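
The sign flip in the update step is easy to get wrong, so here is a minimal NumPy sketch of Adam used for maximization on a toy problem. The function adam_maximize is a hand-written stand-in for optax.adam using the commonly quoted default decay rates (0.9 and 0.999) and epsilon (1e-8), and the quadratic objective is a placeholder for the simulation-based one.

```python
import numpy as np


def adam_maximize(grad_fn, params, num_steps=50, learning_rate=7e-3,
                  b1=0.9, b2=0.999, eps=1e-8):
    """Run Adam on the negated gradient so that the objective is maximized."""
    m = np.zeros_like(params)
    v = np.zeros_like(params)
    for t in range(1, num_steps + 1):
        g = -grad_fn(params)            # Adam minimizes, so flip the sign
        m = b1 * m + (1 - b1) * g       # running mean of the gradient
        v = b2 * v + (1 - b2) * g**2    # running mean of the squared gradient
        m_hat = m / (1 - b1**t)         # bias corrections
        v_hat = v / (1 - b2**t)
        params = params - learning_rate * m_hat / (np.sqrt(v_hat) + eps)
    return params


# Toy concave objective with its maximum at every parameter equal to 0.7.
target = 0.7
objective = lambda p: -np.sum((p - target) ** 2)
grad = lambda p: -2.0 * (p - target)

p0 = np.full((4, 4), 0.5)
p_final = adam_maximize(grad, p0)

print(f"objective before: {objective(p0):.4f}")
print(f"objective after:  {objective(p_final):.4f}")
```

Note that the first Adam step changes every parameter by almost exactly the learning rate, regardless of the gradient magnitude. With learning_rate = 7e-3 this bounds how quickly the design can move away from its random starting point.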

Visualize Results#

Let’s visualize the results of our optimization.

Objective function vs Iteration#

First we inspect the objective function value as a function of optimization iteration number. We see that it steadily increases as expected.

The presence of fabrication constraints tends to create some minor bumps in the optimization, which can be a signal that one needs to reduce the step size, but these results are sufficient for our purposes.

plt.plot(Js)
plt.xlabel('iteration number')
plt.ylabel('objective function')
plt.show()

Final Simulation#

Let’s take a look at the final simulation, which we grab from our history.

sim_data_final = data_history[-1][0]
sim_final = sim_data_final.simulation.to_simulation()[0]

We notice that the structure has reasonably large feature sizes but is not well binarized. This could be improved by increasing the beta projection value slowly over iteration number, as was done in the grating coupler tutorial.

ax = sim_final.plot_eps(z=0.01, monitor_alpha=0, source_alpha=0)


Let’s inspect the flux over each of the output ports as a function of wavelength.

We notice that the top and bottom ports have peaks in transmission at their corresponding design wavelengths, as expected!

# plot flux
for key, color in zip(("top", "bot"), ('royalblue', 'firebrick')):
    freq = freqs[key]
    flux_data = sim_data_final[flux_mnt_names[key]]
    wvl_nm = 1000 * td.C_0 / freq
    wavelengths_nm = 1000 * td.C_0 / np.array(flux_data.flux.f)
    flux = np.array(flux_data.flux.values)
    flux_db = 10 * np.log10(flux)  # decibels use the base-10 logarithm
    label = f"{key} ({int(wvl_nm)} nm)"
    plt.plot(wavelengths_nm, flux_db, label=label, color=color)
    plt.scatter([wvl_nm], [1], 100, marker="*", color=color)
    plt.xlabel('wavelength (nm)')
    plt.ylabel('transmission (dB)')
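
For reference when reading the transmission axis, this short NumPy sketch shows the power-to-decibel conversion on a few round values. The helper name to_db is introduced here for illustration.

```python
import numpy as np


def to_db(transmission):
    """Convert a power transmission fraction to decibels."""
    return 10.0 * np.log10(transmission)


for t in (1.0, 0.5, 0.1, 0.01):
    print(f"T = {t:<5} -> {to_db(t):7.2f} dB")
```

Full transmission corresponds to 0 dB, half power to roughly -3 dB, and each further factor of ten to another -10 dB. Note that np.log is the natural logarithm; using it in place of np.log10 would inflate every value by a factor of about 2.3.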


Let’s also plot the field intensity patterns at the two design wavelengths. We see from this plot the expected result that the power is directed to the design port at each frequency.

# plot fields at the two design wavelengths

fig, axes = plt.subplots(1, 2, tight_layout=True, figsize=(7, 3))

for key, ax in zip(("top", "bot"), axes):
    freq = freqs[key]
    sim_data_final.plot_field("field", "E", "abs^2", f=freq, ax=ax, vmax=1200)
    wvl = 1000 * td.C_0 / freq
    ax.set_title(f"wavelength = {int(wvl)} nm")


Finally, we animate this plot as a function of iteration number. The animation shows the device quickly accomplishing our design objective.

Note: this can take a few minutes to complete.

import matplotlib.animation as animation
from IPython.display import HTML

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, tight_layout=False, figsize=(9, 4))

def animate(i):

    sim_data_list_i = data_history[i]

    sim_i = sim_data_list_i[0].simulation.to_simulation()[0]
    sim_i.plot_eps(z=0.01, monitor_alpha=0, source_alpha=0, ax=ax1)

    for key, ax, sim_data_i in zip(("top", "bot"), (ax2, ax3), sim_data_list_i):

        freq = freqs[key]
        wvl = 1000 * td.C_0 / freq

        int_i = sim_data_i.get_intensity("field").sel(f=freq)
        int_i.squeeze().plot.pcolormesh(x='x', y='y', ax=ax, add_colorbar=False, cmap="magma", vmax=1000)

        ax.set_title(f"wavelength = {int(wvl)} nm")

# create animation
ani = animation.FuncAnimation(fig, animate, frames=len(data_history))

# display the animation (press "play" to start)