# container for everything defining the inverse designfrom__future__importannotationsimportabcimporttypingimportautograd.numpyasanpimportnumpyasnpimportpydantic.v1aspdimporttidy3dastdfromtidy3d.components.autogradimportget_staticfromtidy3d.exceptionsimportValidationErrorfromtidy3d.plugins.expressions.metricsimportMetric,generate_validation_datafromtidy3d.plugins.expressions.typesimportExpressionTypefrom.baseimportInvdesBaseModelfrom.regionimportDesignRegionTypefrom.validatorsimportcheck_pixel_sizePostProcessFnType=typing.Callable[[td.SimulationData],float]classAbstractInverseDesign(InvdesBaseModel,abc.ABC):"""Container for an inverse design problem."""design_region:DesignRegionType=pd.Field(...,title="Design Region",description="Region within which we will optimize the simulation.",)task_name:str=pd.Field(...,title="Task Name",description="Task name to use in the objective function when running the ``JaxSimulation``.",)verbose:bool=pd.Field(False,title="Task Verbosity",description="If ``True``, will print the regular output from ``web`` functions.",)metric:typing.Optional[ExpressionType]=pd.Field(None,title="Objective Metric",description="Serializable expression defining the objective function.",)defmake_objective_fn(self,post_process_fn:typing.Optional[typing.Callable]=None,maximize:bool=True)->typing.Callable[[anp.ndarray],tuple[float,dict]]:"""Construct the objective function for this InverseDesign object."""if(post_process_fnisNone)and(self.metricisNone):raiseValueError("Either 'post_process_fn' or 'metric' must be provided.")if(post_process_fnisnotNone)and(self.metricisnotNone):raiseValueError("Provide only one of 'post_process_fn' or 'metric', not both.")direction_multiplier=1ifmaximizeelse-1defobjective_fn(params:anp.ndarray,aux_data:dict=None)->float:"""Full objective 
function."""data=self.to_simulation_data(params=params)ifself.metricisNone:post_process_val=post_process_fn(data)elifisinstance(data,td.SimulationData):post_process_val=self.metric.evaluate(data)elifgetattr(data,"type",None)=="BatchData":raiseNotImplementedError("Metrics currently do not support 'BatchData'")else:raiseValueError(f"Invalid data type: {type(data)}")penalty_value=self.design_region.penalty_value(params)objective_fn_val=direction_multiplier*post_process_val-penalty_value# Store auxiliary data if providedifaux_dataisnotNone:aux_data["penalty"]=get_static(penalty_value)aux_data["post_process_val"]=get_static(post_process_val)aux_data["objective_fn_val"]=get_static(objective_fn_val)*direction_multiplierifisinstance(data,td.SimulationData):aux_data["sim_data"]=data.to_static()else:aux_data["sim_data"]={k:v.to_static()fork,vindata.items()}aux_data["params"]=paramsreturnobjective_fn_valreturnobjective_fn@propertydefinitial_simulation(self)->td.Simulation:"""Return a simulation with the initial design region parameters."""initial_params=self.design_region.initial_parametersreturnself.to_simulation(initial_params)defrun(self,simulation,**kwargs)->td.SimulationData:"""Run a single tidy3d simulation."""fromtidy3d.webimportrunkwargs.setdefault("verbose",self.verbose)kwargs.setdefault("task_name",self.task_name)returnrun(simulation,**kwargs)defrun_async(self,simulations,**kwargs)->web.BatchData:# noqa: F821"""Run a batch of tidy3d simulations."""fromtidy3d.webimportrun_asynckwargs.setdefault("verbose",self.verbose)returnrun_async(simulations,**kwargs)
class InverseDesign(AbstractInverseDesign):
    """Container for an inverse design problem."""

    simulation: td.Simulation = pd.Field(
        ...,
        title="Base Simulation",
        description="Simulation without the design regions or monitors used in the objective fn.",
    )

    output_monitor_names: typing.Optional[typing.Tuple[str, ...]] = pd.Field(
        None,
        title="Output Monitor Names",
        description="Optional names of monitors whose data the differentiable output depends on. "
        "If this field is left ``None``, the plugin will try to add all compatible monitors to "
        "``JaxSimulation.output_monitors``. While this will work, there may be warnings if the "
        "monitors are not compatible with the ``adjoint`` plugin, for example if there are "
        "``FieldMonitor`` instances with ``.colocate != False``.",
    )

    _check_sim_pixel_size = check_pixel_size("simulation")

    @pd.root_validator(pre=False)
    def _validate_model(cls, values: dict) -> dict:
        """Run the cross-field checks relating ``metric`` to ``simulation``."""
        cls._validate_metric(values)
        return values

    @staticmethod
    def _validate_metric(values: dict) -> dict:
        """Validate each ``Metric`` in the ``metric`` expression against ``simulation``.

        Returns ``values`` unchanged; raises ``ValidationError`` on the first failing check.
        """
        metric_expr = values.get("metric")
        if not metric_expr:
            return values

        simulation = values.get("simulation")
        if simulation is None:
            # ``simulation`` failed its own field validation; the root validator still runs
            # (no ``skip_on_failure``), and pydantic already reports that error, so don't mask
            # it with an ``AttributeError`` on ``None.monitors``.
            return values

        for metric in metric_expr.filter(Metric):
            # the monitor-name check must come first: the later checks assume the monitor exists
            InverseDesign._validate_metric_monitor_name(metric, simulation)
            InverseDesign._validate_metric_mode_index(metric, simulation)
            InverseDesign._validate_metric_f(metric, simulation)
        InverseDesign._validate_metric_data(metric_expr, simulation)
        return values

    @staticmethod
    def _validate_metric_monitor_name(metric: Metric, simulation: td.Simulation) -> None:
        """Validate that the monitor name of the metric exists in the simulation."""
        monitor = next((m for m in simulation.monitors if m.name == metric.monitor_name), None)
        if monitor is None:
            raise ValidationError(
                f"Monitor named '{metric.monitor_name}' associated with the metric not found in the simulation monitors."
            )

    @staticmethod
    def _validate_metric_mode_index(metric: Metric, simulation: td.Simulation) -> None:
        """Validate that the mode index of the metric is within the bounds of the monitor's
        ``ModeSpec.num_modes``."""
        monitor = next((m for m in simulation.monitors if m.name == metric.monitor_name), None)
        mode_spec = getattr(monitor, "mode_spec", None)
        if mode_spec is None:
            # the metric points at a monitor without a mode spec; report a validation error
            # instead of letting an ``AttributeError`` escape the validator
            raise ValidationError(
                f"Monitor '{metric.monitor_name}' associated with the metric does not define a "
                "'mode_spec', so the metric's mode index cannot be validated."
            )
        if metric.mode_index >= mode_spec.num_modes:
            raise ValidationError(
                f"Mode index '{metric.mode_index}' for metric associated with monitor "
                f"'{metric.monitor_name}' is out of bounds. "
                f"Maximum allowed mode index is '{mode_spec.num_modes - 1}'."
            )

    @staticmethod
    def _validate_metric_f(metric: Metric, simulation: td.Simulation) -> None:
        """Validate that the frequencies of the metric are present in the monitor."""
        monitor = next((m for m in simulation.monitors if m.name == metric.monitor_name), None)
        if metric.f is not None:
            metric_f_list = [metric.f] if isinstance(metric.f, float) else metric.f
            if len(metric_f_list) != 1:
                raise ValidationError("Only a single frequency is supported for the metric.")
            for freq in metric_f_list:
                if not any(np.isclose(freq, monitor.freqs, atol=1.0)):
                    raise ValidationError(
                        f"Frequency '{freq}' for metric associated with monitor "
                        f"'{metric.monitor_name}' not found in monitor frequencies."
                    )
        else:
            # without an explicit frequency the metric is only unambiguous for a
            # single-frequency monitor
            if len(monitor.freqs) != 1:
                raise ValidationError(
                    f"Monitor '{metric.monitor_name}' must contain only a single frequency when metric.f is None."
                )

    @staticmethod
    def _validate_metric_data(expr: ExpressionType, simulation: td.Simulation) -> None:
        """Validate that expression can be evaluated and returns a real scalar."""
        data = generate_validation_data(expr)
        try:
            result = expr(data)
        except Exception as e:
            raise ValidationError(f"Failed to evaluate the metric expression: {str(e)}") from e
        # exactly one element: an empty result is not a scalar either
        if len(np.ravel(result)) != 1:
            raise ValidationError(
                f"The expression must return a scalar value or an array of length 1 (got {result})."
            )
        if not np.all(np.isreal(result)):
            raise ValidationError(
                f"The expression must return a real (not complex) value (got {result})."
            )

    def is_output_monitor(self, monitor: td.Monitor) -> bool:
        """Whether a monitor is added to the ``JaxSimulation`` as an ``output_monitor``.

        With ``output_monitor_names`` unset, any monitor of a type listed in
        ``td.components.simulation.OutputMonitorTypes`` qualifies; otherwise only the named ones.
        """
        output_mnt_types = td.components.simulation.OutputMonitorTypes
        if self.output_monitor_names is None:
            return any(isinstance(monitor, mnt_type) for mnt_type in output_mnt_types)

        return monitor.name in self.output_monitor_names

    def separate_output_monitors(self, monitors: typing.Tuple[td.Monitor, ...]) -> dict:
        """Separate monitors into output_monitors and regular monitors.

        Returns a dict with keys ``"monitors"`` and ``"output_monitors"``, each a list that
        preserves the input order.
        """
        monitor_fields = dict(monitors=[], output_monitors=[])

        for monitor in monitors:
            key = "output_monitors" if self.is_output_monitor(monitor) else "monitors"
            monitor_fields[key].append(monitor)

        return monitor_fields

    def to_simulation(self, params: anp.ndarray) -> td.Simulation:
        """Convert the ``InverseDesign`` to a corresponding ``td.Simulation`` with traced fields."""

        # construct the design region to a regular structure
        design_region_structure = self.design_region.to_structure(params)

        # construct mesh override structures and a new grid spec, if applicable
        grid_spec = self.simulation.grid_spec
        mesh_override_structure = self.design_region.mesh_override_structure
        if mesh_override_structure:
            override_structures = list(self.simulation.grid_spec.override_structures)
            override_structures += [mesh_override_structure]
            grid_spec = grid_spec.updated_copy(override_structures=override_structures)

        # the design region structure is appended last
        return self.simulation.updated_copy(
            structures=list(self.simulation.structures) + [design_region_structure],
            grid_spec=grid_spec,
        )

    def to_simulation_data(self, params: anp.ndarray, **kwargs) -> td.SimulationData:
        """Convert the ``InverseDesign`` to a ``td.Simulation`` and run it.

        Extra ``kwargs`` are forwarded to ``self.run``.
        """
        simulation = self.to_simulation(params=params)
        return self.run(simulation, **kwargs)
class InverseDesignMulti(AbstractInverseDesign):
    """``InverseDesign`` with multiple simulations and corresponding postprocess functions."""

    simulations: typing.Tuple[td.Simulation, ...] = pd.Field(
        ...,
        title="Base Simulations",
        description="Set of simulation without the design regions or monitors used in the objective fn.",
    )

    output_monitor_names: typing.Optional[
        typing.Tuple[typing.Union[typing.Tuple[str, ...], None], ...]
    ] = pd.Field(
        None,
        title="Output Monitor Names",
        description="Optional names of monitors whose data the differentiable output depends on. "
        "If this field is left ``None``, the plugin will try to add all compatible monitors to "
        "``JaxSimulation.output_monitors``. While this will work, there may be warnings if the "
        "monitors are not compatible with the ``adjoint`` plugin, for example if there are "
        "``FieldMonitor`` instances with ``.colocate != False``.",
    )

    _check_sim_pixel_size = check_pixel_size("simulations")

    @pd.root_validator()
    def _check_lengths(cls, values):
        """Check the lengths of all of the multi fields.

        Fields that are absent or ``None`` are ignored. An error is raised only when two or
        more of the present fields disagree in length.
        """
        keys = ("simulations", "post_process_fns", "output_monitor_names", "override_structure_dl")
        multi_dict = {key: values.get(key) for key in keys}
        sizes = {key: len(val) for key, val in multi_dict.items() if val is not None}

        # ``> 1`` rather than ``!= 1``: if no multi field is present (e.g. ``simulations`` failed
        # its own validation) there is nothing to compare, and pydantic already reports that error
        if len(set(sizes.values())) > 1:
            raise ValueError(
                f"'InverseDesignMulti' requires that the fields {keys} must either "
                "have the same length or be left ``None``, if optional. Given fields with "
                f"corresponding sizes of '{sizes}'."
            )

        return values

    @property
    def task_names(self) -> list[str]:
        """Task names associated with each of the simulations."""
        return [f"{self.task_name}_{i}" for i in range(len(self.simulations))]

    @property
    def designs(self) -> typing.List[InverseDesign]:
        """List of individual ``InverseDesign`` objects corresponding to this instance."""
        designs_list = []
        for i, (task_name, sim) in enumerate(zip(self.task_names, self.simulations)):
            design_kwargs = dict(
                design_region=self.design_region,
                simulation=sim,
                verbose=self.verbose,
                task_name=task_name,
            )
            # pass per-simulation monitor names at construction instead of validating twice
            # via ``updated_copy``
            if self.output_monitor_names is not None:
                design_kwargs["output_monitor_names"] = self.output_monitor_names[i]
            designs_list.append(InverseDesign(**design_kwargs))
        return designs_list

    def to_simulation(self, params: anp.ndarray) -> dict[str, td.Simulation]:
        """Convert the ``InverseDesign`` to a corresponding dict of ``td.Simulation``s.

        The dict is keyed by ``self.task_names``, in the same order as ``self.simulations``.
        """
        simulation_list = [design.to_simulation(params) for design in self.designs]
        return dict(zip(self.task_names, simulation_list))

    def to_simulation_data(self, params: anp.ndarray, **kwargs) -> web.BatchData:  # noqa: F821
        """Convert the ``InverseDesignMulti`` to a set of ``td.Simulation``s and run async.

        Extra ``kwargs`` are forwarded to ``self.run_async``.
        """
        simulations = self.to_simulation(params)
        return self.run_async(simulations, **kwargs)