# Adjoint optimization of a wavelength division multiplexer

In this notebook, we will use a multi-objective optimization to design a wavelength division multiplexer (WDM).

In short, this device takes in broadband light and directs light of different wavelengths to different output ports.

This demo combines the basic setup of our 3rd tutorial of a mode converter with the multi-frequency feature introduced in Tidy3D version 2.5.

If you are unfamiliar with inverse design, we also recommend our intro to inverse design tutorials and our primer on automatic differentiation with tidy3d.

[1]:

import matplotlib.pylab as plt

# first import tidy3d, its adjoint plugin, numpy, and jax.
import tidy3d as td
import tidy3d.plugins.adjoint as tda
import numpy as np
import jax.numpy as jnp
import jax

np.random.seed(2)


## Setup

First we set up our basic simulation.

We have an input waveguide connected to a square design region, which has two output waveguides.

The square design region is a custom medium with a pixellated permittivity grid that we wish to optimize such that input light of different wavelengths is directed to different output ports.

As this is an SOI device, we typically define the design region and waveguides as silicon sitting on an SiO2 substrate. For this demo, we make a 2D simulation, but it can easily be made 3D by changing the Lz parameter, adding dimension to the structures, and adding a substrate.

[2]:

# material information
n_si = 3.49
n_sio2 = 1.45 # not used in 2D
n_air = 1

# design output wavelengths
wavelength_top = 1.300
wavelength_bot = 1.550

# and their corresponding frequencies and spectral information
freq_top = td.C_0 / wavelength_top
freq_bot = td.C_0 / wavelength_bot
freq0 = (freq_top + freq_bot) / 2.0
fwidth = abs(freq_bot - freq_top)
run_time = 100 / fwidth

# create dictionaries to reference these later by string key 'top' or 'bot'
freqs = dict(top=freq_top, bot=freq_bot)
wavelengths = dict(top=wavelength_top, bot=wavelength_bot)

# size of design region
lx = 2.8
ly = 2.8
lz = td.inf # in 2D, we say the size of components is inf but the size of simulation is 0.

# size of waveguides
wg_width = 0.3
wg_length = 1.5
wg_spacing = 0.8

# spacing between design region and PML in y
buffer = 1.5

# size of simulation
Lx = lx + wg_length * 2
Ly = ly + buffer * 2
Lz = 0.0

# fabrication constraint (feature size and projection strength)
beta = 30

# resolution information
min_steps_per_wvl = 25
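
As a quick sanity check on the spectral setup above, the following sketch recomputes the design frequencies, source bandwidth, and run time in plain Python. It assumes the speed of light in units of um/s, written here as a literal `C_0` that is meant to mirror `td.C_0` rather than being imported from Tidy3D.

```python
# Speed of light in um/s (assumed to match td.C_0, since lengths here are in um).
C_0 = 2.99792458e14

wavelength_top = 1.300  # um
wavelength_bot = 1.550  # um

# Shorter wavelength means higher frequency, so freq_top > freq_bot.
freq_top = C_0 / wavelength_top
freq_bot = C_0 / wavelength_bot

# Source is centered between the two design frequencies and spans their separation.
freq0 = (freq_top + freq_bot) / 2.0
fwidth = abs(freq_bot - freq_top)
run_time = 100 / fwidth

print(f"freq_top = {freq_top:.4e} Hz")
print(f"freq_bot = {freq_bot:.4e} Hz")
print(f"fwidth   = {fwidth:.4e} Hz")
print(f"run_time = {run_time:.4e} s")
```

The bandwidth comes out near 3.72e13 Hz and the run time near 2.69e-12 s, which agrees with the values reported in the solver log later in this notebook.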

[3]:

# define the waveguide ports

wg_in = td.Structure(
    geometry=td.Box(
        center=(-Lx/2, 0, 0),
        size=(wg_length * 2, wg_width, lz),
    ),
    medium=td.Medium(permittivity=n_si**2)
)

wg_top = td.Structure(
    geometry=td.Box(
        center=(+Lx/2, +wg_width/2+wg_spacing/2, 0),
        size=(wg_length * 2, wg_width, lz),
    ),
    medium=td.Medium(permittivity=n_si**2)
)

wg_bot = td.Structure(
    geometry=td.Box(
        center=(+Lx/2, -wg_width/2-wg_spacing/2, 0),
        size=(wg_length * 2, wg_width, lz),
    ),
    medium=td.Medium(permittivity=n_si**2)
)

# and a field monitor that measures fields on the z=0 plane
fld_mnt = td.FieldMonitor(
    center=(0,0,0),
    size=(td.inf, td.inf, 0),
    freqs=[freq_top, freq_bot],
    name="field",
)


Note: any warning emitted by the cell above can be ignored, as it is resolved in Tidy3D versions after 2.4.0.

### Define design region

Here we define the design region as a pixellated grid of permittivity values.

We first define the overall geometry as a JaxBox and also the number of pixels in x and y.

[5]:

nx = 55
ny = 55

design_region_geo = tda.JaxBox(
    size=(lx, ly, lz),
    center=(0, 0, 0)
)
design_region_dl = lx / nx


Next we write a function to give us the pixellated array as a function of our parameters through our filtering and projection methods, which are used to make the resulting structures easier to fabricate. For more details, refer to our 4th lecture in the inverse design 101 lecture series, which focuses on fabrication constraints.

We also wrap this function in another one that generates the entire JaxStructure corresponding to the design region, for convenience later.

[15]:

from tidy3d.plugins.adjoint.utils.filter import ConicFilter

# conic filter radius (um), which sets the smoothing length scale.
# NOTE: this value is an assumption; the filter definition is missing from this copy of the notebook.
radius = 0.120
conic_filter = ConicFilter(radius=radius, design_region_dl=design_region_dl)

# note: params is an array of shape (nx, ny) that stores values between 0 (air) and 1 (silicon)

def tanh_projection(x, beta, eta=0.5):
    tanhbn = jnp.tanh(beta * eta)
    num = tanhbn + jnp.tanh(beta * (x - eta))
    den = tanhbn + jnp.tanh(beta * (1 - eta))
    return num / den

def filter_project(x, beta, eta=0.5):
    x = conic_filter.evaluate(x)
    return tanh_projection(x, beta=beta, eta=eta)

# number of times to filter -> project. Two times with a lower beta (~30) seems to give decent results.
num_projections = 2

def pre_process(params, beta):
    """Get the filtered and projected parameters (0,1) as a function of the raw parameters (0,1)"""
    for _ in range(num_projections):
        params = filter_project(params, beta=beta)
    return params

def make_eps(params, beta):
    """Get the permittivity values (1, eps_wg) array as a function of the parameters (0,1)"""
    params = pre_process(params, beta=beta)
    eps_values = 1 + (n_si**2 - 1) * params
    return eps_values

def make_custom_medium(params, beta):
    """Make JaxCustomMedium as a function of provided parameters."""
    eps = make_eps(params, beta).reshape((nx, ny, 1, 1))
    eps = jnp.where(eps < 1, 1, eps)
    eps = jnp.where(eps > n_si**2, n_si**2, eps)

    xs = list(jnp.linspace(-lx/2, lx/2, nx))
    ys = list(jnp.linspace(-ly/2, ly/2, ny))
    zs = [0]
    freqs = [freq0]
    coords = dict(x=xs, y=ys, z=zs, f=freqs)

    eps_dataset = tda.JaxDataArray(values=eps, coords=coords)

    medium = tda.JaxCustomMedium(
        eps_dataset=tda.JaxPermittivityDataset(
            eps_xx=eps_dataset,
            eps_yy=eps_dataset,
            eps_zz=eps_dataset,
        )
    )

    struct = tda.JaxStructure(
        geometry=design_region_geo,
        medium=medium
    )

    return struct
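
To see what the projection step does in isolation, here is a minimal NumPy sketch of the same tanh projection formula evaluated at a few density values, followed by the linear mapping to permittivity. The helper name `tanh_projection_np` and the sample densities are illustrative and not part of the notebook, and the conic filtering step is left out.

```python
import numpy as np

def tanh_projection_np(x, beta, eta=0.5):
    """NumPy version of the tanh projection formula (illustrative helper)."""
    tanhbn = np.tanh(beta * eta)
    num = tanhbn + np.tanh(beta * (x - eta))
    den = tanhbn + np.tanh(beta * (1 - eta))
    return num / den

# Sample densities between 0 (air) and 1 (silicon).
x = np.array([0.0, 0.3, 0.5, 0.7, 1.0])

soft = tanh_projection_np(x, beta=1)    # nearly linear, keeps gray values
sharp = tanh_projection_np(x, beta=30)  # close to a step at eta=0.5

# Map the projected densities onto relative permittivity between air and silicon.
n_si = 3.49
eps_sharp = 1 + (n_si**2 - 1) * sharp

print("beta=1 :", np.round(soft, 3))
print("beta=30:", np.round(sharp, 3))
print("eps    :", np.round(eps_sharp, 3))
```

The endpoints 0 and 1 and the threshold value 0.5 are fixed points of the projection for any beta, while intermediate values are pushed toward 0 or 1 as beta grows. This is why a larger beta gives a more binarized structure.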


### Define base simulation

With all of these functions and variables defined, we can write a single function to return our "base" JaxSimulation as a function of our design parameters. This function first constructs the design region and then creates a JaxSimulation with all of the basic parameters.

Note, we don't yet have a source or monitors for injecting and measuring our fields, but will add those next after running the mode solver.

[16]:

def make_sim_base(params, beta):

    input_struct = make_custom_medium(params, beta=beta)

    return tda.JaxSimulation(
        size=(Lx, Ly, Lz),
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=min_steps_per_wvl, wavelength=wavelength_top),
        structures=[wg_in, wg_top, wg_bot],
        monitors=[fld_mnt],
        input_structures=[input_struct],
        boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True if Lz else False),
        run_time=run_time,
    )


Let's test out our function. We'll make an initially random array of parameters between 0 and 1 and generate the base simulation to plot and inspect.

[8]:

params0 = np.random.random((nx, ny))
sim_base = make_sim_base(params0, beta=1)

WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

[9]:

ax = sim_base.plot_eps(z=0, monitor_alpha=0.0)


It all looks good, so now we add the bits that define the optimization.

## Adding Mode Sources and Monitors

### Solving modes

First, we need to create our ModeSource and ModeMonitor objects that inject and measure the modes that we are interested in optimizing.

We'll use tidy3d's ModeSolver and use the remote run function that gets more accurate results by running on Flexcompute's servers.

[10]:

from tidy3d.plugins.mode import ModeSolver
from tidy3d.plugins.mode.web import run as run_mode_solver

# we'll ask for 4 modes just to inspect
num_modes = 4

# let's define how large the mode planes are and how far they are from the PML relative to the design region
mode_size = (0, 1.8 * wg_spacing + wg_width, max([Lz, lz, 3]))
space_fraction = 0.2

# make a plane corresponding to where we wish to measure the input mode
plane_in = td.Box(
    center=(-Lx/2 + space_fraction * wg_length, 0, 0),
    size=mode_size,
)

# construct the mode solver using our base sim (converted from JaxSimulation to regular Simulation) + our plane
mode_solver = ModeSolver(
    simulation=sim_base.to_simulation()[0],
    plane=plane_in,
    freqs=[freq_top],
    mode_spec=td.ModeSpec(num_modes=num_modes)
)


Next we run the mode solver on the servers.

[12]:

mode_data = run_mode_solver(mode_solver, reduce_simulation=True)

10:59:18 PST Mode solver created with
solver_id='mo-4bbb01f6-7a1f-490a-a82f-7d81cbd26a0f'.

10:59:19 PST Mode solver status: queued

11:01:30 PST Mode solver status: running

11:01:33 PST Mode solver status: success


And visualize the results.

[17]:

fig, axs = plt.subplots(num_modes, 3, figsize=(12, 12), tight_layout=True)
for mode_index in range(num_modes):
    vmax = 1.1 * max(abs(mode_data.field_components[n].sel(mode_index=mode_index)).max() for n in ("Ex", "Ey", "Ez"))
    for field_name, ax in zip(("Ex", "Ey", "Ez"), axs[mode_index]):
        field = mode_data.field_components[field_name].sel(mode_index=mode_index)
        field.real.plot(label="Real", ax=ax)
        field.imag.plot(ls="--", label="Imag", ax=ax)
        ax.set_title(f'index={mode_index}, {field_name}')
        ax.set_ylim(-vmax, vmax)

axs[0, 0].legend()

print("Effective index of computed modes: ", np.array(mode_data.n_eff))

Effective index of computed modes:  [[3.141611 2.806403 1.953868 1.062474]]


We identify mode_index=0 as the first-order mode that is polarized out of the plane of the device. Let's choose to optimize our device with respect to this as the mode of interest for both the input and output.

We re-set the ModeSpec to only compute the number of modes we need (1) and also update our ModeSolver accordingly.

[18]:

mode_index = 0
mode_spec = td.ModeSpec(num_modes=mode_index+1)
mode_solver = mode_solver.updated_copy(mode_spec=mode_spec)


### Make input and output mode sources and monitors

Next, we will generate the input ModeSource and output ModeMonitor objects using the convenience methods defined in the ModeSolver.

Because our plane was defined at the input port, we'll modify the centers of the ModeMonitors to place them at the output ports to the right of the device.

[19]:

# make source
mode_src = mode_solver.to_source(
    source_time=td.GaussianPulse(
        freq0=freq0,
        fwidth=fwidth,
    ),
    direction="+",
    mode_index=mode_index,
)

# make a basic monitor
mode_mnt = mode_solver.to_monitor(
    freqs=[freq0],
    name="_"
)

# construct the proper centers for the monitors at the 'top' and 'bot' ports
mnt_center_top = list(plane_in.center)
mnt_center_bot = list(plane_in.center)
mnt_center_top[0] = -plane_in.center[0]
mnt_center_bot[0] = -plane_in.center[0]
mnt_center_top[1] = wg_top.geometry.center[1]
mnt_center_bot[1] = wg_bot.geometry.center[1]

# make a dictionary of names and frequencies to refer to later by key
mnt_names = dict(top="mode_top", bot="mode_bot")
mnt_freqs = dict(top=freq_top, bot=freq_bot)

# make two updated copies of the mode monitor with the proper frequencies, centers, and names
mode_mnt_top = mode_mnt.updated_copy(freqs=[mnt_freqs["top"]], center=mnt_center_top, name=mnt_names["top"])
mode_mnt_bot = mode_mnt.updated_copy(freqs=[mnt_freqs["bot"]], center=mnt_center_bot, name=mnt_names["bot"])

# make another dictionary mapping the keys to the monitors
mode_mnts = dict(top=mode_mnt_top, bot=mode_mnt_bot)


For plotting later, we'll add a couple of FluxMonitor objects at the output ports to measure the total flux over a large spectrum. With this data, we should be able to clearly see the difference in transmission for each of the ports at the design wavelengths and get an idea about the device bandwidth.

[20]:

Nf = 121
freqs_flux = np.linspace(freq_bot - fwidth/10, freq_top + fwidth/10, Nf)

flux_mnt_names = dict(top="flux_top", bot="flux_bot")

flux_mnt_top = td.FluxMonitor(
    center=mode_mnt_top.center,
    size=mode_mnt_top.size,
    name=flux_mnt_names["top"],
    freqs=list(freqs_flux),
)

flux_mnt_bot = td.FluxMonitor(
    center=mode_mnt_bot.center,
    size=mode_mnt_bot.size,
    name=flux_mnt_names["bot"],
    freqs=list(freqs_flux),
)
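
To get a feel for the spectral window these flux monitors cover, the sketch below rebuilds the same frequency grid with NumPy and converts it to wavelengths in nm. It is a standalone check: `C_0` is written as a literal that is assumed to equal `td.C_0` (um/s), and no Tidy3D objects are created.

```python
import numpy as np

C_0 = 2.99792458e14  # speed of light in um/s (assumed equal to td.C_0)

freq_top = C_0 / 1.300
freq_bot = C_0 / 1.550
fwidth = abs(freq_bot - freq_top)

# Same grid as the flux monitors: the design band padded by 10% of fwidth on each side.
Nf = 121
freqs_flux = np.linspace(freq_bot - fwidth / 10, freq_top + fwidth / 10, Nf)

# The grid is uniform in frequency, so it is not uniform in wavelength.
wavelengths_nm = 1000 * C_0 / freqs_flux

print(f"{wavelengths_nm.min():.0f} nm to {wavelengths_nm.max():.0f} nm in {Nf} points")
```

The window extends slightly beyond both design wavelengths (1300 nm and 1550 nm), so the transmission peaks are not cut off at the edges of the plot.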


Finally, we will wrap our previous make_sim_base() function in a new one that adds our new objects to this base simulation.

[21]:

def make_sim(params, beta):

    output_monitors = [mode_mnts["top"], mode_mnts["bot"]]

    sim_base = make_sim_base(params, beta=beta)
    return sim_base.updated_copy(
        output_monitors=output_monitors,
        sources=[mode_src],
        monitors=tuple(list(sim_base.monitors) + [flux_mnt_top, flux_mnt_bot])
    )


Let's make the final simulation and visualize it with the sources and monitors added.

[22]:

sim = make_sim(params0, beta=1)


Note: the FluxMonitor objects are overlaying the output ModeMonitor objects.

[23]:

ax = sim.plot_eps(z=0.01)
plt.show()


## Defining objective function

With our simulation fully defined as a function of our parameters, we are ready to define our objective function.

### Computing power transmission

In this case, it is quite simple: we measure the transmitted power in our output waveguide modes. We wish to maximize transmission to the top port at the "top" wavelength (1300 nm) and maximize transmission to the bottom port at the "bot" wavelength (1550 nm).

[24]:

def measure_power(sim_data) -> float:
    """Extract power from simulation data."""

    def get_power(mnt_key: str, freq_key: str) -> float:
        """Get the power at monitor 'mnt_key' at frequency 'freq_key' (both either 'top' or 'bot')."""
        mnt_name = mnt_names[mnt_key]
        freq = freqs[freq_key]
        mnt_data = sim_data[mnt_name]
        amp = mnt_data.amps.sel(direction="+", mode_index=0, f=freq)
        return jnp.abs(amp) ** 2

    power_max = get_power("top", "top") + get_power("bot", "bot")
    return power_max / 2.0
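
The quantity returned by `measure_power` is the average of two squared mode amplitudes. A minimal NumPy sketch of that arithmetic follows, using made-up complex amplitudes; the values are hypothetical and are not simulation output.

```python
import numpy as np

# Hypothetical complex mode amplitudes (not from a real simulation):
# top port at the "top" frequency, bottom port at the "bot" frequency.
amp_top_at_top = 0.6 + 0.3j
amp_bot_at_bot = -0.2 + 0.7j

# Power carried by each mode is the squared magnitude of its amplitude.
power_top = np.abs(amp_top_at_top) ** 2
power_bot = np.abs(amp_bot_at_bot) ** 2

# Average the two so that both design goals are weighted equally.
objective_power = (power_top + power_bot) / 2.0

print(f"power_top = {power_top:.3f}, power_bot = {power_bot:.3f}")
print(f"objective power = {objective_power:.3f}")
```

Only the magnitude of each amplitude matters here, so the phase at the output ports is left unconstrained by this objective.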


Next we add a penalty to produce structures that are invariant under erosion and dilation, which is a useful approach to implementing minimum length scale features.

[25]:

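from tidy3d.plugins.adjoint.utils.penalty import ErosionDilationPenalty

# penalty object; the length scale reuses the radius of the conic filter defined earlier
ed_penalty = ErosionDilationPenalty(length_scale=conic_filter.radius, pixel_size=design_region_dl)

def penalty(params, beta) -> float:
    """Penalty based on the amount of change after erosion and dilation of structures."""
    params = pre_process(params, beta=beta)

    return ed_penalty.evaluate(params)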
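
To build intuition for why invariance under erosion and dilation relates to minimum feature size, here is a small 1D NumPy sketch. It is not the `ErosionDilationPenalty` implementation: the moving-average filter, the thresholds, and every function name below are assumptions chosen for illustration. Erosion and dilation are modeled as filter-then-project with a high or low threshold, and the penalty is the normalized difference between "opening" (erode, then dilate) and "closing" (dilate, then erode).

```python
import numpy as np

def blur(x, width):
    """Moving-average filter that preserves the array length (odd width)."""
    kernel = np.ones(width) / width
    padded = np.pad(x, width // 2, mode="edge")
    return np.convolve(padded, kernel, mode="valid")

def project(x, beta, eta):
    """Tanh projection with threshold eta."""
    num = np.tanh(beta * eta) + np.tanh(beta * (x - eta))
    den = np.tanh(beta * eta) + np.tanh(beta * (1 - eta))
    return num / den

def erode(x, width=5, beta=50, delta=0.25):
    # A high threshold shrinks features.
    return project(blur(x, width), beta, 0.5 + delta)

def dilate(x, width=5, beta=50, delta=0.25):
    # A low threshold grows features.
    return project(blur(x, width), beta, 0.5 - delta)

def ed_penalty_1d(x):
    """Normalized difference between the opening and the closing of x."""
    opened = dilate(erode(x))
    closed = erode(dilate(x))
    return np.linalg.norm(opened - closed) / np.sqrt(x.size)

narrow = np.zeros(40)
narrow[20:22] = 1.0  # 2-pixel feature, smaller than the filter width

wide = np.zeros(40)
wide[10:30] = 1.0    # 20-pixel feature, larger than the filter width

p_narrow = ed_penalty_1d(narrow)
p_wide = ed_penalty_1d(wide)
print(f"narrow feature penalty: {p_narrow:.4f}")
print(f"wide feature penalty:   {p_wide:.4f}")
```

The narrow feature is wiped out by erosion but survives dilation, so its opening and closing disagree and the penalty is large. The wide feature is nearly unchanged by either sequence, so its penalty is close to zero.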


### Writing objective function

Then we write an objective function that constructs our simulation, runs it, measures the power, and subtracts our penalty.

Note, we return a JaxSimulationData as the second output. The reason for this is that we might wish to access our flux and field data later on. jax gives an option has_aux to use only the first output for differentiation while letting the user have access to the 2nd "auxiliary" output, which we will make use of.

[26]:

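def objective(params, beta, verbose=False) -> float:
    sim = make_sim(params, beta=beta)
    # run the simulation through the adjoint plugin so that the result is differentiable
    sim_data = tda.web.run(sim, task_name="WDM_MULTIFREQ", verbose=verbose)
    power = measure_power(sim_data)
    J = power - penalty(params, beta=beta)
    return J, sim_data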


### Differentiating objective

Finally, we can simply use jax to transform this objective function into a function that returns our objective function value, the auxiliary data, and our gradient, which we will feed to the optimizer.

[27]:

grad_fn = jax.value_and_grad(objective, has_aux=True)
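
The calling convention of `jax.value_and_grad` with `has_aux=True` can be seen on a toy function that needs no simulation. In the sketch below, `toy_objective` and its inputs are made up for illustration; only the nesting of the return values carries over to the real objective.

```python
import jax
import jax.numpy as jnp

def toy_objective(params, scale):
    """Return a scalar objective and some auxiliary data (illustrative only)."""
    prediction = scale * params
    value = -jnp.sum((prediction - 1.0) ** 2)
    return value, prediction  # only `value` is differentiated

# Differentiate with respect to the first argument; the second output is passed through.
toy_grad_fn = jax.value_and_grad(toy_objective, has_aux=True)

params = jnp.array([0.0, 0.5, 1.0])
(value, aux), grad = toy_grad_fn(params, scale=2.0)

print("value:", value)
print("aux:  ", aux)
print("grad: ", grad)
```

The result unpacks as `(value, aux), grad`, which is the same pattern used when calling `grad_fn` on the real objective, with the simulation data taking the place of `aux`.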


Let's try out our gradient function with verbosity on for just this run.

[28]:

(J, sim_data), grad = grad_fn(params0, beta=1, verbose=True)

11:02:20 PST Created task 'WDM_MULTIFREQ' with task_id

             View task using web UI at
3-4a23-b70a-1cae0a9791e3'.

11:02:21 PST status = success

11:02:22 PST loading simulation from simulation_data.hdf5

11:02:23 PST WARNING: 2 unique frequencies detected in the output monitors with
a minimum spacing of 3.720e+13 (Hz). Setting the 'fwidth' of the
adjoint sources to 0.1 times this value = 3.720e+12 (Hz) to avoid
spectral overlap. To account for this, the corresponding 'run_time'
in the adjoint simulation is will be set to 2.688527e-11 compared
to 2.688527e-12 in the forward simulation. If the adjoint
'run_time' is large due to small frequency spacing, it could be
better to instead run one simulation per frequency, which can be

             Created task 'WDM_MULTIFREQ_adj' with task_id

             View task using web UI at
2-4379-82f1-707760603828'.

11:02:24 PST status = queued

             To cancel the simulation, use 'web.abort(task_id)' or
Terminating the Python script will not stop the job running on the
cloud.

11:02:32 PST status = preprocess

11:02:36 PST Maximum FlexCredit cost: 0.031. Use 'web.real_cost(task_id)' to get
the billed FlexCredit cost after a simulation run.

             starting up solver

             running solver

11:02:48 PST early shutoff detected at 4%, exiting.

             status = postprocess

11:02:50 PST status = success

11:02:51 PST View simulation result at
2-4379-82f1-707760603828'.


## Run Optimization

Finally, we are ready to optimize our device. We will make use of the optax package to define an optimizer using the adam method, as we've done in the previous adjoint tutorials.

We record a history of objective function values, simulation data, and parameters for visualization later.

[23]:

import optax

# we know that the source fwidth will be set automatically due to multi-freq adjoint, so suppress warnings
td.config.logging_level = "ERROR"

# hyperparameters
num_steps = 50
learning_rate = 5e-2
beta_min = 1
beta_max = 30

# initialize adam optimizer with starting parameters
params = params0
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# store history
Js = []
params_history = [params0]
data_history = []
beta_history = []

for i in range(num_steps):

    # ramp the projection strength linearly from beta_min to beta_max
    perc_done = i / (num_steps - 1)
    beta_i = beta_min * (1 - perc_done) + beta_max * perc_done

    # compute gradient and current objective function value
    (value, data), gradient = grad_fn(params, beta=beta_i)

    # outputs
    print(f"step = {i + 1}")
    print(f"\tJ = {value:.4e}")
    print(f"\tbeta = {beta_i:.2f}")

    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)
    params = optax.apply_updates(params, updates)

    # keep params between 0 and 1
    params = jnp.minimum(1.0, params)
    params = jnp.maximum(0.0, params)

    # save history
    Js.append(value)
    params_history.append(params)
    beta_history.append(beta_i)
    data_history.append(data)

step = 1
J = 1.1304e-01
beta = 1.00
step = 2
J = 1.7998e-01
beta = 1.59
step = 3
J = 1.9274e-01
beta = 2.18
step = 4
J = 5.7359e-02
beta = 2.78
step = 5
J = -2.7653e-02
beta = 3.37
step = 6
J = -1.0726e-01
beta = 3.96
step = 7
J = 5.4190e-02
beta = 4.55
step = 8
J = 1.3695e-01
beta = 5.14
step = 9
J = 1.9039e-01
beta = 5.73
step = 10
J = 2.1604e-01
beta = 6.33
step = 11
J = 2.6796e-01
beta = 6.92
step = 12
J = 2.0278e-01
beta = 7.51
step = 13
J = 3.2679e-01
beta = 8.10
step = 14
J = 3.6340e-01
beta = 8.69
step = 15
J = 3.7389e-01
beta = 9.29
step = 16
J = 3.9022e-01
beta = 9.88
step = 17
J = 4.0384e-01
beta = 10.47
step = 18
J = 4.1686e-01
beta = 11.06
step = 19
J = 4.2954e-01
beta = 11.65
step = 20
J = 4.3786e-01
beta = 12.24
step = 21
J = 4.4402e-01
beta = 12.84
step = 22
J = 4.4862e-01
beta = 13.43
step = 23
J = 4.5884e-01
beta = 14.02
step = 24
J = 4.6795e-01
beta = 14.61
step = 25
J = 4.7021e-01
beta = 15.20
step = 26
J = 4.8514e-01
beta = 15.80
step = 27
J = 5.0607e-01
beta = 16.39
step = 28
J = 5.2328e-01
beta = 16.98
step = 29
J = 5.5337e-01
beta = 17.57
step = 30
J = 5.8120e-01
beta = 18.16
step = 31
J = 6.0239e-01
beta = 18.76
step = 32
J = 6.1259e-01
beta = 19.35
step = 33
J = 5.4708e-01
beta = 19.94
step = 34
J = 6.1207e-01
beta = 20.53
step = 35
J = 6.4012e-01
beta = 21.12
step = 36
J = 6.2275e-01
beta = 21.71
step = 37
J = 6.6233e-01
beta = 22.31
step = 38
J = 6.5566e-01
beta = 22.90
step = 39
J = 6.5845e-01
beta = 23.49
step = 40
J = 6.9423e-01
beta = 24.08
step = 41
J = 7.0369e-01
beta = 24.67
step = 42
J = 6.9838e-01
beta = 25.27
step = 43
J = 6.9214e-01
beta = 25.86
step = 44
J = 7.1860e-01
beta = 26.45
step = 45
J = 7.4005e-01
beta = 27.04
step = 46
J = 7.3577e-01
beta = 27.63
step = 47
J = 7.4667e-01
beta = 28.22
step = 48
J = 7.5887e-01
beta = 28.82
step = 49
J = 7.6499e-01
beta = 29.41
step = 50
J = 7.6311e-01
beta = 30.00
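
The structure of this loop (linear beta ramp, Adam step on the negated gradient, clipping to [0, 1]) can be reproduced without any simulation. The sketch below swaps optax for a hand-written Adam update in NumPy and uses a hypothetical quadratic objective in place of the FDTD-based one, so it only illustrates the mechanics and says nothing about the device itself.

```python
import numpy as np

def toy_objective_and_grad(params):
    """Concave toy objective with its maximum at params = 0.8 (hypothetical)."""
    value = -np.sum((params - 0.8) ** 2)
    grad = -2.0 * (params - 0.8)
    return value, grad

num_steps = 50
learning_rate = 5e-2
beta_min, beta_max = 1, 30
b1, b2, eps = 0.9, 0.999, 1e-8  # standard Adam constants

params = np.full(4, 0.2)
m = np.zeros_like(params)  # first moment estimate
v = np.zeros_like(params)  # second moment estimate
values = []
betas = []

for i in range(num_steps):
    # Linear ramp of the projection strength (unused by the toy objective).
    perc_done = i / (num_steps - 1)
    betas.append(beta_min * (1 - perc_done) + beta_max * perc_done)

    value, grad = toy_objective_and_grad(params)
    values.append(value)

    # Adam update on the negated gradient, so that the step ascends the objective.
    g = -grad
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g**2
    m_hat = m / (1 - b1 ** (i + 1))
    v_hat = v / (1 - b2 ** (i + 1))
    params = params - learning_rate * m_hat / (np.sqrt(v_hat) + eps)

    # Keep params between 0 and 1.
    params = np.clip(params, 0.0, 1.0)

print(f"beta at steps 1, 2, 50: {betas[0]:.2f}, {betas[1]:.2f}, {betas[-1]:.2f}")
print(f"objective: {values[0]:.4f} -> {values[-1]:.4f}")
```

The ramp reproduces the beta values printed in the log above (1.00, 1.59, ..., 30.00). Because the clipping is applied after each update, the parameters always remain valid densities.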


## Visualize Results

Let's visualize the results of our optimization.

### Objective function vs Iteration

First we inspect the objective function value as a function of optimization iteration number. After a dip during the first few iterations, it increases overall, as expected.

The presence of fabrication constraints tends to create some minor bumps in the optimization, which can be a signal that one needs to reduce the step size, but these results are sufficient for our purposes.

[24]:

plt.plot(Js)
plt.xlabel('iteration number')
plt.ylabel('objective function')
plt.show()


### Final Simulation

Let's take a look at the final simulation, which we grab from our history.

[25]:

sim_data_final = data_history[-1]
sim_final = sim_data_final.simulation


We notice that the structure has reasonably large feature sizes but is not well binarized. This could be improved by increasing the beta projection value further, or more slowly over iteration number, as was done in the grating coupler tutorial.

[26]:

ax = sim_final.plot_eps(z=0.01, monitor_alpha=0, source_alpha=0)


### Flux

Let's inspect the flux over each of the output ports as a function of wavelength.

We notice that the top and bottom ports have peaks in transmission at their corresponding design wavelengths, as expected!

[27]:

# plot flux
for key, color in zip(("top", "bot"), ('royalblue', 'firebrick')):
    freq = freqs[key]
    flux_data = sim_data_final[flux_mnt_names[key]]
    wvl_nm = 1000 * td.C_0 / freq
    wavelengths_nm = 1000 * td.C_0 / np.array(flux_data.flux.f)
    flux = np.array(flux_data.flux.values)
    flux_db = 10 * np.log10(flux)  # a power ratio in dB uses the base-10 logarithm
    label = f"{key} ({int(wvl_nm)} nm)"
    plt.plot(wavelengths_nm, flux_db, label=label, color=color)
    plt.scatter([wvl_nm], [0], 100, marker="*", color=color)
plt.xlabel('wavelength (nm)')
plt.ylabel('transmission (dB)')
plt.legend()

plt.show()
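
For reference, transmission in decibels is defined with the base-10 logarithm of the linear power ratio. A short NumPy sketch with sample values (the helper `to_db` and the numbers are illustrative, not simulation results):

```python
import numpy as np

def to_db(transmission):
    """Convert a linear power ratio to decibels."""
    return 10 * np.log10(transmission)

# Sample linear transmission values: full, half, 10%, and 1% of the input power.
transmission = np.array([1.0, 0.5, 0.1, 0.01])
transmission_db = to_db(transmission)

for t, t_db in zip(transmission, transmission_db):
    print(f"T = {t:5.2f} -> {t_db:7.2f} dB")
```

On this scale, 0 dB means all of the injected power reaches the port, which is where the star markers in the plot are placed.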


### Fields

Let's also plot the field intensity patterns at the two design wavelengths. We see from this plot the expected result that the power is directed to the design port at each frequency.

[28]:

# plot fields at the two design wavelengths

fig, axes = plt.subplots(1, 2, tight_layout=True, figsize=(7, 3))

for key, ax in zip(("top", "bot"), axes):
    freq = freqs[key]
    sim_data_final.plot_field("field", "E", "abs^2", f=freq, ax=ax, vmax=1200)
    wvl = 1000 * td.C_0 / freq
    ax.set_title(f"wavelength = {int(wvl)} nm")


### Animation

Finally, we animate this plot as a function of iteration number. The animation shows the device quickly accomplishing our design objective.

Note: this can take a few minutes to complete.

[29]:

import matplotlib.animation as animation
from IPython.display import HTML

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, tight_layout=False, figsize=(9, 4))

def animate(i):

    sim_data_i = data_history[i]

    # permittivity of the design at iteration i
    sim_i = sim_data_i.simulation
    sim_i.plot_eps(z=0.01, monitor_alpha=0, source_alpha=0, ax=ax1)
    ax1.set_aspect('equal')

    # field intensity at each of the two design wavelengths
    for key, ax in zip(("top", "bot"), (ax2, ax3)):

        freq = freqs[key]
        wvl = 1000 * td.C_0 / freq

        int_i = sim_data_i.get_intensity("field").sel(f=freq)
        int_i.squeeze().plot.pcolormesh(x='x', y='y', ax=ax, add_colorbar=False, cmap="magma", vmax=1000)

        ax.set_aspect('equal')
        ax.set_title(f"wavelength = {int(wvl)} nm")

# create animation
ani = animation.FuncAnimation(fig, animate, frames=len(data_history))
plt.close()

[30]:

# display the animation (press "play" to start)
HTML(ani.to_jshtml())
