Adjoint-based shape optimization of a waveguide bend
Note: Tidy3D now supports automatic differentiation natively through autograd. The jax-based adjoint plugin will be deprecated from 2.7 onwards. To see this notebook implemented with the new feature, see this notebook.
In this notebook, we apply the adjoint method to the optimization of a low-loss waveguide bend. We start with a 90 degree bend in a SiN waveguide, parameterized using a td.PolySlab.
We define an objective function that seeks to maximize the power transmitted into the TE0 output mode with respect to the positions of the polygon vertices defining the bend. A penalty is applied to keep the local radii of curvature larger than a pre-defined value.
The resulting device demonstrates low loss and exhibits a smooth geometry.
To install the jax module required for this feature, we recommend running pip install "tidy3d[jax]".
If you are unfamiliar with inverse design, we also recommend our intro to inverse design tutorials and our primer on automatic differentiation with tidy3d.
Setup
First, we import tidy3d and its adjoint plugin. We will also use numpy, matplotlib, and jax.
[1]:
import tidy3d as td
import tidy3d.plugins.adjoint as tda
from tidy3d.plugins.adjoint.web import run_local as run
[2]:
import numpy as np
import matplotlib.pyplot as plt
[3]:
import jax
import jax.numpy as jnp
Next, we define all the global parameters for our device and optimization.
[4]:
wavelength = 1.5
freq0 = td.C_0 / wavelength
# frequency of measurement and source
# note: we only optimize results at the central frequency for now.
fwidth = freq0 / 10
num_freqs = 10
freqs = np.linspace(freq0 - fwidth/2, freq0 + fwidth/2, num_freqs)
# define the discretization of the bend polygon in angle
num_pts = 60
angles = np.linspace(0, np.pi/2, num_pts + 2)[1:-1]
# refractive indices of waveguide and substrate (air above)
n_wg = 2.0
n_sub = 1.5
# min space between waveguide and PML
spc = 1 * wavelength
# length of input and output straight waveguide sections
t = 1 * wavelength
# distance between PML and the mode source / mode monitor
mode_spc = t / 2.0
# height of waveguide core
h = 0.7
# minimum, starting, and maximum allowed thicknesses for the bend geometry
wmin = 0.5
wmid = 1.5
wmax = 2.5
# average radius of curvature of the bend
radius = 6
# minimum allowed radius of curvature of the polygon
min_radius = 150e-3
# name of the monitor measuring the transmission amplitudes for optimization
monitor_name = "mode"
# how many grid points per wavelength in the waveguide core material
min_steps_per_wvl = 30
# how many mode outputs to measure
num_modes = 3
mode_spec = td.ModeSpec(num_modes=num_modes)
Using all of these parameters, we can define the total simulation size.
[5]:
Lx = Ly = t + radius + abs(wmax - wmid) + spc
Lz = spc + h + spc
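As a quick arithmetic check on these definitions, the sketch below recomputes the domain size from the parameter values above (passed as defaults to a hypothetical helper, domain_and_grid, so that nothing in the notebook is overwritten). It also gives a rough estimate of the grid step in the core. That estimate assumes the auto mesher resolves the wavelength inside the core, wavelength / n_wg, with min_steps_per_wvl cells. The nonuniform mesh that td.GridSpec.auto actually builds may differ. Lengths are in microns.

```python
def domain_and_grid(wavelength=1.5, n_wg=2.0, radius=6.0, h=0.7,
                    wmid=1.5, wmax=2.5, min_steps_per_wvl=30):
    """Recompute the domain size and a rough core grid step (hypothetical helper)."""
    spc = 1 * wavelength  # space between waveguide and PML
    t = 1 * wavelength    # length of the straight input/output sections
    Lx = Ly = t + radius + abs(wmax - wmid) + spc
    Lz = spc + h + spc
    # rough grid step, assuming lambda / n_wg is resolved with min_steps_per_wvl cells
    dl_core = wavelength / n_wg / min_steps_per_wvl
    return Lx, Ly, Lz, dl_core

Lx_chk, Ly_chk, Lz_chk, dl_chk = domain_and_grid()
print(f"Lx = Ly = {Lx_chk:.2f} um, Lz = {Lz_chk:.2f} um")
print(f"approx. core grid step = {dl_chk * 1e3:.1f} nm")
print(f"min_radius = 150 nm spans about {0.150 / dl_chk:.1f} grid steps")
```

Under this assumption, the minimum radius of curvature used later is only a handful of grid cells wide.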
Define parameterization
Next we describe how the geometry looks as a function of our design parameters.
At each angle of our bend discretization, we define a parameter that can range from -inf to +inf and controls the thickness of that section. When the parameter is -inf, 0, or +inf, the thickness of that section is wmin, wmid, or wmax, respectively.
This gives us a smooth way to constrain the thickness without needing to enforce bounds in the optimization.
[6]:
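def thickness(param: float) -> float:
    """Thickness of a bend section as a function of a parameter in (-inf, +inf)."""
    # squash the unbounded parameter into (0, 1)
    param_01 = (jnp.tanh(param) + 1.0) / 2.0
    # blend linearly between the minimum and maximum thickness
    return wmax * param_01 + wmin * (1 - param_01)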
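To see how the tanh mapping behaves, here is a NumPy-only copy of thickness. It is named thickness_np, a hypothetical helper, so that the sketch runs without jax and does not replace the notebook function. With the values above, (wmax - wmin)/2 = 1 and wmid is the midpoint, so the mapping reduces to wmid + tanh(param).

```python
import numpy as np

def thickness_np(param, wmin=0.5, wmax=2.5):
    """NumPy version of `thickness`: maps (-inf, +inf) smoothly onto (wmin, wmax)."""
    param_01 = (np.tanh(param) + 1.0) / 2.0            # squash to (0, 1)
    return wmax * param_01 + wmin * (1.0 - param_01)   # linear blend of the bounds

for p in (-10.0, -1.0, 0.0, 1.0, 10.0):
    print(f"param = {p:+5.1f} -> thickness = {thickness_np(p):.4f}")
```

The thickness saturates quickly: beyond |param| of a few units it is pinned near wmin or wmax. The gradient of tanh also vanishes there, so parameters that drift far out tend to stay put.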
Next we write a function to generate all of our bend polygon vertices given our array of design parameters. Note that we add extra vertices at the beginning and end of the bend that are independent of the parameters (static) and are only there to make it easier to connect the bend to the input and output waveguide sections.
[7]:
def make_vertices(params: np.ndarray) -> list:
    """Make bend polygon vertices as a function of design parameters."""
    vertices = []
    # static vertices connecting to the input waveguide
    vertices.append((-Lx/2 + 1e-2, -Ly/2 + t + radius))
    vertices.append((-Lx/2 + t, -Ly/2 + t + radius + wmid/2))
    # outer edge of the bend, from the input side to the output side
    for angle, param in zip(angles, params):
        thickness_i = thickness(param)
        radius_i = radius + thickness_i/2.0
        x = radius_i * np.sin(angle) - Lx/2 + t
        y = radius_i * np.cos(angle) - Ly/2 + t
        vertices.append((x, y))
    # static vertices connecting to the output waveguide
    vertices.append((-Lx/2 + t + radius + wmid/2, -Ly/2 + t))
    vertices.append((-Lx/2 + t + radius, -Ly/2 + 1e-2))
    vertices.append((-Lx/2 + t + radius - wmid/2, -Ly/2 + t))
    # inner edge of the bend, traversed back from the output side to the input side
    for angle, param in zip(angles[::-1], params[::-1]):
        thickness_i = thickness(param)
        radius_i = radius - thickness_i/2.0
        x = radius_i * np.sin(angle) - Lx/2 + t
        y = radius_i * np.cos(angle) - Ly/2 + t
        vertices.append((x, y))
    # static vertex closing the polygon at the input waveguide
    vertices.append((-Lx/2 + t, -Ly/2 + t + radius - wmid/2))
    return vertices
Let's try out our make_vertices function on a set of all-zero parameters, which should give the starting waveguide width of wmid across the bend.
[8]:
params = np.zeros(num_pts)
vertices = make_vertices(params)
[9]:
plt.scatter(*np.array(vertices).T)
ax = plt.gca()
ax.set_aspect("equal")
This looks good. Note again that the extra points at the ends are only there to ensure a solid overlap with the input and output waveguides. At this time, the adjoint plugin does not handle polygons that extend outside of the simulation domain, so we also need to ensure that all points lie inside the domain.
Next, we wrap this in a function that generates a 3D JaxPolySlab geometry from our design parameters. The JaxPolySlab is simply a jax-compatible version of the regular PolySlab geometry that can be differentiated through.
[10]:
def make_polyslab(params: np.ndarray) -> tda.JaxPolySlab:
    """Make a differentiable `tda.JaxPolySlab` for the bend given the design parameters."""
    vertices = make_vertices(params)
    return tda.JaxPolySlab(
        vertices=vertices,
        slab_bounds=(-h/2, h/2),
        axis=2,
    )
Let's visualize this as well.
[11]:
polyslab = make_polyslab(params)
ax = polyslab.plot(z=0)
Following the same pattern, we add a function that returns a list of JaxStructures with just one element (our differentiable polygon bend).
[12]:
def make_input_structures(params) -> list[tda.JaxStructure]:
    """Wrap the differentiable bend in a single-element list of JaxStructures."""
    polyslab = make_polyslab(params)
    medium = tda.JaxMedium(permittivity=n_wg**2)
    return [tda.JaxStructure(geometry=polyslab, medium=medium)]
[13]:
(ring,) = input_structures = make_input_structures(params)
ax = ring.plot(z=0)
Next, we define the other "static" geometries, such as the input waveguide section, output waveguide section, and substrate.
[14]:
box_in = td.Box.from_bounds(
    rmin=(-Lx/2 - 1, -Ly/2 + t + radius - wmid/2, -h/2),
    rmax=(-Lx/2 + t + 1e-3, -Ly/2 + t + radius + wmid/2, +h/2),
)
box_out = td.Box.from_bounds(
    rmin=(-Lx/2 + t + radius - wmid/2, -Ly/2 - 1, -h/2),
    rmax=(-Lx/2 + t + radius + wmid/2, -Ly/2 + t, +h/2),
)
geo_sub = td.Box.from_bounds(
    rmin=(-td.inf, -td.inf, -10000),
    rmax=(+td.inf, +td.inf, -h/2),
)
wg_in = td.Structure(geometry=box_in, medium=td.Medium(permittivity=n_wg**2))
wg_out = td.Structure(geometry=box_out, medium=td.Medium(permittivity=n_wg**2))
substrate = td.Structure(geometry=geo_sub, medium=td.Medium(permittivity=n_sub**2))
Fabrication Constraints
With the current parameterization, it is possible to generate structures with wildly varying radii of curvature, which may be difficult to fabricate. To alleviate this, we introduce a minimum radius of curvature penalty using the tidy3d adjoint utilities. The penalty takes a set of vertices, computes the local radius of curvature using a quadratic Bezier curve, and returns an average penalty value that depends on how much smaller the local radii are than a desired minimum radius.
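The idea behind such a penalty can be sketched in plain NumPy. This is a generic illustration built on our own assumptions, not the RadiusPenalty source, whose exact formula may differ. The sketch takes the quadratic Bezier curve whose control points are each vertex and its two neighbours, evaluates its curvature at t = 0.5, and passes the resulting radius through a sigmoid. The alpha, kappa and min_radius knobs mirror the RadiusPenalty arguments in name only. The function names are hypothetical. Because the Bezier curve does not pass through its middle control point, this estimator returns roughly twice the true radius for points sampled on a circle. It is therefore a proxy for the radius, not an exact measure.

```python
import numpy as np

def bezier_radii(points):
    """Radius of curvature at each interior vertex of an (N, 2) polyline.

    Each consecutive triple (p0, p1, p2) serves as the control points of a
    quadratic Bezier curve B(t), whose curvature is evaluated at t = 0.5.
    """
    p0, p1, p2 = points[:-2], points[1:-1], points[2:]
    d1 = p2 - p0                     # B'(0.5)
    d2 = 2.0 * (p0 - 2.0 * p1 + p2)  # B''(t), constant for a quadratic curve
    cross = d1[:, 0] * d2[:, 1] - d1[:, 1] * d2[:, 0]
    speed = np.linalg.norm(d1, axis=1)
    curvature = np.abs(cross) / speed**3
    # straight segments have zero curvature: cap the radius instead of dividing by zero
    return 1.0 / np.maximum(curvature, 1e-12)

def soft_radius_penalty(points, min_radius, alpha=1.0, kappa=10.0):
    """Mean sigmoid penalty: ~alpha where radius << min_radius, ~0 where radius >> min_radius."""
    radii = bezier_radii(np.asarray(points, dtype=float))
    z = np.clip(kappa * (min_radius - radii), -500.0, 500.0)  # clip to avoid overflow in exp
    return float(np.mean(alpha / (1.0 + np.exp(-z))))

# smooth 90 degree arc of radius 6 (like the starting bend) vs. a tight zigzag
arc_angles = np.linspace(0.0, np.pi / 2, 62)
arc_pts = 6.0 * np.stack([np.cos(arc_angles), np.sin(arc_angles)], axis=1)
zigzag_pts = np.array([[0.05 * i, 0.05 * (i % 2)] for i in range(10)])

print(soft_radius_penalty(arc_pts, min_radius=0.15))
print(soft_radius_penalty(zigzag_pts, min_radius=0.15))
```

The smooth arc is effectively unpenalized. The zigzag, whose local radii are well below min_radius, gets a penalty of order alpha.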
[15]:
from tidy3d.plugins.adjoint.utils.penalty import RadiusPenalty
penalty = RadiusPenalty(min_radius=min_radius, alpha=1.0, kappa=10.0)
We then wrap this penalty to look at only the inner and outer vertices independently and average the penalty from each.
[16]:
def eval_penalty(params):
    """Evaluate penalty on a set of params looking at radius of curvature."""
    vertices = make_vertices(params)
    _vertices = jnp.array(vertices)
    vertices_top = _vertices[1:num_pts+3]  # select outer set of points along bend
    vertices_bot = _vertices[num_pts+4:]  # select inner set of points along bend
    penalty_top = penalty.evaluate(vertices_top)
    penalty_bot = penalty.evaluate(vertices_bot)
    return (penalty_top + penalty_bot) / 2.0
Let's try this out on our starting parameters. We get a JAX array holding a value that is effectively zero, as expected for our smooth starting structure.
[17]:
eval_penalty(params)
[17]:
Array(3.5788543e-23, dtype=float32)
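The slice indices in eval_penalty follow from the order in which make_vertices appends points. The sketch below replays that order with string labels instead of coordinates. The helper vertex_labels is hypothetical and not part of the notebook. It confirms that each slice keeps one full edge of the bend, including its static end points, and drops both waveguide tips.

```python
def vertex_labels(num_pts):
    """Label each vertex in the order `make_vertices` appends it (hypothetical helper)."""
    labels = ["in_tip", "outer_start"]                          # static, input side
    labels += [f"outer_{i}" for i in range(num_pts)]            # outer edge, input -> output
    labels += ["outer_end", "out_tip", "inner_start"]           # static, output side
    labels += [f"inner_{i}" for i in reversed(range(num_pts))]  # inner edge, output -> input
    labels += ["inner_end"]                                     # static, input side
    return labels

n = 60  # num_pts in the notebook
labels = vertex_labels(n)
top = labels[1:n + 3]  # same slice as vertices_top in eval_penalty
bot = labels[n + 4:]   # same slice as vertices_bot in eval_penalty
print(len(labels), top[0], top[-1], len(top))
print(bot[0], bot[-1], len(bot))
```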
Define Simulation
Now we define our sources, monitors, and simulation.
We first define a mode source injected at the input waveguide.
[18]:
mode_width = wmid + 2 * spc
mode_height = Lz
mode_src = td.ModeSource(
    size=(0, mode_width, mode_height),
    center=(-Lx/2 + t/2, -Ly/2 + t + radius, 0),
    direction="+",
    source_time=td.GaussianPulse(
        freq0=freq0,
        fwidth=fwidth,
    ),
)
Next, we define monitors for storing:
- the output mode amplitude at the central frequency;
- the flux on the output plane (for reference);
- the output mode amplitude across a frequency range (for examining the transmission spectrum of our final device);
- the fields, measured directly by a field monitor in the z-normal plane intersecting the waveguide.
[19]:
mode_mnt = td.ModeMonitor(
    size=(mode_width, 0, mode_height),
    center=(-Lx/2 + t + radius, -Ly/2 + t/2, 0),
    name=monitor_name,
    freqs=[freq0],
    mode_spec=mode_spec,
)
flux_mnt = td.FluxMonitor(
    size=(mode_width, 0, mode_height),
    center=(-Lx/2 + t + radius, -Ly/2 + t/2, 0),
    name="flux",
    freqs=[freq0],
)
mode_mnt_bb = td.ModeMonitor(
    size=(mode_width, 0, mode_height),
    center=(-Lx/2 + t + radius, -Ly/2 + t/2, 0),
    name="mode_bb",
    freqs=freqs.tolist(),
    mode_spec=mode_spec,
)
fld_mnt = td.FieldMonitor(
    size=(td.inf, td.inf, 0),
    freqs=[freq0],
    name="field",
)
Next we put everything together into a function that returns a JaxSimulation given our parameters and an optional boolean specifying whether to include the field monitor (to save data when fields are not required).
[20]:
def make_sim(params, use_fld_mnt: bool = True) -> tda.JaxSimulation:
    """Build the JaxSimulation for a given set of design parameters."""
    monitors = [mode_mnt_bb, flux_mnt]
    if use_fld_mnt:
        monitors += [fld_mnt]
    input_structures = make_input_structures(params)
    return tda.JaxSimulation(
        size=(Lx, Ly, Lz),
        input_structures=input_structures,
        structures=[substrate, wg_in, wg_out],
        sources=[mode_src],
        output_monitors=[mode_mnt],
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=min_steps_per_wvl),
        boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True),
        monitors=monitors,
        run_time=10/fwidth,
    )
Let's try it out and plot our simulation.
[21]:
sim = make_sim(params)
f, (ax1, ax2) = plt.subplots(1,2,tight_layout=True, figsize=(10,4))
ax = sim.plot(z=0.01, ax=ax1)
ax = sim.plot(x=-Lx/2+t/2, ax=ax2)
12:20:23 -03 WARNING: 'JaxPolySlab'-containing 'JaxSimulation.input_structures[0]' intersects with 'JaxSimulation.structures[1]'. Note that in this version of the adjoint plugin, there may be errors in the gradient when 'JaxPolySlab' intersects with background structures.
WARNING: Suppressed 1 WARNING message.
WARNING: Structure at structures[3] was detected as being less than half of a central wavelength from a PML on side x-min. To avoid inaccurate results or divergence, please increase gap between any structures and PML or fully extend structure through the pml.
WARNING: Suppressed 1 WARNING message.
Note: we get warnings from the adjoint plugin because the polyslab intersects the static input and output waveguides, and gradients along those overlapping edges may be inaccurate. We can safely ignore those warnings because we don't need gradients with respect to those edges.
[22]:
td.config.logging_level = "ERROR"
Select the desired waveguide mode
Next, we use the ModeSolver to solve for the modes and select the mode_index that gives us the proper injected and measured modes. We plot the field components of the first 3 modes and see that the TE0 mode is mode_index=0.
[23]:
from tidy3d.plugins.mode import ModeSolver
ms = ModeSolver(simulation=sim.to_simulation()[0], plane=mode_src, mode_spec=mode_spec, freqs=mode_mnt.freqs)
data = ms.solve()
print("Effective index of computed modes: ", np.array(data.n_eff))
fig, axs = plt.subplots(num_modes, 3, figsize=(14, 10), tight_layout=True)
for mode_ind in range(num_modes):
    for field_ind, field_name in enumerate(("Ex", "Ey", "Ez")):
        field = data.field_components[field_name].sel(mode_index=mode_ind)
        ax = axs[mode_ind, field_ind]
        field.real.plot(x='y', y='z', ax=ax, cmap='RdBu')
        ax.set_title(f"{field_name}, mode_ind={mode_ind}")
Effective index of computed modes: [[1.7966835 1.7514164 1.6002883]]
Since this is already the default mode index, we can leave the original make_sim() function as is. However, to generate a new mode source with a different mode_index, we could do the following and use the returned mode_src in that function.
[24]:
# select the mode index
mode_index = 0
# make the mode source with appropriate mode index
mode_src = ms.to_source(mode_index=mode_index, source_time=mode_src.source_time, direction=mode_src.direction)
Defining objective function
Now we can define our objective function to maximize. The objective function first generates a simulation given the parameters and runs it with the jax-compatible run function imported above. It then measures the power transmitted into the TE0 output mode at our desired polarization and subtracts the radius of curvature penalty that we defined earlier.
For convenience, we also return the JaxSimulationData as the 2nd output. jax treats it as auxiliary data, and does not differentiate it, when we pass has_aux=True while computing the gradient of this function.
[25]:
def objective(params, use_fld_mnt: bool = True):
    """Return (penalized TE0 transmission, JaxSimulationData) for the given parameters."""
    sim = make_sim(params, use_fld_mnt=use_fld_mnt)
    sim_data = run(sim, task_name='bend', verbose=False)
    # amplitudes of the selected mode travelling in -y through the output plane
    amps = sim_data[monitor_name].amps.sel(direction="-", mode_index=mode_index).values
    transmission = jnp.abs(jnp.array(amps))**2
    J = jnp.sum(transmission) - eval_penalty(params)
    return J, sim_data
Next, we use jax.value_and_grad to transform this objective function into a function that returns:
- the objective function evaluated at the passed parameters;
- the auxiliary JaxSimulationData corresponding to the forward pass (for plotting later);
- the gradient of the objective function with respect to the passed parameters.
[26]:
val_grad = jax.value_and_grad(objective, has_aux=True)
Let's run this function and take a look at the outputs.
[27]:
(val, sim_data), grad = val_grad(params)
[28]:
print(val)
print(grad)
0.56060445
[-0.0188388 -0.02968329 -0.04119703 -0.05332524 -0.06572866 -0.07820014
-0.09031231 -0.10045926 -0.11044138 -0.11628448 -0.11940151 -0.11856014
-0.11367594 -0.10428433 -0.09043697 -0.07393097 -0.05485657 -0.0337466
-0.01146656 0.01097788 0.03232807 0.05213357 0.07016988 0.08564365
0.0990769 0.10944834 0.11628198 0.12678821 0.1184852 0.12728047
0.12726428 0.11848194 0.12681223 0.11626809 0.10944062 0.09910151
0.0856217 0.0701622 0.05214944 0.03230007 0.01097185 -0.01146848
-0.03375959 -0.05487485 -0.07394323 -0.09041391 -0.10433427 -0.11368035
-0.11850713 -0.1194798 -0.11626592 -0.1103911 -0.10054259 -0.09028433
-0.07816786 -0.06577757 -0.05332028 -0.04118454 -0.02970344 -0.01884984]
These seem reasonable and can now be plugged into our optimization algorithm.
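The nested return structure ((value, aux), grad) produced by jax.value_and_grad with has_aux=True can be checked on a toy function that needs no simulation. toy_objective below is a hypothetical stand-in for objective. Only its first output is differentiated. The second output is passed through unchanged as auxiliary data, just like the JaxSimulationData above. Distinct names are used so that running this does not overwrite params.

```python
import jax
import jax.numpy as jnp

def toy_objective(p):
    """Hypothetical stand-in for `objective`: returns (scalar value, auxiliary data)."""
    per_element = jnp.cos(p) ** 2
    return jnp.sum(per_element), {"per_element": per_element}

toy_val_grad = jax.value_and_grad(toy_objective, has_aux=True)

toy_params = jnp.array([0.0, jnp.pi / 4, jnp.pi / 2])
(toy_value, toy_aux), toy_grad = toy_val_grad(toy_params)
print(toy_value)                 # sum of cos^2(p)
print(toy_aux["per_element"])    # auxiliary output, returned as-is
print(toy_grad)                  # d/dp cos^2(p) = -sin(2p)
```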
Optimization Procedure
With our gradients defined, we write a simple optimization loop using the optax package. We use the Adam method with a tunable number of steps and learning rate. The intermediate values, parameters, and data are stored for visualization later.
Note: this will take several minutes. While not shown here, it is good practice to checkpoint your optimization results by saving to file on every iteration, or ensure you have a stable internet connection. See this notebook for more details.
[29]:
import optax
# hyperparameters
num_steps = 40
learning_rate = 0.1
# initialize adam optimizer with starting parameters
params = np.array(params).copy()
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)
# store history
objective_history = []
param_history = [params]
data_history = []
for i in range(num_steps):
    # compute gradient and current objective function value
    (value, sim_data), gradient = val_grad(params)
    # multiply by -1: optax minimizes, but we want to maximize the objective
    gradient = -np.array(gradient.copy())
    # outputs
    print(f"step = {i + 1}")
    print(f"\tJ = {value:.4e}")
    print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")
    # compute and apply updates to the optimizer based on gradient
    updates, opt_state = optimizer.update(gradient, opt_state, params)
    params = optax.apply_updates(params, updates)
    # save history
    objective_history.append(value)
    param_history.append(params)
    data_history.append(sim_data)
step = 1
J = 5.6060e-01
grad_norm = 6.7527e-01
step = 2
J = 8.4790e-01
grad_norm = 4.8038e-01
step = 3
J = 8.9984e-01
grad_norm = 4.0508e-01
step = 4
J = 8.5131e-01
grad_norm = 5.0347e-01
step = 5
J = 8.3507e-01
grad_norm = 4.8750e-01
step = 6
J = 8.6679e-01
grad_norm = 4.8799e-01
step = 7
J = 9.2224e-01
grad_norm = 3.5351e-01
step = 8
J = 9.5242e-01
grad_norm = 1.8942e-01
step = 9
J = 9.4218e-01
grad_norm = 2.4272e-01
step = 10
J = 9.1946e-01
grad_norm = 2.9455e-01
step = 11
J = 9.0531e-01
grad_norm = 3.2003e-01
step = 12
J = 9.1071e-01
grad_norm = 3.1154e-01
step = 13
J = 9.3085e-01
grad_norm = 2.5687e-01
step = 14
J = 9.5300e-01
grad_norm = 1.7602e-01
step = 15
J = 9.6490e-01
grad_norm = 1.0443e-01
step = 16
J = 9.6218e-01
grad_norm = 1.3480e-01
step = 17
J = 9.5239e-01
grad_norm = 1.9509e-01
step = 18
J = 9.4549e-01
grad_norm = 2.2595e-01
step = 19
J = 9.4680e-01
grad_norm = 2.1706e-01
step = 20
J = 9.5441e-01
grad_norm = 1.7688e-01
step = 21
J = 9.6259e-01
grad_norm = 1.2695e-01
step = 22
J = 9.6678e-01
grad_norm = 9.8817e-02
step = 23
J = 9.6526e-01
grad_norm = 1.2305e-01
step = 24
J = 9.6186e-01
grad_norm = 1.4717e-01
step = 25
J = 9.5993e-01
grad_norm = 1.5419e-01
step = 26
J = 9.6063e-01
grad_norm = 1.4312e-01
step = 27
J = 9.6385e-01
grad_norm = 1.2450e-01
step = 28
J = 9.6808e-01
grad_norm = 1.0043e-01
step = 29
J = 9.7131e-01
grad_norm = 7.4764e-02
step = 30
J = 9.7190e-01
grad_norm = 6.6778e-02
step = 31
J = 9.7000e-01
grad_norm = 8.9502e-02
step = 32
J = 9.6791e-01
grad_norm = 1.0968e-01
step = 33
J = 9.6787e-01
grad_norm = 1.1277e-01
step = 34
J = 9.7046e-01
grad_norm = 9.2506e-02
step = 35
J = 9.7376e-01
grad_norm = 5.5469e-02
step = 36
J = 9.7542e-01
grad_norm = 2.4987e-02
step = 37
J = 9.7487e-01
grad_norm = 4.3868e-02
step = 38
J = 9.7339e-01
grad_norm = 6.9073e-02
step = 39
J = 9.7261e-01
grad_norm = 8.0034e-02
step = 40
J = 9.7333e-01
grad_norm = 7.4658e-02
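For reference, optax.adam applies the standard Adam update. The NumPy sketch below is a textbook re-implementation, not optax's code. It uses beta1=0.9, beta2=0.999 and eps=1e-8, which we assume match optax's defaults. It shows why the loop above negates the gradient: Adam descends, so feeding it -dJ/dp climbs J. The toy objective J(p) = -sum((p - 2)^2) and the helper names are hypothetical.

```python
import numpy as np

def adam_maximize(grad_fn, p, learning_rate=0.1, num_steps=500,
                  beta1=0.9, beta2=0.999, eps=1e-8):
    """Maximize an objective with textbook Adam by descending on its negative gradient."""
    p = np.array(p, dtype=float)
    m = np.zeros_like(p)  # first-moment (mean) estimate
    v = np.zeros_like(p)  # second-moment (uncentered variance) estimate
    for step in range(1, num_steps + 1):
        g = -grad_fn(p)  # negate: Adam minimizes, we want to maximize
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**step)  # bias corrections for the zero initialization
        v_hat = v / (1 - beta2**step)
        p = p - learning_rate * m_hat / (np.sqrt(v_hat) + eps)
    return p

def grad_toy_J(p):
    """Gradient of the toy objective J(p) = -sum((p - 2)^2), maximized at p = 2."""
    return -2.0 * (p - 2.0)

p_opt = adam_maximize(grad_toy_J, np.zeros(3))
print(p_opt)
```

Because each step is normalized by sqrt(v_hat), early steps have a size of roughly learning_rate regardless of the gradient magnitude. This is one reason a learning rate of 0.1 is a reasonable starting point for the unbounded tanh parameters used here.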
Analyzing results
After the optimization is finished, let's look at the results.
[30]:
_ = plt.plot(objective_history)
ax = plt.gca()
ax.set_xlabel('iteration number')
ax.set_ylabel('objective function')
ax.set_title('optimization progress')
plt.show()
Next, we can grab our initial and final device from the history lists.
[31]:
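sim_start = make_sim(param_history[0])
data_start = data_history[0]
# data_history[i] was computed with param_history[i]; param_history[-1] holds the
# parameters after the last update, which were never simulated, so use [-2] here
sim_final = make_sim(param_history[-2])
data_final = data_history[-1]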
Let's take a look at the final structure. We see that it has a smooth design, which is symmetric about the 45 degree line.
[32]:
ax = sim_final.plot(z=0.01)
Now let's inspect the difference between the initial and final intensity patterns. The final device is quite effective at coupling light into the output waveguide, which is especially evident when compared to the starting device.
[33]:
f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, tight_layout=True, figsize=(10, 6))
_ = data_start.plot_field('field', 'E', 'abs^2', ax=ax1)
_ = sim_start.plot(z=0, ax=ax2)
ax1.set_title('starting device')
ax2.set_title('starting device')
_ = data_final.plot_field('field', 'E', 'abs^2', ax=ax3)
_ = sim_final.plot(z=0, ax=ax4)
ax3.set_title('final device')
ax4.set_title('final device')
plt.show()
Let's view the transmission now, both in linear and dB scale.
The mode amplitudes are simply an xarray.DataArray that can be selected, post-processed, and plotted.
[34]:
amps = data_final['mode_bb'].amps.sel(direction="-", mode_index=mode_index)
[35]:
transmission = abs(amps)**2
transmission_percent = 100 * transmission
transmission_percent.plot(x="f")
ax = plt.gca()
ax.set_title('mode_index=0, transmitted power %')
ax.set_ylabel('T (%)')
plt.show()
We can also view this on a logarithmic (dB) scale.
[36]:
# insertion loss in dB: positive, and zero for perfect transmission
loss_db = -10 * np.log10(transmission)
loss_db.plot(x="f")
plt.ylabel('loss (dB)')
plt.show()
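The conversion from a power transmission T to an insertion loss in dB is -10 log10(T). The sketch below evaluates it for a few illustrative T values, which are not results from this notebook. insertion_loss_db is a hypothetical helper name.

```python
import numpy as np

def insertion_loss_db(transmission):
    """Insertion loss in dB for a power transmission T in (0, 1]: -10 * log10(T)."""
    return -10.0 * np.log10(transmission)

for T in (0.97, 0.9, 0.5, 0.1):
    print(f"T = {T:.2f} -> loss = {insertion_loss_db(T):.3f} dB")
```

A transmission above 0.9 corresponds to well under 1 dB of loss, while halving the power costs about 3 dB. The function also works elementwise on NumPy arrays and xarray DataArrays.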
Finally, let's animate the field pattern evolution over the entire optimization. This will take a minute or so.
[37]:
import matplotlib.animation as animation
from IPython.display import HTML
fig, (ax1, ax2) = plt.subplots(1, 2, tight_layout=False, figsize=(8, 4))

def animate(i):
    # grab data at iteration "i"
    sim_data_i = data_history[i]
    # plot permittivity
    sim_i = sim_data_i.simulation
    sim_i.plot_eps(z=0, monitor_alpha=0.0, source_alpha=0.0, ax=ax1)
    # ax1.set_aspect('equal')
    # plot intensity
    int_i = sim_data_i.get_intensity("field")
    int_i.squeeze().plot.pcolormesh(x='x', y='y', ax=ax2, add_colorbar=False, cmap="magma")
    # ax2.set_aspect('equal')

# create animation
ani = animation.FuncAnimation(fig, animate, frames=len(data_history))
plt.close()
# display the animation (press "play" to start)
HTML(ani.to_jshtml())