Inverse design optimization of a metalens#
In this notebook, we will use inverse design and the Tidy3D adjoint
plugin to design a high numerical aperture (NA) metalens for optimal focusing to a point. This demo also introduces how to use the adjoint
plugin for objective functions that depend on the FieldMonitor
outputs.
We will follow the basic setup from Mansouree et al., “Large-Scale Parametrized Metasurface Design Using Adjoint Optimization”. The published paper can be found here and the arXiv preprint can be found here.
Setup#
We first perform basic imports of the packages needed.
[1]:
# standard python imports
import numpy as np
from numpy import random
import matplotlib.pyplot as plt
import tidy3d as td
from tidy3d import web
import tidy3d.plugins.adjoint as tda
from tidy3d.plugins.adjoint.web import run as run_adj
import jax
import jax.numpy as jnp
The metalens design consists of a rectangular array of Si rectangular prisms sitting on an SiO2 substrate.
Here we define all of the basic parameters of the setup, including the wavelength, NA, geometrical dimensions, and material properties.
[2]:
# 1 nanometer in units of microns (for conversion)
nm = 1e-3
# free space central wavelength
wavelength = 850 * nm
# desired numerical aperture
NA = 0.94
# shape parameters of metalens unit cell (um) (refer to image above and see paper for details)
H = 430 * nm
S = 320 * nm
# space between bottom PML and substrate (-z)
space_below_sub = 1 * wavelength
# thickness of substrate between source and Si unit cells
thickness_sub = 100 * nm
# side length of entire metalens (um)
side_length = 6
# Number of unit cells in each x and y direction (NxN grid)
N = int(side_length / S)
print(f"for diameter of {side_length:.1f} um, have {N} cells per side")
print(f"full metalens has area of {side_length**2:.1f} um^2 and {N*N} total cells")
# Define material properties at 600 nm
n_Si = 3.84
n_SiO2 = 1.46
air = td.Medium(permittivity=1.0)
SiO2 = td.Medium(permittivity=n_SiO2**2)
Si = td.Medium(permittivity=n_Si**2)
# define symmetry
symmetry = (-1, 1, 0)
for diameter of 6.0 um, have 18 cells per side
full metalens has area of 36.0 um^2 and 324 total cells
Next, we will compute some important quantities derived from these parameters.
[3]:
# using the wavelength in microns, one can use td.C_0 (um/s) to get frequency in Hz
# wavelength_meters = wavelength * meters
f0 = td.C_0 / wavelength
# Compute the domain size in x, y (note: round down from side_length)
length_xy = N * S
# focal length given diameter and numerical aperture
focal_length = length_xy / 2 / NA * np.sqrt(1 - NA**2)
# total domain size in z: (space -> substrate -> unit cell -> 1.7 focal lengths)
length_z = space_below_sub + thickness_sub + H + 1.7 * focal_length
# construct simulation size array
sim_size = (length_xy, length_xy, length_z)
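As a quick sanity check on these derived quantities, the sketch below recomputes the cell count, focal length, and domain height with only the standard library. The focal length expression follows from NA = sin(theta) and tan(theta) = (D/2)/f for focusing in air; the name `NA_check` is introduced here purely for illustration.

```python
import math

# parameters copied from the cells above (all lengths in microns)
wavelength = 0.85
NA = 0.94
H = 0.43
S = 0.32
space_below_sub = 1 * wavelength
thickness_sub = 0.1
side_length = 6

# whole number of unit cells that fit in the requested side length
N = int(side_length / S)
length_xy = N * S

# NA = sin(theta) and tan(theta) = (length_xy / 2) / focal_length
focal_length = length_xy / 2 / NA * math.sqrt(1 - NA**2)
length_z = space_below_sub + thickness_sub + H + 1.7 * focal_length

# recover the NA from the geometry as a consistency check (illustrative name)
NA_check = math.sin(math.atan((length_xy / 2) / focal_length))

print(f"N = {N}, length_xy = {length_xy:.2f} um")
print(f"focal_length = {focal_length:.3f} um, length_z = {length_z:.3f} um")
print(f"NA recovered from geometry = {NA_check:.4f}")
```

With these inputs the focal length comes out at roughly 1.05 um, only a little longer than the wavelength, which reflects how aggressive an NA of 0.94 is.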
Create Metalens Geometry#
Now we will define the structures in our simulation. We will first generate the substrate as a regular td.Box.
[4]:
# define substrate
substrate = td.Structure(
    geometry=td.Box.from_bounds(
        rmin=(-td.inf, -td.inf, -1000),
        rmax=(+td.inf, +td.inf, -length_z / 2 + space_below_sub + thickness_sub)
    ),
    medium=SiO2,
)
Next, we will write a function to make a list of JaxStructure objects corresponding to each unit cell.
Note that the adjoint plugin does not yet support GeometryGroup for JaxBox, so we will keep them as individual JaxStructure objects for now.
[5]:
# define coordinates of each unit cell
centers_x = S * np.arange(N) - length_xy / 2.0 + S / 2.0
centers_y = S * np.arange(N) - length_xy / 2.0 + S / 2.0
center_z = -length_z / 2 + space_below_sub + thickness_sub + H / 2.0
focal_z = center_z + H / 2 + focal_length
x_centers, y_centers = np.meshgrid(centers_x, centers_y, indexing="ij")
xs = x_centers.flatten()
ys = y_centers.flatten()
def get_sizes(params):
    """Returns the actual side lengths of the boxes as a function of design parameters from (-inf, +inf)."""
    return S * (jnp.tanh(params) + 1.0) / 2.0

# initially, start with parameters of 0 (all boxes have side length S/2)
params0 = 0 * np.ones(x_centers.shape)

def make_structures(params, apply_symmetry: bool = True):
    """Make the JaxStructure objects that will be used as .input_structures."""
    sizes = get_sizes(params)
    nx, ny = sizes.shape
    geometries = []
    for i in range(nx):
        i_quad = max(i, nx - 1 - i)
        for j in range(ny):
            j_quad = max(j, ny - 1 - j)
            size = sizes[i_quad, j_quad]
            x0 = x_centers[i, j]
            y0 = y_centers[i, j]
            if apply_symmetry and symmetry[0] != 0 and x0 < -S/2:
                continue
            if apply_symmetry and symmetry[1] != 0 and y0 < -S/2:
                continue
            geometry = tda.JaxBox(
                center=(x0, y0, center_z),
                size=(size, size, H)
            )
            geometries.append(geometry)
    medium = tda.JaxMedium(permittivity=n_Si**2)
    return [tda.JaxStructure(medium=medium, geometry=geo) for geo in geometries]
structures = make_structures(params0)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
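To make the parameterization in get_sizes concrete, here is a minimal pure-Python analogue for a single parameter. It uses `math.tanh` in place of `jnp.tanh`, and the helper name `box_side_length` is hypothetical. The point it illustrates is that any real-valued parameter maps to a side length between 0 and S, so the optimizer can run unconstrained.

```python
import math

S = 0.32  # unit cell period in microns, as defined above

def box_side_length(param: float) -> float:
    """Pure-Python analogue of get_sizes for a single parameter (illustrative helper)."""
    return S * (math.tanh(param) + 1.0) / 2.0

# sweep from strongly negative to strongly positive parameters
for p in (-1e5, -2.0, 0.0, 2.0, 1e5):
    print(f"param = {p:>9}: side length = {box_side_length(p):.4f} um")
```

A parameter of 0 gives S/2, while very negative or very positive values saturate at 0 and S respectively. In the saturated regime the derivative of tanh is essentially zero, so gradients with respect to such parameters vanish.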
Define grid specification#
We define the grid based on the properties of the geometry. The metalens is quasi-periodic in x and y, in that we have clearly defined unit cells, but each is slightly modified from its neighbors. Such structures are best resolved with a grid that matches the periodicity, which is why we use a uniform grid in x and y. In z, we use the automatic nonuniform grid that will place a higher grid density around the metalens region, and a lower one in the air region away from the metalens. To speed up the auto meshing in the region with the pillars, we put an override box in the grid specification.
[6]:
# steps per unit cell along x and y
grids_per_unit_length = 10
# uniform mesh in x and y
grid_x = td.UniformGrid(dl=S / grids_per_unit_length)
grid_y = td.UniformGrid(dl=S / grids_per_unit_length)
# in z, use an automatic nonuniform mesh with the wavelength being the "unit length"
grid_z = td.AutoGrid(min_steps_per_wvl=grids_per_unit_length)
# we need to supply the wavelength because of the automatic mesh in z
grid_spec = td.GridSpec(
    wavelength=wavelength, grid_x=grid_x, grid_y=grid_y, grid_z=grid_z
)
# put an override box over the pillars to avoid parsing a large amount of structures in the mesher
grid_spec = grid_spec.copy(
    update=dict(
        override_structures=[
            td.Structure(
                geometry=td.Box.from_bounds(
                    rmin=(-td.inf, -td.inf, -length_z / 2 + space_below_sub),
                    rmax=(td.inf, td.inf, center_z + H / 2),
                ),
                medium=Si,
            )
        ]
    )
)
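For a rough feel for the resulting resolution, the sketch below computes the uniform in-plane step and an estimated upper bound on the z step in each material. It assumes that `min_steps_per_wvl` counts steps per wavelength inside the material (free-space wavelength divided by refractive index). That is an assumption about the automatic mesher, not something shown in this notebook, so treat the z values as estimates.

```python
wavelength = 0.85  # um
S = 0.32           # um
grids_per_unit_length = 10
refractive_indices = {"air": 1.0, "SiO2": 1.46, "Si": 3.84}

# uniform step along x and y: 10 cells per unit cell
dl_xy = S / grids_per_unit_length

# assumed bound along z: (wavelength / n) / min_steps_per_wvl
dl_z_max = {
    name: wavelength / n / grids_per_unit_length
    for name, n in refractive_indices.items()
}

print(f"dl_xy = {1e3 * dl_xy:.1f} nm")
for name, dl in dl_z_max.items():
    print(f"estimated max dl_z in {name} = {1e3 * dl:.1f} nm")
```

Under this assumption the finest z steps are needed inside the Si pillars, which is consistent with the override box above being assigned the Si medium.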
Define Source#
Now we define the incident fields. We simply use an x-polarized, normally incident plane wave with Gaussian time dependence centered at our central frequency. For more details, see the plane wave source documentation and the gaussian source documentation.
[7]:
# Bandwidth in Hz
fwidth = f0 / 10.0
# time dependence of source
gaussian = td.GaussianPulse(freq0=f0, fwidth=fwidth, phase=0)
source = td.PlaneWave(
    source_time=gaussian,
    size=(td.inf, td.inf, 0),
    center=(0, 0, -length_z / 2 + space_below_sub / 10.0),
    direction="+",
    pol_angle=0,
)
run_time = 50 / fwidth
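The timing quantities above are easy to put into numbers. The sketch below hard-codes the speed of light in um/s (the value `td.C_0` is assumed to hold) and reports the central frequency, the bandwidth, and the run time, both in seconds and in optical periods.

```python
C_0 = 2.99792458e14  # speed of light in um/s (assumed to match td.C_0)
wavelength = 0.85    # um

f0 = C_0 / wavelength    # central frequency in Hz
fwidth = f0 / 10.0       # source bandwidth in Hz
run_time = 50 / fwidth   # total simulated time in seconds
periods = run_time * f0  # run time measured in optical periods

print(f"f0 = {f0:.4e} Hz")
print(f"fwidth = {fwidth:.4e} Hz")
print(f"run_time = {run_time:.4e} s ({periods:.0f} optical periods)")
```

Because `fwidth` is a fixed fraction of `f0`, the run time always spans the same number of optical periods regardless of the chosen wavelength.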
Define Monitors#
Now we define the monitor that measures field output from the FDTD simulation. For simplicity, we measure the fields at the central frequency at the focal spot.
This will be the monitor that we use in our objective function, so it will go into JaxSimulation.output_monitors.
[8]:
# To decrease the amount of data stored, only store the E fields
fields = ["Ex", "Ey", "Ez"]
monitor_focal = td.FieldMonitor(
    center=[0.0, 0.0, focal_z],
    size=[0, 0, 0],
    freqs=[f0],
    name="focal_point",
    fields=fields,
    colocate=False,
)
Create Simulation#
Now we can put everything together and define a JaxSimulation object to be run.
We get a number of warnings about structures being too close to the PML. In FDTD simulations, this can result in instability, as PML are absorbing for propagating fields, but can be amplifying for evanescent fields. This particular simulation runs without any issues even with PML on the sides, but it is best to heed these warnings to avoid problems. There are two ways that we can fix the simulation: one is to just put some space between the last of the metalens boxes and the PML. The other is to use adiabatic absorbers on the sides, which are always stable. The only downside of the absorbers is that they are slightly thicker than the PML, making the overall simulation size slightly larger. This is why we only put them along x and y, while we leave the PML in z.
Note: we add symmetry of (-1, 1, 0) to speed up the simulation by approximately 4x taking into account the symmetry in our source and dielectric function.
[9]:
def make_sim(params, apply_symmetry: bool = True):
    """Make the JaxSimulation for a given array of design parameters."""
    metalens = make_structures(params, apply_symmetry=apply_symmetry)
    sim = tda.JaxSimulation(
        size=sim_size,
        grid_spec=grid_spec,
        structures=[substrate],
        input_structures=metalens,
        sources=[source],
        monitors=[],
        output_monitors=[monitor_focal],
        run_time=run_time,
        boundary_spec=td.BoundarySpec(
            x=td.Boundary.absorber(), y=td.Boundary.absorber(), z=td.Boundary.pml()
        ),
        symmetry=symmetry,
    )
    return sim

sim = make_sim(params0)
[20:38:51] WARNING: Override structures take no effect along x-axis. If intending to apply override structures to this axis, use 'AutoGrid'. (grid_spec.py:555)
WARNING: Override structures take no effect along y-axis. If intending to apply override structures to this axis, use 'AutoGrid'. (grid_spec.py:555)
The warnings are just letting us know that we are using a uniform grid along x and y (per our spec) even though the override structures have some extent in these dimensions. We can ignore them, as this is intended.
Visualize Geometry#
Let's take a look and make sure everything is defined properly.
[10]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 6))
sim.plot(x=0.1, ax=ax1)
sim.plot(y=0.1, ax=ax2)
sim.plot(z=-length_z / 2 + space_below_sub + thickness_sub + H / 2, ax=ax3)
plt.show()
Objective Function#
Now that our simulation is set up, we can define our objective function over the JaxSimulationData results.
We first write a function to take a JaxSimulationData object and return the intensity at the focal point.
Next, we write a function to:
1. Set up our simulation given our design parameters.
2. Run the simulation through the adjoint run function.
3. Compute and return the intensity at the focal point.
[11]:
# turn off warnings as we understand they are just about AutoGrid and can be ignored in our case
td.config.logging_level = "ERROR"
def measure_focal_intensity(sim_data: tda.JaxSimulationData) -> float:
    """Measures electric intensity at focal point."""
    return jnp.sum(sim_data.get_intensity('focal_point').values)

def J(params) -> float:
    """Objective function, returns intensity at focal point as a function of params."""
    sim = make_sim(params)
    sim_data = run_adj(sim, task_name='metalens_adj', verbose=False)
    return measure_focal_intensity(sim_data)
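To illustrate what the objective measures, here is a small standalone sketch of the electric intensity at a single point. It assumes the intensity is |Ex|^2 + |Ey|^2 + |Ez|^2, which is our reading of what `get_intensity` returns rather than something stated in this notebook. The field values are made up for illustration.

```python
import cmath

def intensity(ex: complex, ey: complex, ez: complex) -> float:
    """Electric intensity |Ex|^2 + |Ey|^2 + |Ez|^2 (assumed definition)."""
    return abs(ex) ** 2 + abs(ey) ** 2 + abs(ez) ** 2

# made-up complex field amplitudes at one point and one frequency
ex, ey, ez = 3 + 4j, 0j, 1j
base = intensity(ex, ey, ez)

# a global phase shift of the fields leaves the intensity unchanged
phase = cmath.exp(1j * 0.7)
shifted = intensity(ex * phase, ey * phase, ez * phase)

print(base, shifted)
```

Since the intensity is a real scalar that does not depend on the global phase of the fields, it is a convenient quantity to differentiate with respect to the design parameters.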
We first run our function to test that it works and see the starting value.
[12]:
J(params0)
[12]:
Array(11.112829, dtype=float32)
Next, we use jax to get a function returning the objective value and its gradient, given some parameters.
[13]:
dJ = jax.value_and_grad(J)
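For readers new to `jax.value_and_grad`, the toy example below applies it to a cheap stand-in objective instead of an FDTD simulation. The function `toy_objective` is hypothetical; it reuses the tanh parameterization with S set to 1 and sums the squared sizes, so the result can be checked by hand.

```python
import jax
import jax.numpy as jnp

def toy_objective(params):
    """Hypothetical stand-in for J: sum of squared normalized box sizes."""
    sizes = (jnp.tanh(params) + 1.0) / 2.0
    return jnp.sum(sizes**2)

# returns a function that evaluates both the value and the gradient
toy_value_and_grad = jax.value_and_grad(toy_objective)

params = jnp.zeros((2, 2))
value, grad = toy_value_and_grad(params)

print(value)  # scalar objective value
print(grad)   # array with the same shape as params
```

At zero parameters each size is 0.5, so the value is 4 x 0.25 = 1, and each gradient entry is 2 x 0.5 x 0.5 = 0.5 by the chain rule. The gradient always has the same shape as the parameter array, which is what lets it feed directly into an optimizer.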
And try it out.
[14]:
val, grad = dJ(params0)
[15]:
print(val)
print(grad)
11.112829
[[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 8.052231 1.5407639 -2.633849
-1.7117175 1.5994252 0.37416115 -1.1727542 0.6812252 0.05781552]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 1.170242 -0.38775107 -1.3131913
-0.13359495 0.8831538 -0.1225284 -0.49174562 0.42683142 -0.06464756]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. -1.4127594 -0.9652585 -0.50624955
0.6102758 0.50805783 -0.4469478 -0.17206581 0.43278244 -0.22273514]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. -0.7552655 -0.19202262 0.47696215
0.4672531 -0.21933757 -0.3337388 0.23176189 0.20692322 -0.3084491 ]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0.6831305 0.33028147 0.37160286
-0.21041775 -0.41866505 0.13355729 0.28937018 -0.13324808 -0.16658844]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0.11585653 -0.07828572 -0.16567144
-0.2921306 0.05472778 0.24623685 -0.02784408 -0.21342102 0.09299421]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. -0.3484568 -0.17058687 -0.03697217
0.09520242 0.26396352 -0.07381138 -0.20417789 0.02290697 0.15813121]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0.1712473 0.10132807 0.15602311
0.08441953 -0.0273253 -0.17447795 0.01955975 0.12130421 -0.02494766]
[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0.08062559 0.01353269 -0.03312859
-0.09126864 -0.07907826 0.02552689 0.10252865 -0.02282601 -0.07314368]]
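The block of zeros in the printed gradient is consistent with how make_structures indexes the size array: every unit cell reads its size from `sizes[i_quad, j_quad]`, and those mirrored indices always fall in one quadrant of the array. The sketch below enumerates the indices that are actually read, using the same index arithmetic with nx = ny = 18.

```python
nx = ny = 18  # number of unit cells per side, as above

used = set()
for i in range(nx):
    i_quad = max(i, nx - 1 - i)  # mirror the row index into the upper half
    for j in range(ny):
        j_quad = max(j, ny - 1 - j)  # mirror the column index likewise
        used.add((i_quad, j_quad))

rows = sorted({i for i, _ in used})
cols = sorted({j for _, j in used})

print(f"{len(used)} of {nx * ny} parameters are used")
print(f"row indices {rows[0]}..{rows[-1]}, column indices {cols[0]}..{cols[-1]}")
```

Only the 9 x 9 block with indices 9 through 17 influences the structures, so every other entry of the parameter array has exactly zero gradient and stays at its initial value during optimization.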
Normalize Objective#
To normalize our objective function value to something more understandable, we first run a simulation with no boxes to compute the focal point intensity in this case. Then, we construct a new objective function value that normalizes the raw intensity by this value, giving us an “intensity enhancement” factor. In this normalization, if our objective is given by “x”, it means that the intensity at the focal point is “x” times stronger with our design than with no structures at all.
[16]:
params_empty = -1e5 * np.ones_like(params0)
J_empty = np.array(J(params_empty))
def J_normalized(params):
    return J(params) / J_empty
val_normalized = val / J_empty
dJ_normalized = jax.value_and_grad(J_normalized)
print(val_normalized)
0.8701368
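As a small piece of arithmetic on the values printed so far, the sketch below backs out the empty-simulation intensity implied by the raw and normalized objective values. The numbers are copied from the outputs above, and the variable names are illustrative.

```python
J_initial = 11.112829             # raw objective printed for params0
J_initial_normalized = 0.8701368  # normalized objective printed above

# the normalization is a plain ratio, so the reference value follows directly
J_empty_implied = J_initial / J_initial_normalized

print(f"implied J_empty = {J_empty_implied:.4f}")
print(f"starting design focuses less than empty space: {J_initial_normalized < 1}")
```

A normalized value below 1 means the starting design, with all boxes at side length S/2, delivers less intensity to the focal point than having no boxes at all.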
Optimization#
With our objective function set up, we can now run the optimization.
As before, we will use optax’s “adam” optimizer with initial parameters of all zeros (corresponding to boxes of side length S/2).
[17]:
import optax
# hyperparameters
num_steps = 18
learning_rate = 0.02
# initialize adam optimizer with starting parameters
params = np.array(params0)
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)
# store history
J_history = [val_normalized]
params_history = [params0]
for i in range(num_steps):
    # compute gradient and current objective function value
    value, gradient = dJ_normalized(params)
    # outputs
    print(f"step = {i + 1}")
    print(f"\tJ = {value:.4e}")
    print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")
    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)
    params = optax.apply_updates(params, updates)
    # save history
    J_history.append(value)
    params_history.append(params)
step = 1
J = 8.7014e-01
grad_norm = 7.5746e-01
step = 2
J = 1.9787e+00
grad_norm = 1.0905e+00
step = 3
J = 4.7533e+00
grad_norm = 1.5273e+00
step = 4
J = 7.5140e+00
grad_norm = 1.8209e+00
step = 5
J = 1.3054e+01
grad_norm = 2.8733e+00
step = 6
J = 2.0436e+01
grad_norm = 3.2374e+00
step = 7
J = 2.7075e+01
grad_norm = 2.9542e+00
step = 8
J = 3.1865e+01
grad_norm = 2.3616e+00
step = 9
J = 3.2881e+01
grad_norm = 3.0067e+00
step = 10
J = 3.6570e+01
grad_norm = 2.5087e+00
step = 11
J = 3.9969e+01
grad_norm = 2.7586e+00
step = 12
J = 4.2636e+01
grad_norm = 2.6617e+00
step = 13
J = 4.3821e+01
grad_norm = 2.7791e+00
step = 14
J = 4.5719e+01
grad_norm = 2.6449e+00
step = 15
J = 4.7921e+01
grad_norm = 2.1662e+00
step = 16
J = 4.8882e+01
grad_norm = 2.0180e+00
step = 17
J = 4.8722e+01
grad_norm = 2.4410e+00
step = 18
J = 4.9012e+01
grad_norm = 2.9489e+00
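To show what each optimizer step does, here is a minimal scalar sketch of the standard Adam update used for maximization. It is not optax's implementation. The function `adam_maximize` and the toy objective are hypothetical, and the hyperparameters b1, b2, and eps are the commonly used defaults, assumed here.

```python
import math

def adam_maximize(grad_fn, param, learning_rate=0.02, num_steps=18,
                  b1=0.9, b2=0.999, eps=1e-8):
    """Scalar Adam ascent (illustrative sketch, not optax's implementation)."""
    m, v = 0.0, 0.0
    history = [param]
    for t in range(1, num_steps + 1):
        g = -grad_fn(param)  # flip the sign: Adam minimizes, we want to maximize
        m = b1 * m + (1 - b1) * g      # running mean of gradients
        v = b2 * v + (1 - b2) * g * g  # running mean of squared gradients
        m_hat = m / (1 - b1**t)        # bias corrections
        v_hat = v / (1 - b2**t)
        param = param - learning_rate * m_hat / (math.sqrt(v_hat) + eps)
        history.append(param)
    return history

# toy concave objective f(p) = -(p - 1)^2, with gradient -2 (p - 1)
history = adam_maximize(lambda p: -2.0 * (p - 1.0), param=0.0)

print(f"after 1 step: {history[1]:.4f}")
print(f"after {len(history) - 1} steps: {history[-1]:.4f}")
```

On the first step the bias-corrected ratio has magnitude close to 1, so the parameter moves by about the learning rate regardless of the gradient's scale. This is one reason a small learning rate such as 0.02 still produces steady changes in the tanh parameters over 18 steps.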
[18]:
params_after = params_history[-1]
[19]:
plt.plot(J_history)
plt.xlabel('iterations')
plt.ylabel('objective function (focusing intensity enhancement)')
plt.show()
[20]:
sim_before = make_sim(0 * params_after, apply_symmetry=False).to_simulation()[0]
sim_after = make_sim(params_after, apply_symmetry=False).to_simulation()[0]
[21]:
f, (ax1, ax2) = plt.subplots(1, 2)
sim_before.plot(z=center_z, ax=ax1)
sim_after.plot(z=center_z, ax=ax2)
plt.show()
[22]:
sim_after_mnt = sim_after.updated_copy(
    monitors=list(sim_after.monitors) + [
        td.FieldMonitor(
            size=(0, td.inf, td.inf),
            center=(0, 0, 0),
            freqs=[f0],
            name='fields_yz',
        ),
        td.FieldMonitor(
            size=(td.inf, td.inf, 0),
            center=(0, 0, focal_z),
            freqs=[f0],
            name='far_field',
        ),
    ]
)
[23]:
sim_data_after_mnt = web.run(sim_after_mnt, task_name='meta_near_field_after')
View task using web UI at 'https://tidy3d.simulation.cloud/workbench?taskId=fdve-f56dc12f-663c-43ca-85f5-32f01b3062e7v1'. (webapi.py:190)
[24]:
fig, (ax1, ax2) = plt.subplots(1, 2, tight_layout=True, figsize=(10, 4))
sim_data_after_mnt.plot_field('far_field', 'int', vmax=105, ax=ax1)
sim_data_after_mnt.plot_field('fields_yz', 'int', vmax=180, ax=ax2)
plt.show()
Conclusions#
We notice that our metalens does quite well at focusing at this high NA! For the purposes of demonstration, this is quite a small device, but the same principle can be applied to optimize a much larger metalens.
For more case studies using the adjoint
plugin, see the