Source code for photonforge.circuit_time_stepper

import threading
import time
import warnings
from collections.abc import Sequence
from queue import Queue
from typing import Any

from . import typing as pft
from .cache import _mode_overlap_cache
from .circuit_base import _gather_status, _process_component_netlist
from .extension import (
    Component,
    FiberPort,
    GaussianPort,
    Port,
    TimeStepper,
    config,
    register_time_stepper_class,
)
from .tidy3d_model import _align_and_overlap


def _stepper(work_queue):
    while True:
        fn, args, outputs = work_queue.get()
        if fn is None:
            return
        try:
            result = fn(*args)
        except Exception as err:
            result = str(err)
        outputs.append(result)
        work_queue.task_done()


class CircuitTimeStepper(TimeStepper):
    """Circuit-level time stepper.

    Constructs time steppers for individual circuit elements and handles
    connections between them.

    Each time stepper initialization is preceded by an update to the
    component's technology, the component itself, and its active model by
    calling :attr:`Reference.update`. They are reset to their original state
    afterwards.

    If a reference includes repetitions, it is flattened so that each instance
    is called separately.

    Args:
        mesh_refinement: Minimum number of mesh elements per wavelength used
          for mode solving.
        max_iterations: Maximum number of iterations for self-consistent
          signal propagation through the circuit. A larger value may be
          needed for larger circuits or high-Q feedback loops.
        abs_tolerance: The absolute tolerance for the convergence check.
        rel_tolerance: The relative tolerance for the convergence check.
        max_threads: Maximum number of threads used for stepping individual
          subcomponents.
        verbose: Flag setting the verbosity of mode solver runs.
    """

    def __init__(
        self,
        mesh_refinement: pft.PositiveFloat | None = None,
        max_iterations: pft.PositiveInt = 100,
        abs_tolerance: pft.PositiveFloat = 1e-8,
        rel_tolerance: pft.PositiveFloat = 1e-5,
        max_threads: pft.PositiveInt = 8,
        verbose: bool = True,
    ):
        super().__init__(
            mesh_refinement=mesh_refinement,
            max_iterations=max_iterations,
            abs_tolerance=abs_tolerance,
            rel_tolerance=rel_tolerance,
            max_threads=max_threads,
            verbose=verbose,
        )
        # Status dictionary; populated once `setup_state` is called.
        self._status = None
        # Work queue shared with the worker threads; created lazily in
        # `step_single` and dropped again on shutdown.
        self._work_queue = None
        # Number of worker threads currently attached to `_work_queue`.
        self._num_workers = 0

    def setup_state(
        self,
        *,
        component: Component,
        time_step: float,
        carrier_frequency: float,
        monitors: dict[str, Port | FiberPort | GaussianPort] | None = None,
        updates: dict[Sequence[str | int | None], dict[str, dict[str, Any]]] | None = None,
        chain_technology_updates: bool = True,
        verbose: bool | None = None,
        **kwargs,
    ):
        """Initialize internal circuit variables.

        The netlist is processed synchronously; waiting for the resulting
        runners and the remaining state initialization happen in a daemon
        thread (see ``_setup_and_monitor``). Progress can be followed through
        :attr:`status`.

        Args:
            component: Component for the time stepper.
            time_step: The interval between time steps (in seconds).
            carrier_frequency: The carrier frequency used to construct the
              time stepper. The carrier should be omitted from the input
              signals, as it is handled automatically by the time stepper.
            monitors: Additional ports to include in the outputs as a
              dictionary with monitor names as keys. Subcomponent ports can
              be obtained with :func:`Reference.get_ports` for the 1st level
              of references or with :func:`Component.query` for more
              deeply-nested ports. ``None`` is treated as an empty
              dictionary.
            updates: Dictionary of parameter updates to be applied to
              components, technologies, and models for references within the
              main component. See :func:`CircuitModel.start` for examples.
              ``None`` is treated as an empty dictionary.
            chain_technology_updates: if set, a technology update will
              trigger an update for all components using that technology.
            verbose: If set, overrides the model's ``verbose`` attribute and
              is passed to reference models.
            kwargs: Keyword arguments passed to reference models.

        Returns:
            Object with a status dictionary.
        """
        # Fresh dictionaries per call: avoids sharing one mutable default
        # object between unrelated invocations.
        monitors = {} if monitors is None else monitors
        updates = {} if updates is None else updates

        time_stepper_kwargs = {
            "time_step": time_step,
            "carrier_frequency": carrier_frequency,
        }
        if verbose is None:
            verbose = self.parametric_kwargs["verbose"]
        else:
            # Only an explicit override is forwarded to the sub-steppers.
            time_stepper_kwargs["verbose"] = verbose

        # A non-positive carrier falls back to the sampling rate.
        frequencies = [carrier_frequency if carrier_frequency > 0 else (1 / time_step)]

        # NOTE(review): the meaning of the positional `False` and `None`
        # arguments is defined by `_process_component_netlist` — confirm there.
        (
            runners,
            self.time_steppers,
            _,
            port_connections,
            self.connections,
            instance_port_data,
            self.monitors,
        ) = _process_component_netlist(
            component,
            frequencies,
            self.parametric_kwargs["mesh_refinement"],
            monitors,
            updates,
            chain_technology_updates,
            verbose,
            False,
            kwargs,
            None,
            time_stepper_kwargs,
        )

        # One (instance index, internal key, external key) triple per mode.
        self.port_connections = [
            (index, f"{ref_port_name}@{n}", f"{port_name}@{n}")
            for (index, ref_port_name, num_modes), port_name in port_connections.items()
            for n in range(num_modes)
        ]

        self.component_name = component.name
        self.lock = threading.Lock()
        self._status = {"message": "running", "progress": 0}

        self.setup_thread = threading.Thread(
            daemon=True, target=self._setup_and_monitor, args=(runners, instance_port_data)
        )
        self.setup_thread.start()

        return self

    def _zero_port_state(self, num_instances):
        """Build a per-instance port state with every connected mode set to 0."""
        state = [{} for _ in range(num_instances)]
        for connection in self.connections:
            for index, port_name, num_modes in connection:
                state[index].update({f"{port_name}@{mode}": 0 for mode in range(num_modes)})
        return state

    def _setup_and_monitor(self, runners, instance_port_data):
        """Wait for the runners, then build mode factors and the port state.

        Runs in the thread started by ``setup_state``. While the runners are
        active, their joint progress is published scaled to 95%; the final 5%
        covers the per-instance post-processing done here. If the runners
        report an error, that status is stored and the method returns early.
        """
        runners = {k: v for k, v in runners.items() if v is not None}

        joint_status = _gather_status(*runners.values())
        with self.lock:
            self._status = dict(joint_status)
            self._status["progress"] *= 0.95

        while joint_status["message"] == "running":
            time.sleep(0.3)
            joint_status = _gather_status(*runners.values())
            with self.lock:
                self._status = dict(joint_status)
                self._status["progress"] *= 0.95

        if joint_status["message"] == "error":
            with self.lock:
                self._status = joint_status
            return

        # Runners are done, but this stepper still has work to do.
        with self.lock:
            self._status = joint_status
            self._status["message"] = "running"
            self._status["progress"] *= 0.95

        num_instances = len(instance_port_data)
        self.mode_factors = [{} for _ in range(num_instances)]

        for index, (instance_ports, instance_keys) in enumerate(instance_port_data):
            # Check if reference is needed
            if instance_ports is None:
                continue

            # Fix port phases if a rotation is applied
            if instance_keys is not None:
                for port_name, port in instance_ports:
                    key = instance_keys.get(port_name)
                    if key is None:
                        continue

                    # Port mode overlap: reuse the cached value when present
                    overlap = _mode_overlap_cache[key]
                    if overlap is None:
                        overlap = _align_and_overlap(
                            runners[(index, port_name, 0)].data,
                            runners[(index, port_name, 1)].data,
                        )[0]
                        _mode_overlap_cache[key] = overlap

                    self.mode_factors[index].update(
                        {f"{port_name}@{mode}": overlap[mode] for mode in range(port.num_modes)}
                    )

            with self.lock:
                self._status["progress"] = 95 + 5 * (index + 1) / num_instances

        self.port_state = self._zero_port_state(num_instances)
        self.emitted_convergence_warning = False

        with self.lock:
            self._status["progress"] = 100
            self._status["message"] = "success"

    def reset(self):
        """Reset the state of the circuit variables.

        Zeroes the stored port state, re-arms the convergence warning and
        resets every sub-stepper.
        """
        self.port_state = self._zero_port_state(len(self.port_state))
        self.emitted_convergence_warning = False
        for time_stepper in self.time_steppers.values():
            time_stepper.reset()

    @property
    def status(self):
        """Status dictionary with ``"message"`` and ``"progress"`` entries.

        If the setup thread is no longer alive while the message still reads
        ``"running"``, the message is changed to ``"error"``. Relies on the
        thread and lock created by ``setup_state``.
        """
        # The thread is checked first: once it is dead nothing else writes to
        # `_status`, so the update below cannot race with it.
        if not self.setup_thread.is_alive() and self._status["message"] == "running":
            self._status["message"] = "error"
        with self.lock:
            return self._status

    def _start_workers(self, num_instances):
        """Create the work queue and the daemon worker threads serving it."""
        self._work_queue = Queue()
        self._num_workers = max(1, min(num_instances, self.parametric_kwargs["max_threads"]))
        for _ in range(self._num_workers):
            threading.Thread(daemon=True, target=_stepper, args=(self._work_queue,)).start()

    def _stop_workers(self):
        """Release every worker thread and drop the work queue."""
        work_queue, self._work_queue = self._work_queue, None
        if work_queue is None:
            return
        # Each worker exits after consuming exactly one sentinel, so one
        # sentinel per worker is required to release all of them.
        for _ in range(self._num_workers):
            work_queue.put((None, None, None))
        self._num_workers = 0

    @staticmethod
    def _states_converged(previous, current, abs_tolerance, rel_tolerance):
        """Symmetric closeness test between two per-instance state lists."""
        # math.isclose is symmetric and numpy.isclose is not; this mirrors the
        # symmetric form. Keys missing from `previous` count as 0.
        for prv, nxt in zip(previous, current, strict=True):
            for key, v1 in nxt.items():
                v0 = prv.get(key, 0)
                if abs(v1 - v0) > max(abs_tolerance, rel_tolerance * max(abs(v1), abs(v0))):
                    return False
        return True

    def step_single(
        self, inputs: dict[str, complex], time_index: int, update_state: bool, shutdown: bool
    ) -> dict[str, complex]:
        """Take a single time step on the given inputs.

        Args:
            inputs: Dictionary containing inputs at the current time step,
              mapping port names to complex values.
            time_index: Time series index for the current input.
            update_state: Whether to update the internal stepper state.
            shutdown: Whether this is the last call to the single stepping
              function for the provided :class:`TimeSeries`.

        Returns:
            Outputs at the current time step.

        Raises:
            RuntimeError: If any sub-stepper raised while stepping; the
              message is the string form of the original exception.
        """
        num_instances = len(self.port_state)
        abs_tolerance = self.parametric_kwargs["abs_tolerance"]
        rel_tolerance = self.parametric_kwargs["rel_tolerance"]

        if self._work_queue is None:
            self._start_workers(num_instances)

        input_state = [dict(d) for d in self.port_state]

        # apply input to input_state
        for index, ref_name, port_name in self.port_connections:
            if port_name in inputs:
                input_state[index][ref_name] = inputs[port_name]
            elif ref_name not in input_state[index]:
                input_state[index][ref_name] = 0

        max_iterations = max(self.parametric_kwargs["max_iterations"], 1)
        is_last_iteration = max_iterations == 1
        # Only consulted when max_iterations > 1, in which case the loop
        # below always assigns it before it is read.
        converged = False

        # self.port_state stores inputs at time t. Compute outputs at time t,
        # and use them as inputs at time t, iterate. Once converged, use the
        # outputs as inputs at time t+1
        try:
            for current_iteration in range(1, max_iterations + 1):
                if current_iteration == max_iterations:
                    is_last_iteration = True

                # Sub-steppers only commit state / shut down on the final pass.
                update = update_state and is_last_iteration
                shdwn = shutdown and is_last_iteration

                work_items = []
                for index, mode_factors in enumerate(self.mode_factors):
                    instance_inputs = {
                        k: v * mode_factors.get(k, 1.0) for k, v in input_state[index].items()
                    }
                    work_item = (
                        self.time_steppers[index].step_single,
                        (instance_inputs, time_index, update, shdwn),
                        [],
                    )
                    work_items.append(work_item)
                    self._work_queue.put(work_item)

                # block for runners to finish
                self._work_queue.join()

                output_state = [{} for _ in range(num_instances)]
                for index, (_, _, results) in enumerate(work_items):
                    # Workers report failures as plain strings.
                    if isinstance(results[0], str):
                        raise RuntimeError(results[0])
                    mf = self.mode_factors[index]
                    output_state[index] = {k: v / mf.get(k, 1.0) for k, v in results[0].items()}

                # apply connections
                next_input_state = [{} for _ in range(num_instances)]
                for (idx1, port_name1, num_modes), (idx2, port_name2, _) in self.connections:
                    for mode in range(num_modes):
                        key1 = f"{port_name1}@{mode}"
                        key2 = f"{port_name2}@{mode}"
                        val1 = output_state[idx1].get(key1)
                        if val1 is not None:
                            next_input_state[idx2][key2] = val1
                        val2 = output_state[idx2].get(key2)
                        if val2 is not None:
                            next_input_state[idx1][key1] = val2

                # apply input to next_input_state
                for index, ref_name, port_name in self.port_connections:
                    if port_name in inputs:
                        next_input_state[index][ref_name] = inputs[port_name]
                    elif ref_name not in next_input_state[index]:
                        next_input_state[index][ref_name] = 0

                if is_last_iteration:
                    break

                converged = self._states_converged(
                    input_state, next_input_state, abs_tolerance, rel_tolerance
                )
                if converged:
                    # One more pass is run so sub-steppers can commit state.
                    is_last_iteration = True

                input_state = next_input_state
        finally:
            # Release the workers on the final call even if a sub-stepper
            # failed, so no thread is left blocked on the queue.
            if shutdown:
                self._stop_workers()

        if update_state:
            self.port_state = next_input_state

        if max_iterations > 1 and not self.emitted_convergence_warning and not converged:
            warnings.warn(
                f"Time stepper for component '{self.component_name}' failed to converge. "
                "Consider increasing 'max_iterations'.",
                stacklevel=2,
            )
            self.emitted_convergence_warning = True

        # store outputs
        outputs = {
            port_name: output_state[index].get(ref_name, 0)
            for index, ref_name, port_name in self.port_connections
        }

        # include monitors: "-" entries come from the instance inputs, "+"
        # entries from the instance outputs
        for index, port_name, num_modes, monitor_name in self.monitors:
            outputs.update(
                {
                    f"{monitor_name}@{mode}-": input_state[index].get(f"{port_name}@{mode}", 0)
                    for mode in range(num_modes)
                }
            )
            outputs.update(
                {
                    f"{monitor_name}@{mode}+": output_state[index].get(f"{port_name}@{mode}", 0)
                    for mode in range(num_modes)
                }
            )

        return outputs
# Register the class through the extension module's registration hook.
register_time_stepper_class(CircuitTimeStepper)
# Store a default-constructed instance as the default time stepper under the
# "CircuitModel" key.
config.default_time_steppers["CircuitModel"] = CircuitTimeStepper()