Inverse design integrated with circuit simulation

In this tutorial, we show how to integrate the adjoint plugin of Tidy3D with sax, a differentiable optical circuit simulator. This lets one model a complicated circuit composed of many connected components, each simulated independently with Tidy3D. Because both the adjoint plugin and sax are built on jax, the gradients of the individual components are chained together through the circuit. One can therefore write an objective function in terms of the scattering matrix of the entire circuit and optimize it with respect to the design parameters of every individual Tidy3D simulation.

To demonstrate this capability, we optimize a Mach-Zehnder interferometer (MZI) circuit. This simplified MZI has a single input and two outputs, and we want to switch the transmitted power between the two outputs depending on a phase shift applied to one of its waveguides. The circuit consists of a splitter component that divides the input light between two waveguides, a phase shifter on one of these waveguides, and a combiner component that mixes the light from the two waveguides and sends it to the two outputs. The scattering matrices of the splitter and combiner are computed with Tidy3D simulations, while the waveguide connections and the phase shifter are defined in the sax circuit simulator. Because jax passes all of the gradients through automatically, we can optimize the circuit with respect to the permittivity distributions of both Tidy3D simulations simultaneously.
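For intuition, the switching behavior can be reproduced with an idealized MZI. The sketch below assumes lossless 50/50 couplers described by the common directional-coupler matrix (1/√2)[[1, i], [i, 1]] for both the splitter and the combiner; this matrix and the helper name `ideal_mzi_powers` are illustrative assumptions, not part of the tutorial, and the optimized Tidy3D components are free to take a different form.

```python
import numpy as np

# Assumed ideal, lossless 50/50 coupler (symmetric directional-coupler convention)
COUPLER = np.array([[1, 1j], [1j, 1]]) / np.sqrt(2)

def ideal_mzi_powers(phi):
    """Output powers (top, bottom) of an ideal MZI fed from its top input.

    Hypothetical helper: coupler -> phase phi on the top arm -> coupler.
    """
    field_in = np.array([1.0, 0.0])                    # light enters the top port only
    arms = COUPLER @ field_in                           # split into two arms
    arms = arms * np.array([np.exp(1j * phi), 1.0])     # phase shift on the top arm
    field_out = COUPLER @ arms                          # recombine the two arms
    return np.abs(field_out) ** 2

for phi in (0.0, np.pi / 2, np.pi):
    print(f"phi={phi:.3f}: powers={ideal_mzi_powers(phi).round(3)}")
```

With this coupler convention the output powers are sin²(phi/2) and cos²(phi/2), so all of the power leaves the bottom port at phi=0 and the top port at phi=pi. The objective used later asks for the opposite assignment, which the optimizer can realize through the phases of the designed components.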

Below is a schematic of this process and some of the variable labels we use in the code.

Schematic of the MZI circuit

To install the jax dependencies required for this feature, we recommend running pip install "tidy3d[jax]". You will also need to pip install sax, as well as optax, which is used in the optimization loop.

If you are unfamiliar with inverse design, we also recommend our intro to inverse design tutorials and our primer on automatic differentiation with tidy3d.

Setup

First we import all of the packages we need.

[1]:
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

import sax

import tidy3d as td
import tidy3d.plugins.adjoint as tda

np.random.seed(2)
/Users/twhughes/.pyenv/versions/3.10.9/lib/python3.10/site-packages/sax/backends/__init__.py:24: UserWarning: klujax not found. Please install klujax for better performance during circuit evaluation!
  warnings.warn(

Tidy3D Simulation Parameters

Then we will initialize some parameters needed for our individual component simulations.

For this application, we model each of the Tidy3D components as square design regions accepting 1 or 2 inputs and transmitting to 1 or 2 outputs.

[2]:
# wavelength and frequency
wavelength = 1.0
freq0 = td.C_0 / wavelength

# resolution control
steps_per_wvl = 20

# space between boxes and PML
buffer = 1.0 * wavelength

# design (optimization) region size
lz = td.inf
lx = 3.0
ly = lx
wg_width = 0.4

# number of design pixels along x and y
nx = 120
ny = nx
num_cells = nx * ny

# position of source and monitor (constant for all)
source_x = -lx / 2 - buffer * 0.8
meas_x = lx / 2 + buffer * 0.8

# total size
Lx = lx + 2 * buffer
Ly = ly + 2 * buffer
Lz = 0

# permittivity info
eps_wg = 2.75
eps_deviation_random = 0.5

# random starting parameters, uniformly distributed in [0, 1]
params0 = np.random.random((nx, ny))

# frequency width and run time
freqw = freq0 / 10
run_time = 50 / freqw

Because we want to model a general system with 1 or 2 inputs coupled to 1 or 2 outputs, we define all of the possible input and output waveguides up front to simplify the code later.

[3]:
big_number = Lx * 10

dy = (ly - 2 * wg_width) / 4 + wg_width/2

# all of the possible input and output waveguides
waveguide_in_center = td.Structure(
    geometry=td.Box(
        size=(big_number, wg_width, lz),
        center=(-big_number/2, 0, 0),
    ),
    medium=td.Medium(permittivity=eps_wg)
)

waveguide_in_top = td.Structure(
    geometry=td.Box(
        size=(big_number, wg_width, lz),
        center=(-big_number/2, +dy, 0),
    ),
    medium=td.Medium(permittivity=eps_wg)
)

waveguide_in_bot = td.Structure(
    geometry=td.Box(
        size=(big_number, wg_width, lz),
        center=(-big_number/2, -dy, 0),
    ),
    medium=td.Medium(permittivity=eps_wg)
)

waveguide_out_center = td.Structure(
    geometry=td.Box(
        size=(big_number, wg_width, lz),
        center=(+big_number/2, 0, 0),
    ),
    medium=td.Medium(permittivity=eps_wg),
    name="center"
)

waveguide_out_top = td.Structure(
    geometry=td.Box(
        size=(big_number, wg_width, lz),
        center=(+big_number/2, +dy, 0),
    ),
    medium=td.Medium(permittivity=eps_wg),
    name="top"
)

waveguide_out_bot = td.Structure(
    geometry=td.Box(
        size=(big_number, wg_width, lz),
        center=(+big_number/2, -dy, 0),
    ),
    medium=td.Medium(permittivity=eps_wg),
    name="bot"
)
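The offset dy is chosen so the free space in y is shared evenly: each gap between a design-region edge and the nearest two-port waveguide is half the gap between the two waveguides. A quick arithmetic check with the values used here (lengths in μm, the Tidy3D default unit):

```python
# Values copied from the parameter cell above (um)
ly = 3.0
wg_width = 0.4

dy = (ly - 2 * wg_width) / 4 + wg_width / 2   # center offset of the top/bottom waveguides

free_space = ly - 2 * wg_width                 # y-extent not covered by the two waveguides
edge_gap = ly / 2 - (dy + wg_width / 2)        # design-region edge to top waveguide edge
center_gap = 2 * dy - wg_width                 # gap between the two waveguides

print(f"dy = {dy:.2f}, edge_gap = {edge_gap:.2f}, center_gap = {center_gap:.2f}, free = {free_space:.2f}")
```

The two edge gaps (0.55 μm each) plus the central gap (1.1 μm) add up to the 2.2 μm of free space, so the waveguide centers sit at y = ±0.75 μm.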

We also define some information about our mode source and monitor geometries.

[4]:
# the source and measurement plane size
mode_size = (0, wg_width * 3, lz)

# source plane centered at y=0
source_plane_base = td.Box(
    center=[source_x, 0, 0],
    size=mode_size,
)

def get_source_plane(waveguide: td.Structure) -> td.Box:
    """Source plane with its y position moved to cover a specific waveguide."""
    return source_plane_base.updated_copy(center=(source_x, waveguide.geometry.center[1], 0))

measure_plane = td.Box(
    center=[meas_x, 0, 0],
    size=mode_size,
)

Design Parameterization

As in many of the other adjoint demos, we now define the design region structure using a JaxCustomMedium generated from our design parameters. We apply filtering and projection to obtain smooth, nearly binary features. For more details, we refer the reader to our intro to inverse design tutorials.
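The two steps can be illustrated on their own in 1D. The sketch below uses a hypothetical `conic_kernel_1d` (linearly decaying weights, a 1D analogue of a cone-shaped kernel; it is not Tidy3D's ConicFilter) followed by the same tanh projection formula used in the cell below, written with numpy as `tanh_projection_np`. Filtering averages away pixel-scale noise, and increasing beta pushes the filtered values toward 0 or 1.

```python
import numpy as np

def conic_kernel_1d(radius_px):
    """Hypothetical 1D analogue of a conic filter: linearly decaying, normalized weights."""
    r = np.arange(-radius_px, radius_px + 1)
    weights = np.maximum(0.0, 1.0 - np.abs(r) / radius_px)
    return weights / weights.sum()

def tanh_projection_np(x, beta, eta=0.5):
    """Same formula as tanh_projection in the next cell, using numpy instead of jax.numpy."""
    tanhbn = np.tanh(beta * eta)
    num = tanhbn + np.tanh(beta * (x - eta))
    den = tanhbn + np.tanh(beta * (1 - eta))
    return num / den

rng = np.random.default_rng(0)
raw = rng.random(60)                                          # noisy parameters in [0, 1]
smooth = np.convolve(raw, conic_kernel_1d(5), mode="same")    # filtering removes pixel-scale noise

for beta in (1.0, 10.0, 50.0):
    proj = tanh_projection_np(smooth, beta)
    frac_binary = np.mean((proj < 0.05) | (proj > 0.95))      # share of nearly binary pixels
    print(f"beta={beta:>5}: fraction of near-binary pixels = {frac_binary:.2f}")
```

The projection keeps 0, 0.5, and 1 fixed for any beta (with eta=0.5), which is why beta can be ramped up gradually during the optimization without moving the threshold.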

[5]:
from tidy3d.plugins.adjoint.utils.filter import ConicFilter
from typing import List

radius = 0.120  # conic filter radius (um)
beta = 50

conic_filter = ConicFilter(radius=radius, design_region_dl=float(lx) / nx)

def tanh_projection(x, beta, eta=0.5):
    tanhbn = jnp.tanh(beta * eta)
    num = tanhbn + jnp.tanh(beta * (x - eta))
    den = tanhbn + jnp.tanh(beta * (1 - eta))
    return num / den

def filter_project(x, beta, eta=0.5):
    x = conic_filter.evaluate(x)
    return tanh_projection(x, beta=beta, eta=eta)

def pre_process(params, beta):
    """Filter and project the raw parameters (in [0, 1]) twice, returning values in [0, 1]."""
    params1 = filter_project(params, beta=beta)
    params2 = filter_project(params1, beta=beta)
    return params2

def get_eps(params, beta):
    params = pre_process(params, beta=beta)
    eps_min = 1.0001
    eps_values = eps_min + (eps_wg - eps_min) * params
    return eps_values

def make_input_structures(params, beta) -> List[tda.JaxStructure]:

    size_box_x = float(lx) / nx
    size_box_y = float(ly) / ny
    size_box = (size_box_x, size_box_y, lz)

    x0_min = -lx / 2 + size_box_x / 2
    y0_min = -ly / 2 + size_box_y / 2

    input_structures = []

    coords_x = [x0_min + index_x * size_box_x - 1e-5 for index_x in range(nx)]
    coords_y = [y0_min + index_y * size_box_y - 1e-5 for index_y in range(ny)]

    coords = dict(x=coords_x, y=coords_y, z=[0], f=[freq0])

    eps_boxes = get_eps(params, beta=beta).reshape((nx, ny, 1, 1))

    field_components = {
        f"eps_{dim}{dim}": tda.JaxDataArray(values=eps_boxes, coords=coords) for dim in "xyz"
    }
    eps_dataset = tda.JaxPermittivityDataset(**field_components)
    custom_medium = tda.JaxCustomMedium(eps_dataset=eps_dataset)
    box = tda.JaxBox(center=(0, 0, 0), size=(lx, ly, lz))
    custom_structure = tda.JaxStructure(geometry=box, medium=custom_medium)
    return [custom_structure]

Base Simulation

Next, we write a “base” simulation (without sources or monitors) as a function of our input parameters. The function also accepts the shape of the component, a tuple (number of inputs, number of outputs) that determines which waveguides are added to the simulation.

[6]:
def make_sim_base(params, beta, shape) -> tda.JaxSimulation:

    input_structures = make_input_structures(params, beta=beta)

    num_wg_in, num_wg_out = shape
    if num_wg_in == 1:
        wgs_in = [waveguide_in_center]
    else:
        wgs_in = [waveguide_in_top, waveguide_in_bot]

    if num_wg_out == 1:
        wgs_out = [waveguide_out_center]
    else:
        wgs_out = [waveguide_out_top, waveguide_out_bot]

    return tda.JaxSimulation(
        size=[Lx, Ly, Lz],
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=steps_per_wvl, wavelength=wavelength),
        structures=wgs_in + wgs_out,
        input_structures=input_structures,
        sources=[],
        monitors=[],
        output_monitors=[],
        run_time=run_time,
        subpixel=True,
        boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=False),
        shutoff=1e-8,
        courant=0.9,
    )

Let’s make a base simulation for a few different shapes and plot them to make sure they work properly.

[7]:
f, axes = plt.subplots(2, 2, tight_layout=True, figsize=(10, 8))

for num_in in (1,2):
    for num_out in (1,2):
        ax = axes[num_in - 1, num_out-1]
        shape = (num_in, num_out)
        sim = make_sim_base(params0, beta=5.0, shape=shape)
        _ = sim.plot_eps(z=0, ax=ax)
        ax.set_title(f"sim for shape={shape}")

plt.show()
[Figure: permittivity of the base simulation for each shape]

Mode Solver

Next, we’ll run the mode solver on one of these waveguides to make sure we inject and measure the desired waveguide modes in our system.

[8]:
from tidy3d.plugins.mode import ModeSolver
from tidy3d.plugins.mode.web import run as run_mode_solver
num_modes = 4
mode_spec = td.ModeSpec(num_modes=num_modes)

sim_start = make_sim_base(params0, beta=5.0, shape=(1,1))

mode_solver = ModeSolver(
    simulation=sim_start.to_simulation()[0],
    plane=get_source_plane(sim_start.structures[0]),
    mode_spec=mode_spec,
    freqs=[freq0]
)
modes = run_mode_solver(mode_solver)
[09:51:22] Mode solver created with
           task_id='fdve-c8eaa444-395e-4f6b-800e-f84df8599f86v1',
           solver_id='mo-b647dd3c-92f7-4d74-861f-6d3b2acf472f'.
[09:51:27] Mode solver status: queued
[09:51:29] Mode solver status: running
[09:51:40] Mode solver status: success

Let’s plot the modes.

[9]:
print("Effective index of computed modes: ", np.array(modes.n_eff))

fig, axs = plt.subplots(num_modes, 2, figsize=(10, 14), tight_layout=True)
for mode_ind in range(num_modes):
    for field_ind, field_name in enumerate(("Ey", "Ez")):
        field = modes.field_components[field_name].sel(mode_index=mode_ind)
        ax = axs[mode_ind, field_ind]
        field.real.plot(ax=ax)
        ax.set_title(f'index={mode_ind}, {field_name}(y)')
Effective index of computed modes:  [[1.4718767 1.3555466 1.007765  0.9316404]]
[Figure: real part of Ey and Ez for each computed mode]

We wish to inject the fundamental Ez-polarized mode, which is mode_index=0 above. We store this index in a variable and set ModeSpec.num_modes to the smallest value that still includes it, since solving for additional modes would waste computation.

[10]:
mode_index = 0
num_modes = mode_index + 1

mode_spec = td.ModeSpec(num_modes=num_modes)

Sources and Monitors

Next we will define our input sources and output monitors for this component. We’ll write these as functions of the input and output waveguides so the process of generating them is more general.

[11]:
def make_source(waveguide):

    # source seeding the simulation
    return td.ModeSource(
        source_time=td.GaussianPulse(freq0=freq0, fwidth=freqw),
        center=[source_x, waveguide.geometry.center[1], 0],
        size=mode_size,
        mode_index=mode_index,
        mode_spec=mode_spec,
        direction="+",
    )

def make_output_monitors(waveguides):

    monitors = []

    for waveguide in waveguides:

        # monitor where we compute the objective function from
        measurement_monitor = td.ModeMonitor(
            center=[meas_x, waveguide.geometry.center[1], 0],
            size=mode_size,
            freqs=[freq0],
            mode_spec=mode_spec,
            name=waveguide.name,
        )
        monitors.append(measurement_monitor)

    return monitors

Final Simulation

Finally, we write a function to generate a component simulation based on the design parameters, projection strength, shape (inputs x outputs), and the index of the source we wish to inject.

[12]:
def make_sim(params, beta, shape, source_index: int):
    sim = make_sim_base(params, beta=beta, shape=shape)
    num_wgs_in, num_wgs_out = shape

    wg_in = sim.structures[source_index]
    forward_source_in = make_source(wg_in)

    wgs_out = list(sim.structures)[int(num_wgs_in):]
    output_monitors = make_output_monitors(wgs_out)

    return sim.updated_copy(
        sources=[forward_source_in],
        output_monitors=output_monitors
    )

Let’s generate a simulation and plot it with the sources and monitors to make sure it works properly.

[13]:
ax = make_sim(params0, shape=(2,1), beta=1, source_index=0).plot(z=0)
[Figure: simulation for shape=(2, 1) with its source and output monitors]

Defining Circuit

With a function to generate the component simulations in hand, we can now combine these components into a circuit using sax. We highly recommend the sax documentation for further details, but we give a brief tutorial of the tool over the next few cells.

Components

In sax, the individual “nodes” of the circuit are defined as functions that return the scattering matrix of the component as a dictionary. Here each component is modeled as a Tidy3D simulation, so our component function accepts the design parameters and runs one Tidy3D simulation per input source; each simulation yields one column of the component's scattering matrix.
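The bookkeeping is easiest to see with made-up numbers. In this sketch, `assemble_s_dict` mirrors the loop at the end of `component`, and `add_reciprocal` is a hypothetical stand-in that mimics the mirrored entries sax.reciprocal adds (visible in the printed S-dict further below); the column values are invented for illustration.

```python
# Invented amplitudes: one list per source simulation, one entry per output monitor
s_columns = [
    [0.6 + 0.1j, 0.3 - 0.2j],   # simulation with the source in "in0"
    [0.2 + 0.4j, 0.5 + 0.0j],   # simulation with the source in "in1"
]

def assemble_s_dict(s_columns):
    """Map (input label, output label) -> S-parameter, as in `component`."""
    s_dict = {}
    for index_in, s_col in enumerate(s_columns):
        for index_out, s_element in enumerate(s_col):
            s_dict[(f"in{index_in}", f"out{index_out}")] = s_element
    return s_dict

def add_reciprocal(s_dict):
    """Hypothetical stand-in for sax.reciprocal: copy each (a, b) entry to (b, a)."""
    full = dict(s_dict)
    for (port_a, port_b), value in s_dict.items():
        full[(port_b, port_a)] = value
    return full

s_full = add_reciprocal(assemble_s_dict(s_columns))
for key, value in s_full.items():
    print(key, value)
```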

[14]:
def component(params=params0, beta=5, shape=(2,2)):

    num_in, num_out = shape
    num_in = int(num_in)
    num_out = int(num_out)

    def get_S_column(sim_data):
        """Compute a column of the scattering matrix for a single dataset."""
        outputs = []
        for out_mnt in sim_data.simulation.output_monitors:
            amps = sim_data[out_mnt.name].amps
            amp = jnp.sum(amps.sel(mode_index=mode_index, direction="+", f=freq0))
            outputs.append(amp)
        return outputs

    sims = [make_sim(params, shape=shape, beta=beta, source_index=source_index) for source_index in range(num_in)]
    sim_datas = tda.web.run_async(sims, verbose=False, path_dir="data")

    s_columns = [get_S_column(sim_data) for sim_data in sim_datas]

    # assemble the scattering matrix
    s_dict = {}
    for index_in in range(num_in):
        label_in = "in" + str(index_in)
        s_col = s_columns[index_in]
        for index_out in range(num_out):
            label_out = "out" + str(index_out)
            s_element = s_col[index_out]
            s_dict[(label_in, label_out)] = s_element

    return sax.reciprocal(s_dict)

Note: these component functions must accept only keyword arguments (like x=1), each with a default value. We therefore set params=params0 and beta=5 as defaults for now and show how to pass our own values later.
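A minimal sketch of why the keyword defaults matter, using a hypothetical `toy_component` that only echoes its settings. Binding shape with functools.partial, as the netlist below does, changes that argument's default, and any remaining keyword can still be overridden per call. sax reads these defaults to build the settings of the circuit function, which is why they appear in the circuit_fn signature printed later.

```python
import functools
import inspect

def toy_component(params=None, beta=5.0, shape=(2, 2)):
    """Hypothetical stand-in for the Tidy3D component: returns its own settings."""
    return {"params": params, "beta": beta, "shape": shape}

# fix the shape once, as done for the "splitter" instance in the netlist
splitter = functools.partial(toy_component, shape=(1, 2))

print(inspect.signature(splitter))        # defaults visible to tools that inspect them
print(splitter())                         # all defaults, with shape bound to (1, 2)
print(splitter(params="p1", beta=3.0))    # override selected keywords per call
```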

Let’s test this by calling the function with some example inputs and inspecting the resulting S-matrix.

It returns a dictionary whose keys are (input port, output port) tuples of port names. Because of sax.reciprocal, each entry also appears with the port order reversed.

[15]:
component_sdict = component(params0, beta=1, shape=(1,2))
component_sdict
[15]:
{('in0', 'out0'): Array(0.38201833-0.0332941j, dtype=complex64),
 ('in0', 'out1'): Array(0.40652457-0.05268222j, dtype=complex64),
 ('out0', 'in0'): Array(0.38201833-0.0332941j, dtype=complex64),
 ('out1', 'in0'): Array(0.40652457-0.05268222j, dtype=complex64)}

Next, we define a simpler component function to model the phase shifter. It has a single transmission element, exp(i phi), which adds a phase phi to the light passing through it without changing its amplitude.

[16]:
def phase_shifter(phi: float = 0.0):
    phase_added = jnp.exp(1j * phi)
    s_dict = {("in", "out"): phase_added}
    return sax.reciprocal(s_dict)

Circuit

Next, we combine these components into a circuit with sax.circuit, which lets us define our “instances” (the component functions defined earlier), the “connections” between these instances, and the “ports” of the entire circuit.

We wish to create a (1->2) component, with one output connected to our phase shifter, and then combine everything in a (2->2) component. We define these components and connections below and then specify the ports for the entire S-matrix, which is a (1->2) system.
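Because our components contain no reflection terms (only input-to-output entries), the circuit S-parameters should reduce to a simple sum over the two arms. The hypothetical helper `compose_mzi` below writes that sum out by hand, following the same connections as the netlist; it is a sketch for intuition under that no-reflection assumption, not how sax computes circuits in general. The splitter values are the ones printed above for the initial design, and the combiner values are invented.

```python
import numpy as np

def compose_mzi(s_split, phi, s_comb):
    """Hand-composed MZI S-parameters for reflection-free components (hypothetical helper)."""
    arm0 = s_split[("in0", "out0")] * np.exp(1j * phi)   # splitter,out0 -> phase_shifter -> combiner,in0
    arm1 = s_split[("in0", "out1")]                       # splitter,out1 -> combiner,in1
    return {
        ("in", f"out{k}"): arm0 * s_comb[("in0", f"out{k}")] + arm1 * s_comb[("in1", f"out{k}")]
        for k in (0, 1)
    }

# splitter S-parameters printed earlier for the initial design
s_split = {("in0", "out0"): 0.38201833 - 0.0332941j, ("in0", "out1"): 0.40652457 - 0.05268222j}
# invented combiner S-parameters, for illustration only
s_comb = {("in0", "out0"): 0.5 + 0.3j, ("in0", "out1"): 0.4 - 0.4j,
          ("in1", "out0"): -0.3 + 0.5j, ("in1", "out1"): 0.6 + 0.1j}

for phi in (0.0, np.pi):
    s_mzi = compose_mzi(s_split, phi, s_comb)
    powers = [abs(s_mzi[("in", f"out{k}")]) ** 2 for k in (0, 1)]
    print(f"phi={phi:.3f}: power out0={powers[0]:.4f}, out1={powers[1]:.4f}")
```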

[17]:
import functools

circuit_fn, _ = sax.circuit(
    netlist={
        "instances": {
            "splitter": functools.partial(component, shape=(1,2)),
            "phase_shifter": phase_shifter,
            "combiner": functools.partial(component, shape=(2,2)),
        },
        "connections": {
            "splitter,out0": "phase_shifter,in",
            "phase_shifter,out": "combiner,in0",
            "splitter,out1": "combiner,in1",
        },
        "ports": {
            "in": "splitter,in0",
            "out0": "combiner,out0",
            "out1": "combiner,out1",
        },
    }
)
circuit_fn
[17]:
<function sax.circuit._flat_circuit.<locals>._circuit(*, splitter={'params': Array([[0.4359949 , 0.02592623, 0.5496625 , ..., 0.17671216, 0.59125733,
        0.48926616],
       [0.54790777, 0.69952065, 0.24581116, ..., 0.6424524 , 0.38690034,
        0.85511965],
       [0.3807926 , 0.17830983, 0.7816594 , ..., 0.4921191 , 0.9379131 ,
        0.13442676],
       ...,
       [0.35449517, 0.7365258 , 0.73508275, ..., 0.62516195, 0.26062906,
        0.5743313 ],
       [0.87019104, 0.9364767 , 0.56900996, ..., 0.47169012, 0.08907937,
        0.9284895 ],
       [0.25833175, 0.5660962 , 0.85214543, ..., 0.31971204, 0.79901004,
        0.170014  ]], dtype=float32), 'beta': Array(5., dtype=float32), 'shape': Array([1., 2.], dtype=float32)}, phase_shifter={'phi': Array(0., dtype=float32)}, combiner={'params': Array([[0.4359949 , 0.02592623, 0.5496625 , ..., 0.17671216, 0.59125733,
        0.48926616],
       [0.54790777, 0.69952065, 0.24581116, ..., 0.6424524 , 0.38690034,
        0.85511965],
       [0.3807926 , 0.17830983, 0.7816594 , ..., 0.4921191 , 0.9379131 ,
        0.13442676],
       ...,
       [0.35449517, 0.7365258 , 0.73508275, ..., 0.62516195, 0.26062906,
        0.5743313 ],
       [0.87019104, 0.9364767 , 0.56900996, ..., 0.47169012, 0.08907937,
        0.9284895 ],
       [0.25833175, 0.5660962 , 0.85214543, ..., 0.31971204, 0.79901004,
        0.170014  ]], dtype=float32), 'beta': Array(5., dtype=float32), 'shape': Array([2., 2.], dtype=float32)}) -> 'SType'>

Passing individual parameters

The returned circuit_fn accepts settings for each of our component functions as keyword arguments. Settings meant for a single instance are passed as a dictionary under that instance's name (for example, splitter={"params": ...}), while a top-level keyword such as beta, as used below, is shared by the instances that accept it. This is important because we optimize each of the Tidy3D components with its own independent parameters.

Let’s call the circuit function and print the result, which is the S-matrix for the entire circuit given our passed parameters.

[18]:
# pass specific parameters to each of the sub-functions for the instances
s = circuit_fn(splitter={"params": params0}, combiner={"params": 0 * params0}, beta=3, phase_shifter=dict(phi=2.0))

[19]:
s
[19]:
{('out0', 'out0'): Array(0.+0.j, dtype=complex64),
 ('out0', 'out1'): Array(0.+0.j, dtype=complex64),
 ('out1', 'out0'): Array(0.+0.j, dtype=complex64),
 ('out1', 'out1'): Array(0.+0.j, dtype=complex64),
 ('in', 'in'): Array(0.+0.j, dtype=complex64),
 ('in', 'out0'): Array(0.09807562-0.12380885j, dtype=complex64),
 ('in', 'out1'): Array(0.06793377-0.14270785j, dtype=complex64),
 ('out0', 'in'): Array(0.09807562-0.12380885j, dtype=complex64),
 ('out1', 'in'): Array(0.06793377-0.14270784j, dtype=complex64)}

Objective Function

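With our circuit defined, we can now combine everything into a single objective function. We first write a penalty function that measures how well the structure respects the minimum feature size set by the filter radius. It compares a dilated-then-eroded copy of the projected design with an eroded-then-dilated copy; the two agree when all solid features and gaps are larger than the filter, so their normalized difference penalizes features that are too small.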
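A binary 1D illustration of this comparison, using plain max/min window filters (hypothetical `dilate` and `erode` helpers) in place of the smooth filter-and-project operators used in `penalty`; this substitution is an assumption made for readability. Only the 1-pixel line and the 1-pixel gap are flagged, while the wide features are unaffected.

```python
import numpy as np

def dilate(x, r):
    """Binary dilation: max over a window of half-width r (edges padded with 0)."""
    padded = np.pad(x, r)
    return np.array([padded[i:i + 2 * r + 1].max() for i in range(len(x))])

def erode(x, r):
    """Binary erosion: min over a window of half-width r (edges padded with 1)."""
    padded = np.pad(x, r, constant_values=1)
    return np.array([padded[i:i + 2 * r + 1].min() for i in range(len(x))])

r = 2
# wide solid feature, a 1-pixel line (index 20), and two wide features split by a 1-pixel gap (index 32)
design = np.array([0] * 5 + [1] * 10 + [0] * 5 + [1] + [0] * 5 + [1] * 6 + [0] + [1] * 6 + [0] * 5)

closed = erode(dilate(design, r), r)   # dilate then erode: fills small gaps
opened = dilate(erode(design, r), r)   # erode then dilate: removes small solid features
diff = closed - opened

# same normalization as in `penalty`
score = np.linalg.norm(diff) / np.linalg.norm(np.ones_like(diff))
print("pixels flagged:", np.nonzero(diff)[0])
print("penalty-like score:", score)
```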

[20]:
def penalty(params, beta, delta_eps=0.49):
    params = pre_process(params, beta=beta)
    dilate_fn = lambda x: filter_project(x, beta=100, eta=0.5-delta_eps)
    eroded_fn = lambda x: filter_project(x, beta=100, eta=0.5+delta_eps)

    params_dilate_erode = eroded_fn(dilate_fn(params))
    params_erode_dilate = dilate_fn(eroded_fn(params))
    diff = params_dilate_erode - params_erode_dilate
    return jnp.linalg.norm(diff) / jnp.linalg.norm(jnp.ones_like(diff))

We then write a combined objective function that accepts our parameters for each of the individual components (as one array params) and the projection strength beta applied to each design region.

The objective function uses these parameters to construct each of the individual components and simulates them to compute their scattering matrix. Then, it defines a circuit-level objective to look at the transmission of the entire circuit into the two output ports as a function of the phase shift phi. We seek to maximize transmission to the top port when phi=0 and the bottom port when phi=pi.
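The switching term can be checked with plain numbers. The hypothetical `switch_objective` below repeats the arithmetic of J for given port powers, without the feature penalty: an ideal switch scores 1, a phase-insensitive 50/50 split scores 0, and a switch with its ports swapped scores -1.

```python
def switch_objective(p_top_0, p_bot_0, p_top_pi, p_bot_pi):
    """Switching term of J for given port powers (hypothetical helper, no penalty)."""
    top_minus_bot_0 = p_top_0 - p_bot_0      # want +1: all power in the top port at phi=0
    top_minus_bot_pi = p_top_pi - p_bot_pi   # want -1: all power in the bottom port at phi=pi
    return (top_minus_bot_0 - top_minus_bot_pi) / 2.0

print(switch_objective(1.0, 0.0, 0.0, 1.0))  # ideal switch: 1.0
print(switch_objective(0.5, 0.5, 0.5, 0.5))  # phase-insensitive 50/50 split: 0.0
print(switch_objective(0.0, 1.0, 1.0, 0.0))  # ports swapped: -1.0
```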

[21]:
def J(params, beta) -> float:
    """Circuit-level objective function."""

    params1, params2 = params

    circuit_function = functools.partial(circuit_fn, splitter={"params": params1}, combiner={"params": params2}, beta=beta)

    def top_minus_bot(phi: float) -> float:
        """Power in top port minus power in bottom port."""

        # evaluate the circuit at phi
        sdict = circuit_function(phase_shifter={"phi": phi})

        # S-parameters for the whole circuit
        s_00 = sdict["in", "out0"]
        s_01 = sdict["in", "out1"]

        # power at ports
        power_top = jnp.sum(jnp.abs(s_00)**2)
        power_bot = jnp.sum(jnp.abs(s_01)**2)

        # top power minus bottom power
        return power_top - power_bot

    # combine objectives together: at worst, it will be -1, at best + 1.
    objective = (top_minus_bot(0.0) - top_minus_bot(np.pi)) / 2.0

    # combined penalty for both devices
    penalty_weight = 0.5
    feature_penalty1 = penalty(params=params1, beta=beta)
    feature_penalty2 = penalty(params=params2, beta=beta)
    feature_penalty = penalty_weight * (feature_penalty1 + feature_penalty2) / 2.0

    return objective - feature_penalty

Next, we use jax.value_and_grad to build a function that returns both the value of this objective function and its gradient with respect to the parameters.
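Note that jax.value_and_grad differentiates only with respect to the first positional argument by default (argnums=0), while keyword arguments such as beta are passed through without being differentiated. A toy example with a hypothetical `toy_objective` shows that the resulting gradient has the same shape as the parameter array:

```python
import jax
import jax.numpy as jnp

def toy_objective(params, beta):
    """Hypothetical stand-in for J: depends on an array and a scalar setting."""
    return jnp.sum(jnp.tanh(beta * params))

# differentiates w.r.t. argument 0 (params) only
toy_value_and_grad = jax.value_and_grad(toy_objective)

toy_params = jnp.array([[0.0, 0.5], [1.0, -0.5]])
toy_value, toy_grad = toy_value_and_grad(toy_params, beta=2.0)   # beta is not differentiated
print(toy_value)
print(toy_grad.shape)
```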

[22]:
dJ_fn = jax.value_and_grad(J)

Let’s try running this function with some example parameters and inspect the results.

[23]:
params0_combined = np.stack((params0, params0), axis=0)

val, grad = dJ_fn(params0_combined, beta=1)

[24]:
print(val, grad)
-0.50251895 [[[ 7.14473344e-06  9.11673851e-06  1.05827448e-05 ... -9.17585021e-06
   -7.92271931e-06 -6.19440652e-06]
  [ 8.68899588e-06  1.10152214e-05  1.26921068e-05 ... -1.13362294e-05
   -9.85804763e-06 -7.75598346e-06]
  [ 9.43710984e-06  1.18747666e-05  1.35484370e-05 ... -1.26476180e-05
   -1.10918045e-05 -8.78276478e-06]
  ...
  [-3.22923770e-05 -4.00895296e-05 -4.48482424e-05 ...  4.50681364e-05
    4.12528025e-05  3.38398604e-05]
  [-2.90573880e-05 -3.61831262e-05 -4.05935389e-05 ...  4.25772196e-05
    3.88274893e-05  3.17442318e-05]
  [-2.33855517e-05 -2.92259228e-05 -3.28775859e-05 ...  3.56646669e-05
    3.24215143e-05  2.64082391e-05]]

 [[ 3.00405318e-05  3.74189149e-05  4.21421937e-05 ... -3.59632759e-05
   -3.20999643e-05 -2.58842447e-05]
  [ 3.59139303e-05  4.45573241e-05  5.00202914e-05 ... -4.29933280e-05
   -3.85160092e-05 -3.11923541e-05]
  [ 3.78112854e-05  4.67239806e-05  5.22388145e-05 ... -4.54894471e-05
   -4.09546483e-05 -3.33193311e-05]
  ...
  [-3.32819945e-06 -3.46686102e-06 -3.34001015e-06 ...  9.98089945e-06
    9.28078589e-06  7.96504173e-06]
  [-2.19371759e-06 -2.00998147e-06 -1.55896669e-06 ...  9.13287022e-06
    8.60987348e-06  7.42999919e-06]
  [-1.26351188e-06 -9.15017154e-07 -3.27843736e-07 ...  7.39111147e-06
    7.05704315e-06  6.11908763e-06]]]
[25]:
print(grad.shape)
(2, 120, 120)

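The value and gradient look reasonable: the value lies between -1.5 and 1, the range the objective can take (the switching term spans [-1, 1] and the weighted penalty is at most 0.5), and the gradient is finite and nonzero. Note the gradient has shape (2, nx, ny), holding the gradients with respect to each of the two (nx, ny) pixelated grids of the individual components.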

Optimization Loop

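Next, as in the other examples, we use optax to optimize the entire circuit with the Adam method. Because we want to maximize J, the negative gradient is passed to the optimizer. The projection strength beta is increased linearly from beta0 = 1 toward beta_final = 20 over the iterations, and the parameters are clipped to [0, 1] after every update.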
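The beta schedule and the clipping step can be checked on their own. This numpy sketch repeats the schedule formula from the loop in a hypothetical `beta_schedule` helper and uses np.clip as an equivalent of the jnp.minimum / jnp.maximum pair. Because the last iteration uses i = num_steps - 1, beta stops just short of beta_final, at about 19.58, which matches the final step in the printed log.

```python
import numpy as np

def beta_schedule(i, num_steps, beta0=1.0, beta_final=20.0):
    """Linear ramp used in the loop: beta0 at step 0, approaching beta_final."""
    perc_done = i / num_steps
    return beta0 * (1 - perc_done) + beta_final * perc_done

def clip_params(params):
    """Equivalent to the jnp.minimum / jnp.maximum pair: keep params in [0, 1]."""
    return np.clip(params, 0.0, 1.0)

betas = [beta_schedule(i, 45) for i in range(45)]
print(f"first: {betas[0]:.4e}, second: {betas[1]:.4e}, last: {betas[-1]:.4e}")
print(clip_params(np.array([-0.3, 0.4, 1.7])))
```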

[26]:
import optax

# hyperparameters
num_steps = 45
learning_rate = 1.0

# initialize adam optimizer with starting parameters
params = params0_combined.copy()
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# store history
Js = []
params_history = [params]
beta_history = []

beta0 = 1.0
beta_final = 20

for i in range(num_steps):

    # compute gradient and current objective function value

    perc_done = i / num_steps
    beta = beta0 * (1 - perc_done) + beta_final * perc_done
    value, gradient = dJ_fn(params, beta=beta)

    # outputs
    print(f"step = {i + 1}")
    print(f"\tbeta = {beta:.4e}")
    print(f"\tJ = {value:.4e}")
    print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")

    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)
    params = optax.apply_updates(params, updates)

    # cap the parameters
    params = jnp.minimum(params, 1.0)
    params = jnp.maximum(params, 0.0)

    # save history
    Js.append(value)
    params_history.append(params)
    beta_history.append(beta)

power = J(params_history[-1], beta=beta)
Js.append(power)
step = 1
        beta = 1.0000e+00
        J = -5.0252e-01
        grad_norm = 2.0214e-02
step = 2
        beta = 1.4222e+00
        J = -2.4675e-01
        grad_norm = 1.3238e-02
step = 3
        beta = 1.8444e+00
        J = -2.1442e-01
        grad_norm = 9.7928e-03
step = 4
        beta = 2.2667e+00
        J = -1.6024e-01
        grad_norm = 8.3475e-03
step = 5
        beta = 2.6889e+00
        J = -9.5096e-02
        grad_norm = 8.0457e-03
step = 6
        beta = 3.1111e+00
        J = -1.4640e-02
        grad_norm = 8.4054e-03
step = 7
        beta = 3.5333e+00
        J = 1.3219e-02
        grad_norm = 3.5979e-02
step = 8
        beta = 3.9556e+00
        J = -5.9181e-02
        grad_norm = 3.9888e-02
step = 9
        beta = 4.3778e+00
        J = 1.6096e-01
        grad_norm = 1.1203e-02
step = 10
        beta = 4.8000e+00
        J = 1.4436e-01
        grad_norm = 3.3816e-02
step = 11
        beta = 5.2222e+00
        J = 2.6768e-01
        grad_norm = 9.0722e-03
step = 12
        beta = 5.6444e+00
        J = 2.7818e-01
        grad_norm = 2.2982e-02
step = 13
        beta = 6.0667e+00
        J = 3.4127e-01
        grad_norm = 1.0269e-02
step = 14
        beta = 6.4889e+00
        J = 3.7026e-01
        grad_norm = 1.5204e-02
step = 15
        beta = 6.9111e+00
        J = 4.0043e-01
        grad_norm = 2.1623e-02
step = 16
        beta = 7.3333e+00
        J = 4.4509e-01
        grad_norm = 1.0588e-02
step = 17
        beta = 7.7556e+00
        J = 4.7139e-01
        grad_norm = 7.8376e-03
step = 18
        beta = 8.1778e+00
        J = 4.9354e-01
        grad_norm = 9.4674e-03
step = 19
        beta = 8.6000e+00
        J = 5.1339e-01
        grad_norm = 7.4156e-03
step = 20
        beta = 9.0222e+00
        J = 5.3156e-01
        grad_norm = 5.8523e-03
step = 21
        beta = 9.4444e+00
        J = 5.4703e-01
        grad_norm = 5.7467e-03
step = 22
        beta = 9.8667e+00
        J = 5.6340e-01
        grad_norm = 9.4500e-03
step = 23
        beta = 1.0289e+01
        J = 5.5934e-01
        grad_norm = 2.3881e-02
step = 24
        beta = 1.0711e+01
        J = 5.2488e-01
        grad_norm = 4.0713e-02
step = 25
        beta = 1.1133e+01
        J = 4.8486e-01
        grad_norm = 4.8898e-02
step = 26
        beta = 1.1556e+01
        J = 5.7733e-01
        grad_norm = 2.2211e-02
step = 27
        beta = 1.1978e+01
        J = 5.9732e-01
        grad_norm = 2.4999e-02
step = 28
        beta = 1.2400e+01
        J = 6.0919e-01
        grad_norm = 1.4594e-02
step = 29
        beta = 1.2822e+01
        J = 6.1784e-01
        grad_norm = 1.1322e-02
step = 30
        beta = 1.3244e+01
        J = 6.2263e-01
        grad_norm = 9.7416e-03
step = 31
        beta = 1.3667e+01
        J = 6.3181e-01
        grad_norm = 9.6787e-03
step = 32
        beta = 1.4089e+01
        J = 6.3439e-01
        grad_norm = 8.7720e-03
step = 33
        beta = 1.4511e+01
        J = 6.3985e-01
        grad_norm = 5.5513e-03
step = 34
        beta = 1.4933e+01
        J = 6.4201e-01
        grad_norm = 6.7834e-03
step = 35
        beta = 1.5356e+01
        J = 6.4506e-01
        grad_norm = 1.1735e-02
step = 36
        beta = 1.5778e+01
        J = 6.4778e-01
        grad_norm = 6.5648e-03
step = 37
        beta = 1.6200e+01
        J = 6.5068e-01
        grad_norm = 6.2077e-03
step = 38
        beta = 1.6622e+01
        J = 6.5446e-01
        grad_norm = 5.3513e-03
step = 39
        beta = 1.7044e+01
        J = 6.5579e-01
        grad_norm = 9.0492e-03
step = 40
        beta = 1.7467e+01
        J = 6.5536e-01
        grad_norm = 1.5514e-02
step = 41
        beta = 1.7889e+01
        J = 6.1316e-01
        grad_norm = 3.9242e-02
step = 42
        beta = 1.8311e+01
        J = 4.6327e-01
        grad_norm = 7.3251e-02
step = 43
        beta = 1.8733e+01
        J = 4.7778e-01
        grad_norm = 6.6240e-02
step = 44
        beta = 1.9156e+01
        J = 6.3996e-01
        grad_norm = 1.1438e-02
step = 45
        beta = 1.9578e+01
        J = 6.1807e-01
        grad_norm = 3.0977e-02

Results

Finally, we can inspect the results.

First we plot the objective function over the iterations. It rises from about -0.5 to above 0.6, with a few temporary dips (around steps 8, 23 to 25, and 41 to 43).

[27]:
plt.plot(Js)
plt.xlabel("iterations")
plt.ylabel("objective function")
plt.ylim(-1.5, 1)
plt.show()

[Figure: objective function vs. iteration]

We grab the final design parameters and beta value.

[28]:
params_final = params1_final, params2_final = params_history[-1]
beta_final = beta_history[-1]

And use these to construct the Tidy3D simulations corresponding to the final optimized state of each of the components.

[29]:
sim1_final = make_sim(params1_final, beta=beta_final, source_index=0, shape=(1,2))
sim2_final = make_sim(params2_final, beta=beta_final, source_index=0, shape=(2,2))
sim3_final = make_sim(params2_final, beta=beta_final, source_index=1, shape=(2,2))

Let’s plot these simulations. The second and third are identical except for the source position, which lets us visualize the fields launched from each input of the combiner individually.

[30]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, tight_layout=True, figsize=(10,6))

sim1_final.plot_eps(z=0, ax=ax1)
sim2_final.plot_eps(z=0, ax=ax2)
sim3_final.plot_eps(z=0, ax=ax3)

ax1.set_title('first component (splitter)')
ax2.set_title('second component (combiner), source in0')
ax3.set_title('second component (combiner), source in1')

plt.show()

[Figure: final permittivity of the splitter and combiner]

To visualize the fields, let’s create and add a FieldMonitor to each of the simulations.

[31]:
field_mnt = td.FieldMonitor(
    size=(td.inf, td.inf, 0),
    freqs=[freq0],
    name="field_mnt",
    colocate=True,
)

sim1_final = sim1_final.copy(update=dict(monitors=(field_mnt,)))
sim2_final = sim2_final.copy(update=dict(monitors=(field_mnt,)))
sim3_final = sim3_final.copy(update=dict(monitors=(field_mnt,)))

Next, run the simulations

[37]:
sims_final = (sim1_final, sim2_final, sim3_final)

sim_data1_final, sim_data2_final, sim_data3_final = tda.web.run_async(sims_final, path_dir="data", verbose=False)

and plot the results.

[38]:
f, (axes_eps, axes_fld, axes_int) = plt.subplots(3, 3, figsize=(10, 8), tight_layout=True)
sim_datas = [sim_data1_final, sim_data2_final, sim_data3_final]
for sim_data_final, ax_eps, ax_fld, ax_int in zip(sim_datas, axes_eps, axes_fld, axes_int):
    sim_data_final.simulation.plot_eps(z=0.01, ax=ax_eps)
    sim_data_final.plot_field("field_mnt", "Ez", z=0, ax=ax_fld)
    sim_data_final.plot_field("field_mnt", "E", "abs^2", z=0, ax=ax_int)

[Figure: permittivity (top row), Ez (middle row), and |E|^2 (bottom row) for each of the three final simulations (columns)]

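While this gives an interesting picture, what we really want to visualize is how the fields look under our design conditions, phi=0 and phi=pi, when both inputs of the combiner are driven simultaneously by the light leaving the splitter. For that, we write a function that reads the complex mode amplitudes at the splitter's two outputs, adds the phase shift phi to the top arm, uses the resulting amplitudes and phases for the two sources of the combiner, and runs that simulation.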

[39]:
def get_sim_data_right(phi):
    """Run the combiner driven by the splitter outputs, with phase phi added to the top arm."""

    # complex mode amplitudes leaving the splitter's top and bottom arms
    out_top_1 = sim_data1_final["top"].amps.sel(direction="+", f=freq0, mode_index=0)
    out_bot_1 = sim_data1_final["bot"].amps.sel(direction="+", f=freq0, mode_index=0)

    # apply phi phase shift to top arm
    phase_top = np.angle(out_top_1) + phi
    phase_bot = np.angle(out_bot_1)

    # reuse the combiner sources: sim2 is sourced at the top input, sim3 at the bottom input
    src_top = sim2_final.sources[0]
    src_bot = sim3_final.sources[0]

    # match each source's amplitude and phase to the corresponding splitter output
    src_time_top = src_top.source_time.updated_copy(amplitude=abs(out_top_1), phase=phase_top)
    src_time_bot = src_bot.source_time.updated_copy(amplitude=abs(out_bot_1), phase=phase_bot)

    src_top = src_top.updated_copy(source_time=src_time_top)
    src_bot = src_bot.updated_copy(source_time=src_time_bot)

    # combiner simulation with both inputs excited at once
    sim_right = sim2_final.updated_copy(sources=[src_top, src_bot])
    return tda.web.run(sim_right, task_name=f"phi={phi:.3f}")
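The amplitude and phase bookkeeping in `get_sim_data_right` is equivalent to multiplying the complex amplitude leaving the splitter's top arm by exp(i*phi) before re-injecting it into the combiner. The NumPy sketch below checks that equivalence on made-up amplitudes. The helper `source_params` and the numbers are hypothetical, and the sketch assumes, as the function above does, that a source with amplitude A and phase p corresponds to the complex amplitude A*exp(i*p).

```python
import numpy as np


def source_params(amp: complex, phi: float = 0.0) -> tuple:
    """Hypothetical helper: split a complex mode amplitude into the (amplitude, phase)
    pair used to set a source, adding an extra phase shift phi."""
    return float(np.abs(amp)), float(np.angle(amp) + phi)


# Made-up splitter outputs standing in for the "top" and "bot" monitor amplitudes.
out_top = 0.68 * np.exp(1j * 0.3)
out_bot = 0.69 * np.exp(-1j * 1.2)

for phi in (0.0, np.pi):
    amp_top, phase_top = source_params(out_top, phi)  # phase shifter on the top arm
    amp_bot, phase_bot = source_params(out_bot)       # bottom arm left untouched
    # Re-assembling amplitude * exp(i * phase) gives the phase-shifted complex amplitudes.
    top_in = amp_top * np.exp(1j * phase_top)
    bot_in = amp_bot * np.exp(1j * phase_bot)
    print(f"phi={phi:.3f}: top input {complex(top_in):.3f}, bottom input {complex(bot_in):.3f}")
```

Only the top-arm amplitude picks up the extra phase, so changing phi alters the relative phase of the two waves entering the combiner without changing their magnitudes.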

We compute the field data for the output component for both phi=0 and phi=pi.

[40]:
sim_data_right_p0 = get_sim_data_right(phi=0)
sim_data_right_pi = get_sim_data_right(phi=np.pi)
[17:59:00] Created task 'phi=0.000' with task_id
           'fdve-303d2de1-75c2-438a-9087-f48c40f1abb2v1'.
[17:59:03] status = queued
[17:59:06] status = preprocess
[17:59:10] Maximum FlexCredit cost: 0.025. Use 'web.real_cost(task_id)' to get
           the billed FlexCredit cost after a simulation run.
           starting up solver
[17:59:11] running solver
           To cancel the simulation, use 'web.abort(task_id)' or
           'web.delete(task_id)' or abort/delete the task in the web UI.
           Terminating the Python script will not stop the job running on the
           cloud.
[17:59:17] early shutoff detected, exiting.
           status = postprocess
[17:59:21] status = success
[17:59:23] loading SimulationData from simulation_data.hdf5
[17:59:24] Created task 'phi=3.142' with task_id
           'fdve-fe2d9123-5397-4a8c-9fb0-86c13495bb41v1'.
[17:59:26] status = queued
[17:59:30] status = preprocess
[17:59:37] Maximum FlexCredit cost: 0.025. Use 'web.real_cost(task_id)' to get
           the billed FlexCredit cost after a simulation run.
           starting up solver
           running solver
           To cancel the simulation, use 'web.abort(task_id)' or
           'web.delete(task_id)' or abort/delete the task in the web UI.
           Terminating the Python script will not stop the job running on the
           cloud.
[17:59:43] early shutoff detected, exiting.
[17:59:44] status = postprocess
[17:59:48] status = success
[17:59:50] loading SimulationData from simulation_data.hdf5

We then plot the results. The device works as intended: when phi=0, most of the transmitted light exits the top port, and when phi=pi, most of it exits the bottom port.

[41]:
alpha = 0.0
f, (axes_eps, axes_fld, axes_int) = plt.subplots(3, 3, figsize=(10, 8), tight_layout=True)
sim_datas = [sim_data1_final, sim_data_right_p0, sim_data_right_pi]
for sim_data_final, ax_eps, ax_fld, ax_int, phi in zip(sim_datas, axes_eps, axes_fld, axes_int, (None, "0", "π")):
    sim_data_final.simulation.plot_eps(z=0.01, ax=ax_eps, source_alpha=alpha, monitor_alpha=0)
    sim_data_final.plot_field("field_mnt", "Ez", z=0, ax=ax_fld)
    sim_data_final.plot_field("field_mnt", "E", "abs^2", z=0, ax=ax_int)

    for ax in (ax_eps, ax_fld, ax_int):
        if phi is not None:
            ax.set_title(rf'output sim (phi={phi})')
        else:
            ax.set_title("input sim")

[Figure: permittivity, Ez, and |E|^2 (rows) for the input sim and the output sim at phi=0 and phi=π (columns)]
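The switching seen in the field plots follows from two-beam interference. As a rough sanity check, the sketch below models an idealized, lossless version of the circuit: a 1x2 splitter with an equal, in-phase split, a phase phi on the top arm, and a combiner that we assume acts as the 2x2 matrix (1/sqrt(2)) [[1, 1], [1, -1]]. That combiner matrix is an assumption chosen so that phi=0 routes light to the top port; the optimized Tidy3D combiner is not guaranteed to match it.

```python
import numpy as np

# Idealized, lossless building blocks (assumptions, not the optimized devices).
SPLITTER = np.array([1.0, 1.0]) / np.sqrt(2)                 # 1x2: equal split, equal phase
COMBINER = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2)  # 2x2: sum / difference ports


def mzi_output_powers(phi: float) -> np.ndarray:
    """Return [P_top, P_bot] for unit input power and phase phi on the top arm."""
    arms = SPLITTER * np.array([np.exp(1j * phi), 1.0])  # phase shifter on the top arm
    out = COMBINER @ arms                                 # interfere the two arms
    return np.abs(out) ** 2


for phi in (0.0, np.pi / 2, np.pi):
    p_top, p_bot = mzi_output_powers(phi)
    print(f"phi={phi:.3f}: P_top={p_top:.3f}, P_bot={p_bot:.3f}")
```

In this ideal model P_top = (1 + cos(phi)) / 2, so the power swings completely between the ports. In the simulated device, the "off" port still receives a small fraction of the power and the total transmission is below 100 %, so the ideal model is only a qualitative guide.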

With some minor modifications to this MZI device (such as adding a 2nd input port and a 2nd phase shifter on the output), we can implement any unitary 2x2 matrix. Meshes of such devices can in turn perform arbitrary linear operations in optical circuits, for example in optical neural networks.
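To illustrate the 2x2 case, here is a sketch of a common textbook MZI model with two 50:50 directional couplers and two phase shifters: an internal phase theta between the couplers and an external phase phi on the top input. The coupler matrix (1/sqrt(2)) [[1, 1j], [1j, 1]] and the placement of the phase shifters are modeling assumptions, not the components optimized in this notebook. The sketch checks that the transfer matrix is unitary and that theta alone sets the power split.

```python
import numpy as np

# 50:50 directional coupler (assumed symmetric and lossless).
COUPLER = np.array([[1.0, 1j], [1j, 1.0]]) / np.sqrt(2)


def phase(p: float) -> np.ndarray:
    """Phase shifter acting on the top waveguide only."""
    return np.diag([np.exp(1j * p), 1.0])


def mzi(theta: float, phi: float) -> np.ndarray:
    """2x2 transfer matrix: external phase phi, coupler, internal phase theta, coupler."""
    return COUPLER @ phase(theta) @ COUPLER @ phase(phi)


rng = np.random.default_rng(0)
for theta, phi in rng.uniform(0, 2 * np.pi, size=(3, 2)):
    m = mzi(theta, phi)
    # Unitary: no power is lost or created for any choice of phases.
    assert np.allclose(m.conj().T @ m, np.eye(2))
    # Power from input 0 that stays in output 0 depends only on theta.
    print(f"theta={theta:.2f}: |M00|^2={abs(m[0, 0])**2:.3f}, "
          f"sin^2(theta/2)={np.sin(theta / 2)**2:.3f}")
```

Here theta tunes the splitting ratio and phi the relative phase between the inputs; in this model, theta=pi gives the "bar" state (each input stays in its own waveguide) and theta=0 the "cross" state.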

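With the adjoint plugin of Tidy3D and the differentiable circuit modeling of sax, we have a convenient tool for combining the flexibility of inverse design with the modularity of traditional component-based design, and we can jointly optimize the individual components with minimal overhead.

To quantify the switching, we compute the power transmitted into each output port of the combiner for phi=0 and phi=pi by summing the squared magnitudes of the mode amplitudes recorded by its two output monitors.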

[52]:
power_top_p0 = jnp.sum(jnp.abs(jnp.array(sim_data_right_p0.output_data[0].amps.values))**2)
power_bot_p0 = jnp.sum(jnp.abs(jnp.array(sim_data_right_p0.output_data[1].amps.values))**2)
power_top_pi = jnp.sum(jnp.abs(jnp.array(sim_data_right_pi.output_data[0].amps.values))**2)
power_bot_pi = jnp.sum(jnp.abs(jnp.array(sim_data_right_pi.output_data[1].amps.values))**2)
[58]:
print('phi = 0')
print(f'  Transmission_top = {100 * power_top_p0:.2f} %')
print(f'  Transmission_bot = {100 * power_bot_p0:.2f} %')

print('phi = pi')
print(f'  Transmission_top = {100 * power_top_pi:.2f} %')
print(f'  Transmission_bot = {100 * power_bot_pi:.2f} %')
phi = 0
  Transmission_top = 58.65 %
  Transmission_bot = 0.91 %
phi = pi
  Transmission_top = 0.39 %
  Transmission_bot = 79.51 %
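The printed transmissions can be condensed into two common switch figures of merit: the extinction ratio (power in the intended port over power in the other port) and the insertion loss (total output power relative to the input). The sketch below recomputes both from the percentages printed above, treating 100 % as the input power as the printout does; it is a simple post-processing illustration rather than part of the original workflow.

```python
import math

# Transmissions in percent, copied from the printed output above.
results = {
    "phi = 0": {"on": 58.65, "off": 0.91},   # intended output: top port
    "phi = pi": {"on": 79.51, "off": 0.39},  # intended output: bottom port
}

metrics = {}
for label, t in results.items():
    # Extinction ratio: power in the intended port over power in the other port.
    extinction_db = 10 * math.log10(t["on"] / t["off"])
    # Insertion loss: shortfall of the total output power from 100 % of the input.
    insertion_loss_db = -10 * math.log10((t["on"] + t["off"]) / 100)
    metrics[label] = (extinction_db, insertion_loss_db)
    print(f"{label}: extinction ratio = {extinction_db:.1f} dB, "
          f"insertion loss = {insertion_loss_db:.2f} dB")
```

Both states show a clear contrast between the ports, with the phi=pi state having the larger extinction ratio and the lower loss, consistent with the asymmetry already visible in the raw percentages.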