Adjoint Plugin: 3 Inverse Design Demo#

In this notebook, we will use inverse design and the Tidy3D adjoint plugin to create an integrated photonics component to convert a fundamental waveguide mode to a higher order mode.

[1]:
from typing import List
import numpy as np
import matplotlib.pyplot as plt

# import jax to be able to use automatic differentiation
import jax.numpy as jnp
from jax import grad, value_and_grad

# import regular tidy3d
import tidy3d as td
import tidy3d.web as web
from tidy3d.plugins import ModeSolver

# import the components we need from the adjoint plugin
from tidy3d.plugins.adjoint import JaxSimulation, JaxBox, JaxCustomMedium, JaxStructure, JaxSimulationData, JaxDataArray, JaxPermittivityDataset
from tidy3d.plugins.adjoint.web import run
[16:33:01] WARNING  This version of Tidy3D was pip installed from the 'tidy3d-beta' repository on   __init__.py:103
                    PyPI. Future releases will be uploaded to the 'tidy3d' repository. From now on,                
                    please use 'pip install tidy3d' instead.                                                       
           INFO     Using client version: 1.9.0rc1                                                  __init__.py:121

Setup#

We wish to recreate a device like the diagram below:

[Image: schematic diagram of the device]

A mode source is injected into a waveguide on the left-hand side. The light propagates through a rectangular region filled with pixellated Box objects, each with a permittivity value independently tunable between 1 (vacuum) and some maximum permittivity. Finally, we measure the transmission of the light into a waveguide on the right-hand side.

The goal of the inverse design exercise is to find the permittivities (\(\epsilon_{ij}\)) of each Box in the coupling region to maximize the power conversion between the input mode and the output mode.

Parameters#

First we will define some parameters.

[2]:
# wavelength and frequency
wavelength = 1.0
freq0 = td.C_0 / wavelength
k0 = 2 * np.pi * freq0 / td.C_0

# resolution control
dl = 0.01

# space between boxes and PML
buffer = 0.5 * wavelength

# optimize region size
lz = td.inf
golden_ratio = 1.618
lx = 5.0
ly = lx / golden_ratio
wg_width = .7

# num cells
nx = 120
ny = int(nx / golden_ratio)
num_cells = nx * ny

# position of source and monitor (constant for all)
source_x = -lx/2 - buffer * 0.8
meas_x = lx/2 + buffer * 0.8

# total size
Lx = lx + 2 * buffer
Ly = ly + 2 * buffer
Lz = 0

# permittivity info
eps_wg = 2.75
eps_deviation_random = 0.5
eps_max = 5

# note, we choose the starting permittivities to be uniform with a small, random deviation
eps_boxes = eps_wg * np.ones((nx, ny))
eps_boxes += 2 * (np.random.random((nx, ny)) - 0.5) * eps_deviation_random
eps_boxes = eps_boxes.flatten()

# frequency width and run time
freqw = freq0 / 10
run_time = 10 / freqw

# mode in and out
mode_index_in = 0
mode_index_out = 2
num_modes = max(mode_index_in, mode_index_out) + 1
mode_spec = td.ModeSpec(num_modes=num_modes)
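
As a quick sanity check on these parameters, the sketch below recomputes a few derived quantities in plain Python. It assumes lengths are in microns and uses a locally defined `C_0` (speed of light in um/s) as a stand-in for `td.C_0`; both are assumptions of this sketch, not something the cell above states.

```python
# Values copied from the parameter cell above.
wavelength = 1.0
dl = 0.01
golden_ratio = 1.618
lx = 5.0
ly = lx / golden_ratio
nx = 120
ny = int(nx / golden_ratio)

# Assumption: speed of light in um/s, standing in for td.C_0.
C_0 = 2.99792458e14

freq0 = C_0 / wavelength
freqw = freq0 / 10
run_time = 10 / freqw

# Size of one design pixel and how many grid cells resolve it.
pixel_x = lx / nx
pixel_y = ly / ny

print(f"design pixels: {nx} x {ny} = {nx * ny}")
print(f"pixel size: {pixel_x:.4f} x {pixel_y:.4f}")
print(f"grid cells per pixel: {pixel_x / dl:.1f} x {pixel_y / dl:.1f}")
print(f"run_time: {run_time:.3e} s, i.e. {run_time * freq0:.0f} optical periods")
```

With these values the design region has 120 x 74 = 8880 pixels, which is the length of the gradient vector printed later in the notebook, and each pixel spans roughly four grid cells in each direction.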

Static Components#

Next, we will set up the static parts of the geometry, the input source, and the output monitor using these parameters.

[3]:
waveguide = td.Structure(
    geometry=td.Box(size=(td.inf, wg_width, lz)),
    medium=td.Medium(permittivity=eps_wg)
)

mode_size = (0,4,lz)

# source seeding the simulation
forward_source = td.ModeSource(
    source_time=td.GaussianPulse(freq0=freq0, fwidth=freqw),
    center=[source_x, 0, 0],
    size=mode_size,
    mode_index=mode_index_in,
    mode_spec=mode_spec,
    direction="+"
)

# we'll refer to the measurement monitor by this name often
measurement_monitor_name = 'measurement'

# monitor where we compute the objective function from
measurement_monitor = td.ModeMonitor(
    center=[meas_x, 0, 0],
    size=mode_size,
    freqs=[freq0],
    mode_spec=mode_spec,
    name=measurement_monitor_name,
)

Input Structures#

Next, we write a function that takes our flattened array of permittivity values \(\epsilon_{ij}\) and returns the pixellated design region as a single JaxStructure whose medium is a JaxCustomMedium.

We will feed the result of this function to our JaxSimulation.input_structures and will take the gradient w.r.t. the inputs.

[4]:
def make_input_structures(eps_boxes) -> List[JaxStructure]:

    size_box_x = float(lx) / nx
    size_box_y = float(ly) / ny
    size_box = (size_box_x, size_box_y, lz)

    x0_min = -lx/2 + size_box_x/2
    y0_min = -ly/2 + size_box_y/2

    # pixel-center coordinates (the 1e-5 shift nudges them slightly off the exact centers)
    coords_x = [x0_min + index_x * size_box_x - 1e-5 for index_x in range(nx)]
    coords_y = [y0_min + index_y * size_box_y - 1e-5 for index_y in range(ny)]

    coords = dict(x=coords_x, y=coords_y, z=[0], f=[freq0])

    # collect the permittivities in flattened order: x is the slow index, y the fast index
    values = []
    index_box = 0
    for index_x in range(nx):
        for index_y in range(ny):
            values.append(eps_boxes[index_box])
            index_box += 1

    values = jnp.array(values).reshape((nx, ny, 1, 1))
    field_components = {f"eps_{dim}{dim}": JaxDataArray(values=values, coords=coords) for dim in "xyz"}
    eps_dataset = JaxPermittivityDataset(**field_components)
    custom_medium = JaxCustomMedium(eps_dataset=eps_dataset)
    box = JaxBox(center=(0,0,0), size=(lx, ly, lz))
    custom_structure = JaxStructure(geometry=box, medium=custom_medium)
    return [custom_structure]
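
The function above maps the flat `eps_boxes` vector onto an `(nx, ny)` grid with `x` as the slow index and `y` as the fast index. The NumPy sketch below, on a small hypothetical 4 x 3 grid, illustrates that a nested-loop fill in this order gives the same layout as a direct reshape, and how a flat index relates to a pixel position.

```python
import numpy as np

# Small hypothetical grid, chosen only to keep the example readable.
nx, ny = 4, 3
eps_flat = np.arange(nx * ny, dtype=float)

# Nested-loop fill: x is the outer (slow) index, y the inner (fast) index.
values = []
index_box = 0
for index_x in range(nx):
    for index_y in range(ny):
        values.append(eps_flat[index_box])
        index_box += 1
eps_loop = np.array(values).reshape((nx, ny, 1, 1))

# A direct row-major reshape produces the same array.
eps_direct = eps_flat.reshape((nx, ny, 1, 1))
print(np.array_equal(eps_loop, eps_direct))

# Flat index k corresponds to pixel (k // ny, k % ny).
k = 7
index_x, index_y = divmod(k, ny)
print(index_x, index_y, eps_direct[index_x, index_y, 0, 0])
```

The trailing two axes of length 1 correspond to the `z` and `f` coordinates of the permittivity dataset.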

Jax Simulation#

Next, we write a function to return the JaxSimulation as a function of our \(\epsilon_{ij}\) values.

We make sure to add the pixellated JaxStructure list to input_structures and the measurement_monitor to output_monitors.

[5]:
def make_sim(eps_boxes) -> JaxSimulation:

    input_structures = make_input_structures(eps_boxes)

    return JaxSimulation(
        size=[Lx, Ly, Lz],
        grid_spec=td.GridSpec.uniform(dl=dl),
        structures=[waveguide],
        input_structures=input_structures,
        sources=[forward_source],
        monitors=[],
        output_monitors=[measurement_monitor],
        run_time=run_time,
        subpixel=True,
        boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=False),
        shutoff=1e-8,
        courant=0.9,
    )

Visualize#

Let’s visualize the simulation to see how it looks.

[6]:
sim_start = make_sim(eps_boxes)

f, axes = plt.subplots(1, 3, tight_layout=True, figsize=(15, 10))

for dim, ax in zip('xyz', axes):
    sim_start.plot_eps(**{dim:0}, ax=ax)

plt.show()
           INFO     Remote TPU is not linked into jax; skipping remote TPU.                       xla_bridge.py:170
           INFO     Unable to initialize backend 'tpu_driver': Could not initialize backend       xla_bridge.py:355
                    'tpu_driver'                                                                                   
           INFO     Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no     xla_bridge.py:355
                    attribute 'GpuAllocatorConfig'                                                                 
           INFO     Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no     xla_bridge.py:355
                    attribute 'GpuAllocatorConfig'                                                                 
           INFO     Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no      xla_bridge.py:355
                    attribute 'get_tpu_client'                                                                     
           INFO     Unable to initialize backend 'plugin': xla_extension has no attributes named  xla_bridge.py:355
                    get_plugin_device_client. Compile TensorFlow with                                              
                    //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults                    
                    to false) to enable this.                                                                      
           WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
../_images/notebooks_AdjointPlugin_3_InverseDesign_11_7.png

Select Input and Output Modes#

Next, let’s visualize the mode profiles so we can inspect which mode indices we want to inject and transmit.

[7]:
mode_solver = ModeSolver(simulation=sim_start, plane=forward_source, mode_spec=mode_spec, freqs=[freq0])
modes = mode_solver.solve()

print("Effective index of computed modes: ", np.array(modes.n_eff))

fig, axs = plt.subplots(num_modes, 3, figsize=(20, 15))
for mode_ind in range(num_modes):
    for field_ind, field_name in enumerate(('Ex', 'Ey', 'Ez')):
        field = modes.field_components[field_name].sel(mode_index=mode_ind)
        ax = axs[mode_ind, field_ind]
        field.real.plot(ax=ax)
/Users/twhughes/.pyenv/versions/3.10.9/lib/python3.10/site-packages/numpy/linalg/linalg.py:2139: RuntimeWarning: divide by zero encountered in det
  r = _umath_linalg.det(a, signature=signature)
/Users/twhughes/.pyenv/versions/3.10.9/lib/python3.10/site-packages/numpy/linalg/linalg.py:2139: RuntimeWarning: invalid value encountered in det
  r = _umath_linalg.det(a, signature=signature)
Effective index of computed modes:  [[1.5736595 1.5368265 1.3096673]]
../_images/notebooks_AdjointPlugin_3_InverseDesign_13_2.png

After inspection, we decide to convert the fundamental Ez-polarized mode at the input into the first-order Ez-polarized mode at the output.

From the plots, we see that these modes correspond to the first and third rows, or mode_index=0 and mode_index=2, respectively.

So we make sure that the mode_index_in and mode_index_out variables are set appropriately.

Post Processing#

We will define one more function to tell us how we want to postprocess a JaxSimulationData object to give the conversion power that we are interested in maximizing.

[8]:
def measure_power(sim_data: JaxSimulationData) -> float:
    """Return the power in the output_data amplitude at the mode index of interest."""
    output_amps = sim_data.output_data[0].amps
    amp = output_amps.sel(direction="+", f=freq0, mode_index=mode_index_out)
    return jnp.sum(jnp.abs(amp)**2)

penalty_strength = 0.0
def binary_penalty(eps_boxes, penalty_strength=0.0):
    """Penalty on the mean permittivity: 0 when the mean is 1 or eps_max, largest (penalty_strength * (eps_max - 1) / 4) midway between."""

    delta_eps = eps_max - 1
    eps_average = jnp.mean(eps_boxes)
    above_vacuum = eps_average - 1
    below_epsmax = eps_max - eps_average
    return penalty_strength * above_vacuum * below_epsmax / delta_eps
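
To get a feel for the penalty term, the sketch below evaluates the same formula with NumPy for a few mean permittivities. The helper name `binary_penalty_np` and the default `penalty_strength=1.0` are chosen for this illustration only; the notebook itself calls `binary_penalty` with its default strength of 0.0, so the penalty is switched off there.

```python
import numpy as np

eps_max = 5

def binary_penalty_np(eps_boxes, penalty_strength=1.0):
    """NumPy copy of the penalty formula above, for inspection only."""
    delta_eps = eps_max - 1
    eps_average = np.mean(eps_boxes)
    above_vacuum = eps_average - 1
    below_epsmax = eps_max - eps_average
    return penalty_strength * above_vacuum * below_epsmax / delta_eps

# Uniform designs with different mean permittivities.
for mean_eps in (1.0, 2.0, 3.0, 4.0, 5.0):
    print(mean_eps, binary_penalty_np(np.full(10, mean_eps)))

# A design with half of its pixels at 1 and half at eps_max.
half_and_half = np.array([1.0, 5.0] * 5)
print(binary_penalty_np(half_and_half))
```

Because the formula acts on the mean of `eps_boxes`, the half-and-half design receives the same penalty as a uniform design at the midpoint value of 3.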

Define Objective Function#

Finally, we need to define the objective function that we want to maximize as a function of our input parameters (permittivity of each box) that returns the conversion power. This is the function we will differentiate later.

[9]:
def J(eps_boxes, step_num:int=None) -> float:
    sim = make_sim(eps_boxes)
    task_name = "inv_des"
    if step_num:
        task_name += f"_step_{step_num}"
    sim_data = run(sim, task_name=task_name)
    power = measure_power(sim_data)
    penalty = binary_penalty(eps_boxes)
    return power - penalty

Inverse Design#

Now we are ready to perform the optimization.

We use the jax.value_and_grad function to get the gradient of J with respect to the permittivity of each Box, while also returning the converted power associated with the current iteration, so we can record this value for later.

Let’s try running this function once to make sure it works.

[10]:
dJ_fn = value_and_grad(J)
[11]:
# use a name other than `grad` so we do not shadow the function imported from jax
val, gradient = dJ_fn(eps_boxes)
print(gradient.shape)
[16:33:07] WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
[16:33:09] INFO     Created task 'inv_des_fwd' with task_id '878446e6-2704-434e-a490-79e2a16143d8'.   webapi.py:120
[16:33:13] INFO     status = queued                                                                   webapi.py:262
[16:33:16] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:33:18] INFO     status = preprocess                                                               webapi.py:274
[16:33:22] INFO     starting up solver                                                                webapi.py:278
[16:33:31] INFO     running solver                                                                    webapi.py:284
[16:33:32] INFO     status = postprocess                                                              webapi.py:307
[16:33:38] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:33:41] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 1.04e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
[16:33:43] INFO     Created task 'inv_des_adj' with task_id '9d71232c-1760-492f-b072-1d2aa78e1b80'.   webapi.py:120
[16:33:47] INFO     status = queued                                                                   webapi.py:262
[16:33:50] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:33:53] INFO     status = preprocess                                                               webapi.py:274
[16:33:56] INFO     starting up solver                                                                webapi.py:278
[16:34:06] INFO     running solver                                                                    webapi.py:284
[16:34:07] INFO     status = postprocess                                                              webapi.py:307
[16:34:13] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:34:16] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 2.24e-05 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
(8880,)
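
For readers new to `jax.value_and_grad`, here is a minimal sketch on a toy function that needs no simulation. The function `toy_objective` and its `scale` keyword are hypothetical stand-ins for `J` and `step_num`: the gradient is taken with respect to the first positional argument, and keyword arguments are passed through without being differentiated.

```python
import jax.numpy as jnp
from jax import value_and_grad

def toy_objective(params, scale=1.0):
    """Toy scalar objective: scale * sum(params**2)."""
    return scale * jnp.sum(params ** 2)

# Returns a function that gives (value, gradient w.r.t. the first argument).
toy_fn = value_and_grad(toy_objective)

params = jnp.array([1.0, 2.0, 3.0])
value, gradient = toy_fn(params, scale=0.5)

print(value)           # objective value
print(gradient.shape)  # same shape as params
print(gradient)        # d/dparams of 0.5 * sum(params**2) is params itself
```

The gradient has the same shape as the input, just as the gradient of `J` has one entry per permittivity value.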

Optimization#

We will use the Adam optimization strategy to iteratively update the permittivity values in the JaxCustomMedium.

For more information on what we use to implement this method, see this article.

We will run 8 steps (the default num_steps below) and record both the permittivities and powers at each iteration.

We capture this process in an optimize function, which accepts various parameters that we can tweak.

[12]:
permittivities = np.array(eps_boxes)

Js = []
perms = [permittivities.copy()]  # copy, since optimize() updates permittivities in place

def optimize(
    permittivities,
    step_size=.2,
    num_steps=8,
    eps_max=eps_max,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-8,
):

    mt = np.zeros_like(permittivities)
    vt = np.zeros_like(permittivities)

    for i in range(num_steps):

        t = i + 1
        print(f'step = {t}')

        power, gradient = dJ_fn(permittivities, step_num=t)
        gradient = np.array(gradient).copy()

        mt = beta1 * mt + (1-beta1) * gradient
        vt = beta2 * vt + (1-beta2) * gradient**2

        mt_hat = mt / (1 - beta1**t)
        vt_hat = vt / (1 - beta2**t)

        update = step_size * mt_hat / (np.sqrt(vt_hat) + epsilon)

        Js.append(power)
        print(f'\tJ = {power:.4e}')
        print(f'\tgrad_norm = {np.linalg.norm(gradient):.4e}')

        permittivities += update
        permittivities[permittivities > eps_max] = eps_max
        permittivities[permittivities < 1.0] = 1.0
        perms.append(permittivities.copy())
    return permittivities
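
The same update rule can be tried on a toy problem without running any simulations. The sketch below is an illustration under stated assumptions, not the notebook's optimizer: `adam_ascent`, `toy_gradient`, and `target` are hypothetical names, the step size and step count are picked for the toy problem, and `epsilon` is placed in the denominator as in the standard form of Adam. As in `optimize`, the update is added (gradient ascent) and the result is clipped to the allowed range.

```python
import numpy as np

def adam_ascent(grad_fn, x0, step_size=0.05, num_steps=500,
                beta1=0.9, beta2=0.999, epsilon=1e-8,
                lower=1.0, upper=5.0):
    """Maximize with Adam, clipping to [lower, upper] after each step."""
    x = np.array(x0, dtype=float)
    mt = np.zeros_like(x)  # running mean of the gradient
    vt = np.zeros_like(x)  # running mean of the squared gradient

    for t in range(1, num_steps + 1):
        gradient = grad_fn(x)

        mt = beta1 * mt + (1 - beta1) * gradient
        vt = beta2 * vt + (1 - beta2) * gradient**2

        # bias-corrected moment estimates
        mt_hat = mt / (1 - beta1**t)
        vt_hat = vt / (1 - beta2**t)

        # ascent step, then projection onto the allowed range
        x = x + step_size * mt_hat / (np.sqrt(vt_hat) + epsilon)
        x = np.clip(x, lower, upper)
    return x

# Toy objective -sum((x - target)**2); its maximum for the last entry lies outside the bounds.
target = np.array([2.0, 4.5, 7.0])

def toy_gradient(x):
    return -2.0 * (x - target)

x_final = adam_ascent(toy_gradient, x0=np.full(3, 2.75))
print(x_final)
```

The first two entries settle near their targets, while the third is held at the upper bound by the clipping step, which mirrors how permittivities are confined to the range from 1 to eps_max.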

Let’s run the optimize function.

[13]:
perms_after = optimize(permittivities)
step = 1
[16:34:25] WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
[16:34:26] INFO     Created task 'inv_des_step_1_fwd' with task_id                                    webapi.py:120
                    'a0457c33-62a4-482e-b6ea-062e094b87c9'.                                                        
[16:34:30] INFO     status = queued                                                                   webapi.py:262
[16:34:33] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:34:34] INFO     status = preprocess                                                               webapi.py:274
[16:34:38] INFO     starting up solver                                                                webapi.py:278
[16:34:47] INFO     running solver                                                                    webapi.py:284
[16:34:48] INFO     status = postprocess                                                              webapi.py:307
[16:34:54] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:34:57] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 1.04e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
[16:34:58] INFO     Created task 'inv_des_step_1_adj' with task_id                                    webapi.py:120
                    '26b19d35-e838-4378-9cc0-a6a9d9288f1b'.                                                        
[16:35:03] INFO     status = queued                                                                   webapi.py:262
[16:35:05] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:35:08] INFO     status = preprocess                                                               webapi.py:274
[16:35:14] INFO     starting up solver                                                                webapi.py:278
[16:35:24] INFO     running solver                                                                    webapi.py:284
           INFO     status = postprocess                                                              webapi.py:307
[16:35:30] INFO     status = success                                                                  webapi.py:307
[16:35:31] INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:35:34] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 2.24e-05 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
        J = 2.5973e-03
        grad_norm = 3.0960e-02
step = 2
[16:35:41] WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
[16:35:42] INFO     Created task 'inv_des_step_2_fwd' with task_id                                    webapi.py:120
                    'd7cce203-e642-4400-bef7-37afa031a471'.                                                        
[16:35:47] INFO     status = queued                                                                   webapi.py:262
[16:35:49] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:35:52] INFO     status = preprocess                                                               webapi.py:274
[16:35:55] INFO     starting up solver                                                                webapi.py:278
[16:36:04] INFO     running solver                                                                    webapi.py:284
[16:36:05] INFO     status = postprocess                                                              webapi.py:307
[16:36:11] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:36:14] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 4.44e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
[16:36:16] INFO     Created task 'inv_des_step_2_adj' with task_id                                    webapi.py:120
                    '9756272a-e34c-48d5-918e-ff98974c940b'.                                                        
[16:36:20] INFO     status = queued                                                                   webapi.py:262
[16:36:23] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:36:24] INFO     status = preprocess                                                               webapi.py:274
[16:36:28] INFO     starting up solver                                                                webapi.py:278
[16:36:37] INFO     running solver                                                                    webapi.py:284
[16:36:38] INFO     status = postprocess                                                              webapi.py:307
[16:36:44] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:36:47] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 4.52e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
        J = 5.4169e-01
        grad_norm = 3.2903e-01
step = 3
[16:36:55] WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
           INFO     Created task 'inv_des_step_3_fwd' with task_id                                    webapi.py:120
                    'f3e05153-8cfb-461a-9dc5-398cabb9c645'.                                                        
[16:37:00] INFO     status = queued                                                                   webapi.py:262
[16:37:03] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:37:05] INFO     status = preprocess                                                               webapi.py:274
[16:37:09] INFO     starting up solver                                                                webapi.py:278
[16:37:18] INFO     running solver                                                                    webapi.py:284
[16:37:19] INFO     status = postprocess                                                              webapi.py:307
[16:37:25] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:37:28] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 8.64e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
[16:37:29] INFO     Created task 'inv_des_step_3_adj' with task_id                                    webapi.py:120
                    '50350a48-8719-400f-91aa-712644bed75b'.                                                        
[16:37:34] INFO     status = queued                                                                   webapi.py:262
[16:37:37] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:37:40] INFO     status = preprocess                                                               webapi.py:274
[16:37:46] INFO     starting up solver                                                                webapi.py:278
[16:37:56] INFO     running solver                                                                    webapi.py:284
[16:37:57] INFO     status = postprocess                                                              webapi.py:307
[16:38:03] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:38:06] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 1.79e-05 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
        J = 7.9532e-01
        grad_norm = 9.4370e-01
step = 4
[16:38:14] WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
           INFO     Created task 'inv_des_step_4_fwd' with task_id                                    webapi.py:120
                    '7c51d7eb-6016-4fe0-8e4e-114c00a50427'.                                                        
[16:38:19] INFO     status = queued                                                                   webapi.py:262
[16:38:21] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:38:23] INFO     status = preprocess                                                               webapi.py:274
[16:38:27] INFO     starting up solver                                                                webapi.py:278
[16:38:35] INFO     running solver                                                                    webapi.py:284
[16:38:36] INFO     status = postprocess                                                              webapi.py:307
[16:38:43] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:38:46] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 3.53e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
[16:38:47] INFO     Created task 'inv_des_step_4_adj' with task_id                                    webapi.py:120
                    '08cae26c-0607-495e-a75e-bf33baa17bd1'.                                                        
[16:38:52] INFO     status = queued                                                                   webapi.py:262
[16:38:55] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:38:57] INFO     status = preprocess                                                               webapi.py:274
[16:39:00] INFO     starting up solver                                                                webapi.py:278
[16:39:09] INFO     running solver                                                                    webapi.py:284
[16:39:10] INFO     status = postprocess                                                              webapi.py:307
[16:39:16] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:39:19] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 1.01e-05 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
        J = 8.7387e-01
        grad_norm = 7.1291e-01
step = 5
[16:39:27] WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
           INFO     Created task 'inv_des_step_5_fwd' with task_id                                    webapi.py:120
                    '4fa9e4d4-8800-4468-afe0-49a7b1a5568a'.                                                        
[16:39:32] INFO     status = queued                                                                   webapi.py:262
[16:39:35] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:39:37] INFO     status = preprocess                                                               webapi.py:274
[16:39:40] INFO     starting up solver                                                                webapi.py:278
[16:39:49] INFO     running solver                                                                    webapi.py:284
[16:39:50] INFO     status = postprocess                                                              webapi.py:307
[16:39:57] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:40:01] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 2.5e-06 is greater than the simulation      webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
[16:40:02] INFO     Created task 'inv_des_step_5_adj' with task_id                                    webapi.py:120
                    '3fc7422a-4ea0-4d4b-a47c-65f6393a6f35'.                                                        
[16:40:06] INFO     status = queued                                                                   webapi.py:262
[16:40:09] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:40:17] INFO     status = preprocess                                                               webapi.py:274
[16:40:20] INFO     starting up solver                                                                webapi.py:278
[16:40:30] INFO     running solver                                                                    webapi.py:284
[16:40:31] INFO     status = postprocess                                                              webapi.py:307
[16:40:37] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:40:40] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 6.68e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
        J = 9.0226e-01
        grad_norm = 1.6261e-01
step = 6
[16:40:48] WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
           INFO     Created task 'inv_des_step_6_fwd' with task_id                                    webapi.py:120
                    '4569c5ac-59fa-4f36-8514-7a7e1bc4a78b'.                                                        
[16:40:53] INFO     status = queued                                                                   webapi.py:262
[16:40:56] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:40:58] INFO     status = preprocess                                                               webapi.py:274
[16:41:02] INFO     starting up solver                                                                webapi.py:278
[16:41:11] INFO     running solver                                                                    webapi.py:284
[16:41:12] INFO     status = postprocess                                                              webapi.py:307
[16:41:18] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
[16:41:19] INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:41:21] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
[16:41:22] WARNING  Simulation final field decay value of 2.75e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
[16:41:23] INFO     Created task 'inv_des_step_6_adj' with task_id                                    webapi.py:120
                    '4e3948f0-cde6-4e02-9766-0c450d9498fb'.                                                        
[16:41:28] INFO     status = queued                                                                   webapi.py:262
[16:41:30] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:41:32] INFO     status = preprocess                                                               webapi.py:274
[16:41:36] INFO     starting up solver                                                                webapi.py:278
[16:41:45] INFO     running solver                                                                    webapi.py:284
[16:41:46] INFO     status = postprocess                                                              webapi.py:307
[16:41:51] INFO     status = success                                                                  webapi.py:307
[16:41:52] INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:41:55] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 7.28e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
        J = 9.1641e-01
        grad_norm = 5.0582e-01
step = 7
[16:42:02] WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
[16:42:03] INFO     Created task 'inv_des_step_7_fwd' with task_id                                    webapi.py:120
                    '3674fb66-9793-461d-aca6-7903cbedf9f9'.                                                        
[16:42:07] INFO     status = queued                                                                   webapi.py:262
[16:42:14] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:42:15] INFO     status = preprocess                                                               webapi.py:274
[16:42:19] INFO     starting up solver                                                                webapi.py:278
[16:42:27] INFO     running solver                                                                    webapi.py:284
[16:42:29] INFO     status = postprocess                                                              webapi.py:307
[16:42:35] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
[16:42:36] INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:42:38] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
[16:42:39] WARNING  Simulation final field decay value of 5.85e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
[16:42:40] INFO     Created task 'inv_des_step_7_adj' with task_id                                    webapi.py:120
                    '8934bfd8-8d00-48a9-ae58-5887b31a4aca'.                                                        
[16:42:45] INFO     status = queued                                                                   webapi.py:262
[16:42:51] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:42:52] INFO     status = preprocess                                                               webapi.py:274
[16:42:56] INFO     starting up solver                                                                webapi.py:278
[16:43:06] INFO     running solver                                                                    webapi.py:284
[16:43:07] INFO     status = postprocess                                                              webapi.py:307
[16:43:13] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:43:16] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 1.3e-05 is greater than the simulation      webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
        J = 9.3694e-01
        grad_norm = 1.7838e-01
step = 8
[16:43:24] WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
           INFO     Created task 'inv_des_step_8_fwd' with task_id                                    webapi.py:120
                    '93188538-a57c-4e2e-b279-f3866a504e2b'.                                                        
[16:43:28] INFO     status = queued                                                                   webapi.py:262
[16:43:37] INFO     status = preprocess                                                               webapi.py:274
[16:43:40] INFO     starting up solver                                                                webapi.py:278
[16:43:50] INFO     running solver                                                                    webapi.py:284
           INFO     status = postprocess                                                              webapi.py:307
[16:43:57] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:44:00] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 8.86e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
[16:44:02] INFO     Created task 'inv_des_step_8_adj' with task_id                                    webapi.py:120
                    '1281e4b2-24d5-4ab6-9158-5b9cf438a849'.                                                        
[16:44:06] INFO     status = queued                                                                   webapi.py:262
[16:44:09] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:44:11] INFO     status = preprocess                                                               webapi.py:274
[16:44:15] INFO     starting up solver                                                                webapi.py:278
[16:44:23] INFO     running solver                                                                    webapi.py:284
[16:44:24] INFO     status = postprocess                                                              webapi.py:307
[16:44:29] INFO     status = success                                                                  webapi.py:307
[16:44:30] INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:44:33] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 3.23e-05 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            
        J = 9.3959e-01
        grad_norm = 3.4754e-01

Then we record the final power value, which includes the parameter update from the last iteration.

[14]:
power = J(perms_after)
Js.append(power)
[16:44:36] WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
[16:44:37] INFO     Created task 'inv_des' with task_id '631ff587-5fae-497c-935d-965a9e75cf15'.       webapi.py:120
[16:44:41] INFO     status = queued                                                                   webapi.py:262
[16:44:44] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:44:46] INFO     status = preprocess                                                               webapi.py:274
[16:44:50] INFO     starting up solver                                                                webapi.py:278
[16:44:58] INFO     running solver                                                                    webapi.py:284
           INFO     status = postprocess                                                              webapi.py:307
[16:45:02] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:45:04] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 5.84e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            

Results#

First, we plot the objective function (the power converted to the first-order mode) as a function of the optimization step and see that it converges well.

The final device converts about 95% of the input power to the first-order mode, up from less than 1% at the start, and running more steps may improve it further.

[15]:
plt.plot(Js)
plt.xlabel('iterations')
plt.ylabel('objective function')
plt.show()
../_images/notebooks_AdjointPlugin_3_InverseDesign_29_0.png
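
To put numbers on the trend in the plot, one can look at the change in the objective between consecutive evaluations. The sketch below is illustrative only: the values in `Js_logged` are copied by hand from the `J = ...` lines of the log above (steps 4 to 8) plus the final power conversion printed in the next cell, and the names `Js_logged`, `step_gains`, and `is_monotonic` are hypothetical. In the notebook itself, the full `Js` list would be used instead.

```python
import numpy as np

# Objective values copied from the log output (steps 4-8) and the final evaluation.
Js_logged = np.array([0.87387, 0.90226, 0.91641, 0.93694, 0.93959, 0.9487])

# Gain in converted power between consecutive evaluations.
step_gains = np.diff(Js_logged)

for i, gain in enumerate(step_gains):
    print(f"evaluation {i} -> {i + 1}: gain = {gain * 100:.2f} percentage points")

# True if the objective never decreased over these evaluations.
is_monotonic = bool(np.all(step_gains > 0))
print(f"monotonically increasing: {is_monotonic}")
```

The gains remain positive through the last logged evaluations, which is consistent with the remark that additional steps may still improve the device.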
[16]:
print(f'Initial power conversion = {Js[0]*100:.2f} %')
print(f'Final power conversion = {Js[-1]*100:.2f} %')
Initial power conversion = 0.26 %
Final power conversion = 94.87 %
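
Photonic device performance is often quoted as a loss in decibels rather than as a percentage. As a minimal sketch, assuming the two power fractions printed above, the conversion is `-10 * log10(fraction)`. The helper name `insertion_loss_db` is hypothetical and not part of Tidy3D.

```python
import math


def insertion_loss_db(power_fraction: float) -> float:
    """Return the loss in dB for a transmitted power fraction in (0, 1]."""
    return -10.0 * math.log10(power_fraction)


initial_fraction = 0.0026  # 0.26 %, from the printed output above
final_fraction = 0.9487    # 94.87 %, from the printed output above

initial_loss = insertion_loss_db(initial_fraction)
final_loss = insertion_loss_db(final_fraction)

print(f"Initial loss = {initial_loss:.2f} dB")
print(f"Final loss = {final_loss:.2f} dB")
```

In the notebook, `Js[0]` and `Js[-1]` could be passed to the helper directly instead of the hard-coded values.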

Next, we visualize the final structure. We build the simulation from the final permittivity values, convert it to a regular Simulation, and plot its permittivity.

[17]:
sim_final = make_sim(perms_after)
           WARNING  'BoundarySpec.z' uses default value, which is 'Periodic()' but will change to   boundary.py:607
                    'PML()' in Tidy3D version 2.0. We recommend explicitly setting all boundary                    
                    conditions ahead of this release to avoid unexpected results.                                  
[18]:
sim_final = sim_final.to_simulation()[0]
sim_final.plot_eps(z=0)
[18]:
<AxesSubplot: title={'center': 'cross section at z=0.00'}, xlabel='x', ylabel='y'>
../_images/notebooks_AdjointPlugin_3_InverseDesign_33_1.png

Finally, we want to inspect the fields, so we add a field monitor to the Simulation and perform one more run to record the field values for plotting.

[19]:
field_mnt = td.FieldMonitor(
    size=(td.inf, td.inf, 0),
    freqs=[freq0],
    name='field_mnt',
)

sim_final = sim_final.copy(update=dict(monitors=(field_mnt,)))
[20]:
sim_data_final = web.run(sim_final, task_name='inv_des_final')
[16:45:05] INFO     Created task 'inv_des_final' with task_id '93de5273-5f3f-44eb-a331-425a4bf08ee3'. webapi.py:120
[16:45:09] INFO     status = queued                                                                   webapi.py:262
[16:45:12] INFO     Maximum FlexUnit cost: 0.025                                                      webapi.py:253
[16:45:15] INFO     status = preprocess                                                               webapi.py:274
[16:45:18] INFO     starting up solver                                                                webapi.py:278
[16:45:26] INFO     running solver                                                                    webapi.py:284
[16:45:27] INFO     status = postprocess                                                              webapi.py:307
[16:45:34] INFO     status = success                                                                  webapi.py:307
           INFO     Billed FlexUnit cost: 0.025                                                       webapi.py:311
           INFO     downloading file "output/monitor_data.hdf5" to "simulation_data.hdf5"             webapi.py:593
[16:45:40] INFO     loading SimulationData from simulation_data.hdf5                                  webapi.py:415
           WARNING  Simulation final field decay value of 5.84e-06 is greater than the simulation     webapi.py:421
                    shutoff threshold of 1e-08. Consider simulation again with large run_time                      
                    duration for more accurate results.                                                            

The field plots show the expected behavior: the fundamental mode injected on the left is converted to the first-order mode in the output waveguide on the right.

[21]:
f, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 6))
ax1 = sim_data_final.plot_field('field_mnt', 'Ez', z=0, ax=ax1)
ax2 = sim_data_final.plot_field('field_mnt', 'int', z=0, ax=ax2)
../_images/notebooks_AdjointPlugin_3_InverseDesign_38_0.png