# Source code for tidy3d.plugins.adjoint.web

"""Adjoint-specific webapi."""
from typing import Tuple, Dict, List
from functools import partial
import os
import tempfile

import pydantic.v1 as pd
from jax import custom_vjp
from jax.tree_util import register_pytree_node_class

from ...components.simulation import Simulation
from ...components.data.sim_data import SimulationData
from tidy3d.web.api.webapi import run as web_run
from tidy3d.web.api.webapi import wait_for_connection
from tidy3d.web.core.s3utils import download_file, upload_file
from tidy3d.web.api.asynchronous import run_async as web_run_async
from ...web.api.container import BatchData, DEFAULT_DATA_DIR, Job, Batch
from ...components.types import Literal

from .components.base import JaxObject
from .components.simulation import JaxSimulation, JaxInfo
from .components.data.sim_data import JaxSimulationData


# file names and paths for server side adjoint
SIM_VJP_FILE = "output/jax_sim_vjp.hdf5"
JAX_INFO_FILE = "jax_info.json"


@register_pytree_node_class
class RunResidual(JaxObject):
    """Residual handed from the forward pass of ``run`` to its backward pass.

    ``run_fwd`` stores the ``task_id`` of the forward task here and ``run_bwd`` reads it back.
    """

    fwd_task_id: str = pd.Field(
        ...,
        title="Forward task_id",
        description="task_id of the forward simulation.",
    )


@register_pytree_node_class
class RunResidualBatch(JaxObject):
    """Residual handed from the forward pass of ``run_async`` to its backward pass.

    ``run_async_fwd`` stores one forward ``task_id`` per simulation here and ``run_async_bwd``
    reads them back in the same order.
    """

    fwd_task_ids: Tuple[str, ...] = pd.Field(
        ...,
        title="Forward task_ids",
        description="task_ids of the forward simulations.",
    )


@register_pytree_node_class
class RunResidualAsync(JaxObject):
    """Residual container holding a dictionary of forward ``task_id`` strings for async runs.

    NOTE(review): nothing in the visible part of this module instantiates this class;
    the meaning of the dictionary keys is presumably the task name -- confirm before relying on it.
    """

    fwd_task_ids: Dict[str, str] = pd.Field(
        ...,
        title="Forward task_ids",
        description="task_ids of the forward simulation for async.",
    )


def _task_name_fwd(task_name: str) -> str:
    """task name for forward run as a function of the original task name."""
    return str(task_name) + "_fwd"


def _task_name_adj(task_name: str) -> str:
    """task name for adjoint run as a function of the original task name."""
    return str(task_name) + "_adj"


def tidy3d_run_fn(simulation: Simulation, task_name: str, **kwargs) -> SimulationData:
    """Forward a plain :class:`.Simulation` (already converted from its jax type) to ``web.run``.

    Any extra keyword arguments are passed through unchanged.
    """
    sim_data = web_run(simulation=simulation, task_name=task_name, **kwargs)
    return sim_data


def tidy3d_run_async_fn(simulations: Dict[str, Simulation], **kwargs) -> BatchData:
    """Forward a ``task_name -> Simulation`` mapping (already converted from jax types) to
    ``web.run_async``.

    Any extra keyword arguments are passed through unchanged.
    """
    batch_data = web_run_async(simulations=simulations, **kwargs)
    return batch_data


""" Running a single simulation using web.run. """


@partial(custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5))
def run(
    simulation: JaxSimulation,
    task_name: str,
    folder_name: str = "default",
    path: str = "simulation_data.hdf5",
    callback_url: str = None,
    verbose: bool = True,
) -> JaxSimulationData:
    """Run a :class:`.JaxSimulation` through ``web.run`` and return the results as a
    :class:`.JaxSimulationData`.

    The function is a ``jax.custom_vjp`` whose forward and backward rules are registered
    further down in this module, so it can be used inside a function passed to ``jax.grad``.
    Only ``simulation`` is differentiable; every other argument is non-differentiable.

    Parameters
    ----------
    simulation : :class:`.JaxSimulation`
        Simulation to upload to server.
    task_name : str
        Name of task.
    folder_name : str = "default"
        Name of folder to store task on web UI.
    path : str = "simulation_data.hdf5"
        Path to download results file (.hdf5), including filename.
    callback_url : str = None
        Http PUT url to receive simulation finish event. The body content is a json file with
        fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
    verbose : bool = True
        If `True`, will print progressbars and status, otherwise, will run silently.

    Returns
    -------
    :class:`.JaxSimulationData`
        Object containing solver results for the supplied :class:`.JaxSimulation`.
    """
    # strip the jax-specific parts; ``info`` carries what is needed to restore them afterwards
    sim_tidy3d, info = simulation.to_simulation()

    sim_data_tidy3d = tidy3d_run_fn(
        simulation=sim_tidy3d,
        task_name=str(task_name),
        folder_name=folder_name,
        path=path,
        callback_url=callback_url,
        verbose=verbose,
    )

    # rebuild the jax-aware data container from the plain result
    return JaxSimulationData.from_sim_data(sim_data_tidy3d, info)
def run_fwd(
    simulation: JaxSimulation,
    task_name: str,
    folder_name: str,
    path: str,
    callback_url: str,
    verbose: bool,
) -> Tuple[JaxSimulationData, Tuple[RunResidual]]:
    """Run forward pass and stash extra objects for the backwards pass.

    Returns
    -------
    Tuple[JaxSimulationData, Tuple[RunResidual]]
        The simulation data for the original simulation, and a 1-tuple holding the residual
        (forward ``task_id``) that ``run_bwd`` needs.
    """
    sim_fwd, jax_info_fwd, jax_info_orig = simulation.to_simulation_fwd()

    sim_data_orig, task_id = webapi_run_adjoint_fwd(
        simulation=sim_fwd,
        jax_info=jax_info_fwd,
        task_name=str(task_name),
        folder_name=folder_name,
        path=path,
        callback_url=callback_url,
        verbose=verbose,
    )

    res = RunResidual(fwd_task_id=task_id)
    jax_sim_data_orig = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
    return jax_sim_data_orig, (res,)


def run_bwd(
    task_name: str,
    folder_name: str,
    path: str,
    callback_url: str,
    verbose: bool,
    res: tuple,
    sim_data_vjp: JaxSimulationData,
) -> Tuple[JaxSimulation]:
    """Run backward pass and return simulation storing vjp of the objective w.r.t. the sim.

    The non-differentiable arguments come first, followed by the residual tuple produced by
    ``run_fwd`` and the incoming vjp data. ``path`` is accepted but not used here: the adjoint
    result is fetched through ``download_sim_vjp``.
    """
    fwd_task_id = res[0].fwd_task_id

    # build the adjoint simulation from the vjp data
    fwidth_adj = sim_data_vjp.simulation._fwidth_adjoint
    run_time_adj = sim_data_vjp.simulation._run_time_adjoint
    jax_sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj)
    sim_adj, jax_info_adj = jax_sim_adj.to_simulation()

    sim_vjp = webapi_run_adjoint_bwd(
        sim_adj=sim_adj,
        jax_info_adj=jax_info_adj,
        fwd_task_id=fwd_task_id,
        task_name=_task_name_adj(task_name),
        folder_name=folder_name,
        callback_url=callback_url,
        verbose=verbose,
    )

    # only the input structures of the downloaded simulation are used; everything else is
    # taken from the simulation attached to the vjp data
    sim_vjp = sim_data_vjp.simulation.updated_copy(input_structures=sim_vjp.input_structures)
    return (sim_vjp,)


"""TO DO: IMPLEMENT this section IN WEBAPI """


@wait_for_connection
def upload_jax_info(jax_info: JaxInfo, task_id: str, verbose: bool) -> None:
    """Upload jax_info for a task with a given task_id.

    The object is written to a file inside a temporary directory, which is removed again once
    the upload has finished (or failed), so no temporary file is left behind.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        fname = os.path.join(tmp_dir, JAX_INFO_FILE)
        jax_info.to_file(fname)
        upload_file(
            task_id,
            fname,
            JAX_INFO_FILE,
            verbose=verbose,
        )


@wait_for_connection
def download_sim_vjp(task_id: str, verbose: bool) -> JaxSimulation:
    """Download the vjp loaded simulation from the server to return to jax.

    The file is downloaded into a temporary directory that is removed after loading.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        fname = os.path.join(tmp_dir, os.path.basename(SIM_VJP_FILE))
        download_file(task_id, SIM_VJP_FILE, to_file=fname, verbose=verbose)
        # NOTE(review): assumes ``from_file`` reads the file fully before returning -- confirm
        return JaxSimulation.from_file(fname)


AdjointSimulationType = Literal["tidy3d", "adjoint_fwd", "adjoint_bwd"]


class AdjointJob(Job):
    """Job that uploads a jax_info object and also includes new fields for adjoint tasks."""

    simulation_type: AdjointSimulationType = pd.Field(
        None,
        title="Simulation Type",
        description="Type of simulation, used internally only.",
    )

    jax_info: JaxInfo = pd.Field(
        None,
        title="Jax Info",
        description="Container of information needed to reconstruct jax simulation.",
    )

    def start(self) -> None:
        """Start running a :class:`AdjointJob`. after uploading jax info.

        Note
        ----
        To monitor progress of the :class:`Job`, call :meth:`Job.monitor` after started.
        """
        upload_jax_info(task_id=self.task_id, jax_info=self.jax_info, verbose=self.verbose)
        super().start()


class AdjointBatch(Batch):
    """Batch that uploads a jax_info object and also includes new fields for adjoint tasks."""

    simulation_type: AdjointSimulationType = pd.Field(
        "tidy3d",
        title="Simulation Type",
        description="Type of simulation, used internally only.",
    )

    jobs: Dict[str, AdjointJob] = pd.Field(
        None,
        title="Simulations",
        description="Mapping of task names to individual AdjointJob object for each task "
        "in the batch. Set by ``AdjointBatch.upload``, leave as None.",
    )

    jax_infos: Dict[str, JaxInfo] = pd.Field(
        ...,
        title="Jax Info Dict",
        description="Containers of information needed to reconstruct JaxSimulation for each item.",
    )

    @pd.root_validator()
    def _add_jax_infos(cls, values) -> dict:
        """Add jax_info fields to the uploaded jobs."""
        jax_infos = values.get("jax_infos")
        jobs = values.get("jobs")

        # nothing to attach yet (jobs not created), or ``jax_infos`` is absent from ``values``
        if jobs is None or jax_infos is None:
            return values

        # build a fresh mapping instead of assigning into the dict being iterated
        values["jobs"] = {
            task_name: job.updated_copy(jax_info=jax_infos[task_name])
            for task_name, job in jobs.items()
        }
        return values

    def start(self) -> None:
        """Start running all tasks in the :class:`Batch`.

        Note
        ----
        To monitor the running simulations, can call :meth:`Batch.monitor`.
        """
        for job in self.jobs.values():
            upload_jax_info(task_id=job.task_id, jax_info=job.jax_info, verbose=self.verbose)
            job.start()


def webapi_run_adjoint_fwd(
    simulation: Simulation,
    jax_info: JaxInfo,
    task_name: str,
    folder_name: str,
    path: str,
    callback_url: str,
    verbose: bool,
) -> Tuple[SimulationData, str]:
    """Runs the forward simulation on our servers, stores the gradient data for later.

    Returns
    -------
    Tuple[SimulationData, str]
        The data returned by the job, and the ``task_id`` of the job.
    """
    job = AdjointJob(
        simulation=simulation,
        task_name=task_name,
        folder_name=folder_name,
        callback_url=callback_url,
        verbose=verbose,
        simulation_type="adjoint_fwd",
        jax_info=jax_info,
    )

    # honor the caller's download path, as the non-differentiated ``run`` does
    sim_data = job.run(path=path)
    return sim_data, job.task_id


def webapi_run_adjoint_bwd(
    sim_adj: Simulation,
    jax_info_adj: JaxInfo,
    fwd_task_id: str,
    task_name: str,
    folder_name: str,
    callback_url: str,
    verbose: bool,
) -> JaxSimulation:
    """Runs adjoint simulation on our servers, grabs the gradient data from fwd for processing.

    The forward task is linked to the adjoint job through ``parent_tasks``. After the job has
    been started and monitored, the vjp simulation is downloaded and returned.
    """
    job = AdjointJob(
        simulation=sim_adj,
        task_name=task_name,
        folder_name=folder_name,
        callback_url=callback_url,
        verbose=verbose,
        simulation_type="adjoint_bwd",
        parent_tasks=[fwd_task_id],
        jax_info=jax_info_adj,
    )

    job.start()
    job.monitor()

    sim_vjp = download_sim_vjp(task_id=job.task_id, verbose=verbose)
    return sim_vjp


""" END WEBAPI ADDITIONS """

# register the custom forward and backward functions
run.defvjp(run_fwd, run_bwd)


""" Running a batch of simulations using web.run_async. """


def _task_name_orig(index: int) -> int:
    """Task name as function of index into simulations. Note: for original must be int."""
    return int(index)
@partial(custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5))
def run_async(
    simulations: Tuple[JaxSimulation, ...],
    folder_name: str = "default",
    path_dir: str = DEFAULT_DATA_DIR,
    callback_url: str = None,
    verbose: bool = True,
    num_workers: int = None,
) -> Tuple[JaxSimulationData, ...]:
    """Run a collection of :class:`.JaxSimulation` objects through ``web.run_async`` and return
    one :class:`.JaxSimulationData` per input, in input order.

    The function is a ``jax.custom_vjp`` whose forward and backward rules are registered
    further down in this module. Only ``simulations`` is differentiable.

    Parameters
    ----------
    simulations : Tuple[:class:`.JaxSimulation`, ...]
        Collection of :class:`.JaxSimulations` to run asynchronously.
    folder_name : str = "default"
        Name of folder to store each task on web UI.
    path_dir : str
        Base directory where data will be downloaded, by default current working directory.
    callback_url : str = None
        Http PUT url to receive simulation finish event. The body content is a json file with
        fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
    verbose : bool = True
        If `True`, will print progressbars and status, otherwise, will run silently.
    num_workers: int = None
        Number of tasks to submit at once in a batch, if None, will run all at the same time.

    Note
    ----
    This is an experimental feature and may not work on all systems or configurations.
    For more details, see ``https://realpython.com/async-io-python/``.

    Returns
    -------
    Tuple[:class:`.JaxSimulationData`, ...]
        Contains the :class:`.JaxSimulationData` of each :class:`.JaxSimulation`.
    """
    # task names are the stringified positions of the simulations: "0", "1", ...
    names = [str(_task_name_orig(position)) for position in range(len(simulations))]

    # each conversion yields a (Simulation, JaxInfo) pair; split the pairs into two sequences
    converted = [jax_sim.to_simulation() for jax_sim in simulations]
    plain_sims, infos = tuple(zip(*converted))
    sims_by_name = dict(zip(names, plain_sims))
    infos_by_name = dict(zip(names, infos))

    # run using regular tidy3d simulation running fn
    batch_data = tidy3d_run_async_fn(
        simulations=sims_by_name,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
        num_workers=num_workers,
    )

    # convert back to jax type, preserving the input order
    return [
        JaxSimulationData.from_sim_data(batch_data[name], jax_info=infos_by_name[name])
        for name in names
    ]
def run_async_fwd(
    simulations: Tuple[JaxSimulation, ...],
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
    num_workers: int,
) -> Tuple[List[JaxSimulationData], Tuple[RunResidualBatch]]:
    """Run forward pass and stash extra objects for the backwards pass.

    ``num_workers`` is accepted to match the signature of ``run_async`` but is not forwarded
    to the batch run.

    Returns
    -------
    Tuple[List[JaxSimulationData], Tuple[RunResidualBatch]]
        The data for each original simulation, and a 1-tuple holding the residual (forward
        task ids) that ``run_async_bwd`` needs.
    """
    jax_infos_orig = []
    sims_fwd = []
    jax_infos_fwd = []
    for simulation in simulations:
        sim_fwd, jax_info_fwd, jax_info_orig = simulation.to_simulation_fwd()
        jax_infos_orig.append(jax_info_orig)
        sims_fwd.append(sim_fwd)
        jax_infos_fwd.append(jax_info_fwd)

    batch_data_orig, fwd_task_ids = webapi_run_async_adjoint_fwd(
        simulations=sims_fwd,
        jax_infos=jax_infos_fwd,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
    )

    # relies on the batch data iterating in the same order the simulations were supplied
    batch_data_orig = [sim_data for _, sim_data in batch_data_orig.items()]

    jax_batch_data_orig = [
        JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
        for sim_data_orig, jax_info_orig in zip(batch_data_orig, jax_infos_orig)
    ]

    residual = RunResidualBatch(fwd_task_ids=fwd_task_ids)
    return jax_batch_data_orig, (residual,)


def run_async_bwd(
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
    num_workers: int,
    res: tuple,
    batch_data_vjp: Tuple[JaxSimulationData, ...],
) -> Tuple[List[JaxSimulation]]:
    """Run backward pass and return simulation storing vjp of the objective w.r.t. the sim.

    ``num_workers`` is accepted to match the signature of ``run_async`` but is not forwarded
    to the batch run.
    """
    fwd_task_ids = res[0].fwd_task_ids

    sims_adj = []
    jax_infos_adj = []
    parent_tasks_adj = []

    for sim_data_vjp, fwd_task_id in zip(batch_data_vjp, fwd_task_ids):
        # each adjoint task has exactly one parent: its forward task
        parent_tasks_adj.append([str(fwd_task_id)])

        fwidth_adj = sim_data_vjp.simulation._fwidth_adjoint
        run_time_adj = sim_data_vjp.simulation._run_time_adjoint
        jax_sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj)
        sim_adj, jax_info_adj = jax_sim_adj.to_simulation()
        sims_adj.append(sim_adj)
        jax_infos_adj.append(jax_info_adj)

    sims_vjp = webapi_run_async_adjoint_bwd(
        simulations=sims_adj,
        jax_infos=jax_infos_adj,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
        parent_tasks=parent_tasks_adj,
    )

    # update the JaxSimulation.input_structures in the sim_data_vjps using the adjoint returned vals
    sims_vjp_updated = [
        sim_data_vjp.simulation.updated_copy(input_structures=sim_vjp.input_structures)
        for sim_vjp, sim_data_vjp in zip(sims_vjp, batch_data_vjp)
    ]

    return (sims_vjp_updated,)


def webapi_run_async_adjoint_fwd(
    simulations: Tuple[Simulation, ...],
    jax_infos: Tuple[JaxInfo, ...],
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
) -> Tuple[BatchData, Tuple[str, ...]]:
    """Runs the forward simulations on our servers, stores the gradient data for later.

    Returns
    -------
    Tuple[BatchData, Tuple[str, ...]]
        The batch data, and the values of its ``task_ids`` mapping as a tuple.
    """
    task_names = [str(_task_name_orig(i)) for i in range(len(simulations))]

    sims_by_name = dict(zip(task_names, simulations))
    jax_infos_by_name = dict(zip(task_names, jax_infos))

    batch = AdjointBatch(
        simulations=sims_by_name,
        jax_infos=jax_infos_by_name,
        folder_name=folder_name,
        callback_url=callback_url,
        verbose=verbose,
        simulation_type="adjoint_fwd",
    )

    batch_data_orig = batch.run(path_dir=path_dir)

    return batch_data_orig, tuple(batch_data_orig.task_ids.values())


def webapi_run_async_adjoint_bwd(
    simulations: Tuple[Simulation, ...],
    jax_infos: Tuple[JaxInfo, ...],
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
    parent_tasks: List[List[str]],
) -> List[JaxSimulation]:
    """Runs the adjoint simulations on our servers and downloads the resulting vjp simulations.

    ``parent_tasks`` holds, per simulation, the ids of the tasks it depends on. ``path_dir`` is
    accepted but not used here: results are fetched through ``download_sim_vjp``.
    """
    # same naming scheme as the forward batch
    task_names = [str(_task_name_orig(i)) for i in range(len(simulations))]

    sims_by_name = dict(zip(task_names, simulations))
    jax_infos_by_name = dict(zip(task_names, jax_infos))
    parent_tasks_dict = dict(zip(task_names, (tuple(task_ids) for task_ids in parent_tasks)))

    batch = AdjointBatch(
        simulations=sims_by_name,
        jax_infos=jax_infos_by_name,
        folder_name=folder_name,
        callback_url=callback_url,
        verbose=verbose,
        simulation_type="adjoint_bwd",
        parent_tasks=parent_tasks_dict,
    )

    batch.start()
    batch.monitor()

    return [
        download_sim_vjp(task_id=job.task_id, verbose=verbose) for job in batch.jobs.values()
    ]


# register the custom forward and backward functions
run_async.defvjp(run_async_fwd, run_async_bwd)


""" Options to do the previous but all client side (mainly for testing / debugging)."""


@partial(custom_vjp, nondiff_argnums=tuple(range(1, 6)))
def run_local(
    simulation: JaxSimulation,
    task_name: str,
    folder_name: str = "default",
    path: str = "simulation_data.hdf5",
    callback_url: str = None,
    verbose: bool = True,
) -> JaxSimulationData:
    """Submits a :class:`.JaxSimulation` to server, starts running, monitors progress, downloads,
    and loads results as a :class:`.JaxSimulationData` object.
    Can be included within a function that will have ``jax.grad`` applied.

    The gradient rules registered for this function (``run_local_fwd`` / ``run_local_bwd``)
    combine the forward and adjoint gradient data on the client.

    Parameters
    ----------
    simulation : :class:`.JaxSimulation`
        Simulation to upload to server.
    task_name : str
        Name of task.
    folder_name : str = "default"
        Name of folder to store task on web UI.
    path : str = "simulation_data.hdf5"
        Path to download results file (.hdf5), including filename.
    callback_url : str = None
        Http PUT url to receive simulation finish event. The body content is a json file with
        fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
    verbose : bool = True
        If `True`, will print progressbars and status, otherwise, will run silently.

    Returns
    -------
    :class:`.JaxSimulationData`
        Object containing solver results for the supplied :class:`.JaxSimulation`.
    """
    # convert to regular tidy3d (and accounting info)
    sim_tidy3d, jax_info = simulation.to_simulation()

    # run using regular tidy3d simulation running fn
    sim_data_tidy3d = tidy3d_run_fn(
        simulation=sim_tidy3d,
        task_name=str(task_name),
        folder_name=folder_name,
        path=path,
        callback_url=callback_url,
        verbose=verbose,
    )

    # convert back to jax type and return
    return JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)


def run_local_fwd(
    simulation: JaxSimulation,
    task_name: str,
    folder_name: str,
    path: str,
    callback_url: str,
    verbose: bool,
) -> Tuple[JaxSimulationData, tuple]:
    """Run forward pass and stash extra objects for the backwards pass.

    Returns the data for the original simulation (gradient data removed) and a 1-tuple holding
    the full forward data, which ``run_local_bwd`` needs.
    """
    # add the gradient monitors and run the forward simulation
    grad_mnts = simulation.get_grad_monitors(
        input_structures=simulation.input_structures, freqs_adjoint=simulation.freqs_adjoint
    )
    sim_fwd = simulation.updated_copy(**grad_mnts)
    sim_data_fwd = run(
        simulation=sim_fwd,
        task_name=_task_name_fwd(task_name),
        folder_name=folder_name,
        path=path,
        callback_url=callback_url,
        verbose=verbose,
    )

    # remove the gradient data from the returned version (not needed)
    sim_data_orig = sim_data_fwd.copy(update=dict(grad_data=(), simulation=simulation))

    return sim_data_orig, (sim_data_fwd,)


def run_local_bwd(
    task_name: str,
    folder_name: str,
    path: str,
    callback_url: str,
    verbose: bool,
    res: tuple,
    sim_data_vjp: JaxSimulationData,
) -> Tuple[JaxSimulation]:
    """Run backward pass and return simulation storing vjp of the objective w.r.t. the sim."""
    # grab the forward simulation and its gradient monitor data
    (sim_data_fwd,) = res
    grad_data_fwd = sim_data_fwd.grad_data_symmetry
    grad_eps_data_fwd = sim_data_fwd.grad_eps_data_symmetry

    # make and run adjoint simulation
    fwidth_adj = sim_data_fwd.simulation._fwidth_adjoint
    run_time_adj = sim_data_fwd.simulation._run_time_adjoint
    sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj)
    sim_data_adj = run(
        simulation=sim_adj,
        task_name=_task_name_adj(task_name),
        folder_name=folder_name,
        path=path,
        callback_url=callback_url,
        verbose=verbose,
    )
    sim_data_adj = sim_data_adj.normalize_adjoint_fields()

    grad_data_adj = sim_data_adj.grad_data_symmetry

    # get gradient and insert into the resulting simulation structure medium
    sim_vjp = sim_data_vjp.simulation.store_vjp(grad_data_fwd, grad_data_adj, grad_eps_data_fwd)

    return (sim_vjp,)


# register the custom forward and backward functions
run_local.defvjp(run_local_fwd, run_local_bwd)


""" Running a batch of simulations using web.run_async. """


def _task_name_orig_local(index: int, task_name_suffix: str = None):
    """Task name as function of index into simulations. Note: for original must be int.

    Returns ``int(index)`` when no suffix is given, otherwise the string ``f"{index}{suffix}"``.
    """
    if task_name_suffix is not None:
        return f"{index}{task_name_suffix}"
    return int(index)


@partial(custom_vjp, nondiff_argnums=tuple(range(1, 7)))
def run_async_local(
    simulations: Tuple[JaxSimulation, ...],
    folder_name: str = "default",
    path_dir: str = DEFAULT_DATA_DIR,
    callback_url: str = None,
    verbose: bool = True,
    num_workers: int = None,
    task_name_suffix: str = None,
) -> Tuple[JaxSimulationData, ...]:
    """Submits a set of :class:`.JaxSimulation` objects to server, starts running,
    monitors progress, downloads, and loads results
    as a tuple of :class:`.JaxSimulationData` objects.

    Parameters
    ----------
    simulations : Tuple[:class:`.JaxSimulation`, ...]
        Collection of :class:`.JaxSimulations` to run asynchronously.
    folder_name : str = "default"
        Name of folder to store each task on web UI.
    path_dir : str
        Base directory where data will be downloaded, by default current working directory.
    callback_url : str = None
        Http PUT url to receive simulation finish event. The body content is a json file with
        fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
    verbose : bool = True
        If `True`, will print progressbars and status, otherwise, will run silently.
    num_workers: int = None
        Number of tasks to submit at once in a batch, if None, will run all at the same time.
    task_name_suffix : str = None
        Optional suffix appended to each task name (the index of the simulation as a string).

    Note
    ----
    This does the adjoint processing on the client side. So more data will be required for
    download.

    Returns
    -------
    Tuple[:class:`.JaxSimulationData`, ...]
        Contains the :class:`.JaxSimulationData` of each :class:`.JaxSimulation`.
    """
    # string task name -> jax simulation, in input order ("0", "1", ... plus optional suffix)
    jax_sims_by_name = {
        str(_task_name_orig_local(i, task_name_suffix)): jax_sim
        for i, jax_sim in enumerate(simulations)
    }

    # each conversion yields a (Simulation, JaxInfo) pair
    task_info = {name: jax_sim.to_simulation() for name, jax_sim in jax_sims_by_name.items()}
    sims_tidy3d = {name: sim for name, (sim, _) in task_info.items()}
    jax_infos = {name: jax_info for name, (_, jax_info) in task_info.items()}

    # run using regular tidy3d simulation running fn
    batch_data_tidy3d = tidy3d_run_async_fn(
        simulations=sims_tidy3d,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
        num_workers=num_workers,
    )

    # convert back to jax type and return, preserving the input order
    return [
        JaxSimulationData.from_sim_data(batch_data_tidy3d[name], jax_info=jax_infos[name])
        for name in jax_sims_by_name
    ]


def run_async_local_fwd(
    simulations: Tuple[JaxSimulation, ...],
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
    num_workers: int,
    task_name_suffix: str,
) -> Tuple[List[JaxSimulationData], tuple]:
    """Run forward pass and stash extra objects for the backwards pass.

    The supplied ``task_name_suffix`` is not used; forward tasks always get the suffix
    produced by ``_task_name_fwd("")``.
    """
    task_name_suffix_fwd = _task_name_fwd("")

    # add the gradient monitors to every simulation
    sims_fwd = []
    for simulation in simulations:
        grad_mnts = simulation.get_grad_monitors(
            input_structures=simulation.input_structures, freqs_adjoint=simulation.freqs_adjoint
        )
        sims_fwd.append(simulation.updated_copy(**grad_mnts))

    batch_data_fwd = run_async_local(
        simulations=sims_fwd,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
        num_workers=num_workers,
        task_name_suffix=task_name_suffix_fwd,
    )

    # remove the gradient data from the returned version (not needed)
    batch_data_orig = [
        sim_data_fwd.copy(update=dict(grad_data=(), simulation=simulations[i]))
        for i, sim_data_fwd in enumerate(batch_data_fwd)
    ]

    return batch_data_orig, (batch_data_fwd,)


def run_async_local_bwd(
    folder_name: str,
    path_dir: str,
    callback_url: str,
    verbose: bool,
    num_workers: int,
    task_name_suffix: str,
    res: tuple,
    batch_data_vjp: Tuple[JaxSimulationData, ...],
) -> Tuple[List[JaxSimulation]]:
    """Run backward pass and return simulation storing vjp of the objective w.r.t. the sim.

    The supplied ``task_name_suffix`` is not used; adjoint tasks always get the suffix
    produced by ``_task_name_adj("")``.
    """
    # grab the forward simulation data stashed by the forward pass
    (batch_data_fwd,) = res

    task_name_suffix_adj = _task_name_adj("")

    # make and run adjoint simulations
    sims_adj = []
    for i, sim_data_fwd in enumerate(batch_data_fwd):
        fwidth_adj = sim_data_fwd.simulation._fwidth_adjoint
        run_time_adj = sim_data_fwd.simulation._run_time_adjoint
        sim_data_vjp = batch_data_vjp[i]
        sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj)
        sims_adj.append(sim_adj)

    batch_data_adj = run_async_local(
        simulations=sims_adj,
        folder_name=folder_name,
        path_dir=path_dir,
        callback_url=callback_url,
        verbose=verbose,
        num_workers=num_workers,
        task_name_suffix=task_name_suffix_adj,
    )

    # combine forward and adjoint gradient data into the vjp simulation for each task
    sims_vjp = []
    for i, (sim_data_fwd, sim_data_adj) in enumerate(zip(batch_data_fwd, batch_data_adj)):
        sim_data_adj = sim_data_adj.normalize_adjoint_fields()

        grad_data_fwd = sim_data_fwd.grad_data_symmetry
        grad_data_adj = sim_data_adj.grad_data_symmetry
        grad_data_eps_fwd = sim_data_fwd.grad_eps_data_symmetry

        sim_data_vjp = batch_data_vjp[i]
        sim_vjp = sim_data_vjp.simulation.store_vjp(grad_data_fwd, grad_data_adj, grad_data_eps_fwd)
        sims_vjp.append(sim_vjp)

    return (sims_vjp,)


# register the custom forward and backward functions
run_async_local.defvjp(run_async_local_fwd, run_async_local_bwd)