Adjoint optimization of a wavelength division multiplexer#

In this notebook, we will use a multi-objective optimization to design a wavelength division multiplexer (WDM).

In short, this device takes in broadband light and directs light of different wavelengths to different output ports.

This demo combines the basic setup of our 3rd tutorial of a mode converter with the multi-frequency feature introduced in Tidy3D version 2.5.

If you are unfamiliar with inverse design, we also recommend our intro to inverse design tutorials and our primer on automatic differentiation with tidy3d.

[1]:
import matplotlib.pyplot as plt

# first import tidy3d, its adjoint plugin, numpy, and jax.
import tidy3d as td
import tidy3d.plugins.adjoint as tda
import numpy as np
import jax.numpy as jnp
import jax

np.random.seed(2)

Setup#

First we set up our basic simulation.

We have an input waveguide connected to a square design region, which has two output waveguides.

The square design region is a custom medium with a pixellated permittivity grid that we wish to optimize such that input light of different wavelengths is directed to different output ports.

As this is an SOI device, we would typically define the design region and waveguides as silicon sitting on an SiO2 substrate. For this demo, we make a 2D simulation, but it can easily be made 3D by changing the Lz parameter, adding dimension to the structures, and adding a substrate.

[2]:
# material information
n_si = 3.49
n_sio2 = 1.45 # not used in 2D
n_air = 1

# design output wavelengths
wavelength_top = 1.300
wavelength_bot = 1.550

# and their corresponding frequencies and spectral information
freq_top = td.C_0 / wavelength_top
freq_bot = td.C_0 / wavelength_bot
freq0 = (freq_top + freq_bot) / 2.0
fwidth = abs(freq_bot - freq_top)
run_time = 100 / fwidth

# create dictionaries to reference these later by string key 'top' or 'bot'
freqs = dict(top=freq_top, bot=freq_bot)
wavelengths = dict(top=wavelength_top, bot=wavelength_bot)

# size of design region
lx = 2.8
ly = 2.8
lz = td.inf # in 2D, we say the size of components is inf but the size of simulation is 0.

# size of waveguides
wg_width = 0.3
wg_length = 1.5
wg_spacing = 0.8

# spacing between design region and PML in y
buffer = 1.5

# size of simulation
Lx = lx + wg_length * 2
Ly = ly + buffer * 2
Lz = 0.0

# fabrication constraint (feature size and projection strength)
radius = 0.200
beta = 30

# resolution information
min_steps_per_wvl = 25
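
As a quick sanity check on the spectral settings above, the following sketch recomputes the two design frequencies, the source bandwidth, and the run time in plain Python. It hard-codes the speed of light in μm/s, which is assumed here to be the value and unit convention held by `td.C_0`; the variable names mirror the cell above.

```python
# Speed of light in um/s (assumed to match td.C_0, which works in microns).
C_0 = 299792458.0 * 1e6

wavelength_top = 1.300  # um
wavelength_bot = 1.550  # um

# Frequencies of the two design wavelengths.
freq_top = C_0 / wavelength_top
freq_bot = C_0 / wavelength_bot

# Source bandwidth and simulation run time, as defined in the setup cell.
fwidth = abs(freq_bot - freq_top)
run_time = 100 / fwidth

print(f"freq_top = {freq_top:.4e} Hz")
print(f"freq_bot = {freq_bot:.4e} Hz")
print(f"fwidth   = {fwidth:.4e} Hz")
print(f"run_time = {run_time:.4e} s")
```

The frequency spacing (about 3.72e13 Hz) and run time (about 2.69e-12 s) agree with the values reported in the adjoint warning log printed later in this notebook.
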
[3]:
# define the waveguide ports

wg_in = td.Structure(
    geometry=td.Box(
        center=(-Lx/2, 0, 0),
        size=(wg_length * 2, wg_width, lz),
    ),
    medium=td.Medium(permittivity=n_si**2)
)

wg_top = td.Structure(
    geometry=td.Box(
        center=(+Lx/2, +wg_width/2+wg_spacing/2, 0),
        size=(wg_length * 2, wg_width, lz),
    ),
    medium=td.Medium(permittivity=n_si**2)
)

wg_bot = td.Structure(
    geometry=td.Box(
        center=(+Lx/2, -wg_width/2-wg_spacing/2, 0),
        size=(wg_length * 2, wg_width, lz),
    ),
    medium=td.Medium(permittivity=n_si**2)
)

# and a field monitor that measures fields on the z=0 plane
fld_mnt = td.FieldMonitor(
    center=(0,0,0),
    size=(td.inf, td.inf, 0),
    freqs=[freq_top, freq_bot],
    name="field",
)

Note we can ignore this warning as it will be resolved after 2.4.0

Define design region#

Here we define the design region as a pixellated grid of permittivity values.

We first define the overall geometry as a JaxBox and also the number of pixels in x and y.

[4]:
nx = 55
ny = 55

design_region_geo = tda.JaxBox(
    size=(lx, ly, lz),
    center=(0,0,0)
)
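
To get a feel for the scales involved, the sketch below works out the design pixel size and compares it with the filter radius and with a rough estimate of the simulation grid step. The grid-step formula is an assumption (shortest in-medium wavelength divided by `min_steps_per_wvl`); the automatic mesher may choose somewhat different values.

```python
# Values copied from the setup cells above.
lx, nx = 2.8, 55            # design region width (um) and pixel count
radius = 0.200              # conic filter radius (um)
n_si = 3.49
wavelength_top = 1.300      # um, wavelength passed to the automatic grid
min_steps_per_wvl = 25

# Size of one design pixel and the filter radius measured in pixels.
pixel_size = lx / nx
radius_in_pixels = radius / pixel_size

# Rough estimate of the FDTD cell size inside silicon (an assumption about
# how the automatic grid resolves the in-medium wavelength).
approx_grid_step = wavelength_top / n_si / min_steps_per_wvl

print(f"pixel size        = {pixel_size * 1000:.1f} nm")
print(f"filter radius     = {radius_in_pixels:.1f} pixels")
print(f"approx. grid step = {approx_grid_step * 1000:.1f} nm")
```

Under these assumptions each design pixel is a few grid cells wide, and the filter radius spans roughly four pixels.
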

Next we write a function to give us the pixellated array as a function of our parameters through our filtering and projection methods, which are used to make the resulting structures easier to fabricate. For more details, refer to our 4th lecture in the inverse design 101 lecture series, which focuses on fabrication constraints.

We also wrap this function in another one that generates the entire JaxStructure corresponding to the design region, for convenience later.

[5]:
from tidy3d.plugins.adjoint.utils.filter import ConicFilter

conic_filter = ConicFilter(radius=radius, design_region_dl=lx/nx)

# note: params is an array of shape (nx, ny) that stores values between 0 (air) and 1 (silicon)

def tanh_projection(x, beta, eta=0.5):
    tanhbn = jnp.tanh(beta * eta)
    num = tanhbn + jnp.tanh(beta * (x - eta))
    den = tanhbn + jnp.tanh(beta * (1 - eta))
    return num / den

def filter_project(x, beta, eta=0.5):
    x = conic_filter.evaluate(x)
    return tanh_projection(x, beta=beta, eta=eta)

# number of times to filter -> project. Two times with a lower beta (~30) seems to give decent results.
num_projections = 2

def pre_process(params, beta):
    """Get the filtered and projected density array (values between 0 and 1) as a function of the parameters (0,1)"""
    for _ in range(num_projections):
        params = filter_project(params, beta=beta)
    return params

def make_eps(params, beta):
    params = pre_process(params, beta=beta)
    eps_values = 1 + (n_si**2 - 1) * params
    return eps_values

def make_custom_medium(params, beta):
    """Make JaxCustomMedium as a function of provided parameters."""
    eps = make_eps(params, beta).reshape((nx, ny, 1, 1))
    eps = jnp.where(eps < 1, 1, eps)
    eps = jnp.where(eps > n_si**2, n_si**2, eps)

    xs = list(jnp.linspace(-lx/2, lx/2, nx))
    ys = list(jnp.linspace(-ly/2, ly/2, ny))
    zs = [0]
    freqs = [freq0]
    coords = dict(x=xs, y=ys, z=zs, f=freqs)

    eps_dataset = tda.JaxDataArray(values=eps, coords=coords)

    medium = tda.JaxCustomMedium(
        eps_dataset=tda.JaxPermittivityDataset(
            eps_xx=eps_dataset,
            eps_yy=eps_dataset,
            eps_zz=eps_dataset,
        )
    )

    struct = tda.JaxStructure(
        geometry=design_region_geo,
        medium=medium
    )

    return struct
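
To illustrate what the projection step does, here is a minimal NumPy copy of the `tanh_projection` formula from the cell above, evaluated at a few sample values for a small and a large `beta`. It leaves out the conic filter and automatic differentiation, so it is only a sketch of the projection itself.

```python
import numpy as np

def tanh_projection(x, beta, eta=0.5):
    """NumPy copy of the projection used above (same formula, no autodiff)."""
    tanhbn = np.tanh(beta * eta)
    num = tanhbn + np.tanh(beta * (x - eta))
    den = tanhbn + np.tanh(beta * (1 - eta))
    return num / den

# Sample density values between 0 (air) and 1 (silicon).
x = np.array([0.0, 0.25, 0.5, 0.75, 1.0])

soft = tanh_projection(x, beta=1)    # gentle projection, close to linear
sharp = tanh_projection(x, beta=30)  # strong projection, close to a step

print("beta=1 :", np.round(soft, 3))
print("beta=30:", np.round(sharp, 3))
```

For any `beta`, the endpoints 0 and 1 map to themselves and, with `eta=0.5`, the midpoint stays at 0.5. Increasing `beta` pushes intermediate values toward 0 or 1, which is what drives the design toward a binary structure.
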

Define base simulation#

With all of these functions and variables defined, we can write a single function to return our "base" JaxSimulation as a function of our design parameters. This function first constructs the design region and then creates a JaxSimulation with all of the basic parameters.

Note, we don't yet have a source or monitors for injecting and measuring our fields, but will add those next after running the mode solver.

[6]:
def make_sim_base(params, beta):

    input_struct = make_custom_medium(params, beta=beta)

    return tda.JaxSimulation(
        size=(Lx, Ly, Lz),
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=min_steps_per_wvl, wavelength=wavelength_top),
        structures=[wg_in, wg_top, wg_bot],
        monitors=[fld_mnt],
        input_structures=[input_struct],
        boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True if Lz else False),
        run_time=run_time,
    )

Let's test out our function. We'll make an initially random array of parameters between 0 and 1 and generate the base simulation to plot and inspect.

[7]:
params0 = np.random.random((nx, ny))
sim_base = make_sim_base(params0, beta=1)
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[8]:
ax = sim_base.plot_eps(z=0, monitor_alpha=0.0)
../_images/notebooks_AdjointPlugin9WDM_13_0.png

It all looks good, so now we add the bits that define the optimization.

Adding Mode Sources and Monitors#

Solving modes#

First, we need to create our ModeSource and ModeMonitor objects that inject and measure the modes that we are interested in optimizing.

We'll use tidy3d's ModeSolver and use the remote run function that gets more accurate results by running on Flexcompute's servers.

[9]:
from tidy3d.plugins.mode import ModeSolver
from tidy3d.plugins.mode.web import run as run_mode_solver

# we'll ask for 4 modes just to inspect
num_modes = 4

# let's define how large the mode planes are and how far they are from the PML relative to the design region
mode_size = (0, 1.8 * wg_spacing + wg_width, max([Lz, lz, 3]))
space_fraction = 0.2

# make a plane corresponding to where we wish to measure the input mode
plane_in = td.Box(
    center=(-Lx/2 + space_fraction * wg_length, 0, 0),
    size=mode_size,
)

# construct the mode solver using our base sim (converted from `JaxSimulation` to regular `Simulation`) + our plane
mode_solver = ModeSolver(
    simulation=sim_base.to_simulation()[0],
    plane=plane_in,
    freqs=[freq_top],
    mode_spec=td.ModeSpec(num_modes=num_modes)
)

Next we run the mode solver on the servers.

[10]:
mode_data = run_mode_solver(mode_solver)
11:04:55 -03 Mode solver created with
             task_id='fdve-7c6df90b-568c-4a83-be65-6158396ee5f9',
             solver_id='mo-9e52b651-3cf5-4147-bf98-411afc499e89'.
11:04:59 -03 Mode solver status: queued
11:05:11 -03 Mode solver status: running
11:05:20 -03 Mode solver status: success

And visualize the results.

[11]:
fig, axs = plt.subplots(num_modes, 3, figsize=(12, 12), tight_layout=True)
for mode_index in range(num_modes):
    vmax = 1.1 * max(abs(mode_data.field_components[n].sel(mode_index=mode_index)).max() for n in ("Ex", "Ey", "Ez"))
    for field_name, ax in zip(("Ex", "Ey", "Ez"), axs[mode_index]):
        field = mode_data.field_components[field_name].sel(mode_index=mode_index)
        field.real.plot(label="Real", ax=ax)
        field.imag.plot(ls="--", label="Imag", ax=ax)
        ax.set_title(f'index={mode_index}, {field_name}')
        ax.set_ylim(-vmax, vmax)

axs[0, 0].legend()

print("Effective index of computed modes: ", np.array(mode_data.n_eff))
Effective index of computed modes:  [[3.141611 2.806403 1.953868 1.062474]]
../_images/notebooks_AdjointPlugin9WDM_19_1.png

We identify mode_index=0 as the first order mode that is out of plane of the device. Let's choose to optimize our device with respect to this as the mode of interest for both the input and output.

We re-set the ModeSpec to only compute the number of modes we need (1) and also update our ModeSolver accordingly.

[12]:
mode_index = 0
mode_spec = td.ModeSpec(num_modes=mode_index+1)
mode_solver = mode_solver.updated_copy(mode_spec=mode_spec)

Make input and output mode sources and monitors#

Next, we will generate the input ModeSource and output ModeMonitor objects using the convenience methods defined in the ModeSolver.

Because our plane was defined at the input port, we'll modify the centers of the ModeMonitors to place them at the output ports to the right of the device.

[13]:
# make source
mode_src = mode_solver.to_source(
    source_time=td.GaussianPulse(
        freq0=freq0,
        fwidth=fwidth,
    ),
    direction="+",
    mode_index=mode_index,
)

# make a basic monitor
mode_mnt = mode_solver.to_monitor(
    freqs=[freq0],
    name="_"
)

# construct the proper centers for the monitors at the 'top' and 'bot' ports
mnt_center_top = list(plane_in.center)
mnt_center_bot = list(plane_in.center)
mnt_center_top[0] = -plane_in.center[0]
mnt_center_bot[0] = -plane_in.center[0]
mnt_center_top[1] = wg_top.geometry.center[1]
mnt_center_bot[1] = wg_bot.geometry.center[1]

# make a dictionary of names and frequencies to refer to later by key
mnt_names = dict(top="mode_top", bot="mode_bot")
mnt_freqs = dict(top=freq_top, bot=freq_bot)

# make two updated copies of the mode monitor with the proper frequencies, centers, and names
mode_mnt_top = mode_mnt.updated_copy(freqs=[mnt_freqs["top"]], center=mnt_center_top, name=mnt_names["top"])
mode_mnt_bot = mode_mnt.updated_copy(freqs=[mnt_freqs["bot"]], center=mnt_center_bot, name=mnt_names["bot"])

# make another dictionary mapping the keys to the monitors
mode_mnts = dict(top=mode_mnt_top, bot=mode_mnt_bot)
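
The monitor placement above can be checked with a little arithmetic. The sketch below recomputes the input plane position, the mirrored output position, and the waveguide centers from the setup values, then confirms that the mode plane centered on the top waveguide stops short of the bottom waveguide core. The helper names `top_plane_lower_edge` and `bot_wg_upper_edge` are introduced here only for illustration.

```python
# Geometry values copied from the setup cell.
lx, wg_length = 2.8, 1.5
wg_width, wg_spacing = 0.3, 0.8
space_fraction = 0.2

Lx = lx + 2 * wg_length
x_in = -Lx / 2 + space_fraction * wg_length   # x position of the input plane
x_out = -x_in                                 # output monitors mirror it in x
y_top = +(wg_width / 2 + wg_spacing / 2)      # center of the top waveguide
y_bot = -(wg_width / 2 + wg_spacing / 2)      # center of the bottom waveguide

# Extent of each mode plane in y, and how close it gets to the other waveguide.
mode_height = 1.8 * wg_spacing + wg_width
top_plane_lower_edge = y_top - mode_height / 2
bot_wg_upper_edge = y_bot + wg_width / 2

print(f"input plane at x = {x_in:.2f}, output monitors at x = {x_out:.2f}")
print(f"waveguide centers at y = {y_top:.2f} and y = {y_bot:.2f}")
print(f"top plane lower edge = {top_plane_lower_edge:.2f}, "
      f"bottom waveguide upper edge = {bot_wg_upper_edge:.2f}")
```
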

Add flux monitors#

For plotting later, we'll add a couple of FluxMonitor objects at the output ports to measure the total flux over a broad spectrum. With this data, we should be able to clearly see the difference in transmission between the two output ports and get an idea of the device bandwidth.

[14]:
Nf = 121
freqs_flux = np.linspace(freq_bot - fwidth/10, freq_top + fwidth/10, Nf)

flux_mnt_names = dict(top="flux_top", bot="flux_bot")

flux_mnt_top = td.FluxMonitor(
    center=mode_mnt_top.center,
    size=mode_mnt_top.size,
    name=flux_mnt_names["top"],
    freqs=list(freqs_flux),
)

flux_mnt_bot = td.FluxMonitor(
    center=mode_mnt_bot.center,
    size=mode_mnt_bot.size,
    name=flux_mnt_names["bot"],
    freqs=list(freqs_flux),
)
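
To see which wavelength range these flux monitors cover, the sketch below rebuilds the same frequency grid with NumPy and converts it to nanometers. As before, the speed of light in μm/s is hard-coded on the assumption that it matches `td.C_0`.

```python
import numpy as np

C_0 = 299792458.0 * 1e6  # um/s, assumed to match td.C_0
freq_top = C_0 / 1.300
freq_bot = C_0 / 1.550
fwidth = abs(freq_bot - freq_top)

# Same frequency grid as the flux monitors: design band plus 10% on each side.
Nf = 121
freqs_flux = np.linspace(freq_bot - fwidth / 10, freq_top + fwidth / 10, Nf)

# Convert to wavelength in nm; the points are evenly spaced in frequency,
# so they are not evenly spaced in wavelength.
wavelengths_nm = 1000 * C_0 / freqs_flux

print(f"longest wavelength : {wavelengths_nm.max():.1f} nm")
print(f"shortest wavelength: {wavelengths_nm.min():.1f} nm")
```

Both design wavelengths (1300 nm and 1550 nm) fall inside this range with a small margin on either side.
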

Add to simulation#

Finally, we will wrap our previous make_sim_base() function in a new one that adds our new objects to this base simulation.

[15]:
def make_sim(params, beta):

    output_monitors = [mode_mnts["top"], mode_mnts["bot"]]

    sim_base = make_sim_base(params, beta=beta)
    return sim_base.updated_copy(
        output_monitors=output_monitors,
        sources=[mode_src],
        monitors=tuple(list(sim_base.monitors) + [flux_mnt_top, flux_mnt_bot])
    )

Let's make the final simulation and visualize it with the sources and monitors added.

[16]:
sim = make_sim(params0, beta=1)

Note: the FluxMonitor objects are overlaying the output ModeMonitor objects.

[17]:
ax = sim.plot_eps(z=0.01)
plt.show()
../_images/notebooks_AdjointPlugin9WDM_31_0.png

Defining objective function#

With our simulation fully defined as a function of our parameters, we are ready to define our objective function.

Computing power transmission#

In this case, it is quite simple: we measure the transmitted power in our output waveguide mode. We wish to maximize transmission to the top port at the "top" wavelength (1300 nm) and maximize transmission to the bottom port at the "bot" wavelength (1550 nm).

[18]:
def measure_power(sim_data) -> float:
    """Extract power from simulation data."""

    def get_power(mnt_key: str, freq_key: str) -> float:
        """Get the power at monitor 'mnt_key' at frequency 'freq_key' (both either 'top' or 'bot')."""
        mnt_name = mnt_names[mnt_key]
        freq = freqs[freq_key]
        mnt_data = sim_data[mnt_name]
        amp = mnt_data.amps.sel(direction="+", mode_index=0, f=freq)
        return jnp.abs(amp) ** 2

    power_max = get_power("top", "top") + get_power("bot", "bot")
    return power_max / 2.0
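
As a small worked example of this figure of merit, the sketch below applies the same formula to a pair of made-up complex mode amplitudes. The amplitude values and the `amps` dictionary are hypothetical and stand in for the data that would come from the mode monitors.

```python
# Hypothetical complex mode amplitudes, chosen only for illustration.
amps = {
    ("top", "top"): 0.8 + 0.3j,   # top port at the "top" frequency
    ("bot", "bot"): 0.1 - 0.9j,   # bottom port at the "bot" frequency
}

def get_power(mnt_key, freq_key):
    """Power in the mode is the squared magnitude of its complex amplitude."""
    return abs(amps[(mnt_key, freq_key)]) ** 2

# Average of the two design-port powers, as in measure_power above.
power = (get_power("top", "top") + get_power("bot", "bot")) / 2.0
print(f"average design-port power = {power:.3f}")
```

Because only the magnitude enters, the phase of each amplitude does not affect the objective. Assuming the amplitudes are normalized to the injected power, a value of 1 would mean that all of the power reaches the intended port in the chosen mode at both wavelengths.
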

Next we add a penalty to produce structures that are invariant under erosion and dilation, which is a useful approach to implementing minimum length scale features.

[19]:
def penalty(params, beta, delta_eps=0.49):
    params = pre_process(params, beta=beta)
    dilate_fn = lambda x: filter_project(x, beta=beta, eta=0.5-delta_eps)
    eroded_fn = lambda x: filter_project(x, beta=beta, eta=0.5+delta_eps)

    params_dilate_erode = eroded_fn(dilate_fn(params))
    params_erode_dilate = dilate_fn(eroded_fn(params))
    diff = params_dilate_erode - params_erode_dilate
    return jnp.linalg.norm(diff) / jnp.linalg.norm(jnp.ones_like(diff))
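
The normalization in the last line of `penalty` may look unusual. Dividing the norm of `diff` by the norm of an all-ones array of the same shape is the same as taking the root-mean-square of `diff`, which keeps the penalty independent of the number of pixels. The sketch below checks this numerically with NumPy on a random stand-in array (not actual optimization data).

```python
import numpy as np

# Random stand-in for the difference between the two morphological operations.
rng = np.random.default_rng(0)
diff = rng.uniform(-1, 1, size=(55, 55))

# Normalization used in the penalty function above.
normalized = np.linalg.norm(diff) / np.linalg.norm(np.ones_like(diff))

# Root-mean-square of the same array.
rms = np.sqrt(np.mean(diff ** 2))

print(f"normalized norm = {normalized:.6f}")
print(f"rms             = {rms:.6f}")
```
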

Writing objective function#

Then we write an objective function that constructs our simulation, runs it, measures the power, and subtracts our penalty.

Note, we return a JaxSimulationData as the second output. The reason for this is that we might wish to access our flux and field data later on. jax provides the has_aux option, which uses only the first output for differentiation while letting the user access the second "auxiliary" output, which we will make use of.

[20]:
def objective(params, beta, verbose=False) -> float:
    sim = make_sim(params, beta=beta)
    sim_data = tda.web.run(sim, task_name="WDM_MULTIFREQ", verbose=verbose)
    power = measure_power(sim_data)
    J = power - penalty(params, beta=beta)
    return J, sim_data

Differentiating objective#

Finally, we can simply use jax to transform this objective function into a function that returns our objective function value, the auxiliary data, and our gradient, which we will feed to the optimizer.

[21]:
grad_fn = jax.value_and_grad(objective, has_aux=True)
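
For readers who have not used `has_aux` before, here is a minimal sketch of the same pattern on a toy function that needs no simulation. The function `toy_objective` and its auxiliary dictionary are made up for illustration; only the `jax.value_and_grad(..., has_aux=True)` call mirrors the cell above.

```python
import jax
import jax.numpy as jnp

def toy_objective(x, scale):
    """Return (scalar to differentiate, auxiliary data that is passed through)."""
    value = scale * jnp.sum(x ** 2)
    aux = {"max_x": jnp.max(x)}
    return value, aux

# Differentiate with respect to the first argument only; aux is not differentiated.
toy_grad_fn = jax.value_and_grad(toy_objective, has_aux=True)

x0 = jnp.array([1.0, 2.0, 3.0])
(value, aux), grad = toy_grad_fn(x0, 0.5)

print("value:", value)
print("aux  :", aux)
print("grad :", grad)
```

The return value is nested as `(value, aux), grad`, which is the same unpacking used with `grad_fn` in this notebook.
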

Let's try out our gradient function with verbosity on for just this run.

[22]:
(J, sim_data), grad = grad_fn(params0, beta=1, verbose=True)
11:05:27 -03 Created task 'WDM_MULTIFREQ' with task_id
             'fdve-f68cf898-b522-4b9c-96ff-f48bf8dc3193' and task_type 'FDTD'.
11:05:31 -03 status = queued
11:05:35 -03 status = preprocess
11:05:42 -03 Maximum FlexCredit cost: 0.025. Use 'web.real_cost(task_id)' to get
             the billed FlexCredit cost after a simulation run.
             starting up solver
             running solver
             To cancel the simulation, use 'web.abort(task_id)' or
             'web.delete(task_id)' or abort/delete the task in the web UI.
             Terminating the Python script will not stop the job running on the
             cloud.
11:05:55 -03 early shutoff detected at 32%, exiting.
             status = postprocess
11:06:11 -03 status = success
11:06:17 -03 loading simulation from simulation_data.hdf5
11:06:18 -03 WARNING: 2 unique frequencies detected in the output monitors with 
             a minimum spacing of 3.720e+13 (Hz). Setting the 'fwidth' of the   
             adjoint sources to 0.1 times this value = 3.720e+12 (Hz) to avoid  
             spectral overlap. To account for this, the corresponding 'run_time'
             in the adjoint simulation is will be set to 2.688527e-11 compared  
             to 2.688527e-12 in the forward simulation. If the adjoint          
             'run_time' is large due to small frequency spacing, it could be    
             better to instead run one simulation per frequency, which can be   
             done in parallel using 'tidy3d.plugins.adjoint.web.run_async'.     
             Created task 'WDM_MULTIFREQ_adj' with task_id
             'fdve-3113d201-09a0-49ec-b9fa-57e3a9dc7dd6' and task_type 'FDTD'.
11:06:22 -03 status = queued
11:06:28 -03 status = preprocess
11:06:34 -03 Maximum FlexCredit cost: 0.032. Use 'web.real_cost(task_id)' to get
             the billed FlexCredit cost after a simulation run.
             starting up solver
11:06:35 -03 running solver
             To cancel the simulation, use 'web.abort(task_id)' or
             'web.delete(task_id)' or abort/delete the task in the web UI.
             Terminating the Python script will not stop the job running on the
             cloud.
11:06:47 -03 early shutoff detected at 4%, exiting.
             status = postprocess
11:06:54 -03 status = success

Run Optimization#

Finally, we are ready to optimize our device. We will make use of the optax package to define an optimizer using the adam method, as we've done in the previous adjoint tutorials.

We record a history of objective function values, simulation data, and parameters for visualization later.

[23]:
import optax

# we know that the source fwidth will be set automatically due to multi-freq adjoint, so suppress warnings
td.config.logging_level = "ERROR"

# hyperparameters
num_steps = 50
learning_rate = 5e-2
beta_min = 1
beta_max = 30

# initialize adam optimizer with starting parameters
params = params0
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# store history
Js = []
params_history = [params0]
data_history = []
beta_history = []

for i in range(num_steps):

    perc_done = i / (num_steps - 1)
    beta_i = beta_min * (1 - perc_done) + beta_max * perc_done

    # compute gradient and current objective function value
    (value, data), gradient = grad_fn(params, beta=beta_i)

    # outputs
    print(f"step = {i + 1}")
    print(f"\tJ = {value:.4e}")
    print(f"\tbeta = {beta_i:.2f}")
    print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")

    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)
    params = optax.apply_updates(params, updates)

    # keep params between 0 and 1
    params = jnp.minimum(1.0, params)
    params = jnp.maximum(0.0, params)

    # save history
    Js.append(value)
    params_history.append(params)
    beta_history.append(beta_i)
    data_history.append(data)
step = 1
        J = 1.1304e-01
        beta = 1.00
        grad_norm = 1.3148e-01
step = 2
        J = 1.7998e-01
        beta = 1.59
        grad_norm = 2.2417e-01
step = 3
        J = 1.9274e-01
        beta = 2.18
        grad_norm = 2.6165e-01
step = 4
        J = 5.7359e-02
        beta = 2.78
        grad_norm = 3.2920e-01
step = 5
        J = -2.7653e-02
        beta = 3.37
        grad_norm = 5.1749e-01
step = 6
        J = -1.0726e-01
        beta = 3.96
        grad_norm = 5.9930e-01
step = 7
        J = 5.4190e-02
        beta = 4.55
        grad_norm = 2.8124e-01
step = 8
        J = 1.3695e-01
        beta = 5.14
        grad_norm = 3.0855e-01
step = 9
        J = 1.9039e-01
        beta = 5.73
        grad_norm = 3.4029e-01
step = 10
        J = 2.1604e-01
        beta = 6.33
        grad_norm = 5.2777e-01
step = 11
        J = 2.6796e-01
        beta = 6.92
        grad_norm = 5.9430e-01
step = 12
        J = 2.0278e-01
        beta = 7.51
        grad_norm = 1.1882e+00
step = 13
        J = 3.2679e-01
        beta = 8.10
        grad_norm = 6.3214e-01
step = 14
        J = 3.6340e-01
        beta = 8.69
        grad_norm = 2.8383e-01
step = 15
        J = 3.7389e-01
        beta = 9.29
        grad_norm = 2.3837e-01
step = 16
        J = 3.9022e-01
        beta = 9.88
        grad_norm = 1.6353e-01
step = 17
        J = 4.0384e-01
        beta = 10.47
        grad_norm = 1.3485e-01
step = 18
        J = 4.1686e-01
        beta = 11.06
        grad_norm = 1.2928e-01
step = 19
        J = 4.2954e-01
        beta = 11.65
        grad_norm = 1.1254e-01
step = 20
        J = 4.3786e-01
        beta = 12.24
        grad_norm = 1.2386e-01
step = 21
        J = 4.4402e-01
        beta = 12.84
        grad_norm = 1.3647e-01
step = 22
        J = 4.4862e-01
        beta = 13.43
        grad_norm = 3.4103e-01
step = 23
        J = 4.5884e-01
        beta = 14.02
        grad_norm = 1.4109e-01
step = 24
        J = 4.6795e-01
        beta = 14.61
        grad_norm = 8.9034e-02
step = 25
        J = 4.7021e-01
        beta = 15.20
        grad_norm = 1.6547e-01
step = 26
        J = 4.8514e-01
        beta = 15.80
        grad_norm = 1.6582e-01
step = 27
        J = 5.0607e-01
        beta = 16.39
        grad_norm = 1.3549e-01
step = 28
        J = 5.2328e-01
        beta = 16.98
        grad_norm = 1.7107e-01
step = 29
        J = 5.5337e-01
        beta = 17.57
        grad_norm = 1.8353e-01
step = 30
        J = 5.8120e-01
        beta = 18.16
        grad_norm = 2.6351e-01
step = 31
        J = 6.0239e-01
        beta = 18.76
        grad_norm = 2.1676e-01
step = 32
        J = 6.1259e-01
        beta = 19.35
        grad_norm = 4.3566e-01
step = 33
        J = 5.4708e-01
        beta = 19.94
        grad_norm = 9.3248e-01
step = 34
        J = 6.1207e-01
        beta = 20.53
        grad_norm = 6.7785e-01
step = 35
        J = 6.4012e-01
        beta = 21.12
        grad_norm = 2.7947e-01
step = 36
        J = 6.2275e-01
        beta = 21.71
        grad_norm = 6.0492e-01
step = 37
        J = 6.6233e-01
        beta = 22.31
        grad_norm = 1.6761e-01
step = 38
        J = 6.5566e-01
        beta = 22.90
        grad_norm = 6.2367e-01
step = 39
        J = 6.5845e-01
        beta = 23.49
        grad_norm = 6.8979e-01
step = 40
        J = 6.9423e-01
        beta = 24.08
        grad_norm = 3.0279e-01
step = 41
        J = 7.0369e-01
        beta = 24.67
        grad_norm = 2.5525e-01
step = 42
        J = 6.9838e-01
        beta = 25.27
        grad_norm = 5.4912e-01
step = 43
        J = 6.9214e-01
        beta = 25.86
        grad_norm = 7.5892e-01
step = 44
        J = 7.1860e-01
        beta = 26.45
        grad_norm = 4.9100e-01
step = 45
        J = 7.4005e-01
        beta = 27.04
        grad_norm = 2.2158e-01
step = 46
        J = 7.3577e-01
        beta = 27.63
        grad_norm = 5.3908e-01
step = 47
        J = 7.4667e-01
        beta = 28.22
        grad_norm = 4.1632e-01
step = 48
        J = 7.5887e-01
        beta = 28.82
        grad_norm = 2.1059e-01
step = 49
        J = 7.6499e-01
        beta = 29.41
        grad_norm = 3.3670e-01
step = 50
        J = 7.6311e-01
        beta = 30.00
        grad_norm = 4.7441e-01
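
The `beta` values printed above follow from the linear ramp in the loop. The sketch below reproduces that schedule on its own; `beta_at` is a helper name introduced here for illustration.

```python
num_steps = 50
beta_min, beta_max = 1, 30

def beta_at(i):
    """Linear ramp from beta_min at the first step to beta_max at the last."""
    perc_done = i / (num_steps - 1)
    return beta_min * (1 - perc_done) + beta_max * perc_done

betas = [beta_at(i) for i in range(num_steps)]

# Show the first few values and the final one.
print(", ".join(f"{b:.2f}" for b in betas[:3]), "...", f"{betas[-1]:.2f}")
```

Each step raises `beta` by 29/49, or about 0.59, so the projection sharpens gradually rather than all at once.
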

Visualize Results#

Let's visualize the results of our optimization.

Objective function vs Iteration#

First we inspect the objective function value as a function of optimization iteration number. We see that it increases overall, as expected.

The presence of fabrication constraints tends to create some minor bumps in the optimization, which can be a signal that one needs to reduce the step size, but these results are sufficient for our purposes.

[24]:
plt.plot(Js)
plt.xlabel('iteration number')
plt.ylabel('objective function')
plt.show()
../_images/notebooks_AdjointPlugin9WDM_45_0.png

Final Simulation#

Let's take a look at the final simulation, which we grab from our history.

[25]:
sim_data_final = data_history[-1]
sim_final = sim_data_final.simulation.to_simulation()[0]

We notice that the structure has reasonably large feature sizes but is not well binarized. This could be improved by increasing the beta projection value slowly over iteration number, as was done in the grating coupler tutorial.

[26]:
ax = sim_final.plot_eps(z=0.01, monitor_alpha=0, source_alpha=0)
../_images/notebooks_AdjointPlugin9WDM_49_0.png

Flux#

Let's inspect the flux over each of the output ports as a function of wavelength.

We notice that the top and bottom ports have peaks in transmission at their corresponding design wavelengths, as expected!

[27]:
# plot flux
for key, color in zip(("top", "bot"), ('royalblue', 'firebrick')):
    freq = freqs[key]
    flux_data = sim_data_final[flux_mnt_names[key]]
    wvl_nm = 1000 * td.C_0 / freq
    wavelengths_nm = 1000 * td.C_0 / np.array(flux_data.flux.f)
    flux = np.array(flux_data.flux.values)
    flux_db = 10 * np.log10(flux)
    label = f"{key} ({int(wvl_nm)} nm)"
    plt.plot(wavelengths_nm, flux_db, label=label, color=color)
    plt.scatter([wvl_nm], [0], 100, marker="*", color=color)
    plt.xlabel('wavelength (nm)')
    plt.ylabel('transmission (dB)')
    plt.legend()

plt.show()
../_images/notebooks_AdjointPlugin9WDM_51_0.png
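
For reference when reading the decibel axis, the sketch below converts a few linear transmission values to dB using the standard definition based on the base-10 logarithm. `to_db` is a helper name introduced here only for illustration.

```python
import math

def to_db(transmission):
    """Convert a linear power ratio to decibels (base-10 logarithm)."""
    return 10 * math.log10(transmission)

for t in (1.0, 0.5, 0.1, 0.01):
    print(f"T = {t:<5} -> {to_db(t):6.2f} dB")
```

Full transmission corresponds to 0 dB, half the power to about -3 dB, and each further factor of ten in suppression lowers the value by 10 dB.
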

Fields#

Let's also plot the field intensity patterns at the two design wavelengths. We see from this plot the expected result that the power is directed to the design port at each frequency.

[28]:
# plot fields at the two design wavelengths

fig, axes = plt.subplots(1, 2, tight_layout=True, figsize=(7, 3))

for key, ax in zip(("top", "bot"), axes):
    freq = freqs[key]
    sim_data_final.plot_field("field", "E", "abs^2", f=freq, ax=ax, vmax=1200)
    wvl = 1000 * td.C_0 / freq
    ax.set_title(f"wavelength = {int(wvl)} nm")
../_images/notebooks_AdjointPlugin9WDM_53_0.png

Animation#

Finally, we animate this plot as a function of iteration number. The animation shows the device quickly accomplishing our design objective.

Note: can take a few minutes to complete

[29]:
import matplotlib.animation as animation
from IPython.display import HTML

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, tight_layout=False, figsize=(9, 4))

def animate(i):

    sim_data_i = data_history[i]

    sim_i = sim_data_i.simulation.to_simulation()[0]
    sim_i.plot_eps(z=0.01, monitor_alpha=0, source_alpha=0, ax=ax1)
    ax1.set_aspect('equal')

    for key, ax in zip(("top", "bot"), (ax2, ax3)):

        freq = freqs[key]
        wvl = 1000 * td.C_0 / freq

        int_i = sim_data_i.get_intensity("field").sel(f=freq)
        int_i.squeeze().plot.pcolormesh(x='x', y='y', ax=ax, add_colorbar=False, cmap="magma", vmax=1000)

        ax.set_aspect('equal')
        ax.set_title(f"wavelength = {int(wvl)} nm")

# create animation
ani = animation.FuncAnimation(fig, animate, frames=len(data_history))
plt.close()

[30]:
# display the animation (press "play" to start)
HTML(ani.to_jshtml())


To save the animation as a file, uncomment the line below

Note: can take several more minutes to complete

[31]:
# ani.save('img/animation_wdm_adjoint.gif', fps=60)

Export to GDS#

The Simulation object has the .to_gds_file convenience function to export the final design to a GDS file. In addition to a file name, it is necessary to set a cross-sectional plane (z = 0 in this case) on which to evaluate the geometry, a frequency to evaluate the permittivity, and a permittivity_threshold to define the shape boundaries in custom mediums. See the GDS export notebook for a detailed example on using .to_gds_file and other GDS related functions.

[32]:
sim_final = data_history[-1].simulation.to_simulation()[0]
sim_final.to_gds_file(
    fname="./misc/inv_des_wdm.gds",
    z=0,
    permittivity_threshold=(n_si**2 + 1) / 2,
    frequency=freq0,
)
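
The threshold passed above is the midpoint between the permittivity of air and that of silicon. As a rough illustration of the idea, the sketch below applies that threshold to a handful of hypothetical permittivity samples. How `to_gds_file` applies the threshold internally is not shown in this notebook, so the simple comparison here is an assumption.

```python
import numpy as np

n_si = 3.49
eps_air, eps_si = 1.0, n_si ** 2

# Midpoint between air and silicon, as used for permittivity_threshold above.
threshold = (eps_si + eps_air) / 2

# Hypothetical permittivity samples across a material boundary.
eps = np.array([1.0, 2.5, 6.0, 7.0, 11.0, eps_si])

# Samples above the threshold are treated as silicon in this sketch.
is_silicon = eps > threshold

print(f"threshold = {threshold:.3f}")
print("silicon mask:", is_silicon)
```

A design that is not fully binarized will contain intermediate values near the threshold, so the exported shape boundaries can be sensitive to this choice.
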