Inverse design optimization of a metalens#

In this notebook, we will use inverse design and the Tidy3D adjoint plugin to design a high numerical aperture (NA) metalens for optimal focusing to a point. This demo also introduces how to use the adjoint plugin for objective functions that depend on the FieldMonitor outputs.

We will follow the basic setup from Mansouree et al., “Large-Scale Parametrized Metasurface Design Using Adjoint Optimization”. The published paper can be found here and the arXiv preprint can be found here.

Setup#

We first perform basic imports of the packages needed.

[1]:
# standard python imports
import numpy as np
from numpy import random
import matplotlib.pyplot as plt

import tidy3d as td
from tidy3d import web

import tidy3d.plugins.adjoint as tda
from tidy3d.plugins.adjoint.web import run as run_adj
import jax
import jax.numpy as jnp

The metalens design consists of a rectangular array of Si rectangular prisms sitting on an SiO2 substrate.

Here we define all of the basic parameters of the setup, including the wavelength, NA, geometrical dimensions, and material properties.

[2]:
# 1 nanometer in units of microns (for conversion)
nm = 1e-3

# free space central wavelength
wavelength = 850 * nm

# desired numerical aperture
NA = 0.94

# shape parameters of metalens unit cell (um) (refer to image above and see paper for details)
H = 430 * nm
S = 320 * nm

# space between bottom PML and substrate (-z)
space_below_sub = 1 * wavelength

# thickness of substrate between source and Si unit cells
thickness_sub = 100 * nm

# side length of entire metalens (um)
side_length = 6

# Number of unit cells in each x and y direction (NxN grid)
N = int(side_length / S)

print(f"for diameter of {side_length:.1f} um, have {N} cells per side")
print(f"full metalens has area of {side_length**2:.1f} um^2 and {N*N} total cells")

# Define material properties (constant refractive indices, no dispersion)
n_Si = 3.84
n_SiO2 = 1.46
air = td.Medium(permittivity=1.0)
SiO2 = td.Medium(permittivity=n_SiO2**2)
Si = td.Medium(permittivity=n_Si**2)

# define symmetry
symmetry = (-1, 1, 0)
for diameter of 6.0 um, have 18 cells per side
full metalens has area of 36.0 um^2 and 324 total cells

Next, we will compute some important quantities derived from these parameters.

[3]:
# with the wavelength in microns, td.C_0 (um/s) gives the frequency in Hz
f0 = td.C_0 / wavelength

# Compute the domain size in x, y (note: round down from side_length)
length_xy = N * S

# focal length given diameter and numerical aperture
focal_length = length_xy / 2 / NA * np.sqrt(1 - NA**2)

# total domain size in z: (space -> substrate -> unit cell -> 1.7 focal lengths)
length_z = space_below_sub + thickness_sub + H + 1.7 * focal_length

# construct simulation size array
sim_size = (length_xy, length_xy, length_z)
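
As a sanity check on the focal length formula: for an aperture of width D focusing at distance f in air, the marginal ray satisfies tan(theta) = (D / 2) / f and NA = sin(theta), which gives f = (D / 2) * sqrt(1 - NA^2) / NA. The standalone sketch below uses only NumPy and re-declares the notebook's parameters. It recomputes the derived quantities and confirms that the resulting focal length reproduces the requested NA.

```python
import numpy as np

# parameters copied from the notebook (all lengths in microns)
NA = 0.94
S = 0.320
side_length = 6

# number of unit cells per side and the resulting (rounded-down) aperture
N = int(side_length / S)
length_xy = N * S

# focal length: f = (D / 2) * sqrt(1 - NA^2) / NA
focal_length = length_xy / 2 / NA * np.sqrt(1 - NA**2)

# recover the NA from the geometry: NA = sin(arctan((D / 2) / f))
theta = np.arctan((length_xy / 2) / focal_length)
NA_check = np.sin(theta)

print(f"N = {N}, aperture = {length_xy:.2f} um")
print(f"focal length = {focal_length:.3f} um")
print(f"recovered NA = {NA_check:.4f}")
```

Because the aperture is rounded down to a whole number of unit cells, the focal length is computed from `length_xy` rather than `side_length`.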

Create Metalens Geometry#

Now we will define the structures in our simulation. We will first generate the substrate as a regular td.Box.

[4]:
# define substrate
substrate = td.Structure(
    geometry=td.Box.from_bounds(
        rmin=(-td.inf, -td.inf, -1000),
        rmax=(+td.inf, +td.inf, -length_z / 2 + space_below_sub + thickness_sub)
    ),
    medium=SiO2,
)

Next, we will write a function to make a list of JaxStructure objects corresponding to each unit cell.

Note that the adjoint plugin does not yet support GeometryGroup for JaxBox, so we will keep them as individual JaxStructure objects for now.

[5]:
# define coordinates of each unit cell
centers_x = S * np.arange(N) - length_xy / 2.0 + S / 2.0
centers_y = S * np.arange(N) - length_xy / 2.0 + S / 2.0
center_z = -length_z / 2 + space_below_sub + thickness_sub + H / 2.0

focal_z = center_z + H / 2 + focal_length


x_centers, y_centers = np.meshgrid(centers_x, centers_y, indexing="ij")
xs = x_centers.flatten()
ys = y_centers.flatten()

def get_sizes(params):
    """Returns the actual side lengths of the boxes as a function of design parameters from (-inf, +inf)."""
    return S * (jnp.tanh(params) + 1.0) / 2.0

# initially, start with parameters of 0 (all boxes have side length S/2)
params0 = 0 * np.ones(x_centers.shape)

def make_structures(params, apply_symmetry: bool = True):
    """Make the JaxStructure objects that will be used as .input_structures."""

    sizes = get_sizes(params)
    nx, ny = sizes.shape
    geometries = []

    for i in range(nx):
        i_quad = max(i, nx - 1 - i)
        for j in range(ny):
            j_quad = max(j, ny - 1 - j)
            size = sizes[i_quad, j_quad]
            x0 = x_centers[i, j]
            y0 = y_centers[i, j]

            if apply_symmetry and symmetry[0] != 0 and x0 < -S/2:
                continue

            if apply_symmetry and symmetry[1] != 0 and y0 < -S/2:
                continue

            geometry = tda.JaxBox(
                center=(x0, y0, center_z),
                size=(size, size, H)
            )

            geometries.append(geometry)
    medium = tda.JaxMedium(permittivity=n_Si**2)
    return [tda.JaxStructure(medium=medium, geometry=geo) for geo in geometries]

structures = make_structures(params0)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
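
The `tanh` reparameterization in `get_sizes` can be examined in isolation. The sketch below is a NumPy stand-in (the names `get_sizes_np` and `d_size_d_param` are introduced here for illustration only) showing how unbounded parameters map to side lengths between 0 and S, and how the slope of that mapping shrinks as the parameters grow in magnitude.

```python
import numpy as np

S = 0.320  # unit cell period (um), as in the notebook

def get_sizes_np(params):
    """NumPy stand-in for get_sizes: maps (-inf, +inf) to side lengths in [0, S]."""
    return S * (np.tanh(params) + 1.0) / 2.0

def d_size_d_param(params):
    """Analytic derivative of the mapping: (S / 2) * (1 - tanh(p)^2)."""
    return S / 2.0 * (1.0 - np.tanh(params) ** 2)

test_params = np.array([-1e5, -2.0, 0.0, 2.0, 1e5])
sizes = get_sizes_np(test_params)
slopes = d_size_d_param(test_params)

for p, s, ds in zip(test_params, sizes, slopes):
    print(f"param = {p:10.1f} -> size = {s:.4f} um, d(size)/d(param) = {ds:.4f}")
```

A parameter of 0 gives a box of side S/2, where the mapping is steepest. Very negative parameters give boxes of essentially zero size, and very positive ones fill the unit cell. In both saturated regimes the slope is close to zero, so boxes pushed far into saturation receive very small gradient updates.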

Define grid specification#

We define the grid based on the properties of the geometry. The metalens is quasi-periodic in x and y, in that we have clearly defined unit cells, but each is slightly modified from its neighbors. Such structures are best resolved with a grid that matches the periodicity, which is why we use a uniform grid in x and y. In z, we use the automatic nonuniform grid that will place a higher grid density around the metalens region, and a lower one in the air region away from the metalens. To speed up the auto meshing in the region with the pillars, we put an override box in the grid specification.

[6]:
# steps per unit cell along x and y
grids_per_unit_length = 10

# uniform mesh in x and y
grid_x = td.UniformGrid(dl=S / grids_per_unit_length)
grid_y = td.UniformGrid(dl=S / grids_per_unit_length)

# in z, use an automatic nonuniform mesh with the wavelength being the "unit length"
grid_z = td.AutoGrid(min_steps_per_wvl=grids_per_unit_length)

# we need to supply the wavelength because of the automatic mesh in z
grid_spec = td.GridSpec(
    wavelength=wavelength, grid_x=grid_x, grid_y=grid_y, grid_z=grid_z
)

# put an override box over the pillars to avoid parsing a large number of structures in the mesher
grid_spec = grid_spec.copy(
    update=dict(
        override_structures=[
            td.Structure(
                geometry=td.Box.from_bounds(
                    rmin=(-td.inf, -td.inf, -length_z / 2 + space_below_sub),
                    rmax=(td.inf, td.inf, center_z + H / 2),
                ),
                medium=Si,
            )
        ]
    )
)
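
To get a feel for the resulting resolution, the sketch below estimates the grid steps implied by these settings. It assumes that `min_steps_per_wvl` is counted against the wavelength inside each material (my reading of the Tidy3D documentation), so the z estimates are rough upper bounds rather than the exact steps the mesher will choose.

```python
# parameters copied from the notebook (lengths in microns)
wavelength = 0.850
S = 0.320
n_Si = 3.84
n_SiO2 = 1.46
steps = 10  # grids_per_unit_length

# uniform step in x and y: a fixed number of cells per unit cell
dl_xy = S / steps

# rough upper bounds on the z step in each medium, assuming the automatic
# grid targets `steps` cells per wavelength *inside* the material
dl_z_max = {
    "air": wavelength / 1.0 / steps,
    "SiO2": wavelength / n_SiO2 / steps,
    "Si": wavelength / n_Si / steps,
}

# cells per Si wavelength actually provided by the uniform x-y grid
cells_per_wvl_si_xy = (wavelength / n_Si) / dl_xy

print(f"dl_xy = {dl_xy * 1e3:.1f} nm")
for name, dl in dl_z_max.items():
    print(f"max dl_z in {name}: {dl * 1e3:.1f} nm")
print(f"x-y cells per wavelength in Si: {cells_per_wvl_si_xy:.1f}")
```

Under this assumption the uniform x-y grid is somewhat coarser than the z grid inside the pillars (roughly 7 cells per wavelength in Si), which is a trade-off between speed and accuracy that suits a demonstration.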

Define Source#

Now we define the incident fields. We simply use an x-polarized, normally incident plane wave with Gaussian time dependence centered at our central frequency. For more details, see the plane wave source documentation and the Gaussian source documentation.

[7]:
# Bandwidth in Hz
fwidth = f0 / 10.0

# time dependence of source
gaussian = td.GaussianPulse(freq0=f0, fwidth=fwidth, phase=0)

source = td.PlaneWave(
    source_time=gaussian,
    size=(td.inf, td.inf, 0),
    center=(0, 0, -length_z / 2 + space_below_sub / 10.0),
    direction="+",
    pol_angle=0,
)

run_time = 50 / fwidth
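
For orientation, the source timing can be worked out by hand. The sketch below reproduces the arithmetic with plain Python, hard-coding the speed of light in um/s in place of `td.C_0` so that it runs without Tidy3D installed.

```python
C_0 = 2.99792458e14  # speed of light in um/s (the quantity td.C_0 represents)
wavelength = 0.850   # free-space wavelength (um)

f0 = C_0 / wavelength   # central frequency (Hz)
fwidth = f0 / 10.0      # pulse bandwidth (Hz)
run_time = 50 / fwidth  # total simulated time (s)

# number of optical periods at f0 contained in the run time
periods = run_time * f0

print(f"f0       = {f0:.4e} Hz")
print(f"fwidth   = {fwidth:.4e} Hz")
print(f"run_time = {run_time:.4e} s ({periods:.0f} periods at f0)")
```

Since `fwidth` is defined as a fixed fraction of `f0`, the run time always spans 500 optical periods regardless of the wavelength chosen.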

Define Monitors#

Now we define the monitor that measures field output from the FDTD simulation. For simplicity, we measure the fields at the central frequency at the focal spot.

This will be the monitor that we use in our objective function, so it will go into JaxSimulation.output_monitors.

[8]:
# To decrease the amount of data stored, only store the E fields
fields = ["Ex", "Ey", "Ez"]

monitor_focal = td.FieldMonitor(
    center=[0.0, 0.0, focal_z],
    size=[0, 0, 0],
    freqs=[f0],
    name="focal_point",
    fields=fields,
    colocate=False,
)

Create Simulation#

Now we can put everything together and define a JaxSimulation object to be run.

With PML on the sides, we would get a number of warnings about structures being too close to the PML. In FDTD simulations, this can result in instability, as PML are absorbing for propagating fields but can be amplifying for evanescent fields. This particular simulation runs without any issues even with PML on the sides, but it is best to heed these warnings to avoid problems. There are two ways to fix the simulation: one is to put some space between the outermost metalens boxes and the PML; the other is to use adiabatic absorbers on the sides, which are always stable. The only downside of the absorbers is that they are slightly thicker than the PML, making the overall simulation size slightly larger. This is why we only put them along x and y, while we leave the PML in z.

Note: we add symmetry of (-1, 1, 0) to speed up the simulation by approximately 4x, taking advantage of the symmetry in our source and dielectric function.

[9]:
def make_sim(params, apply_symmetry: bool = True):
    """Build the JaxSimulation for a given array of design parameters."""
    metalens = make_structures(params, apply_symmetry=apply_symmetry)
    sim = tda.JaxSimulation(
        size=sim_size,
        grid_spec=grid_spec,
        structures=[substrate],
        input_structures=metalens,
        sources=[source],
        monitors=[],
        output_monitors=[monitor_focal],
        run_time=run_time,
        boundary_spec=td.BoundarySpec(
            x=td.Boundary.absorber(), y=td.Boundary.absorber(), z=td.Boundary.pml()
        ),
        symmetry=symmetry,
    )
    return sim

sim = make_sim(params0)
[20:38:51] WARNING: Override structures take no effect along    grid_spec.py:555
           x-axis. If intending to apply override structures to                 
           this axis, use 'AutoGrid'.                                           
           WARNING: Override structures take no effect along    grid_spec.py:555
           y-axis. If intending to apply override structures to                 
           this axis, use 'AutoGrid'.                                           

The warnings are just letting us know that we are using a uniform grid along x and y (per our spec) even though the override structures have some extent in these dimensions. We can ignore them, as this is intended.

Visualize Geometry#

Let's take a look and make sure everything is defined properly.

[10]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 6))

sim.plot(x=0.1, ax=ax1)
sim.plot(y=0.1, ax=ax2)
sim.plot(z=-length_z / 2 + space_below_sub + thickness_sub + H / 2, ax=ax3)
plt.show()

../_images/notebooks_AdjointPlugin7Metalens_22_12.png

Objective Function#

Now that our simulation is set up, we can define our objective function over the JaxSimulationData results.

We first write a function to take a JaxSimulationData object and return the intensity at the focal point.

Next, we write a function to:

  1. Set up our simulation given our design parameters.

  2. Run the simulation through the adjoint run function.

  3. Compute and return the intensity at the focal point.

[11]:
# turn off warnings as we understand they are just about AutoGrid and can be ignored in our case
td.config.logging_level = "ERROR"

def measure_focal_intensity(sim_data: tda.JaxSimulationData) -> float:
    """Measures electric intensity at focal point."""
    return jnp.sum(sim_data.get_intensity('focal_point').values)

def J(params) -> float:
    """Objective function, returns intensity at focal point as a function of params."""
    sim = make_sim(params)
    sim_data = run_adj(sim, task_name='metalens_adj', verbose=False)
    return measure_focal_intensity(sim_data)
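
The quantity being maximized is the electric intensity at a single point. The sketch below assumes that `get_intensity` returns |Ex|^2 + |Ey|^2 + |Ez|^2, which is my understanding of the method rather than something stated in this notebook. The complex field values are made up purely for illustration.

```python
import numpy as np

# hypothetical complex field amplitudes at one point and one frequency
Ex = 2.0 + 1.0j
Ey = 0.0 + 0.5j
Ez = -1.0 + 0.0j

def intensity(ex, ey, ez):
    """Electric intensity |Ex|^2 + |Ey|^2 + |Ez|^2 (assumed definition)."""
    return abs(ex) ** 2 + abs(ey) ** 2 + abs(ez) ** 2

I0 = intensity(Ex, Ey, Ez)

# a global phase applied to all components leaves the intensity unchanged
phase = np.exp(1j * 0.7)
I_rotated = intensity(Ex * phase, Ey * phase, Ez * phase)

print(f"intensity = {I0:.4f}")
print(f"intensity after global phase shift = {I_rotated:.4f}")
```

Since the monitor has zero size and a single frequency, the `jnp.sum` in `measure_focal_intensity` reduces a one-element array to a scalar, which is what `jax.value_and_grad` requires.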

We first run our function to test that it works and see the starting value.

[12]:
J(params0)
[12]:
Array(11.112829, dtype=float32)

Next, we use jax to get a function returning the objective value and its gradient, given some parameters.

[13]:
dJ = jax.value_and_grad(J)

And try it out.

[14]:
val, grad = dJ(params0)
[15]:
print(val)
print(grad)
11.112829
[[ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          8.052231    1.5407639  -2.633849
  -1.7117175   1.5994252   0.37416115 -1.1727542   0.6812252   0.05781552]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          1.170242   -0.38775107 -1.3131913
  -0.13359495  0.8831538  -0.1225284  -0.49174562  0.42683142 -0.06464756]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.         -1.4127594  -0.9652585  -0.50624955
   0.6102758   0.50805783 -0.4469478  -0.17206581  0.43278244 -0.22273514]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.         -0.7552655  -0.19202262  0.47696215
   0.4672531  -0.21933757 -0.3337388   0.23176189  0.20692322 -0.3084491 ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.6831305   0.33028147  0.37160286
  -0.21041775 -0.41866505  0.13355729  0.28937018 -0.13324808 -0.16658844]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.11585653 -0.07828572 -0.16567144
  -0.2921306   0.05472778  0.24623685 -0.02784408 -0.21342102  0.09299421]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.         -0.3484568  -0.17058687 -0.03697217
   0.09520242  0.26396352 -0.07381138 -0.20417789  0.02290697  0.15813121]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.1712473   0.10132807  0.15602311
   0.08441953 -0.0273253  -0.17447795  0.01955975  0.12130421 -0.02494766]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.08062559  0.01353269 -0.03312859
  -0.09126864 -0.07907826  0.02552689  0.10252865 -0.02282601 -0.07314368]]
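
The large blocks of zeros in this gradient follow from the indexing in `make_structures`: each box at grid position (i, j) reads its size from `sizes[max(i, nx - 1 - i), max(j, ny - 1 - j)]`, so only one quadrant of the parameter array ever influences the simulation. The short NumPy sketch below reproduces that index mapping for the 18 x 18 grid used here and marks which entries are actually read.

```python
import numpy as np

N = 18  # cells per side, as in the notebook

# for each cell index i, the parameter index actually read by make_structures
i_quad = np.array([max(i, N - 1 - i) for i in range(N)])

# mask of parameter entries that are ever used
used = np.zeros((N, N), dtype=bool)
for i in i_quad:
    for j in i_quad:
        used[i, j] = True

print("parameter index read for each cell index:", i_quad)
print("number of parameters that affect the design:", used.sum())
print("smallest index used:", i_quad.min())
```

Only the block `params[9:, 9:]` is used, which matches the nonzero block of the printed gradient. The remaining entries have exactly zero gradient and are left unchanged by the optimizer.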

Normalize Objective#

To normalize our objective function value to something more understandable, we first run a simulation with no boxes to compute the focal point intensity in this case. Then, we construct a new objective function value that normalizes the raw intensity by this value, giving us an “intensity enhancement” factor. In this normalization, if our objective is given by “x”, it means that the intensity at the focal point is “x” times stronger with our design than with no structures at all.

[16]:
params_empty = -1e5 * np.ones_like(params0)
J_empty = np.array(J(params_empty))

def J_normalized(params):
    return J(params) / J_empty

val_normalized = val / J_empty

dJ_normalized = jax.value_and_grad(J_normalized)

print(val_normalized)
0.8701368

Optimization#

With our objective function set up, we can now run the optimization.

We will use optax’s “adam” optimizer with initial parameters of all zeros (corresponding to boxes of side length S/2).

[17]:
import optax

# hyperparameters
num_steps = 18
learning_rate = 0.02

# initialize adam optimizer with starting parameters
params = np.array(params0)
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# store history
J_history = [val_normalized]
params_history = [params0]

for i in range(num_steps):

    # compute gradient and current objective function value
    value, gradient = dJ_normalized(params)

    # outputs
    print(f"step = {i + 1}")
    print(f"\tJ = {value:.4e}")
    print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")

    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)
    params = optax.apply_updates(params, updates)

    # save history
    J_history.append(value)
    params_history.append(params)

step = 1
        J = 8.7014e-01
        grad_norm = 7.5746e-01
step = 2
        J = 1.9787e+00
        grad_norm = 1.0905e+00
step = 3
        J = 4.7533e+00
        grad_norm = 1.5273e+00
step = 4
        J = 7.5140e+00
        grad_norm = 1.8209e+00
step = 5
        J = 1.3054e+01
        grad_norm = 2.8733e+00
step = 6
        J = 2.0436e+01
        grad_norm = 3.2374e+00
step = 7
        J = 2.7075e+01
        grad_norm = 2.9542e+00
step = 8
        J = 3.1865e+01
        grad_norm = 2.3616e+00
step = 9
        J = 3.2881e+01
        grad_norm = 3.0067e+00
step = 10
        J = 3.6570e+01
        grad_norm = 2.5087e+00
step = 11
        J = 3.9969e+01
        grad_norm = 2.7586e+00
step = 12
        J = 4.2636e+01
        grad_norm = 2.6617e+00
step = 13
        J = 4.3821e+01
        grad_norm = 2.7791e+00
step = 14
        J = 4.5719e+01
        grad_norm = 2.6449e+00
step = 15
        J = 4.7921e+01
        grad_norm = 2.1662e+00
step = 16
        J = 4.8882e+01
        grad_norm = 2.0180e+00
step = 17
        J = 4.8722e+01
        grad_norm = 2.4410e+00
step = 18
        J = 4.9012e+01
        grad_norm = 2.9489e+00
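
To make the update rule less of a black box, here is a minimal NumPy sketch of Adam applied to a toy concave objective. It is not optax's implementation: the `adam_step`, `toy_objective`, and `toy_gradient` names are hypothetical, and the `b1`, `b2`, and `eps` values are what I believe optax's defaults to be. As in the loop above, the gradient is negated before the update so that the optimizer maximizes the objective.

```python
import numpy as np

def adam_step(grad, state, lr=0.02, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update. state = (step count, first moment, second moment)."""
    t, m, v = state
    t += 1
    m = b1 * m + (1 - b1) * grad       # running mean of gradients
    v = b2 * v + (1 - b2) * grad**2    # running mean of squared gradients
    m_hat = m / (1 - b1**t)            # bias correction
    v_hat = v / (1 - b2**t)
    update = -lr * m_hat / (np.sqrt(v_hat) + eps)
    return update, (t, m, v)

def toy_objective(p):
    """Hypothetical concave objective with its maximum at p = 1."""
    return -np.sum((p - 1.0) ** 2)

def toy_gradient(p):
    """Analytic gradient of toy_objective."""
    return -2.0 * (p - 1.0)

params = np.zeros(4)
state = (0, np.zeros_like(params), np.zeros_like(params))
history = [toy_objective(params)]

for _ in range(18):
    grad = toy_gradient(params)
    # Adam minimizes, so pass the negated gradient to maximize the objective
    update, state = adam_step(-grad, state)
    params = params + update
    history.append(toy_objective(params))

print(f"objective: {history[0]:.4f} -> {history[-1]:.4f}")
print(f"params after 18 steps: {params}")
```

One property worth noticing is that the very first Adam step has a magnitude close to the learning rate for every parameter, independent of the gradient scale. With a learning rate of 0.02 and only 18 steps, each parameter can move by at most a few tenths from its starting value of 0.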
[18]:
params_after = params_history[-1]
[19]:
plt.plot(J_history)
plt.xlabel('iterations')
plt.ylabel('objective function (focusing intensity enhancement)')
plt.show()
../_images/notebooks_AdjointPlugin7Metalens_37_0.png
[20]:
sim_before = make_sim(0 * params_after, apply_symmetry=False).to_simulation()[0]
sim_after = make_sim(params_after, apply_symmetry=False).to_simulation()[0]

[21]:
f, (ax1, ax2) = plt.subplots(1, 2)

sim_before.plot(z=center_z, ax=ax1)
sim_after.plot(z=center_z, ax=ax2)

plt.show()
../_images/notebooks_AdjointPlugin7Metalens_39_0.png
[22]:
sim_after_mnt = sim_after.updated_copy(monitors=list(sim_after.monitors) +
[
    td.FieldMonitor(
        size=(0, td.inf, td.inf),
        center=(0,0,0),
        freqs=[f0],
        name='fields_yz',
    ),
    td.FieldMonitor(
        size=(td.inf, td.inf, 0),
        center=(0,0,focal_z),
        freqs=[f0],
        name='far_field',
    ),
])
[23]:
sim_data_after_mnt = web.run(sim_after_mnt, task_name='meta_near_field_after')
[21:11:10] Created task 'meta_near_field_after' with task_id       webapi.py:188
           'fdve-f56dc12f-663c-43ca-85f5-32f01b3062e7v1'.                       
[21:11:11] status = queued                                         webapi.py:361
[21:11:20] status = preprocess                                     webapi.py:355
[21:11:24] Maximum FlexCredit cost: 0.025. Use                     webapi.py:341
           'web.real_cost(task_id)' to get the billed FlexCredit                
           cost after a simulation run.                                         
           starting up solver                                      webapi.py:377
           running solver                                          webapi.py:386
           To cancel the simulation, use 'web.abort(task_id)' or   webapi.py:387
           'web.delete(task_id)' or abort/delete the task in the                
           web UI. Terminating the Python script will not stop the              
           job running on the cloud.                                            
[21:11:34] early shutoff detected, exiting.                        webapi.py:404
           status = postprocess                                    webapi.py:419
[21:11:38] status = success                                        webapi.py:426
[21:11:44] loading SimulationData from simulation_data.hdf5        webapi.py:590
[24]:
fig, (ax1, ax2) = plt.subplots(1, 2, tight_layout=True, figsize=(10, 4))
sim_data_after_mnt.plot_field('far_field', 'int', vmax=105, ax=ax1)
sim_data_after_mnt.plot_field('fields_yz', 'int', vmax=180, ax=ax2)
plt.show()
../_images/notebooks_AdjointPlugin7Metalens_42_0.png

Conclusions#

We notice that our metalens does quite well at focusing at this high NA! For the purposes of demonstration, this is quite a small device, but the same principle can be applied to optimize a much larger metalens.

For more case studies using the adjoint plugin, see the
