"""Adjoint-specific webapi."""
from typing import Tuple, Dict, List
from functools import partial
import tempfile
import pydantic.v1 as pd
from jax import custom_vjp
from jax.tree_util import register_pytree_node_class
from ...components.simulation import Simulation
from ...components.data.sim_data import SimulationData
from tidy3d.web.api.webapi import run as web_run
from tidy3d.web.api.webapi import wait_for_connection
from tidy3d.web.core.s3utils import download_file, upload_file
from tidy3d.web.api.asynchronous import run_async as web_run_async
from ...web.api.container import BatchData, DEFAULT_DATA_DIR, Job, Batch
from ...components.types import Literal
from .components.base import JaxObject
from .components.simulation import JaxSimulation, JaxInfo
from .components.data.sim_data import JaxSimulationData
# file names and paths for server side adjoint
# Remote file (downloaded by ``download_sim_vjp``) holding the JaxSimulation whose input
# structures carry the vjp computed on the server.
SIM_VJP_FILE = "output/jax_sim_vjp.hdf5"
# Remote file name under which ``upload_jax_info`` stores the serialized JaxInfo of a task.
JAX_INFO_FILE = "jax_info.json"
@register_pytree_node_class
class RunResidual(JaxObject):
    """Residual handed from the forward to the backward pass of a single adjoint run.

    Only the id of the forward task is kept; the backward pass registers it as the parent
    task of the adjoint simulation.
    """

    fwd_task_id: str = pd.Field(
        default=...,
        description="task_id of the forward simulation.",
        title="Forward task_id",
    )
@register_pytree_node_class
class RunResidualBatch(JaxObject):
    """Residual handed from the forward to the backward pass of a batched adjoint run.

    Holds one forward task id per simulation, in the same order as the simulations.
    """

    fwd_task_ids: Tuple[str, ...] = pd.Field(
        default=...,
        description="task_ids of the forward simulations.",
        title="Forward task_ids",
    )
@register_pytree_node_class
class RunResidualAsync(JaxObject):
    """Residual handed from the forward to the backward pass of an async adjoint run.

    Maps each task name to the id of its forward task.
    """

    fwd_task_ids: Dict[str, str] = pd.Field(
        default=...,
        description="task_ids of the forward simulation for async.",
        title="Forward task_ids",
    )
def _task_name_fwd(task_name: str) -> str:
"""task name for forward run as a function of the original task name."""
return str(task_name) + "_fwd"
def _task_name_adj(task_name: str) -> str:
"""task name for adjoint run as a function of the original task name."""
return str(task_name) + "_adj"
def tidy3d_run_fn(simulation: Simulation, task_name: str, **kwargs) -> SimulationData:
    """Run a plain tidy3d :class:`.Simulation` (already converted from its jax counterpart).

    Any extra keyword arguments are forwarded unchanged to the regular ``web.run``.
    """
    sim_data = web_run(simulation=simulation, task_name=task_name, **kwargs)
    return sim_data
def tidy3d_run_async_fn(simulations: Dict[str, Simulation], **kwargs) -> BatchData:
    """Run a mapping of task name -> plain tidy3d :class:`.Simulation` with ``web.run_async``.

    Any extra keyword arguments are forwarded unchanged.
    """
    batch_data = web_run_async(simulations=simulations, **kwargs)
    return batch_data
""" Running a single simulation using web.run. """
@partial(custom_vjp, nondiff_argnums=tuple(range(1, 6)))
def run(
    simulation: JaxSimulation,
    task_name: str,
    folder_name: str = "default",
    path: str = "simulation_data.hdf5",
    callback_url: str = None,
    verbose: bool = True,
) -> JaxSimulationData:
    """Submits a :class:`.JaxSimulation` to server, starts running, monitors progress, downloads,
    and loads results as a :class:`.JaxSimulationData` object.

    Can be included within a function that will have ``jax.grad`` applied. Only ``simulation``
    is differentiable; every other argument is static (``nondiff_argnums``).

    Parameters
    ----------
    simulation : :class:`.JaxSimulation`
        Simulation to upload to server.
    task_name : str
        Name of task.
    folder_name : str = "default"
        Name of folder to store task on web UI.
    path : str = "simulation_data.hdf5"
        Path to download results file (.hdf5), including filename.
    callback_url : str = None
        Http PUT url to receive simulation finish event. The body content is a json file with
        fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
    verbose : bool = True
        If `True`, will print progressbars and status, otherwise, will run silently.

    Returns
    -------
    :class:`.JaxSimulationData`
        Object containing solver results for the supplied :class:`.JaxSimulation`.
    """
    # split into a regular tidy3d Simulation plus the info needed to rebuild the jax objects
    sim, jax_info = simulation.to_simulation()
    sim_data = tidy3d_run_fn(
        simulation=sim,
        task_name=str(task_name),
        folder_name=folder_name,
        path=path,
        callback_url=callback_url,
        verbose=verbose,
    )
    return JaxSimulationData.from_sim_data(sim_data, jax_info)
def run_fwd(
    simulation: JaxSimulation,
    task_name: str,
    folder_name: str,
    path: str,
    callback_url: str,
    verbose: bool,
) -> Tuple[JaxSimulationData, Tuple[RunResidual]]:
    """Forward rule of :func:`run`: run the forward adjoint task and keep its task id.

    Returns the simulation data rebuilt with the *original* jax info, together with a
    one-element residual tuple consumed by :func:`run_bwd`.
    """
    sim_fwd, jax_info_fwd, jax_info_orig = simulation.to_simulation_fwd()
    shared_kwargs = dict(
        folder_name=folder_name,
        path=path,
        callback_url=callback_url,
        verbose=verbose,
    )
    sim_data_orig, fwd_task_id = webapi_run_adjoint_fwd(
        simulation=sim_fwd,
        jax_info=jax_info_fwd,
        task_name=str(task_name),
        **shared_kwargs,
    )
    residual = RunResidual(fwd_task_id=fwd_task_id)
    jax_sim_data = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
    return jax_sim_data, (residual,)
def run_bwd(
    task_name: str,
    folder_name: str,
    path: str,
    callback_url: str,
    verbose: bool,
    res: tuple,
    sim_data_vjp: JaxSimulationData,
) -> Tuple[JaxSimulation]:
    """Backward rule of :func:`run`: run the adjoint task and return the vjp simulation.

    ``path`` is accepted to match the forward signature but is not used here.
    """
    fwd_task_id = res[0].fwd_task_id
    sim_orig = sim_data_vjp.simulation

    # the adjoint source bandwidth and run time come from the original simulation
    jax_sim_adj = sim_data_vjp.make_adjoint_simulation(
        fwidth=sim_orig._fwidth_adjoint,
        run_time=sim_orig._run_time_adjoint,
    )
    sim_adj, jax_info_adj = jax_sim_adj.to_simulation()

    sim_vjp_server = webapi_run_adjoint_bwd(
        sim_adj=sim_adj,
        jax_info_adj=jax_info_adj,
        fwd_task_id=fwd_task_id,
        task_name=_task_name_adj(task_name),
        folder_name=folder_name,
        callback_url=callback_url,
        verbose=verbose,
    )
    # only the input structures (which carry the gradients) are taken from the server result
    sim_vjp = sim_orig.updated_copy(input_structures=sim_vjp_server.input_structures)
    return (sim_vjp,)
"""TODO: implement this section in the webapi module."""
@wait_for_connection
def upload_jax_info(jax_info: JaxInfo, task_id: str, verbose: bool) -> None:
    """Upload ``jax_info`` to the task ``task_id`` under the remote name ``JAX_INFO_FILE``.

    The info is first written as JSON into a private temporary directory. The directory and
    the file are removed once the upload has finished, so no temporary file is left behind.
    The original closed-then-reused ``NamedTemporaryFile`` name was racy and leaked the file.
    """
    import os  # local import: only needed to build the temporary file path

    with tempfile.TemporaryDirectory() as tmp_dir:
        # the ".json" extension selects the JSON serializer in ``to_file``
        data_file = os.path.join(tmp_dir, JAX_INFO_FILE)
        jax_info.to_file(data_file)
        upload_file(
            task_id,
            data_file,
            JAX_INFO_FILE,
            verbose=verbose,
        )
@wait_for_connection
def download_sim_vjp(task_id: str, verbose: bool) -> JaxSimulation:
    """Download the vjp-loaded simulation (``SIM_VJP_FILE``) of ``task_id`` and load it.

    The file is downloaded into a private temporary directory and loaded before that
    directory is removed. This avoids the racy reuse of a closed ``NamedTemporaryFile``
    name and the leftover temporary file of the previous version.
    """
    import os  # local import: only needed to build the temporary file path

    with tempfile.TemporaryDirectory() as tmp_dir:
        data_file = os.path.join(tmp_dir, "jax_sim_vjp.hdf5")
        download_file(task_id, SIM_VJP_FILE, to_file=data_file, verbose=verbose)
        # NOTE(review): assumes from_file fully reads the hdf5 file before returning
        return JaxSimulation.from_file(data_file)
# Allowed ``simulation_type`` values for the adjoint job/batch containers: a plain tidy3d
# run, the forward run of an adjoint computation, or its backward (adjoint) run.
AdjointSimulationType = Literal["tidy3d", "adjoint_fwd", "adjoint_bwd"]
class AdjointJob(Job):
    """:class:`Job` for adjoint tasks that also uploads a :class:`JaxInfo` before starting."""

    simulation_type: AdjointSimulationType = pd.Field(
        default=None,
        description="Type of simulation, used internally only.",
        title="Simulation Type",
    )

    jax_info: JaxInfo = pd.Field(
        default=None,
        description="Container of information needed to reconstruct jax simulation.",
        title="Jax Info",
    )

    def start(self) -> None:
        """Upload this job's ``jax_info`` to its task, then start it like a regular job.

        Note
        ----
        To monitor progress of the :class:`Job`, call :meth:`Job.monitor` after started.
        """
        task_id = self.task_id
        upload_jax_info(task_id=task_id, jax_info=self.jax_info, verbose=self.verbose)
        super().start()
class AdjointBatch(Batch):
    """Batch that uploads a jax_info object and also includes new fields for adjoint tasks."""

    simulation_type: AdjointSimulationType = pd.Field(
        "tidy3d",
        title="Simulation Type",
        description="Type of simulation, used internally only.",
    )

    jobs: Dict[str, AdjointJob] = pd.Field(
        None,
        title="Simulations",
        description="Mapping of task names to individual AdjointJob object for each task "
        "in the batch. Set by ``AdjointBatch.upload``, leave as None.",
    )

    jax_infos: Dict[str, JaxInfo] = pd.Field(
        ...,
        title="Jax Info Dict",
        description="Containers of information needed to reconstruct JaxSimulation for each item.",
    )

    @pd.root_validator()
    def _add_jax_infos(cls, values: dict) -> dict:
        """Attach to every job the ``jax_info`` stored under the same task name.

        Returns ``values`` unchanged when there are no jobs yet. It also returns them
        unchanged when ``jax_infos`` is missing, for example when that field failed its own
        validation. In that case pydantic reports the field error; it is not masked by a
        TypeError raised here.
        """
        jax_infos = values.get("jax_infos")
        jobs = values.get("jobs")
        if jobs is None or jax_infos is None:
            return values
        # build a fresh mapping rather than mutating the incoming dict in place
        values["jobs"] = {
            task_name: job.updated_copy(jax_info=jax_infos[task_name])
            for task_name, job in jobs.items()
        }
        return values

    def start(self) -> None:
        """Start running all tasks in the :class:`Batch`.

        Every job is an :class:`AdjointJob`, whose ``start`` already uploads its ``jax_info``
        before starting the task. Uploading it here as well would send it twice.

        Note
        ----
        To monitor the running simulations, can call :meth:`Batch.monitor`.
        """
        for job in self.jobs.values():
            job.start()
def webapi_run_adjoint_fwd(
    simulation: Simulation,
    jax_info: JaxInfo,
    task_name: str,
    folder_name: str,
    path: str,
    callback_url: str,
    verbose: bool,
) -> Tuple[SimulationData, str]:
    """Runs the forward simulation on our servers, stores the gradient data for later.

    The simulation is submitted as an ``"adjoint_fwd"`` :class:`AdjointJob` carrying
    ``jax_info``, and its results are downloaded to ``path``. Previously ``path`` was
    ignored and the default location was always used.

    Returns
    -------
    Tuple[:class:`.SimulationData`, str]
        The loaded forward results and the task id of the forward job. The backward pass
        uses that id as its parent task.
    """
    job = AdjointJob(
        simulation=simulation,
        task_name=task_name,
        folder_name=folder_name,
        callback_url=callback_url,
        verbose=verbose,
        simulation_type="adjoint_fwd",
        jax_info=jax_info,
    )
    sim_data = job.run(path=path)
    return sim_data, job.task_id
def webapi_run_adjoint_bwd(
    sim_adj: Simulation,
    jax_info_adj: JaxInfo,
    fwd_task_id: str,
    task_name: str,
    folder_name: str,
    callback_url: str,
    verbose: bool,
) -> JaxSimulation:
    """Run the adjoint simulation server-side and fetch the resulting vjp simulation.

    The task is an ``"adjoint_bwd"`` :class:`AdjointJob` whose parent is the forward task
    ``fwd_task_id``. No simulation data is downloaded; only the vjp simulation is returned.
    """
    adjoint_job = AdjointJob(
        simulation=sim_adj,
        jax_info=jax_info_adj,
        simulation_type="adjoint_bwd",
        parent_tasks=[fwd_task_id],
        task_name=task_name,
        folder_name=folder_name,
        callback_url=callback_url,
        verbose=verbose,
    )
    adjoint_job.start()
    adjoint_job.monitor()
    return download_sim_vjp(task_id=adjoint_job.task_id, verbose=verbose)
""" END WEBAPI ADDITIONS """
# register the custom forward and backward functions: when differentiated, ``run`` uses
# ``run_fwd`` (forward run + residual) and ``run_bwd`` (server-side adjoint run)
run.defvjp(run_fwd, run_bwd)
""" Running a batch of simulations using web.run_async. """
def _task_name_orig(index: int):
"""Task name as function of index into simulations. Note: for original must be int."""
return int(index)
@partial(custom_vjp, nondiff_argnums=tuple(range(1, 6)))
def run_async(
    simulations: Tuple[JaxSimulation, ...],
    folder_name: str = "default",
    path_dir: str = DEFAULT_DATA_DIR,
    callback_url: str = None,
    verbose: bool = True,
    num_workers: int = None,
) -> Tuple[JaxSimulationData, ...]:
    """Submits a set of :class:`.JaxSimulation` objects to server, starts running,
    monitors progress, downloads, and loads results
    as a tuple of :class:`.JaxSimulationData` objects.

    Only ``simulations`` is differentiable; every other argument is static
    (``nondiff_argnums``).

    Parameters
    ----------
    simulations : Tuple[:class:`.JaxSimulation`, ...]
        Collection of :class:`.JaxSimulations` to run asynchronously.
    folder_name : str = "default"
        Name of folder to store each task on web UI.
    path_dir : str
        Base directory where data will be downloaded, by default current working directory.
    callback_url : str = None
        Http PUT url to receive simulation finish event. The body content is a json file with
        fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
    verbose : bool = True
        If `True`, will print progressbars and status, otherwise, will run silently.
    num_workers: int = None
        Number of tasks to submit at once in a batch, if None, will run all at the same time.

    Note
    ----
    This is an experimental feature and may not work on all systems or configurations.
    For more details, see ``https://realpython.com/async-io-python/``.

    Returns
    -------
    Tuple[:class:`.JaxSimulationData`, ...]
        Contains the :class:`.JaxSimulationData` of each :class:`.JaxSimulation`, in the
        same order as ``simulations`` (returned as a list).
    """
    # task names are the string form of each simulation's position in ``simulations``
    task_names = [str(_task_name_orig(i)) for i in range(len(simulations))]

    # convert every simulation, keyed by task name; unlike ``zip(*...)`` unpacking this
    # does not crash with a confusing ValueError when ``simulations`` is empty
    sims_tidy3d = {}
    jax_infos = {}
    for task_name, jax_sim in zip(task_names, simulations):
        sims_tidy3d[task_name], jax_infos[task_name] = jax_sim.to_simulation()

    # run using regular tidy3d simulation running fn
    batch_data_tidy3d = tidy3d_run_async_fn(
        simulations=sims_tidy3d,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
        num_workers=num_workers,
    )

    # convert back to jax type, preserving the input order
    return [
        JaxSimulationData.from_sim_data(batch_data_tidy3d[task_name], jax_info=jax_infos[task_name])
        for task_name in task_names
    ]
def run_async_fwd(
    simulations: Tuple[JaxSimulation, ...],
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
    num_workers: int,
) -> Tuple[Tuple[JaxSimulationData, ...], RunResidualBatch]:
    """Forward rule of :func:`run_async`: run every forward adjoint task and keep the ids.

    ``num_workers`` is accepted for signature compatibility but is not forwarded.
    """
    converted = [simulation.to_simulation_fwd() for simulation in simulations]
    sims_fwd = [sim_fwd for sim_fwd, _, _ in converted]
    jax_infos_fwd = [info_fwd for _, info_fwd, _ in converted]
    jax_infos_orig = [info_orig for _, _, info_orig in converted]

    batch_data_orig, fwd_task_ids = webapi_run_async_adjoint_fwd(
        simulations=sims_fwd,
        jax_infos=jax_infos_fwd,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
    )

    # load every result first, then rebuild the jax data with the original jax infos
    sim_datas_orig = [sim_data for _, sim_data in batch_data_orig.items()]
    jax_batch_data_orig = [
        JaxSimulationData.from_sim_data(sim_data_orig, info_orig)
        for sim_data_orig, info_orig in zip(sim_datas_orig, jax_infos_orig)
    ]

    residual = RunResidualBatch(fwd_task_ids=fwd_task_ids)
    return jax_batch_data_orig, (residual,)
def run_async_bwd(
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
    num_workers: int,
    res: tuple,
    batch_data_vjp: Tuple[JaxSimulationData, ...],
) -> Tuple[Dict[str, JaxSimulation]]:
    """Backward rule of :func:`run_async`: run all adjoint tasks, return vjp simulations.

    Each adjoint task gets its forward task as parent. The returned simulations are the
    original ones with ``input_structures`` replaced by those computed on the server.
    """
    fwd_task_ids = res[0].fwd_task_ids

    sims_adj, jax_infos_adj, parent_tasks_adj = [], [], []
    for sim_data_vjp, fwd_task_id in zip(batch_data_vjp, fwd_task_ids):
        parent_tasks_adj.append([str(fwd_task_id)])
        sim_orig = sim_data_vjp.simulation
        jax_sim_adj = sim_data_vjp.make_adjoint_simulation(
            fwidth=sim_orig._fwidth_adjoint,
            run_time=sim_orig._run_time_adjoint,
        )
        sim_adj, jax_info_adj = jax_sim_adj.to_simulation()
        sims_adj.append(sim_adj)
        jax_infos_adj.append(jax_info_adj)

    sims_vjp = webapi_run_async_adjoint_bwd(
        simulations=sims_adj,
        jax_infos=jax_infos_adj,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
        parent_tasks=parent_tasks_adj,
    )

    # update the JaxSimulation.input_structures in the sim_data_vjps using the adjoint returned vals
    sims_vjp_updated = [
        sim_data_vjp.simulation.updated_copy(input_structures=sim_vjp.input_structures)
        for sim_vjp, sim_data_vjp in zip(sims_vjp, batch_data_vjp)
    ]
    return (sims_vjp_updated,)
def webapi_run_async_adjoint_fwd(
    simulations: Tuple[Simulation, ...],
    jax_infos: Tuple[JaxInfo, ...],
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
) -> Tuple[BatchData, Dict[str, str]]:
    """Run all forward simulations server-side as one ``"adjoint_fwd"`` :class:`AdjointBatch`.

    Returns the batch data (downloaded under ``path_dir``) and the forward task ids, as a
    tuple taken from ``batch_data.task_ids.values()``.
    """
    task_names = [str(_task_name_orig(i)) for i in range(len(simulations))]
    batch = AdjointBatch(
        simulations={name: sim for name, sim in zip(task_names, simulations)},
        jax_infos={name: info for name, info in zip(task_names, jax_infos)},
        folder_name=folder_name,
        callback_url=callback_url,
        verbose=verbose,
        simulation_type="adjoint_fwd",
    )
    batch_data_orig = batch.run(path_dir=path_dir)
    fwd_task_ids = tuple(batch_data_orig.task_ids.values())
    return batch_data_orig, fwd_task_ids
def webapi_run_async_adjoint_bwd(
    simulations: Tuple[Simulation, ...],
    jax_infos: Tuple[JaxInfo, ...],
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
    parent_tasks: List[List[str]],
) -> List[JaxSimulation]:
    """Runs the adjoint (backward) simulations on our servers and fetches their vjp results.

    The simulations run as one ``"adjoint_bwd"`` :class:`AdjointBatch`. Each simulation
    gets the matching entry of ``parent_tasks`` (its forward task ids) as parents. No
    simulation data is downloaded, so ``path_dir`` is unused.

    Returns
    -------
    List[:class:`.JaxSimulation`]
        One vjp simulation per input simulation, in the same order as ``simulations``.
    """
    # same naming scheme as the forward batch ("0", "1", ...)
    task_names = [str(_task_name_orig(i)) for i in range(len(simulations))]
    batch = AdjointBatch(
        simulations=dict(zip(task_names, simulations)),
        jax_infos=dict(zip(task_names, jax_infos)),
        folder_name=folder_name,
        callback_url=callback_url,
        verbose=verbose,
        simulation_type="adjoint_bwd",
        parent_tasks={name: tuple(ids) for name, ids in zip(task_names, parent_tasks)},
    )
    batch.start()
    batch.monitor()

    # look jobs up by task name so the output order follows ``simulations`` and does not
    # depend on the iteration order of ``batch.jobs``
    return [
        download_sim_vjp(task_id=batch.jobs[task_name].task_id, verbose=verbose)
        for task_name in task_names
    ]
# register the custom forward and backward functions: when differentiated, ``run_async``
# uses ``run_async_fwd`` and ``run_async_bwd`` (server-side adjoint batch)
run_async.defvjp(run_async_fwd, run_async_bwd)
""" Same functionality as above, but with all adjoint processing done client side (mainly for testing / debugging)."""
@partial(custom_vjp, nondiff_argnums=tuple(range(1, 6)))
def run_local(
    simulation: JaxSimulation,
    task_name: str,
    folder_name: str = "default",
    path: str = "simulation_data.hdf5",
    callback_url: str = None,
    verbose: bool = True,
) -> JaxSimulationData:
    """Run a :class:`.JaxSimulation` and load the results as a :class:`.JaxSimulationData`.

    Same interface as :func:`run`, but when differentiated the adjoint processing is done
    on the client (``run_local_fwd`` / ``run_local_bwd``). This is mainly useful for
    testing and debugging.

    Parameters
    ----------
    simulation : :class:`.JaxSimulation`
        Simulation to upload to server.
    task_name : str
        Name of task.
    folder_name : str = "default"
        Name of folder to store task on web UI.
    path : str = "simulation_data.hdf5"
        Path to download results file (.hdf5), including filename.
    callback_url : str = None
        Http PUT url to receive simulation finish event. The body content is a json file with
        fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
    verbose : bool = True
        If `True`, will print progressbars and status, otherwise, will run silently.

    Returns
    -------
    :class:`.JaxSimulationData`
        Object containing solver results for the supplied :class:`.JaxSimulation`.
    """
    sim_tidy3d, jax_info = simulation.to_simulation()
    run_options = dict(
        folder_name=folder_name,
        path=path,
        callback_url=callback_url,
        verbose=verbose,
    )
    sim_data_tidy3d = tidy3d_run_fn(
        simulation=sim_tidy3d, task_name=str(task_name), **run_options
    )
    return JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)
def run_local_fwd(
    simulation: JaxSimulation,
    task_name: str,
    folder_name: str,
    path: str,
    callback_url: str,
    verbose: bool,
) -> Tuple[JaxSimulationData, tuple]:
    """Forward rule of :func:`run_local`: run with gradient monitors added.

    Returns the data without gradient information (and with the original simulation),
    plus the full forward data as the residual for :func:`run_local_bwd`.
    """
    grad_monitors = simulation.get_grad_monitors(
        input_structures=simulation.input_structures,
        freqs_adjoint=simulation.freqs_adjoint,
    )
    sim_data_fwd = run(
        simulation=simulation.updated_copy(**grad_monitors),
        task_name=_task_name_fwd(task_name),
        folder_name=folder_name,
        path=path,
        callback_url=callback_url,
        verbose=verbose,
    )
    # the caller sees the original simulation and no gradient data
    sim_data_orig = sim_data_fwd.copy(update={"grad_data": (), "simulation": simulation})
    return sim_data_orig, (sim_data_fwd,)
def run_local_bwd(
    task_name: str,
    folder_name: str,
    path: str,
    callback_url: str,
    verbose: bool,
    res: tuple,
    sim_data_vjp: JaxSimulationData,
) -> Tuple[JaxSimulation]:
    """Backward rule of :func:`run_local`: run the adjoint simulation and compute the vjp.

    The vjp is computed on the client from the forward gradient data (in the residual) and
    the normalized adjoint gradient data.
    """
    (sim_data_fwd,) = res
    grad_data_fwd = sim_data_fwd.grad_data_symmetry
    grad_eps_data_fwd = sim_data_fwd.grad_eps_data_symmetry

    # the adjoint source bandwidth and run time come from the forward simulation
    sim_fwd = sim_data_fwd.simulation
    sim_adj = sim_data_vjp.make_adjoint_simulation(
        fwidth=sim_fwd._fwidth_adjoint,
        run_time=sim_fwd._run_time_adjoint,
    )
    sim_data_adj = run(
        simulation=sim_adj,
        task_name=_task_name_adj(task_name),
        folder_name=folder_name,
        path=path,
        callback_url=callback_url,
        verbose=verbose,
    ).normalize_adjoint_fields()
    grad_data_adj = sim_data_adj.grad_data_symmetry

    sim_vjp = sim_data_vjp.simulation.store_vjp(grad_data_fwd, grad_data_adj, grad_eps_data_fwd)
    return (sim_vjp,)
# register the custom forward and backward functions: when differentiated, ``run_local``
# uses ``run_local_fwd`` and ``run_local_bwd`` (client-side adjoint processing)
run_local.defvjp(run_local_fwd, run_local_bwd)
""" Running a batch of simulations using web.run_async, with client-side adjoint processing. """
def _task_name_orig_local(index: int, task_name_suffix: str = None):
"""Task name as function of index into simulations. Note: for original must be int."""
if task_name_suffix is not None:
return f"{index}{task_name_suffix}"
return int(index)
@partial(custom_vjp, nondiff_argnums=tuple(range(1, 7)))
def run_async_local(
    simulations: Tuple[JaxSimulation, ...],
    folder_name: str = "default",
    path_dir: str = DEFAULT_DATA_DIR,
    callback_url: str = None,
    verbose: bool = True,
    num_workers: int = None,
    task_name_suffix: str = None,
) -> Tuple[JaxSimulationData, ...]:
    """Submits a set of :class:`.JaxSimulation` objects to server, starts running,
    monitors progress, downloads, and loads results
    as a tuple of :class:`.JaxSimulationData` objects.

    Uses ``asyncio`` to perform these steps asynchronously.

    Parameters
    ----------
    simulations : Tuple[:class:`.JaxSimulation`, ...]
        Collection of :class:`.JaxSimulations` to run asynchronously.
    folder_name : str = "default"
        Name of folder to store each task on web UI.
    path_dir : str
        Base directory where data will be downloaded, by default current working directory.
    callback_url : str = None
        Http PUT url to receive simulation finish event. The body content is a json file with
        fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
    verbose : bool = True
        If `True`, will print progressbars and status, otherwise, will run silently.
    num_workers: int = None
        Number of tasks to submit at once in a batch, if None, will run all at the same time.
    task_name_suffix : str = None
        Suffix appended to each task name (``"<index><suffix>"``). If None, the task names
        are just the indices ``"0"``, ``"1"``, ...

    Note
    ----
    This does the adjoint processing on the client side. So more data will be required for download.

    Returns
    -------
    Tuple[:class:`.JaxSimulationData`, ...]
        Contains the :class:`.JaxSimulationData` of each :class:`.JaxSimulation`, in the
        same order as ``simulations`` (returned as a list).
    """
    # compute the string task names once and reuse them for submission and lookup
    task_names = [
        str(_task_name_orig_local(i, task_name_suffix)) for i in range(len(simulations))
    ]
    task_info = [jax_sim.to_simulation() for jax_sim in simulations]
    sims_tidy3d = {name: sim for name, (sim, _) in zip(task_names, task_info)}
    jax_infos = {name: jax_info for name, (_, jax_info) in zip(task_names, task_info)}

    # run using regular tidy3d simulation running fn
    batch_data_tidy3d = tidy3d_run_async_fn(
        simulations=sims_tidy3d,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
        num_workers=num_workers,
    )

    # convert back to jax type, preserving the input order
    return [
        JaxSimulationData.from_sim_data(batch_data_tidy3d[name], jax_info=jax_infos[name])
        for name in task_names
    ]
def run_async_local_fwd(
    simulations: Tuple[JaxSimulation, ...],
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
    num_workers: int,
    task_name_suffix: str,
) -> Tuple[Dict[str, JaxSimulationData], tuple]:
    """Forward rule of :func:`run_async_local`: run every simulation with gradient monitors.

    The forward tasks always use the ``"_fwd"`` suffix; the ``task_name_suffix`` argument
    is not used here. Returns the data stripped of gradient information (and with the
    original simulations), plus the full forward data as residual.
    """

    def _with_grad_monitors(sim):
        # add the gradient monitors required by the adjoint computation
        monitors = sim.get_grad_monitors(
            input_structures=sim.input_structures, freqs_adjoint=sim.freqs_adjoint
        )
        return sim.updated_copy(**monitors)

    suffix_fwd = _task_name_fwd("")
    sims_fwd = [_with_grad_monitors(sim) for sim in simulations]

    batch_data_fwd = run_async_local(
        simulations=sims_fwd,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
        num_workers=num_workers,
        task_name_suffix=suffix_fwd,
    )

    # remove the gradient data from the returned version (not needed)
    batch_data_orig = [
        sim_data_fwd.copy(update=dict(grad_data=(), simulation=sim_orig))
        for sim_orig, sim_data_fwd in zip(simulations, batch_data_fwd)
    ]
    return batch_data_orig, (batch_data_fwd,)
def run_async_local_bwd(
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
    num_workers: int,
    task_name_suffix: str,
    res: tuple,
    batch_data_vjp: Tuple[JaxSimulationData, ...],
) -> Tuple[Dict[str, JaxSimulation]]:
    """Run backward pass and return simulation storing vjp of the objective w.r.t. the sim.

    The adjoint simulations run with the ``"_adj"`` task-name suffix; ``task_name_suffix``
    is not used here. The forward gradient data is read from the residual only once per
    simulation, when computing the vjp. The previous version also computed it in a separate
    loop whose results were never used.
    """
    # grab the forward simulation data (with gradient monitor data) from the residual
    (batch_data_fwd,) = res
    task_name_suffix_adj = _task_name_adj("")

    # make the adjoint simulations; bandwidth and run time come from each forward simulation
    sims_adj = []
    for sim_data_fwd, sim_data_vjp in zip(batch_data_fwd, batch_data_vjp):
        sim_fwd = sim_data_fwd.simulation
        sim_adj = sim_data_vjp.make_adjoint_simulation(
            fwidth=sim_fwd._fwidth_adjoint, run_time=sim_fwd._run_time_adjoint
        )
        sims_adj.append(sim_adj)

    batch_data_adj = run_async_local(
        simulations=sims_adj,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
        num_workers=num_workers,
        task_name_suffix=task_name_suffix_adj,
    )

    # combine forward and (normalized) adjoint gradient data into each vjp simulation
    sims_vjp = []
    for sim_data_fwd, sim_data_adj, sim_data_vjp in zip(
        batch_data_fwd, batch_data_adj, batch_data_vjp
    ):
        sim_data_adj = sim_data_adj.normalize_adjoint_fields()
        grad_data_fwd = sim_data_fwd.grad_data_symmetry
        grad_data_adj = sim_data_adj.grad_data_symmetry
        grad_eps_data_fwd = sim_data_fwd.grad_eps_data_symmetry
        sim_vjp = sim_data_vjp.simulation.store_vjp(grad_data_fwd, grad_data_adj, grad_eps_data_fwd)
        sims_vjp.append(sim_vjp)
    return (sims_vjp,)
# register the custom forward and backward functions: when differentiated,
# ``run_async_local`` uses ``run_async_local_fwd`` and ``run_async_local_bwd``
run_async_local.defvjp(run_async_local_fwd, run_async_local_bwd)