Inverse design integrated with circuit simulation#
In this tutorial, we will show how to integrate the adjoint plugin of Tidy3D with a differentiable optical circuit simulator, sax. This allows one to model a complicated circuit composed of many connected components, each simulated independently using Tidy3D. Through the adjoint plugin and jax, the gradients of all of the individual components are similarly connected. This allows one to write an objective function in terms of the scattering matrix of the entire circuit and optimize this function with respect to the design parameters in each of the individual Tidy3D simulations.
To demonstrate this capability, in this notebook we optimize a Mach-Zehnder Interferometer (MZI) circuit. This simplified MZI has a single input and two outputs. We wish to switch the transmitted power between the two outputs depending on a phase shift applied to a waveguide in the system. Our circuit has a single splitter component that takes the input light and splits it into two waveguides. We apply the phase shift to one of these waveguides, then add a component that combines the light from the two waveguides, mixes it together, and sends it to our two outputs. The scattering matrices of the two components are computed using Tidy3D simulations, and the waveguide connections and phase shifter are defined using the sax circuit simulator. As all of the gradients are passed automatically through jax, we then optimize our circuit with respect to the permittivity distributions in each of the two Tidy3D simulations simultaneously.
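To build intuition for the switching behavior before running any simulation, the following minimal sketch models an idealized, lossless MZI in plain Python. It assumes a perfect 50/50 splitter and a perfect 50/50 combiner, which the optimized Tidy3D components can only approximate. The function name `ideal_mzi_powers` is hypothetical and is not part of the notebook.

```python
import cmath
import math


def ideal_mzi_powers(phi):
    """Output powers (top, bottom) of an idealized lossless MZI for unit input power."""
    # ideal 1->2 splitter: equal amplitude in both arms
    arm_top = 1 / math.sqrt(2)
    arm_bot = 1 / math.sqrt(2)
    # the phase shifter acts on the top arm only
    arm_top = arm_top * cmath.exp(1j * phi)
    # ideal 2->2 combiner: sum goes to the top port, difference to the bottom port
    out_top = (arm_top + arm_bot) / math.sqrt(2)
    out_bot = (arm_top - arm_bot) / math.sqrt(2)
    return abs(out_top) ** 2, abs(out_bot) ** 2


for phi in (0.0, math.pi / 2, math.pi):
    p_top, p_bot = ideal_mzi_powers(phi)
    print(f"phi = {phi:.3f}: top = {p_top:.3f}, bottom = {p_bot:.3f}")
```

In this idealized picture the top output carries cos^2(phi/2) of the power and the bottom output carries sin^2(phi/2), so phi=0 and phi=pi route all of the light to opposite ports. The optimization in this notebook tries to find component designs that approach this behavior.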
Below is a schematic of this process and some of the variable labels we use in the code.
To install the jax module required for this feature, we recommend running pip install "tidy3d[jax]". You will also need to pip install sax.
If you are unfamiliar with inverse design, we also recommend our intro to inverse design tutorials and our primer on automatic differentiation with tidy3d.
Setup#
First we import all of the packages we need.
[1]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import sax
import tidy3d as td
import tidy3d.plugins.adjoint as tda
np.random.seed(2)
/Users/twhughes/.pyenv/versions/3.10.9/lib/python3.10/site-packages/sax/backends/__init__.py:24: UserWarning: klujax not found. Please install klujax for better performance during circuit evaluation!
warnings.warn(
Tidy3D Simulation Parameters#
Then we will initialize some parameters needed for our individual component simulations.
For this application, we model each of the Tidy3D components as square design regions accepting 1 or 2 inputs and transmitting to 1 or 2 outputs.
[2]:
# wavelength and frequency
wavelength = 1.0
freq0 = td.C_0 / wavelength
# resolution control
steps_per_wvl = 20
# space between boxes and PML
buffer = 1.0 * wavelength
# optimize region size
lz = td.inf
lx = 3.0
ly = lx
wg_width = 0.4
# num cells
nx = 120
ny = nx
num_cells = nx * ny
# position of source and monitor (constant for all)
source_x = -lx / 2 - buffer * 0.8
meas_x = lx / 2 + buffer * 0.8
# total size
Lx = lx + 2 * buffer
Ly = ly + 2 * buffer
Lz = 0
# permittivity info
eps_wg = 2.75
eps_deviation_random = 0.5
# random starting parameters, uniform in [0, 1)
params0 = np.random.random((nx, ny))
# frequency width and run time
freqw = freq0 / 10
run_time = 50 / freqw
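As a quick sanity check on these settings, the sketch below recomputes a few derived quantities in plain Python. The value of `C_0` is written out by hand (speed of light in um/s) on the assumption that it matches `td.C_0`, and the grid-step estimate assumes the automatic grid resolves the wavelength inside the waveguide material.

```python
import math

# values copied from the parameter cell above (lengths in um)
wavelength = 1.0
steps_per_wvl = 20
buffer = 1.0 * wavelength
lx = 3.0
nx = 120
eps_wg = 2.75

# assumption: speed of light in um/s, intended to match td.C_0
C_0 = 2.99792458e14

pixel_size = lx / nx                 # side length of one design pixel
Lx = lx + 2 * buffer                 # total simulation width
source_x = -lx / 2 - buffer * 0.8    # x position of the source plane
freq0 = C_0 / wavelength             # central frequency (Hz)
run_time = 50 / (freq0 / 10)         # simulation run time (s)

# rough estimate of the largest grid step inside the waveguide material
grid_step_wg = wavelength / math.sqrt(eps_wg) / steps_per_wvl

print(f"design pixel size  : {pixel_size:.4f} um")
print(f"grid step estimate : {grid_step_wg:.4f} um")
print(f"simulation width   : {Lx:.1f} um, source at x = {source_x:.1f} um")
print(f"run time           : {run_time:.3e} s")
```

The source plane at x = -2.3 sits between the edge of the design region (x = -1.5) and the edge of the simulation domain (x = -2.5), and each design pixel (0.025 um) is slightly smaller than the estimated grid step in the waveguide (about 0.030 um).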
Because we want to be able to model a general system of 1 or 2 inputs coupling to 1 or 2 outputs, we pre-define all of the possible waveguide configurations beforehand to make things simpler later.
[3]:
big_number = Lx * 10
dy = (ly - 2 * wg_width) / 4 + wg_width/2
# all of the possible input and output waveguides
waveguide_in_center = td.Structure(
geometry=td.Box(
size=(big_number, wg_width, lz),
center=(-big_number/2, 0, 0),
),
medium=td.Medium(permittivity=eps_wg)
)
waveguide_in_top = td.Structure(
geometry=td.Box(
size=(big_number, wg_width, lz),
center=(-big_number/2, +dy, 0),
),
medium=td.Medium(permittivity=eps_wg)
)
waveguide_in_bot = td.Structure(
geometry=td.Box(
size=(big_number, wg_width, lz),
center=(-big_number/2, -dy, 0),
),
medium=td.Medium(permittivity=eps_wg)
)
waveguide_out_center = td.Structure(
geometry=td.Box(
size=(big_number, wg_width, lz),
center=(+big_number/2, 0, 0),
),
medium=td.Medium(permittivity=eps_wg),
name="center"
)
waveguide_out_top = td.Structure(
geometry=td.Box(
size=(big_number, wg_width, lz),
center=(+big_number/2, +dy, 0),
),
medium=td.Medium(permittivity=eps_wg),
name="top"
)
waveguide_out_bot = td.Structure(
geometry=td.Box(
size=(big_number, wg_width, lz),
center=(+big_number/2, -dy, 0),
),
medium=td.Medium(permittivity=eps_wg),
name="bot"
)
We also define some information about our mode source and monitor geometries.
[4]:
# the source and measurement plane size
mode_size = (0, wg_width * 3, lz)
# source plane centered at y=0
source_plane_base = td.Box(
center=[source_x, 0, 0],
size=mode_size,
)
def get_source_plane(waveguide: td.Structure) -> td.Box:
    """Source plane with y position moved to cover a specific waveguide."""
    return source_plane_base.updated_copy(center=(source_x, waveguide.geometry.center[1], 0))
measure_plane = td.Box(
center=[meas_x, 0, 0],
size=mode_size,
)
Design Parameterization#
As in many of the other adjoint demos, we now define our design region structure using a JaxCustomMedium generated as a function of our design parameters. We will apply filtering and projection to create smooth features. For more details, we refer the reader to our intro to inverse design tutorials.
[5]:
from tidy3d.plugins.adjoint.utils.filter import ConicFilter
from typing import List
radius = .120
beta = 50
conic_filter = ConicFilter(radius=radius, design_region_dl=float(lx) / nx)
def tanh_projection(x, beta, eta=0.5):
    tanhbn = jnp.tanh(beta * eta)
    num = tanhbn + jnp.tanh(beta * (x - eta))
    den = tanhbn + jnp.tanh(beta * (1 - eta))
    return num / den

def filter_project(x, beta, eta=0.5):
    x = conic_filter.evaluate(x)
    return tanh_projection(x, beta=beta, eta=eta)

def pre_process(params, beta):
    """Filter and project the parameters (0, 1) twice to get a smoothed density in (0, 1)."""
    params1 = filter_project(params, beta=beta)
    params2 = filter_project(params1, beta=beta)
    return params2

def get_eps(params, beta):
    """Get the permittivity values (eps_min, eps_wg) array as a function of the parameters (0, 1)."""
    params = pre_process(params, beta=beta)
    eps_min = 1.0001
    eps_values = eps_min + (eps_wg - eps_min) * params
    return eps_values

def make_input_structures(params, beta) -> List[tda.JaxStructure]:
    size_box_x = float(lx) / nx
    size_box_y = float(ly) / ny
    size_box = (size_box_x, size_box_y, lz)
    x0_min = -lx / 2 + size_box_x / 2
    y0_min = -ly / 2 + size_box_y / 2
    input_structures = []
    coords_x = [x0_min + index_x * size_box_x - 1e-5 for index_x in range(nx)]
    coords_y = [y0_min + index_y * size_box_y - 1e-5 for index_y in range(ny)]
    coords = dict(x=coords_x, y=coords_y, z=[0], f=[freq0])
    eps_boxes = get_eps(params, beta=beta).reshape((nx, ny, 1, 1))
    field_components = {
        f"eps_{dim}{dim}": tda.JaxDataArray(values=eps_boxes, coords=coords) for dim in "xyz"
    }
    eps_dataset = tda.JaxPermittivityDataset(**field_components)
    custom_medium = tda.JaxCustomMedium(eps_dataset=eps_dataset)
    box = tda.JaxBox(center=(0, 0, 0), size=(lx, ly, lz))
    custom_structure = tda.JaxStructure(geometry=box, medium=custom_medium)
    return [custom_structure]
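To see what the projection step does on its own, here is a minimal scalar sketch of the tanh projection and the density-to-permittivity mapping defined above. It uses `math` instead of `jnp` and leaves out the conic filter entirely, so it illustrates only the thresholding and scaling. The helper name `to_eps` is hypothetical.

```python
import math


def tanh_projection(x, beta, eta=0.5):
    """Scalar version of the projection defined above (math.tanh instead of jnp.tanh)."""
    tanhbn = math.tanh(beta * eta)
    num = tanhbn + math.tanh(beta * (x - eta))
    den = tanhbn + math.tanh(beta * (1 - eta))
    return num / den


def to_eps(density, eps_min=1.0001, eps_wg=2.75):
    """Linear map from a density in [0, 1] to a permittivity in [eps_min, eps_wg]."""
    return eps_min + (eps_wg - eps_min) * density


samples = (0.0, 0.4, 0.5, 0.6, 1.0)
for beta in (1, 5, 50):
    projected = [tanh_projection(x, beta) for x in samples]
    print(f"beta = {beta:>2}: " + ", ".join(f"{v:.3f}" for v in projected))
```

The end points 0 and 1 and the threshold value 0.5 are left unchanged for every beta, while values on either side of the threshold are pushed toward 0 or 1 more strongly as beta grows. This is why the optimization later ramps beta up gradually: the design starts nearly continuous and ends close to binary.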
Base Simulation#
Next, we write a “base” simulation (without sources or monitors) as a function of our input parameters. We also accept the shape of our component, which specifies the number of inputs and outputs. This determines which waveguides we add to our simulation.
[6]:
def make_sim_base(params, beta, shape) -> tda.JaxSimulation:
    input_structures = make_input_structures(params, beta=beta)
    num_wg_in, num_wg_out = shape
    if num_wg_in == 1:
        wgs_in = [waveguide_in_center]
    else:
        wgs_in = [waveguide_in_top, waveguide_in_bot]
    if num_wg_out == 1:
        wgs_out = [waveguide_out_center]
    else:
        wgs_out = [waveguide_out_top, waveguide_out_bot]
    return tda.JaxSimulation(
        size=[Lx, Ly, Lz],
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=steps_per_wvl, wavelength=wavelength),
        structures=wgs_in + wgs_out,
        input_structures=input_structures,
        sources=[],
        monitors=[],
        output_monitors=[],
        run_time=run_time,
        subpixel=True,
        boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=False),
        shutoff=1e-8,
        courant=0.9,
    )
Let’s make a base simulation for a few different shapes and plot them to make sure they work properly.
[7]:
f, axes = plt.subplots(2, 2, tight_layout=True, figsize=(10, 8))
for num_in in (1, 2):
    for num_out in (1, 2):
        ax = axes[num_in - 1, num_out - 1]
        shape = (num_in, num_out)
        sim = make_sim_base(params0, beta=5.0, shape=shape)
        _ = sim.plot_eps(z=0, ax=ax)
        ax.set_title(f"sim for shape={shape}")
plt.show()
Mode Solver#
Next, we’ll run the mode solver on one of these waveguides to make sure we inject and measure the desired waveguide modes in our system.
[8]:
from tidy3d.plugins.mode import ModeSolver
from tidy3d.plugins.mode.web import run as run_mode_solver
num_modes = 4
mode_spec = td.ModeSpec(num_modes=num_modes)
sim_start = make_sim_base(params0, beta=5.0, shape=(1,1))
mode_solver = ModeSolver(
simulation=sim_start.to_simulation()[0],
plane=get_source_plane(sim_start.structures[0]),
mode_spec=td.ModeSpec(num_modes=num_modes),
freqs=[freq0]
)
modes = run_mode_solver(mode_solver)
[09:51:22] Mode solver created with task_id='fdve-c8eaa444-395e-4f6b-800e-f84df8599f86v1', solver_id='mo-b647dd3c-92f7-4d74-861f-6d3b2acf472f'.
[09:51:27] Mode solver status: queued
[09:51:29] Mode solver status: running
[09:51:40] Mode solver status: success
Let’s plot the modes.
[9]:
print("Effective index of computed modes: ", np.array(modes.n_eff))
fig, axs = plt.subplots(num_modes, 2, figsize=(10, 14), tight_layout=True)
for mode_ind in range(num_modes):
    for field_ind, field_name in enumerate(("Ey", "Ez")):
        field = modes.field_components[field_name].sel(mode_index=mode_ind)
        ax = axs[mode_ind, field_ind]
        field.real.plot(ax=ax)
        ax.set_title(f'index={mode_ind}, {field_name}(y)')
Effective index of computed modes: [[1.4718767 1.3555466 1.007765 0.9316404]]
We wish to inject the fundamental Ez-polarized mode, which is given by mode_index=0 above. We therefore store this index in a variable and reset ModeSpec.num_modes to the smallest value that still includes it, since solving for extra modes would waste computation.
[10]:
mode_index = 0
num_modes = mode_index + 1
mode_spec = td.ModeSpec(num_modes=num_modes)
Sources and Monitors#
Next we will define our input sources and output monitors for this component. We’ll write these as functions of the input and output waveguides so the process of generating them is more general.
[11]:
def make_source(waveguide):
    # source seeding the simulation
    return td.ModeSource(
        source_time=td.GaussianPulse(freq0=freq0, fwidth=freqw),
        center=[source_x, waveguide.geometry.center[1], 0],
        size=mode_size,
        mode_index=mode_index,
        mode_spec=mode_spec,
        direction="+",
    )

def make_output_monitors(waveguides):
    monitors = []
    for waveguide in waveguides:
        # monitor where we compute the objective function from
        measurement_monitor = td.ModeMonitor(
            center=[meas_x, waveguide.geometry.center[1], 0],
            size=mode_size,
            freqs=[freq0],
            mode_spec=mode_spec,
            name=waveguide.name,
        )
        monitors.append(measurement_monitor)
    return monitors
Final Simulation#
Finally, we write a function to generate a component simulation based on the design parameters, projection strength, shape (inputs x outputs), and the index of the source we wish to inject.
[12]:
def make_sim(params, beta, shape, source_index: int):
    sim = make_sim_base(params, beta=beta, shape=shape)
    num_wgs_in, num_wgs_out = shape
    wg_in = sim.structures[source_index]
    forward_source_in = make_source(wg_in)
    wgs_out = list(sim.structures)[int(num_wgs_in):]
    output_monitors = make_output_monitors(wgs_out)
    return sim.updated_copy(
        sources=[forward_source_in],
        output_monitors=output_monitors
    )
Let’s generate a simulation and plot it with the sources and monitors to make sure it works properly.
[13]:
ax = make_sim(params0, shape=(2,1), beta=1, source_index=0).plot(z=0)
Defining Circuit#
With our function to generate the component simulations, we can now focus on combining these components into a circuit using sax. We highly recommend referring to the sax documentation for any additional information, but will give a brief tutorial of the tool through the next few cells.
Components#
In sax, the individual “nodes” in the circuit are defined as functions that return the scattering matrix of that component as a dictionary. In our case, our individual components are modelled as Tidy3D simulations. Therefore, we will write our component function to accept the design parameters and run one Tidy3D simulation per input source to construct the scattering matrix of the system.
[14]:
def component(params=params0, beta=5, shape=(2,2)):
    num_in, num_out = shape
    num_in = int(num_in)
    num_out = int(num_out)

    def get_S_column(sim_data):
        """Compute a column of the scattering matrix for a single dataset."""
        outputs = []
        for out_mnt in sim_data.simulation.output_monitors:
            amps = sim_data[out_mnt.name].amps
            amp = jnp.sum(amps.sel(mode_index=mode_index, direction="+", f=freq0))
            outputs.append(amp)
        return outputs

    sims = [make_sim(params, shape=shape, beta=beta, source_index=source_index) for source_index in range(num_in)]
    sim_datas = tda.web.run_async(sims, verbose=False, path_dir="data")
    s_columns = [get_S_column(sim_data) for sim_data in sim_datas]
    # assemble the scattering matrix
    s_dict = {}
    for index_in in range(num_in):
        label_in = "in" + str(index_in)
        s_col = s_columns[index_in]
        for index_out in range(num_out):
            label_out = "out" + str(index_out)
            s_element = s_col[index_out]
            s_dict[(label_in, label_out)] = s_element
    return sax.reciprocal(s_dict)
Note: these component functions must only contain keyword arguments (like x=1) with default values. So we define params=params0 and beta=5 as defaults for now, but will show how to pass our own values later.
Let’s test this out by calling this function with some example inputs and visualizing the s-matrix.
We see that it returns a dictionary where the keys are tuples mapping the names of our input waveguide to our output waveguide.
[15]:
component_sdict = component(params0, beta=1, shape=(1,2))
component_sdict
[15]:
{('in0', 'out0'): Array(0.38201833-0.0332941j, dtype=complex64),
('in0', 'out1'): Array(0.40652457-0.05268222j, dtype=complex64),
('out0', 'in0'): Array(0.38201833-0.0332941j, dtype=complex64),
('out1', 'in0'): Array(0.40652457-0.05268222j, dtype=complex64)}
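The mirrored entries in this output reflect reciprocity: the amplitude from `in0` to `out0` also appears under the reversed key. The sketch below reproduces that structure with a plain dictionary. The helper `make_reciprocal` is a hypothetical stand-in written for illustration and is not the sax implementation; the two amplitudes are copied from the output above.

```python
def make_reciprocal(sdict):
    """Return a copy of sdict in which every (a, b) entry is mirrored to (b, a)."""
    full = dict(sdict)
    for (port_a, port_b), value in sdict.items():
        # keep an explicitly given reverse entry if there is one
        full.setdefault((port_b, port_a), value)
    return full


# forward amplitudes copied from the printed S-matrix above
forward_only = {
    ("in0", "out0"): 0.38201833 - 0.0332941j,
    ("in0", "out1"): 0.40652457 - 0.05268222j,
}
full = make_reciprocal(forward_only)

# fraction of the input power that reaches the two outputs
total_power = sum(abs(value) ** 2 for value in forward_only.values())
print(sorted(full))
print(f"total transmitted power: {total_power:.3f}")
```

Port pairs that do not appear in the dictionary, such as reflections from `in0` back into `in0`, are simply not modelled by this component. With the random starting parameters, only about a third of the input power reaches the two outputs.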
Next, we define a simpler component function to model our phase shifter. This component takes the phase value phi and multiplies the amplitude passing through the connection by exp(1j * phi).
[16]:
def phase_shifter(phi: float = 0.0):
    phase_added = jnp.exp(1j * phi)
    s_dict = {("in", "out"): phase_added}
    return sax.reciprocal(s_dict)
Circuit#
Next, we need to combine these components together into a circuit. We do this through sax.circuit, which lets us define our “instances” (the component functions defined earlier), the “connections” between each of these instances, and then the “ports” for the entire circuit.
We wish to create a (1->2) component, with one output connected to our phase shifter, and then combine everything in a (2->2) component. We define these components and connections below and then specify the ports for the entire S-matrix, which is a (1->2) system.
[17]:
import functools
circuit_fn, _ = sax.circuit(
netlist={
"instances": {
"splitter": functools.partial(component, shape=(1,2)),
"phase_shifter": phase_shifter,
"combiner": functools.partial(component, shape=(2,2)),
},
"connections": {
"splitter,out0": "phase_shifter,in",
"phase_shifter,out": "combiner,in0",
"splitter,out1": "combiner,in1",
},
"ports": {
"in": "splitter,in0",
"out0": "combiner,out0",
"out1": "combiner,out1",
},
}
)
circuit_fn
[17]:
<function sax.circuit._flat_circuit.<locals>._circuit(*, splitter={'params': Array([[0.4359949 , 0.02592623, 0.5496625 , ..., 0.17671216, 0.59125733,
0.48926616],
[0.54790777, 0.69952065, 0.24581116, ..., 0.6424524 , 0.38690034,
0.85511965],
[0.3807926 , 0.17830983, 0.7816594 , ..., 0.4921191 , 0.9379131 ,
0.13442676],
...,
[0.35449517, 0.7365258 , 0.73508275, ..., 0.62516195, 0.26062906,
0.5743313 ],
[0.87019104, 0.9364767 , 0.56900996, ..., 0.47169012, 0.08907937,
0.9284895 ],
[0.25833175, 0.5660962 , 0.85214543, ..., 0.31971204, 0.79901004,
0.170014 ]], dtype=float32), 'beta': Array(5., dtype=float32), 'shape': Array([1., 2.], dtype=float32)}, phase_shifter={'phi': Array(0., dtype=float32)}, combiner={'params': Array([[0.4359949 , 0.02592623, 0.5496625 , ..., 0.17671216, 0.59125733,
0.48926616],
[0.54790777, 0.69952065, 0.24581116, ..., 0.6424524 , 0.38690034,
0.85511965],
[0.3807926 , 0.17830983, 0.7816594 , ..., 0.4921191 , 0.9379131 ,
0.13442676],
...,
[0.35449517, 0.7365258 , 0.73508275, ..., 0.62516195, 0.26062906,
0.5743313 ],
[0.87019104, 0.9364767 , 0.56900996, ..., 0.47169012, 0.08907937,
0.9284895 ],
[0.25833175, 0.5660962 , 0.85214543, ..., 0.31971204, 0.79901004,
0.170014 ]], dtype=float32), 'beta': Array(5., dtype=float32), 'shape': Array([2., 2.], dtype=float32)}) -> 'SType'>
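To make the wiring in this netlist concrete, the following sketch composes S-dictionaries by hand for the same topology. It is a feed-forward simplification that ignores all reflections, whereas `sax.circuit` evaluates the full circuit. The function `mzi_forward` and the idealized component dictionaries are hypothetical and only mirror the port names used above.

```python
import cmath


def mzi_forward(splitter, combiner, phi):
    """Circuit amplitudes from 'in' to each output, ignoring all reflections."""
    # splitter out0 passes through the phase shifter into combiner in0
    arm0 = splitter[("in0", "out0")] * cmath.exp(1j * phi)
    # splitter out1 connects directly to combiner in1
    arm1 = splitter[("in0", "out1")]
    return {
        ("in", out): arm0 * combiner[("in0", out)] + arm1 * combiner[("in1", out)]
        for out in ("out0", "out1")
    }


# idealized lossless components (assumption, not simulation results)
r = 1 / 2 ** 0.5
ideal_splitter = {("in0", "out0"): r, ("in0", "out1"): r}
ideal_combiner = {
    ("in0", "out0"): r, ("in0", "out1"): r,
    ("in1", "out0"): r, ("in1", "out1"): -r,
}

for phi in (0.0, cmath.pi):
    s = mzi_forward(ideal_splitter, ideal_combiner, phi)
    powers = {key: round(abs(value) ** 2, 6) for key, value in s.items()}
    print(f"phi = {phi:.3f}: {powers}")
```

Each circuit-level entry is a sum over the two arms of (splitter amplitude) x (phase factor) x (combiner amplitude), which is what allows gradients of a circuit-level objective to flow back into both components.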
Passing individual parameters#
The circuit_fn returned is a function that accepts parameters to each of our component functions. It is worth noting that we can pass different inputs to different functions by passing them as keyword arguments, as shown below. This is important to note as we will be optimizing each of the Tidy3D components individually with their own independent parameters.
Let’s call the circuit function and print the result, which is the S-matrix for the entire circuit given our passed parameters.
[18]:
# how to pass specific parameters to each of the sub-functions for the instances
s = circuit_fn(splitter={"params": params0}, combiner={"params": 0 * params0}, beta=3, phase_shifter=dict(phi=2.0))
[19]:
s
[19]:
{('out0', 'out0'): Array(0.+0.j, dtype=complex64),
('out0', 'out1'): Array(0.+0.j, dtype=complex64),
('out1', 'out0'): Array(0.+0.j, dtype=complex64),
('out1', 'out1'): Array(0.+0.j, dtype=complex64),
('in', 'in'): Array(0.+0.j, dtype=complex64),
('in', 'out0'): Array(0.09807562-0.12380885j, dtype=complex64),
('in', 'out1'): Array(0.06793377-0.14270785j, dtype=complex64),
('out0', 'in'): Array(0.09807562-0.12380885j, dtype=complex64),
('out1', 'in'): Array(0.06793377-0.14270784j, dtype=complex64)}
Objective Function#
With our circuit defined, we can now combine everything into a single objective function. We first write a penalty function that evaluates how well the structure respects the feature size constraints that we defined earlier.
[20]:
def penalty(params, beta, delta_eps=0.49):
    params = pre_process(params, beta=beta)
    dilate_fn = lambda x: filter_project(x, beta=100, eta=0.5-delta_eps)
    eroded_fn = lambda x: filter_project(x, beta=100, eta=0.5+delta_eps)
    params_dilate_erode = eroded_fn(dilate_fn(params))
    params_erode_dilate = dilate_fn(eroded_fn(params))
    diff = params_dilate_erode - params_erode_dilate
    return jnp.linalg.norm(diff) / jnp.linalg.norm(jnp.ones_like(diff))
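The penalty relies on shifting the projection threshold eta to mimic dilation and erosion. The sketch below is a scalar re-implementation of the projection using `math` in place of `jnp`, and it leaves out the conic filter, so it only illustrates how the two thresholds treat intermediate values.

```python
import math


def tanh_projection(x, beta, eta=0.5):
    """Scalar version of the projection used in the notebook."""
    tanhbn = math.tanh(beta * eta)
    num = tanhbn + math.tanh(beta * (x - eta))
    den = tanhbn + math.tanh(beta * (1 - eta))
    return num / den


delta_eps = 0.49
beta = 100
x_values = (0.1, 0.5, 0.9)

# low threshold (eta = 0.01): almost everything is pushed to 1
dilated = [tanh_projection(x, beta, eta=0.5 - delta_eps) for x in x_values]
# high threshold (eta = 0.99): almost everything is pushed to 0
eroded = [tanh_projection(x, beta, eta=0.5 + delta_eps) for x in x_values]

for x, d, e in zip(x_values, dilated, eroded):
    print(f"x = {x:.1f}: dilated = {d:.4f}, eroded = {e:.4f}")
```

With eta = 0.01, any filtered value above roughly 0.01 becomes solid, which grows features. With eta = 0.99, only values above roughly 0.99 survive, which shrinks them. The penalty compares the result of applying these two operations in either order, and the two results should agree more closely when the features are large compared with the filter radius.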
We then write a combined objective function that accepts our parameters for each of the individual components (as one array params) and the projection strength beta applied to each design region.
The objective function uses these parameters to construct each of the individual components and simulates them to compute their scattering matrix. Then, it defines a circuit-level objective to look at the transmission of the entire circuit into the two output ports as a function of the phase shift phi. We seek to maximize transmission to the top port when phi=0 and the bottom port when phi=pi.
[21]:
def J(params, beta) -> float:
    """Circuit-level objective function."""
    params1, params2 = params
    circuit_function = functools.partial(circuit_fn, splitter={"params": params1}, combiner={"params": params2}, beta=beta)

    def top_minus_bot(phi: float) -> float:
        """Power in top port minus power in bottom port."""
        # evaluate the circuit at phi
        sdict = circuit_function(phase_shifter={"phi": phi})
        # S-parameters for the whole circuit
        s_00 = sdict["in", "out0"]
        s_01 = sdict["in", "out1"]
        # power at ports
        power_top = jnp.sum(jnp.abs(s_00)**2)
        power_bot = jnp.sum(jnp.abs(s_01)**2)
        # top power minus bottom power
        return power_top - power_bot

    # combine objectives together: at worst, it will be -1, at best +1.
    objective = (top_minus_bot(0.0) - top_minus_bot(np.pi)) / 2.0
    # combined penalty for both devices
    penalty_weight = 0.5
    feature_penalty1 = penalty(params=params1, beta=beta)
    feature_penalty2 = penalty(params=params2, beta=beta)
    feature_penalty = penalty_weight * (feature_penalty1 + feature_penalty2) / 2.0
    return objective - feature_penalty
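The range quoted in the code comment (from -1 to +1) can be checked with a small stand-alone sketch of the switching term, leaving out the feature-size penalty. The function `switching_objective` is a hypothetical helper that takes the four port powers directly instead of running the circuit.

```python
def switching_objective(p_top_0, p_bot_0, p_top_pi, p_bot_pi):
    """Switching term of the objective: (top - bottom) at phi=0 minus (top - bottom) at phi=pi, halved."""
    return ((p_top_0 - p_bot_0) - (p_top_pi - p_bot_pi)) / 2.0


# perfect switch: all power to the top port at phi=0 and to the bottom port at phi=pi
best = switching_objective(1.0, 0.0, 0.0, 1.0)
# perfectly wrong switch: the two ports are swapped
worst = switching_objective(0.0, 1.0, 1.0, 0.0)
# lossless device whose outputs do not depend on phi at all
no_switching = switching_objective(0.5, 0.5, 0.5, 0.5)

print(best, worst, no_switching)
```

A device that transmits well but does not respond to phi scores 0, so the objective rewards both high transmission and a strong dependence on the phase shift.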
Next we use jax to compute a function that returns the value of this objective function and its gradient when passed some input parameters.
[22]:
dJ_fn = jax.value_and_grad(J)
Let’s try running this function with some example parameters and inspect the results.
[23]:
params0_combined = np.stack((params0, params0), axis=0)
val, grad = dJ_fn(params0_combined, beta=1)
[24]:
print(val, grad)
-0.50251895 [[[ 7.14473344e-06 9.11673851e-06 1.05827448e-05 ... -9.17585021e-06
-7.92271931e-06 -6.19440652e-06]
[ 8.68899588e-06 1.10152214e-05 1.26921068e-05 ... -1.13362294e-05
-9.85804763e-06 -7.75598346e-06]
[ 9.43710984e-06 1.18747666e-05 1.35484370e-05 ... -1.26476180e-05
-1.10918045e-05 -8.78276478e-06]
...
[-3.22923770e-05 -4.00895296e-05 -4.48482424e-05 ... 4.50681364e-05
4.12528025e-05 3.38398604e-05]
[-2.90573880e-05 -3.61831262e-05 -4.05935389e-05 ... 4.25772196e-05
3.88274893e-05 3.17442318e-05]
[-2.33855517e-05 -2.92259228e-05 -3.28775859e-05 ... 3.56646669e-05
3.24215143e-05 2.64082391e-05]]
[[ 3.00405318e-05 3.74189149e-05 4.21421937e-05 ... -3.59632759e-05
-3.20999643e-05 -2.58842447e-05]
[ 3.59139303e-05 4.45573241e-05 5.00202914e-05 ... -4.29933280e-05
-3.85160092e-05 -3.11923541e-05]
[ 3.78112854e-05 4.67239806e-05 5.22388145e-05 ... -4.54894471e-05
-4.09546483e-05 -3.33193311e-05]
...
[-3.32819945e-06 -3.46686102e-06 -3.34001015e-06 ... 9.98089945e-06
9.28078589e-06 7.96504173e-06]
[-2.19371759e-06 -2.00998147e-06 -1.55896669e-06 ... 9.13287022e-06
8.60987348e-06 7.42999919e-06]
[-1.26351188e-06 -9.15017154e-07 -3.27843736e-07 ... 7.39111147e-06
7.05704315e-06 6.11908763e-06]]]
[25]:
print(grad.shape)
(2, 120, 120)
The resulting value and gradient are reasonable. Note the gradient is shaped (2, nx, ny), which represents the gradients with respect to each of the two (nx, ny) pixelated grids for the individual components.
Optimization Loop#
Next, as in the other examples, we use optax to optimize the entire circuit with the Adam optimization method. Since we want to maximize the objective, we pass the negated gradient to the optimizer.
[26]:
import optax
# hyperparameters
num_steps = 45
learning_rate = 1.0
# initialize adam optimizer with starting parameters
params = params0_combined.copy()
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)
# store history
Js = []
params_history = [params]
beta_history = []
beta0 = 1.0
beta_final = 20
for i in range(num_steps):
    # compute gradient and current objective function value
    perc_done = i / num_steps
    beta = beta0 * (1 - perc_done) + beta_final * perc_done
    value, gradient = dJ_fn(params, beta=beta)
    # outputs
    print(f"step = {i + 1}")
    print(f"\tbeta = {beta:.4e}")
    print(f"\tJ = {value:.4e}")
    print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")
    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)
    params = optax.apply_updates(params, updates)
    # cap the parameters
    params = jnp.minimum(params, 1.0)
    params = jnp.maximum(params, 0.0)
    # save history
    Js.append(value)
    params_history.append(params)
    beta_history.append(beta)
# evaluate the objective once more for the final parameters
power = J(params_history[-1], beta=beta)
Js.append(power)
step = 1
beta = 1.0000e+00
J = -5.0252e-01
grad_norm = 2.0214e-02
step = 2
beta = 1.4222e+00
J = -2.4675e-01
grad_norm = 1.3238e-02
step = 3
beta = 1.8444e+00
J = -2.1442e-01
grad_norm = 9.7928e-03
step = 4
beta = 2.2667e+00
J = -1.6024e-01
grad_norm = 8.3475e-03
step = 5
beta = 2.6889e+00
J = -9.5096e-02
grad_norm = 8.0457e-03
step = 6
beta = 3.1111e+00
J = -1.4640e-02
grad_norm = 8.4054e-03
step = 7
beta = 3.5333e+00
J = 1.3219e-02
grad_norm = 3.5979e-02
step = 8
beta = 3.9556e+00
J = -5.9181e-02
grad_norm = 3.9888e-02
step = 9
beta = 4.3778e+00
J = 1.6096e-01
grad_norm = 1.1203e-02
step = 10
beta = 4.8000e+00
J = 1.4436e-01
grad_norm = 3.3816e-02
step = 11
beta = 5.2222e+00
J = 2.6768e-01
grad_norm = 9.0722e-03
step = 12
beta = 5.6444e+00
J = 2.7818e-01
grad_norm = 2.2982e-02
step = 13
beta = 6.0667e+00
J = 3.4127e-01
grad_norm = 1.0269e-02
step = 14
beta = 6.4889e+00
J = 3.7026e-01
grad_norm = 1.5204e-02
step = 15
beta = 6.9111e+00
J = 4.0043e-01
grad_norm = 2.1623e-02
step = 16
beta = 7.3333e+00
J = 4.4509e-01
grad_norm = 1.0588e-02
step = 17
beta = 7.7556e+00
J = 4.7139e-01
grad_norm = 7.8376e-03
step = 18
beta = 8.1778e+00
J = 4.9354e-01
grad_norm = 9.4674e-03
step = 19
beta = 8.6000e+00
J = 5.1339e-01
grad_norm = 7.4156e-03
step = 20
beta = 9.0222e+00
J = 5.3156e-01
grad_norm = 5.8523e-03
step = 21
beta = 9.4444e+00
J = 5.4703e-01
grad_norm = 5.7467e-03
step = 22
beta = 9.8667e+00
J = 5.6340e-01
grad_norm = 9.4500e-03
step = 23
beta = 1.0289e+01
J = 5.5934e-01
grad_norm = 2.3881e-02
step = 24
beta = 1.0711e+01
J = 5.2488e-01
grad_norm = 4.0713e-02
step = 25
beta = 1.1133e+01
J = 4.8486e-01
grad_norm = 4.8898e-02
step = 26
beta = 1.1556e+01
J = 5.7733e-01
grad_norm = 2.2211e-02
step = 27
beta = 1.1978e+01
J = 5.9732e-01
grad_norm = 2.4999e-02
step = 28
beta = 1.2400e+01
J = 6.0919e-01
grad_norm = 1.4594e-02
step = 29
beta = 1.2822e+01
J = 6.1784e-01
grad_norm = 1.1322e-02
step = 30
beta = 1.3244e+01
J = 6.2263e-01
grad_norm = 9.7416e-03
step = 31
beta = 1.3667e+01
J = 6.3181e-01
grad_norm = 9.6787e-03
step = 32
beta = 1.4089e+01
J = 6.3439e-01
grad_norm = 8.7720e-03
step = 33
beta = 1.4511e+01
J = 6.3985e-01
grad_norm = 5.5513e-03
step = 34
beta = 1.4933e+01
J = 6.4201e-01
grad_norm = 6.7834e-03
step = 35
beta = 1.5356e+01
J = 6.4506e-01
grad_norm = 1.1735e-02
step = 36
beta = 1.5778e+01
J = 6.4778e-01
grad_norm = 6.5648e-03
step = 37
beta = 1.6200e+01
J = 6.5068e-01
grad_norm = 6.2077e-03
step = 38
beta = 1.6622e+01
J = 6.5446e-01
grad_norm = 5.3513e-03
step = 39
beta = 1.7044e+01
J = 6.5579e-01
grad_norm = 9.0492e-03
step = 40
beta = 1.7467e+01
J = 6.5536e-01
grad_norm = 1.5514e-02
step = 41
beta = 1.7889e+01
J = 6.1316e-01
grad_norm = 3.9242e-02
step = 42
beta = 1.8311e+01
J = 4.6327e-01
grad_norm = 7.3251e-02
step = 43
beta = 1.8733e+01
J = 4.7778e-01
grad_norm = 6.6240e-02
step = 44
beta = 1.9156e+01
J = 6.3996e-01
grad_norm = 1.1438e-02
step = 45
beta = 1.9578e+01
J = 6.1807e-01
grad_norm = 3.0977e-02
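The beta values printed in this log follow from the linear ramp inside the loop. The short sketch below recomputes the schedule on its own, using the same `num_steps`, `beta0` and `beta_final` as the optimization cell.

```python
num_steps = 45
beta0 = 1.0
beta_final = 20

# same linear interpolation as in the optimization loop
betas = [
    beta0 * (1 - i / num_steps) + beta_final * (i / num_steps)
    for i in range(num_steps)
]

print(f"step 1 : beta = {betas[0]:.4e}")
print(f"step 2 : beta = {betas[1]:.4e}")
print(f"step {num_steps}: beta = {betas[-1]:.4e}")
```

Because the ramp uses i / num_steps, the schedule stops one increment short of beta_final, which matches the last logged value of about 19.58. Dividing by num_steps - 1 instead would be one way to end exactly at beta_final, if that is desired.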
Results#
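Finally, we can inspect the results.
First we plot the objective function over iteration number and note that it increases overall, with a few temporary dips along the way (for example around steps 23 to 25 and 41 to 43 in the log above).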
[27]:
plt.plot(Js)
plt.xlabel("iterations")
plt.ylabel("objective function")
plt.ylim(-1.5, 1)
plt.show()
We grab the final design parameters and beta value.
[28]:
params_final = params1_final, params2_final = params_history[-1]
beta_final = beta_history[-1]
And use these to construct the Tidy3D simulations corresponding to the final optimized state of each of the components.
[29]:
sim1_final = make_sim(params1_final, beta=beta_final, source_index=0, shape=(1,2))
sim2_final = make_sim(params2_final, beta=beta_final, source_index=0, shape=(2,2))
sim3_final = make_sim(params2_final, beta=beta_final, source_index=1, shape=(2,2))
Let’s plot these simulations. Note that the 2nd and 3rd are the same except for the source, so we can visualize the fields excited from each of the individual inputs.
[30]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, tight_layout=True, figsize=(10,6))
sim1_final.plot_eps(z=0, ax=ax1)
sim2_final.plot_eps(z=0, ax=ax2)
sim3_final.plot_eps(z=0, ax=ax3)
ax1.set_title('first component (splitter)')
ax2.set_title('second component (combiner)')
ax3.set_title('second component (combiner)')
plt.show()
To visualize the fields, let’s create and add a FieldMonitor to each of the simulations.
[31]:
field_mnt = td.FieldMonitor(
size=(td.inf, td.inf, 0),
freqs=[freq0],
name="field_mnt",
colocate=True,
)
sim1_final = sim1_final.copy(update=dict(monitors=(field_mnt,)))
sim2_final = sim2_final.copy(update=dict(monitors=(field_mnt,)))
sim3_final = sim3_final.copy(update=dict(monitors=(field_mnt,)))
Next, run the simulations
[37]:
sims_final = (sim1_final, sim2_final, sim3_final)
sim_data1_final, sim_data2_final, sim_data3_final = tda.web.run_async(sims_final, path_dir="data", verbose=False)
and plot the results.
[38]:
f, (axes_eps, axes_fld, axes_int) = plt.subplots(3, 3, figsize=(10, 8), tight_layout=True)
sim_datas = [sim_data1_final, sim_data2_final, sim_data3_final]
for sim_data_final, ax_eps, ax_fld, ax_int in zip(sim_datas, axes_eps, axes_fld, axes_int):
    sim_data_final.simulation.plot_eps(z=0.01, ax=ax_eps)
    sim_data_final.plot_field("field_mnt", "Ez", z=0, ax=ax_fld)
    sim_data_final.plot_field("field_mnt", "E", "abs^2", z=0, ax=ax_int)
While this gives an interesting picture, what we really want to visualize is how the fields look under our design conditions when phi=0 and phi=pi. For that, we write a function that computes the source parameters for the 2nd component for a given value of phi and runs that simulation.
[39]:
def get_sim_data_right(phi):
    out_top_1 = sim_data1_final["top"].amps.sel(direction="+", f=freq0, mode_index=0)
    out_bot_1 = sim_data1_final["bot"].amps.sel(direction="+", f=freq0, mode_index=0)
    # apply phi phase shift to top arm
    phase_top = np.angle(out_top_1) + phi
    phase_bot = np.angle(out_bot_1)
    src_top = sim2_final.sources[0]
    src_bot = sim3_final.sources[0]
    src_time_top = src_top.source_time.updated_copy(amplitude=abs(out_top_1), phase=phase_top)
    src_time_bot = src_bot.source_time.updated_copy(amplitude=abs(out_bot_1), phase=phase_bot)
    src_top = src_top.updated_copy(source_time=src_time_top)
    src_bot = src_bot.updated_copy(source_time=src_time_bot)
    sim_right = sim2_final.updated_copy(sources=[src_top, src_bot])
    return tda.web.run(sim_right, task_name=f"phi={phi:.3f}")
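The function above converts each complex arm amplitude from the first component into a source amplitude and phase for the second. The sketch below shows that decomposition with `cmath` on a made-up amplitude. The helper `source_settings` is hypothetical, the numerical value is illustrative and not a simulation result, and the reconstruction assumes the source phase enters as a factor exp(1j * phase).

```python
import cmath


def source_settings(arm_amplitude, extra_phase=0.0):
    """Split a complex arm amplitude into (amplitude, phase), adding an optional phase shift."""
    return abs(arm_amplitude), cmath.phase(arm_amplitude) + extra_phase


# illustrative complex amplitude leaving the splitter's top arm (not a simulation result)
out_top = 0.5 - 0.3j

amp, phase = source_settings(out_top, extra_phase=cmath.pi)

# under the stated assumption, the source then carries this complex amplitude
rebuilt = amp * cmath.exp(1j * phase)
print(f"amplitude = {amp:.4f}, phase = {phase:.4f} rad, rebuilt = {rebuilt:.4f}")
```

Adding phi = pi to the phase flips the sign of the complex amplitude while leaving its magnitude unchanged, which is exactly the role of the phase shifter in the circuit.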
We compute the field data for the output component for both phi=0 and phi=pi.
[40]:
sim_data_right_p0 = get_sim_data_right(phi=0)
sim_data_right_pi = get_sim_data_right(phi=np.pi)
[17:59:00] Created task 'phi=0.000' with task_id 'fdve-303d2de1-75c2-438a-9087-f48c40f1abb2v1'.
View task using web UI at 'https://tidy3d.simulation.cloud/workbench?taskId=fdve-303d2de1-75c2- 438a-9087-f48c40f1abb2v1'.
[17:59:03] status = queued
[17:59:06] status = preprocess
[17:59:10] Maximum FlexCredit cost: 0.025. Use 'web.real_cost(task_id)' to get the billed FlexCredit cost after a simulation run.
starting up solver
[17:59:11] running solver
To cancel the simulation, use 'web.abort(task_id)' or 'web.delete(task_id)' or abort/delete the task in the web UI. Terminating the Python script will not stop the job running on the cloud.
[17:59:17] early shutoff detected, exiting.
status = postprocess
[17:59:21] status = success
[17:59:22] View simulation result at 'https://tidy3d.simulation.cloud/workbench?taskId=fdve-303d2de1-75c2- 438a-9087-f48c40f1abb2v1'.
[17:59:23] loading SimulationData from simulation_data.hdf5
[17:59:24] Created task 'phi=3.142' with task_id 'fdve-fe2d9123-5397-4a8c-9fb0-86c13495bb41v1'.
View task using web UI at 'https://tidy3d.simulation.cloud/workbench?taskId=fdve-fe2d9123-5397- 4a8c-9fb0-86c13495bb41v1'.
[17:59:26] status = queued
[17:59:30] status = preprocess
[17:59:37] Maximum FlexCredit cost: 0.025. Use 'web.real_cost(task_id)' to get the billed FlexCredit cost after a simulation run.
starting up solver
running solver
To cancel the simulation, use 'web.abort(task_id)' or 'web.delete(task_id)' or abort/delete the task in the web UI. Terminating the Python script will not stop the job running on the cloud.
[17:59:43] early shutoff detected, exiting.
[17:59:44] status = postprocess
[17:59:48] status = success
View simulation result at 'https://tidy3d.simulation.cloud/workbench?taskId=fdve-fe2d9123-5397- 4a8c-9fb0-86c13495bb41v1'.
[17:59:50] loading SimulationData from simulation_data.hdf5
And plot the results. Note that the device works as intended: when phi=0, the transmitted light goes predominantly into the top port, and when phi=pi, it goes predominantly into the bottom port.
[41]:
alpha = 0.0
f, (axes_eps, axes_fld, axes_int) = plt.subplots(3, 3, figsize=(10, 8), tight_layout=True)
sim_datas = [sim_data1_final, sim_data_right_p0, sim_data_right_pi]
for sim_data_final, ax_eps, ax_fld, ax_int, phi in zip(sim_datas, axes_eps, axes_fld, axes_int, (None, "0", "π")):
    sim_data_final.simulation.plot_eps(z=0.01, ax=ax_eps, source_alpha=alpha, monitor_alpha=0)
    sim_data_final.plot_field("field_mnt", "Ez", z=0, ax=ax_fld)
    sim_data_final.plot_field("field_mnt", "E", "abs^2", z=0, ax=ax_int)
    for ax in (ax_eps, ax_fld, ax_int):
        if phi is not None:
            ax.set_title(rf'output sim (phi={phi})')
        else:
            ax.set_title("input sim")
With some minor modifications to this MZI device (such as adding a 2nd input port and adding a 2nd phase shifter on the output), we can implement any unitary 2x2 matrix and build very complex components for performing arbitrary linear operations in optical circuits, such as optical neural networks.
With the adjoint plugin of Tidy3D and the differentiable circuit modeling of sax, we have a convenient tool for combining the power and flexibility of inverse design with the modularity of traditional component design and can perform co-optimization of individual components with minimal overhead.
[52]:
power_top_p0 = jnp.sum(jnp.abs(jnp.array(sim_data_right_p0.output_data[0].amps.values))**2)
power_bot_p0 = jnp.sum(jnp.abs(jnp.array(sim_data_right_p0.output_data[1].amps.values))**2)
power_top_pi = jnp.sum(jnp.abs(jnp.array(sim_data_right_pi.output_data[0].amps.values))**2)
power_bot_pi = jnp.sum(jnp.abs(jnp.array(sim_data_right_pi.output_data[1].amps.values))**2)
[58]:
print('phi = 0')
print(f' Transmission_top = {100 * power_top_p0:.2f} %')
print(f' Transmission_bot = {100 * power_bot_p0:.2f} %')
print('phi = pi')
print(f' Transmission_top = {100 * power_top_pi:.2f} %')
print(f' Transmission_bot = {100 * power_bot_pi:.2f} %')
phi = 0
Transmission_top = 58.65 %
Transmission_bot = 0.91 %
phi = pi
Transmission_top = 0.39 %
Transmission_bot = 79.51 %
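As a brief follow-up, the sketch below takes the four transmission values printed above and computes the total transmitted power in each state, along with the switching term of the objective (without the feature-size penalty). The numbers are copied by hand from the printed output, so they carry its rounding.

```python
# transmissions printed above, as fractions of the input power
top_0, bot_0 = 0.5865, 0.0091      # phi = 0
top_pi, bot_pi = 0.0039, 0.7951    # phi = pi

total_0 = top_0 + bot_0            # total transmitted power at phi = 0
total_pi = top_pi + bot_pi         # total transmitted power at phi = pi

# switching term of the objective, before any penalty is subtracted
switching = ((top_0 - bot_0) - (top_pi - bot_pi)) / 2.0

print(f"total transmission at phi = 0 : {100 * total_0:.2f} %")
print(f"total transmission at phi = pi: {100 * total_pi:.2f} %")
print(f"switching term of objective   : {switching:.4f}")
```

The light is routed to the intended port in both states, with roughly 60 % of the input power transmitted at phi = 0 and 80 % at phi = pi. The remaining power is lost to scattering or reflection, which is why the switching term stays below its ideal value of 1.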