Adjoint optimization of a wavelength division multiplexer#
In this notebook, we will use a multi-objective optimization to design a wavelength division multiplexer (WDM).
In short, this device takes in broadband light and directs light of different wavelengths to different output ports.
This demo combines the basic setup of our 3rd tutorial of a mode converter with the multi-frequency feature introduced in Tidy3D version 2.5.
If you are unfamiliar with inverse design, we also recommend our intro to inverse design tutorials and our primer on automatic differentiation with tidy3d.
[1]:
# first import tidy3d, its adjoint plugin, numpy, and jax.
import tidy3d as td
import tidy3d.plugins.adjoint as tda
import numpy as np
import jax.numpy as jnp
import jax
np.random.seed(2)
Setup#
First we set up our basic simulation.
We have an input waveguide connected to a square design region, which has two output waveguides.
The square design region is a custom medium with a pixellated permittivity grid that we wish to optimize such that input light of different wavelengths gets directed to different output ports.
As this is an SOI device, we typically define the design region and waveguides as silicon sitting on an SiO2 substrate. For this demo, we make a 2D simulation, but it can easily be made 3D by changing the Lz parameter, adding dimension to the structures, and adding a substrate.
[2]:
# material information
n_si = 3.49
n_sio2 = 1.45 # not used in 2D
n_air = 1
# design output wavelengths
wavelength_top = 1.300
wavelength_bot = 1.550
# and their corresponding frequencies and spectral information
freq_top = td.C_0 / wavelength_top
freq_bot = td.C_0 / wavelength_bot
freq0 = (freq_top + freq_bot) / 2.0
fwidth = abs(freq_bot - freq_top)
run_time = 100 / fwidth
# create dictionaries to reference these later by string key 'top' or 'bot'
freqs = dict(top=freq_top, bot=freq_bot)
wavelengths = dict(top=wavelength_top, bot=wavelength_bot)
# size of design region
lx = 2.8
ly = 2.8
lz = td.inf # in 2D, we say the size of components is inf but the size of simulation is 0.
# size of waveguides
wg_width = 0.3
wg_length = 1.5
wg_spacing = 0.8
# spacing between design region and PML in y
buffer = 1.5
# size of simulation
Lx = lx + wg_length * 2
Ly = ly + buffer * 2
Lz = 0.0
# fabrication constraint (feature size and projection strength)
radius = 0.200
beta = 30
# resolution information
min_steps_per_wvl = 25
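As a quick sanity check on the spectral settings above, the following standalone sketch repeats the wavelength-to-frequency arithmetic without importing tidy3d. The constant C_0_UM_PER_S is an assumption standing in for td.C_0 (speed of light in um/s), and the variable names with unit suffixes are illustrative only.

```python
# Standalone sketch: reproduce the spectral quantities from the setup cell.
# Assumption: td.C_0 is the speed of light in um/s; the value is hard-coded here.
C_0_UM_PER_S = 2.99792458e14

wavelength_top_um = 1.300
wavelength_bot_um = 1.550

# shorter wavelength -> higher frequency
freq_top_hz = C_0_UM_PER_S / wavelength_top_um
freq_bot_hz = C_0_UM_PER_S / wavelength_bot_um

# central frequency and source bandwidth, defined as in the setup cell
freq0_hz = (freq_top_hz + freq_bot_hz) / 2.0
fwidth_hz = abs(freq_bot_hz - freq_top_hz)

# run time is 100 divided by the source bandwidth
run_time_s = 100 / fwidth_hz

print(f"freq_top = {freq_top_hz:.4e} Hz")
print(f"freq_bot = {freq_bot_hz:.4e} Hz")
print(f"fwidth   = {fwidth_hz:.4e} Hz")
print(f"run_time = {run_time_s:.4e} s")
```

These numbers are consistent with the 3.720e+13 Hz frequency spacing and the 2.688527e-12 forward run time reported in the adjoint warning later in this notebook.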
[3]:
# define the waveguide ports
wg_in = td.Structure(
geometry=td.Box(
center=(-Lx/2, 0, 0),
size=(wg_length * 2, wg_width, lz),
),
medium=td.Medium(permittivity=n_si**2)
)
wg_top = td.Structure(
geometry=td.Box(
center=(+Lx/2, +wg_width/2+wg_spacing/2, 0),
size=(wg_length * 2, wg_width, lz),
),
medium=td.Medium(permittivity=n_si**2)
)
wg_bot = td.Structure(
geometry=td.Box(
center=(+Lx/2, -wg_width/2-wg_spacing/2, 0),
size=(wg_length * 2, wg_width, lz),
),
medium=td.Medium(permittivity=n_si**2)
)
# and a field monitor that measures fields on the z=0 plane
fld_mnt = td.FieldMonitor(
center=(0,0,0),
size=(td.inf, td.inf, 0),
freqs=[freq_top, freq_bot],
name="field",
)
13:12:34 EDT WARNING: Default value for the field monitor 'colocate' setting has changed to 'True' in Tidy3D 2.4.0. All field components will be colocated to the grid boundaries. Set to 'False' to get the raw fields on the Yee grid instead.
Note we can ignore this warning as it will be resolved after 2.4.0
Define design region#
Here we define the design region as a pixellated grid of permittivity values.
We first define the overall geometry as a JaxBox and also the number of pixels in x and y.
[4]:
nx = 55
ny = 55
design_region_geo = tda.JaxBox(
size=(lx, ly, lz),
center=(0,0,0)
)
Next we write a function to give us the pixellated array as a function of our parameters through our filtering and projection methods, which are used to make the resulting structures easier to fabricate. For more details, refer to our 4th lecture in the inverse design 101 lecture series, which focuses on fabrication constraints.
We also wrap this function in another one that generates the entire JaxStructure corresponding to the design region, for convenience later.
[5]:
from tidy3d.plugins.adjoint.utils.filter import ConicFilter
conic_filter = ConicFilter(radius=radius, design_region_dl=lx/nx)
# note: params is an array of shape (nx, ny) that stores values between 0 (air) and 1 (silicon)
def tanh_projection(x, beta, eta=0.5):
    tanhbn = jnp.tanh(beta * eta)
    num = tanhbn + jnp.tanh(beta * (x - eta))
    den = tanhbn + jnp.tanh(beta * (1 - eta))
    return num / den

def filter_project(x, beta, eta=0.5):
    x = conic_filter.evaluate(x)
    return tanh_projection(x, beta=beta, eta=eta)

# number of times to filter -> project. Two times with a lower beta (~30) seems to give decent results.
num_projections = 2

def pre_process(params, beta):
    """Get the filtered and projected parameters (0, 1) as a function of the raw parameters (0, 1)."""
    for _ in range(num_projections):
        params = filter_project(params, beta=beta)
    return params

def make_eps(params, beta):
    """Get the permittivity values (1, eps_wg) array as a function of the parameters (0, 1)."""
    params = pre_process(params, beta=beta)
    eps_values = 1 + (n_si**2 - 1) * params
    return eps_values

def make_custom_medium(params, beta):
    """Make JaxCustomMedium as a function of provided parameters."""
    eps = make_eps(params, beta).reshape((nx, ny, 1, 1))
    xs = list(jnp.linspace(-lx/2, lx/2, nx))
    ys = list(jnp.linspace(-ly/2, ly/2, ny))
    zs = [0]
    freqs = [freq0]
    coords = dict(x=xs, y=ys, z=zs, f=freqs)
    eps_dataset = tda.JaxDataArray(values=eps, coords=coords)
    medium = tda.JaxCustomMedium(
        eps_dataset=tda.JaxPermittivityDataset(
            eps_xx=eps_dataset,
            eps_yy=eps_dataset,
            eps_zz=eps_dataset,
        )
    )
    struct = tda.JaxStructure(
        geometry=design_region_geo,
        medium=medium
    )
    return struct
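To get a feel for what the projection step does, here is a minimal NumPy sketch of the same tanh projection formula, evaluated at a low and a high beta. The name tanh_projection_np and the sample values are illustrative and not part of the notebook; the conic filter itself is not reproduced here.

```python
import numpy as np

def tanh_projection_np(x, beta, eta=0.5):
    """NumPy copy of the tanh projection formula (illustrative helper)."""
    tanhbn = np.tanh(beta * eta)
    num = tanhbn + np.tanh(beta * (x - eta))
    den = tanhbn + np.tanh(beta * (1 - eta))
    return num / den

# sample parameter values between 0 (air) and 1 (silicon)
x = np.array([0.0, 0.25, 0.5, 0.75, 1.0])

# low beta: values are only mildly reshaped; high beta: values are pushed toward 0 or 1
soft = tanh_projection_np(x, beta=1)
sharp = tanh_projection_np(x, beta=30)
print("beta=1 :", np.round(soft, 3))
print("beta=30:", np.round(sharp, 3))

# conic filter radius expressed in design pixels, using the notebook's numbers
radius_um, lx_um, nx_pixels = 0.200, 2.8, 55
radius_in_pixels = radius_um / (lx_um / nx_pixels)
print(f"filter radius = {radius_in_pixels:.2f} pixels")
```

The endpoints 0 and 1 and the threshold eta are fixed points of the projection for any beta, while intermediate values move toward the nearest endpoint as beta grows. This is why a larger beta gives a more binarized design.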
Define base simulation#
With all of these functions and variables defined, we can write a single function to return our “base” JaxSimulation as a function of our design parameters. This function first constructs the design region and then creates a JaxSimulation with all of the basic parameters.
Note, we don’t yet have a source or monitors for injecting and measuring our fields, but will add those next after running the mode solver.
[6]:
def make_sim_base(params, beta):
    input_struct = make_custom_medium(params, beta=beta)
    return tda.JaxSimulation(
        size=(Lx, Ly, Lz),
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=min_steps_per_wvl, wavelength=wavelength_top),
        structures=[wg_in, wg_top, wg_bot],
        monitors=[fld_mnt],
        input_structures=[input_struct],
        boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True if Lz else False),
        run_time=run_time,
    )
Let’s test out our function. We’ll make an initially random array of parameters between 0 and 1 and generate the base simulation to plot and inspect.
[7]:
params0 = np.random.random((nx, ny))
sim_base = make_sim_base(params0, beta=1)
[8]:
ax = sim_base.plot_eps(z=0, monitor_alpha=0.0)
It all looks good, so now we add the bits that define the optimization.
Adding Mode Sources and Monitors#
Solving modes#
First, we need to create our ModeSource and ModeMonitor objects that inject and measure the modes that we are interested in optimizing.
We’ll use tidy3d’s ModeSolver together with the remote run function, which gets more accurate results by running on Flexcompute’s servers.
[9]:
from tidy3d.plugins.mode import ModeSolver
from tidy3d.plugins.mode.web import run as run_mode_solver
# we'll ask for 4 modes just to inspect
num_modes = 4
# let's define how large the mode planes are and how far they are from the PML relative to the design region
mode_size = (0, 1.8 * wg_spacing + wg_width, max([Lz, lz, 3]))
space_fraction = 0.2
# make a plane corresponding to where we wish to measure the input mode
plane_in = td.Box(
center=(-Lx/2 + space_fraction * wg_length, 0, 0),
size=mode_size,
)
# construct the mode solver using our base sim (converted from `JaxSimulation` to regular `Simulation`) + our plane
mode_solver = ModeSolver(
simulation=sim_base.to_simulation()[0],
plane=plane_in,
freqs=[freq_top],
mode_spec=td.ModeSpec(num_modes=num_modes)
)
Next we run the mode solver on the servers.
[10]:
mode_data = run_mode_solver(mode_solver)
13:12:44 EDT Mode solver created with task_id='fdve-f75482a3-807f-46c2-952f-ab578a2ab189v1', solver_id='mo-1e66a52c-b2db-4f09-81a0-969eaa609cbd'.
13:12:48 EDT Mode solver status: queued
13:12:49 EDT Mode solver status: running
13:13:02 EDT Mode solver status: success
And visualize the results.
[11]:
import matplotlib.pyplot as plt
print("Effective index of computed modes: ", jnp.array(mode_data.n_eff))
fig, axs = plt.subplots(num_modes, 3, figsize=(4*num_modes, 10), tight_layout=True)
for mode_ind in range(num_modes):
    for field_ind, field_name in enumerate(("Ex", "Ey", "Ez")):
        field = mode_data.field_components[field_name].sel(mode_index=mode_ind)
        ax = axs[mode_ind, field_ind]
        field.real.plot(ax=ax)
        ax.set_title(f"{field_name}, index={mode_ind}")
Effective index of computed modes: [[3.141611 2.806403 1.953868 1.062474]]
We identify mode_index=0 as the first order mode that is out of plane of the device. Let’s choose to optimize our device with respect to this as the mode of interest for both the input and output.
We re-set the ModeSpec to only compute the number of modes we need (1) and also update our ModeSolver accordingly.
[12]:
mode_index = 0
mode_spec = td.ModeSpec(num_modes=mode_index+1)
mode_solver = mode_solver.updated_copy(mode_spec=mode_spec)
Make input and output mode sources and monitors#
Next, we will generate the input ModeSource and output ModeMonitor objects using the convenience methods defined in the ModeSolver.
Because our plane was defined at the input port, we’ll modify the centers of the ModeMonitors to place them at the output ports to the right of the device.
[13]:
# make source
mode_src = mode_solver.to_source(
source_time=td.GaussianPulse(
freq0=freq0,
fwidth=fwidth,
),
direction="+",
mode_index=mode_index,
)
# make a basic monitor
mode_mnt = mode_solver.to_monitor(
freqs=[freq0],
name="_"
)
# construct the proper centers for the monitors at the 'top' and 'bot' ports
mnt_center_top = list(plane_in.center)
mnt_center_bot = list(plane_in.center)
mnt_center_top[0] = -plane_in.center[0]
mnt_center_bot[0] = -plane_in.center[0]
mnt_center_top[1] = wg_top.geometry.center[1]
mnt_center_bot[1] = wg_bot.geometry.center[1]
# make a dictionary of names and frequencies to refer to later by key
mnt_names = dict(top="mode_top", bot="mode_bot")
mnt_freqs = dict(top=freq_top, bot=freq_bot)
# make two updated copies of the mode monitor with the proper frequencies, centers, and names
mode_mnt_top = mode_mnt.updated_copy(freqs=[mnt_freqs["top"]], center=mnt_center_top, name=mnt_names["top"])
mode_mnt_bot = mode_mnt.updated_copy(freqs=[mnt_freqs["bot"]], center=mnt_center_bot, name=mnt_names["bot"])
# make another dictionary mapping the keys to the monitors
mode_mnts = dict(top=mode_mnt_top, bot=mode_mnt_bot)
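The monitor placement above is plain coordinate arithmetic, which the following standalone sketch repeats with the notebook's dimensions (in um). The helper y_span and the overlap check are illustrative additions, not part of the notebook.

```python
# Standalone sketch: where the mode planes sit, using the notebook's dimensions (um).
lx, wg_length, wg_width, wg_spacing = 2.8, 1.5, 0.3, 0.8
space_fraction = 0.2

Lx = lx + wg_length * 2
plane_width_y = 1.8 * wg_spacing + wg_width

# input plane x position, and the mirrored x position used for both output monitors
x_in = -Lx / 2 + space_fraction * wg_length
x_out = -x_in

# y centers of the two output waveguides
y_top = +wg_width / 2 + wg_spacing / 2
y_bot = -wg_width / 2 - wg_spacing / 2

def y_span(center_y, width):
    """Return the (min, max) extent in y of something centered at center_y (illustrative helper)."""
    return (center_y - width / 2, center_y + width / 2)

# check that the top monitor plane does not reach the bottom waveguide core
top_plane = y_span(y_top, plane_width_y)
bot_core = y_span(y_bot, wg_width)
top_plane_clears_bot_core = top_plane[0] > bot_core[1]

print(f"x_in = {x_in:.2f}, x_out = {x_out:.2f}")
print(f"y_top = {y_top:.2f}, y_bot = {y_bot:.2f}")
print("top plane spans y in", tuple(round(v, 2) for v in top_plane))
print("top plane clears bottom core:", top_plane_clears_bot_core)
```

With these dimensions each output mode plane covers its own waveguide without overlapping the core of the other one, so the mode decomposition at each port should mostly see the intended waveguide.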
Add flux monitors#
For plotting later, we’ll add a couple of FluxMonitor objects at the output ports to measure the total flux over a large spectrum. With this data, we should be able to clearly see the difference in transmission for each of the ports at the design region and get an idea about the device bandwidth.
[14]:
Nf = 121
freqs_flux = np.linspace(freq_bot - fwidth/10, freq_top + fwidth/10, Nf)
flux_mnt_names = dict(top="flux_top", bot="flux_bot")
flux_mnt_top = td.FluxMonitor(
center=mode_mnt_top.center,
size=mode_mnt_top.size,
name=flux_mnt_names["top"],
freqs=list(freqs_flux),
)
flux_mnt_bot = td.FluxMonitor(
center=mode_mnt_bot.center,
size=mode_mnt_bot.size,
name=flux_mnt_names["bot"],
freqs=list(freqs_flux),
)
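To see which wavelengths the flux monitors actually cover, the sketch below rebuilds the same frequency grid with NumPy alone. C_0_UM_PER_S is an assumed stand-in for td.C_0 (um/s), and wavelengths_flux_nm and df are illustrative names.

```python
import numpy as np

# Assumption: td.C_0 is the speed of light in um/s; hard-coded here to stay standalone.
C_0_UM_PER_S = 2.99792458e14
freq_top = C_0_UM_PER_S / 1.300
freq_bot = C_0_UM_PER_S / 1.550
fwidth = abs(freq_bot - freq_top)

# same grid as the flux monitors: 10% of fwidth added on each side of the design band
Nf = 121
freqs_flux = np.linspace(freq_bot - fwidth / 10, freq_top + fwidth / 10, Nf)

# convert to wavelength in nm (frequencies ascend, so wavelengths descend)
wavelengths_flux_nm = 1000 * C_0_UM_PER_S / freqs_flux
df = freqs_flux[1] - freqs_flux[0]

print(f"wavelength range: {wavelengths_flux_nm.min():.1f} nm to {wavelengths_flux_nm.max():.1f} nm")
print(f"frequency step: {df:.3e} Hz")
```

Because the grid spans 1.2 times fwidth in 120 equal steps, the two design frequencies land on grid points (indices 10 and 110), so the flux spectrum can be read off exactly at the design wavelengths.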
Add to simulation#
Finally, we will wrap our previous make_sim_base() function in a new one that adds our new objects to this base simulation.
[15]:
def make_sim(params, beta):
    output_monitors = [mode_mnts["top"], mode_mnts["bot"]]
    sim_base = make_sim_base(params, beta=beta)
    return sim_base.updated_copy(
        output_monitors=output_monitors,
        sources=[mode_src],
        monitors=tuple(list(sim_base.monitors) + [flux_mnt_top, flux_mnt_bot])
    )
Let’s make the final simulation and visualize it with the sources and monitors added.
[16]:
sim = make_sim(params0, beta=1)
Note: the FluxMonitor objects are overlaying the output ModeMonitor objects.
[17]:
ax = sim.plot_eps(z=0.01)
plt.show()
Defining objective function#
With our simulation fully defined as a function of our parameters, we are ready to define our objective function.
Computing power transmission#
In this case, it is quite simple: we measure the transmitted power in our output waveguide mode. We wish to maximize transmission to the top port at the “top” wavelength (1300 nm) and maximize transmission to the bottom port at the “bot” wavelength (1550 nm).
[18]:
def measure_power(sim_data) -> float:
    """Extract power from simulation data."""

    def get_power(mnt_key: str, freq_key: str) -> float:
        """Get the power at monitor 'mnt_key' at frequency 'freq_key' (both either 'top' or 'bot')."""
        mnt_name = mnt_names[mnt_key]
        freq = freqs[freq_key]
        mnt_data = sim_data[mnt_name]
        amp = mnt_data.amps.sel(direction="+", mode_index=0, f=freq)
        return jnp.abs(amp) ** 2

    power_max = get_power("top", "top") + get_power("bot", "bot")
    return power_max / 2.0
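The power measure reduces to the squared magnitude of a complex mode amplitude, averaged over the two ports. A minimal sketch with made-up amplitudes illustrates this; the amplitude values and the names mode_power and figure_of_merit are hypothetical, standing in for the entries of mnt_data.amps.

```python
import numpy as np

# hypothetical complex mode amplitudes (not simulation results)
amp_top_at_top = 0.6 + 0.3j   # top port, evaluated at the "top" frequency
amp_bot_at_bot = 0.1 - 0.7j   # bottom port, evaluated at the "bot" frequency

def mode_power(amp):
    """Power carried by a mode, given its complex amplitude (illustrative helper)."""
    return np.abs(amp) ** 2

# average the two powers, mirroring power_max / 2.0 in measure_power
figure_of_merit = (mode_power(amp_top_at_top) + mode_power(amp_bot_at_bot)) / 2.0
print(f"figure of merit = {figure_of_merit:.3f}")
```

Assuming the amplitudes are normalized so that a squared magnitude of 1 means all injected power reaches that mode, dividing by two keeps this part of the objective between 0 and 1.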
Next we add a penalty to produce structures that are invariant under erosion and dilation, which is a useful approach to implementing minimum length scale features.
[19]:
def penalty(params, beta, delta_eps=0.49):
    params = pre_process(params, beta=beta)
    dilate_fn = lambda x: filter_project(x, beta=beta, eta=0.5-delta_eps)
    eroded_fn = lambda x: filter_project(x, beta=beta, eta=0.5+delta_eps)
    params_dilate_erode = eroded_fn(dilate_fn(params))
    params_erode_dilate = dilate_fn(eroded_fn(params))
    diff = params_dilate_erode - params_erode_dilate
    return jnp.linalg.norm(diff) / jnp.linalg.norm(jnp.ones_like(diff))
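The normalization in the last line of the penalty is worth a short note: dividing the norm of the difference by the norm of an all-ones array of the same shape gives the root-mean-square difference per pixel. The NumPy sketch below checks this on a random array; the array is made-up data, not an actual erosion/dilation difference.

```python
import numpy as np

# hypothetical per-pixel differences on a 55 x 55 design grid (random, for illustration)
rng = np.random.default_rng(0)
diff = rng.uniform(-1.0, 1.0, size=(55, 55))

# normalization as written in the penalty function (default norm is Frobenius for 2D arrays)
penalty_as_written = np.linalg.norm(diff) / np.linalg.norm(np.ones_like(diff))

# equivalent form: root-mean-square of the per-pixel difference
rms = np.sqrt(np.mean(diff ** 2))

print(f"norm ratio = {penalty_as_written:.6f}")
print(f"rms        = {rms:.6f}")
```

Since the projected parameters lie between 0 and 1, this RMS value also lies between 0 and 1, which puts the penalty on the same scale as the averaged power it is subtracted from.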
Writing objective function#
Then we write an objective function that constructs our simulation, runs it, measures the power, and subtracts our penalty.
Note, we return a JaxSimulationData as the second output. The reason for this is that we might wish to access our flux and field data later on. jax gives an option has_aux to use only the first output for differentiation while letting the user have access to the 2nd “auxiliary” output, which we will make use of.
[20]:
def objective(params, beta, verbose=False) -> float:
    sim = make_sim(params, beta=beta)
    sim_data = tda.web.run(sim, task_name="WDM_MULTIFREQ", verbose=verbose)
    power = measure_power(sim_data)
    J = power - penalty(params, beta=beta)
    return J, sim_data
Differentiating objective#
Finally, we can simply use jax to transform this objective function into a function that returns our objective function value, the auxiliary data, and our gradient, which we will feed to the optimizer.
[21]:
grad_fn = jax.value_and_grad(objective, has_aux=True)
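For readers new to has_aux, here is a small self-contained sketch of the same pattern on a toy function that needs no simulation. The function toy_objective and its inputs are hypothetical; only jax.value_and_grad and its has_aux option come from the notebook.

```python
import jax
import jax.numpy as jnp

def toy_objective(x, scale):
    """Return (scalar value, auxiliary data); only the scalar value is differentiated."""
    y = scale * x
    value = jnp.sum(y ** 2)
    aux = {"y": y}  # extra data we want back, analogous to the simulation data
    return value, aux

# differentiate with respect to the first positional argument (x) by default
toy_grad_fn = jax.value_and_grad(toy_objective, has_aux=True)

x0 = jnp.array([1.0, 2.0, 3.0])
(toy_value, toy_aux), toy_grad = toy_grad_fn(x0, scale=2.0)

print("value:", toy_value)
print("aux y:", toy_aux["y"])
print("grad :", toy_grad)
```

The return structure is ((value, aux), grad), which is the unpacking pattern used with grad_fn in this notebook. Keyword arguments such as scale here (or beta and verbose in the notebook) are passed through and not differentiated.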
Let’s try out our gradient function with verbosity on for just this run.
[22]:
(J, sim_data), grad = grad_fn(params0, beta=1, verbose=True)
13:13:13 EDT Created task 'WDM_MULTIFREQ' with task_id 'fdve-d055ad08-1512-4931-90e8-8977ba9a382bv1' and task_type 'FDTD'.
View task using web UI at 'https://tidy3d.simulation.cloud/workbench?taskId=fdve-d055ad08-1512-4931-90e8-8977ba9a382bv1'.
13:13:15 EDT status = queued
13:13:25 EDT status = preprocess
13:13:33 EDT Maximum FlexCredit cost: 0.025. Use 'web.real_cost(task_id)' to get the billed FlexCredit cost after a simulation run.
starting up solver
13:13:34 EDT running solver
To cancel the simulation, use 'web.abort(task_id)' or 'web.delete(task_id)' or abort/delete the task in the web UI. Terminating the Python script will not stop the job running on the cloud.
13:13:45 EDT early shutoff detected, exiting.
status = postprocess
13:14:08 EDT status = success
13:14:09 EDT View simulation result at 'https://tidy3d.simulation.cloud/workbench?taskId=fdve-d055ad08-1512-4931-90e8-8977ba9a382bv1'.
13:14:11 EDT loading simulation from simulation_data.hdf5
WARNING: 2 unique frequencies detected in the output monitors with a minimum spacing of 3.720e+13 (Hz). Setting the 'fwidth' of the adjoint sources to 0.1 times this value = 3.720e+12 (Hz) to avoid spectral overlap. To account for this, the corresponding 'run_time' in the adjoint simulation is will be set to 2.688527e-11 compared to 2.688527e-12 in the forward simulation. If the adjoint 'run_time' is large due to small frequency spacing, it could be better to instead run one simulation per frequency, which can be done in parallel using 'tidy3d.plugins.adjoint.web.run_async'.
Created task 'WDM_MULTIFREQ_adj' with task_id 'fdve-ac06b779-409c-4aab-9672-99508fa35ff5v1' and task_type 'FDTD'.
View task using web UI at 'https://tidy3d.simulation.cloud/workbench?taskId=fdve-ac06b779-409c-4aab-9672-99508fa35ff5v1'.
13:14:14 EDT status = queued
13:14:23 EDT status = preprocess
13:14:28 EDT Maximum FlexCredit cost: 0.025. Use 'web.real_cost(task_id)' to get the billed FlexCredit cost after a simulation run.
starting up solver
running solver
To cancel the simulation, use 'web.abort(task_id)' or 'web.delete(task_id)' or abort/delete the task in the web UI. Terminating the Python script will not stop the job running on the cloud.
13:14:39 EDT early shutoff detected, exiting.
status = postprocess
13:14:46 EDT status = success
View simulation result at 'https://tidy3d.simulation.cloud/workbench?taskId=fdve-ac06b779-409c-4aab-9672-99508fa35ff5v1'.
Run Optimization#
Finally, we are ready to optimize our device. We will make use of the optax package to define an optimizer using the adam method, as we’ve done in the previous adjoint tutorials.
We record a history of objective function values, simulation data, and parameters for visualization later.
[23]:
import optax
# we know that the source fwidth will be set automatically due to multi-freq adjoint, so suppress warnings
td.config.logging_level = "ERROR"
# hyperparameters
num_steps = 50
learning_rate = 5e-2
beta_min = 1
beta_max = 30
# initialize adam optimizer with starting parameters
params = params0
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)
# store history
Js = []
params_history = [params0]
data_history = []
beta_history = []
for i in range(num_steps):
    perc_done = i / (num_steps - 1)
    beta_i = beta_min * (1 - perc_done) + beta_max * perc_done
    # compute gradient and current objective function value
    (value, data), gradient = grad_fn(params, beta=beta_i)
    # outputs
    print(f"step = {i + 1}")
    print(f"\tJ = {value:.4e}")
    print(f"\tbeta = {beta_i:.2f}")
    print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")
    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)
    params = optax.apply_updates(params, updates)
    # keep params between 0 and 1
    params = jnp.minimum(1.0, params)
    params = jnp.maximum(0.0, params)
    # save history
    Js.append(value)
    params_history.append(params)
    beta_history.append(beta_i)
    data_history.append(data)
step = 1
J = 1.1304e-01
beta = 1.00
grad_norm = 1.3146e-01
step = 2
J = 1.7990e-01
beta = 1.59
grad_norm = 2.2436e-01
step = 3
J = 1.9295e-01
beta = 2.18
grad_norm = 2.6208e-01
step = 4
J = 5.7722e-02
beta = 2.78
grad_norm = 3.2809e-01
step = 5
J = -2.6573e-02
beta = 3.37
grad_norm = 5.0954e-01
step = 6
J = -1.0810e-01
beta = 3.96
grad_norm = 6.0937e-01
step = 7
J = 5.5483e-02
beta = 4.55
grad_norm = 2.9692e-01
step = 8
J = 1.3207e-01
beta = 5.14
grad_norm = 3.5711e-01
step = 9
J = 1.9169e-01
beta = 5.73
grad_norm = 3.6781e-01
step = 10
J = 2.2510e-01
beta = 6.33
grad_norm = 4.3558e-01
step = 11
J = 2.2839e-01
beta = 6.92
grad_norm = 4.7313e-01
step = 12
J = 2.8006e-01
beta = 7.51
grad_norm = 6.4329e-01
step = 13
J = 3.4886e-01
beta = 8.10
grad_norm = 5.0031e-01
step = 14
J = 3.8850e-01
beta = 8.69
grad_norm = 3.2984e-01
step = 15
J = 4.0177e-01
beta = 9.29
grad_norm = 3.4545e-01
step = 16
J = 4.1679e-01
beta = 9.88
grad_norm = 3.0066e-01
step = 17
J = 4.5674e-01
beta = 10.47
grad_norm = 3.2095e-01
step = 18
J = 4.5385e-01
beta = 11.06
grad_norm = 2.6492e-01
step = 19
J = 4.6074e-01
beta = 11.65
grad_norm = 2.2217e-01
step = 20
J = 4.8522e-01
beta = 12.24
grad_norm = 2.5342e-01
step = 21
J = 5.1445e-01
beta = 12.84
grad_norm = 2.4852e-01
step = 22
J = 5.3739e-01
beta = 13.43
grad_norm = 2.2100e-01
step = 23
J = 5.5660e-01
beta = 14.02
grad_norm = 2.6047e-01
step = 24
J = 5.8012e-01
beta = 14.61
grad_norm = 1.8017e-01
step = 25
J = 5.9612e-01
beta = 15.20
grad_norm = 2.1880e-01
step = 26
J = 6.1565e-01
beta = 15.80
grad_norm = 1.9455e-01
step = 27
J = 6.3848e-01
beta = 16.39
grad_norm = 2.6951e-01
step = 28
J = 6.6192e-01
beta = 16.98
grad_norm = 4.8207e-01
step = 29
J = 6.9834e-01
beta = 17.57
grad_norm = 3.6163e-01
step = 30
J = 6.7018e-01
beta = 18.16
grad_norm = 9.7225e-01
step = 31
J = 5.8992e-01
beta = 18.76
grad_norm = 1.4539e+00
step = 32
J = 6.6032e-01
beta = 19.35
grad_norm = 5.5199e-01
step = 33
J = 6.3486e-01
beta = 19.94
grad_norm = 7.7368e-01
step = 34
J = 6.7879e-01
beta = 20.53
grad_norm = 3.1729e-01
step = 35
J = 6.9260e-01
beta = 21.12
grad_norm = 3.9760e-01
step = 36
J = 6.8910e-01
beta = 21.71
grad_norm = 4.3376e-01
step = 37
J = 7.0953e-01
beta = 22.31
grad_norm = 2.1688e-01
step = 38
J = 7.2231e-01
beta = 22.90
grad_norm = 1.7996e-01
step = 39
J = 7.3707e-01
beta = 23.49
grad_norm = 1.4756e-01
step = 40
J = 7.4738e-01
beta = 24.08
grad_norm = 1.8689e-01
step = 41
J = 7.4766e-01
beta = 24.67
grad_norm = 2.5960e-01
step = 42
J = 7.5494e-01
beta = 25.27
grad_norm = 1.4756e-01
step = 43
J = 7.5888e-01
beta = 25.86
grad_norm = 2.1101e-01
step = 44
J = 7.6419e-01
beta = 26.45
grad_norm = 1.7915e-01
step = 45
J = 7.6097e-01
beta = 27.04
grad_norm = 4.2852e-01
step = 46
J = 7.4819e-01
beta = 27.63
grad_norm = 8.0884e-01
step = 47
J = 7.3060e-01
beta = 28.22
grad_norm = 1.1788e+00
step = 48
J = 7.5213e-01
beta = 28.82
grad_norm = 6.2305e-01
step = 49
J = 7.6127e-01
beta = 29.41
grad_norm = 4.9933e-01
step = 50
J = 7.7012e-01
beta = 30.00
grad_norm = 4.5721e-01
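The beta values printed in the log follow a linear ramp from beta_min to beta_max over the run. The short standalone sketch below reproduces that schedule; the function name beta_schedule is illustrative, while the formula and hyperparameters are the ones used in the optimization loop.

```python
# Standalone sketch of the linear beta ramp used in the optimization loop.
num_steps = 50
beta_min = 1
beta_max = 30

def beta_schedule(i, num_steps, beta_min, beta_max):
    """Linearly interpolate beta from beta_min (i = 0) to beta_max (i = num_steps - 1)."""
    perc_done = i / (num_steps - 1)
    return beta_min * (1 - perc_done) + beta_max * perc_done

betas = [beta_schedule(i, num_steps, beta_min, beta_max) for i in range(num_steps)]

# print a few entries using the 1-based step numbering of the log
for step in (1, 2, 10, 50):
    print(f"step = {step:2d}, beta = {betas[step - 1]:.2f}")
```

Starting with a small beta keeps the projection soft, so gradients can reshape the design freely early on. The ramp then gradually pushes the design toward binary values as the optimization settles.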
Visualize Results#
Let’s visualize the results of our optimization.
Objective function vs Iteration#
First we inspect the objective function value as a function of optimization iteration number. After an early dip around iterations 4 to 6, it increases overall and ends near 0.77, as expected.
The presence of fabrication constraints tends to create some minor bumps in the optimization, which can be a signal that one needs to reduce the step size, but these results are sufficient for our purposes.
[24]:
plt.plot(Js)
plt.xlabel('iteration number')
plt.ylabel('objective function')
plt.show()
Final Simulation#
Let’s take a look at the final simulation, which we grab from our history.
[25]:
sim_data_final = data_history[-1]
sim_final = sim_data_final.simulation.to_simulation()[0]
We notice that the structure has reasonably large feature sizes but is not well binarized. This could be improved by increasing the beta projection value slowly over iteration number, as was done in the grating coupler tutorial.
[26]:
ax = sim_final.plot_eps(z=0.01, monitor_alpha=0, source_alpha=0)
Flux#
Let’s inspect the flux over each of the output ports as a function of wavelength.
We notice that the top and bottom ports have peaks in transmission at their corresponding design wavelengths, as expected!
[27]:
# plot flux
for key, color in zip(("top", "bot"), ('royalblue', 'firebrick')):
    freq = freqs[key]
    flux_data = sim_data_final[flux_mnt_names[key]]
    wvl_nm = 1000 * td.C_0 / freq
    wavelengths_nm = 1000 * td.C_0 / np.array(flux_data.flux.f)
    flux = np.array(flux_data.flux.values)
    # decibels use the base-10 logarithm
    flux_db = 10 * np.log10(flux)
    label = f"{key} ({int(wvl_nm)} nm)"
    plt.plot(wavelengths_nm, flux_db, label=label, color=color)
    plt.scatter([wvl_nm], [0], 100, marker="*", color=color)
plt.xlabel('wavelength (nm)')
plt.ylabel('transmission (dB)')
plt.legend()
plt.show()
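As a reference for reading the dB axis, the sketch below converts a few power fractions to decibels. Decibels are defined with the base-10 logarithm; the helper to_db and the sample fractions are illustrative and not taken from the simulation.

```python
import math

def to_db(power_fraction):
    """Convert a power ratio to decibels (illustrative helper)."""
    return 10.0 * math.log10(power_fraction)

# sample transmission fractions (made-up values, not simulation results)
for frac in (1.0, 0.5, 0.1, 0.01):
    print(f"transmission = {frac:5.2f} -> {to_db(frac):7.2f} dB")

# using the natural logarithm instead would scale every value by ln(10)
natural_log_scale = math.log(10.0)
print(f"10 * ln(x) is larger in magnitude than the dB value by a factor of {natural_log_scale:.3f}")
```

On this scale 0 dB means full transmission, about -3 dB means half the power, and each further -10 dB is another factor of ten. The star markers in the plot above sit at 0 dB for the two design wavelengths.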
Fields#
Let’s also plot the field intensity patterns at the two design wavelengths. We see from this plot the expected result that the power is directed to the design port at each frequency.
[28]:
# plot fields at the two design wavelengths
fig, axes = plt.subplots(1, 2, tight_layout=True, figsize=(7, 3))
for key, ax in zip(("top", "bot"), axes):
    freq = freqs[key]
    sim_data_final.plot_field("field", "E", "abs^2", f=freq, ax=ax, vmax=1200)
    wvl = 1000 * td.C_0 / freq
    ax.set_title(f"wavelength = {int(wvl)} nm")
Animation#
Finally, we animate this plot as a function of iteration number. The animation shows the device quickly accomplishing our design objective.
Note: can take a few minutes to complete
[29]:
import matplotlib.animation as animation
from IPython.display import HTML
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, tight_layout=False, figsize=(9, 4))

def animate(i):
    # plot the permittivity and both field intensities for iteration i
    sim_data_i = data_history[i]
    sim_i = sim_data_i.simulation.to_simulation()[0]
    sim_i.plot_eps(z=0.01, monitor_alpha=0, source_alpha=0, ax=ax1)
    ax1.set_aspect('equal')
    for key, ax in zip(("top", "bot"), (ax2, ax3)):
        freq = freqs[key]
        wvl = 1000 * td.C_0 / freq
        int_i = sim_data_i.get_intensity("field").sel(f=freq)
        int_i.squeeze().plot.pcolormesh(x='x', y='y', ax=ax, add_colorbar=False, cmap="magma", vmax=1000)
        ax.set_aspect('equal')
        ax.set_title(f"wavelength = {int(wvl)} nm")

# create animation
ani = animation.FuncAnimation(fig, animate, frames=len(data_history))
plt.close()
[30]:
# display the animation (press "play" to start)
HTML(ani.to_jshtml())
To save the animation as a file, uncomment the line below
Note: can take several more minutes to complete
[31]:
# ani.save('img/animation_wdm_adjoint.gif', fps=60)