"""Fit PoleResidue Dispersion models to optical NK data"""from__future__importannotationsimportcodecsimportcsvfromtypingimportList,Optional,Tupleimportnumpyasnpimportrequestsimportscipy.optimizeasoptfrompydantic.v1importField,validatorfromrich.progressimportProgressfromtidy3d.web.core.environmentimportEnvfrom...components.baseimportTidy3dBaseModel,cached_property,skip_if_fields_missingfrom...components.mediumimportAbstractMedium,PoleResiduefrom...components.typesimportArrayFloat1D,Axfrom...components.vizimportadd_ax_if_nonefrom...constantsimportC_0,HBAR,MICROMETERfrom...exceptionsimportSetupError,ValidationError,WebErrorfrom...logimportget_logging_console,log
class DispersionFitter(Tidy3dBaseModel):
    """Tool for fitting refractive index data to get a
    dispersive medium described by :class:`.PoleResidue` model."""

    wvl_um: ArrayFloat1D = Field(
        ...,
        title="Wavelength data",
        description="Wavelength data in micrometers.",
        units=MICROMETER,
    )

    n_data: ArrayFloat1D = Field(
        ...,
        title="Index of refraction data",
        description="Real part of the complex index of refraction.",
    )

    k_data: ArrayFloat1D = Field(
        None,
        title="Extinction coefficient data",
        description="Imaginary part of the complex index of refraction.",
    )

    wvl_range: Tuple[Optional[float], Optional[float]] = Field(
        (None, None),
        title="Wavelength range [wvl_min,wvl_max] for fitting",
        description="Truncate the wavelength, n and k data to the wavelength range '[wvl_min, "
        "wvl_max]' for fitting.",
        units=MICROMETER,
    )

    @validator("wvl_um", always=True)
    def _setup_wvl(cls, val):
        """Validate that the wavelength data is not empty."""
        if val.size == 0:
            raise ValidationError("Wavelength data cannot be empty.")
        return val

    @validator("n_data", always=True)
    @skip_if_fields_missing(["wvl_um"])
    def _ndata_length_match_wvl(cls, val, values):
        """Validate that n_data has the same shape as wvl_um."""
        if val.shape != values["wvl_um"].shape:
            raise ValidationError("The length of 'n_data' doesn't match 'wvl_um'.")
        return val

    @validator("k_data", always=True)
    @skip_if_fields_missing(["wvl_um"])
    def _kdata_setup_and_length_match(cls, val, values):
        """Validate the length of k_data, or set k to zeros if it's None."""
        if val is None:
            return np.zeros_like(values["wvl_um"])
        if val.shape != values["wvl_um"].shape:
            raise ValidationError("The length of 'k_data' doesn't match 'wvl_um'.")
        return val

    @cached_property
    def data_in_range(self) -> Tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D]:
        """Filter the wavelength-nk data to wavelength range for fitting.

        Returns
        -------
        Tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D]
            Filtered wvl_um, n_data, k_data

        Raises
        ------
        SetupError
            If no data point lies within ``wvl_range``.
        """
        # start from "keep everything" and narrow with each bound that is set
        ind_select = np.ones(self.wvl_um.shape, dtype=bool)
        if self.wvl_range[0] is not None:
            ind_select = np.logical_and(self.wvl_um >= self.wvl_range[0], ind_select)

        if self.wvl_range[1] is not None:
            ind_select = np.logical_and(self.wvl_um <= self.wvl_range[1], ind_select)

        if not np.any(ind_select):
            raise SetupError("No data within 'wvl_range'")

        return self.wvl_um[ind_select], self.n_data[ind_select], self.k_data[ind_select]

    @cached_property
    def lossy(self) -> bool:
        """Find out if the medium is lossy or lossless
        based on the filtered input data.

        Returns
        -------
        bool
            True for lossy medium; False for lossless medium
        """
        _, _, k_data = self.data_in_range
        return k_data is not None and np.any(k_data)

    @property
    def eps_data(self) -> complex:
        """Convert filtered input n(k) data into complex permittivity.

        Returns
        -------
        complex
            Complex-valued relative permittivity, evaluated for the filtered n and k data.
        """
        _, n_data, k_data = self.data_in_range
        return AbstractMedium.nk_to_eps_complex(n=n_data, k=k_data)

    @property
    def freqs(self) -> Tuple[float, ...]:
        """Convert filtered input wavelength data to frequency.

        Returns
        -------
        Tuple[float, ...]
            Frequency array converted from filtered input wavelength data
        """
        wvl_um, _, _ = self.data_in_range
        return C_0 / wvl_um

    @property
    def frequency_range(self) -> Tuple[float, float]:
        """Frequency range of filtered input data

        Returns
        -------
        Tuple[float, float]
            The minimal frequency and the maximal frequency
        """
        # ``freqs`` is a plain property: evaluate it once instead of once per bound
        freqs = self.freqs
        return freqs.min(), freqs.max()

    @staticmethod
    def _unpack_coeffs(coeffs):
        """Unpack coefficient vector into complex pole parameters.

        Parameters
        ----------
        coeffs : np.ndarray[real]
            Array of real coefficients for the pole residue fit.

        Returns
        -------
        Tuple[np.ndarray[complex], np.ndarray[complex]]
            "a" and "c" poles for the PoleResidue model.

        Raises
        ------
        ValueError
            If the number of coefficients is not a multiple of 4.
        """
        if len(coeffs) % 4 != 0:
            raise ValueError(f"len(coeffs) must be multiple of 4, got {len(coeffs)=}.")

        # layout per pole: [a.real, a.imag, c.real, c.imag]
        a_real = coeffs[0::4]
        a_imag = coeffs[1::4]
        c_real = coeffs[2::4]
        c_imag = coeffs[3::4]

        poles_a = a_real + 1j * a_imag
        poles_c = c_real + 1j * c_imag
        return poles_a, poles_c

    @staticmethod
    def _pack_coeffs(pole_a, pole_c):
        """Pack complex a and c pole parameters into coefficient array.

        Parameters
        ----------
        pole_a : np.ndarray[complex]
            Array of complex "a" poles for the PoleResidue dispersive model.
        pole_c : np.ndarray[complex]
            Array of complex "c" poles for the PoleResidue dispersive model.

        Returns
        -------
        np.ndarray[float]
            Array of real coefficients for the pole residue fit.
        """
        # stacking along axis 1 then flattening interleaves the four parts per pole
        stacked_coeffs = np.stack((pole_a.real, pole_a.imag, pole_c.real, pole_c.imag), axis=1)
        return stacked_coeffs.flatten()

    @staticmethod
    def _coeffs_to_poles(coeffs):
        """Convert model coefficients to poles.

        Parameters
        ----------
        coeffs : np.ndarray[float]
            Array of real coefficients for the pole residue fit.

        Returns
        -------
        List[Tuple[complex, complex]]
            List of complex poles (a, c)
        """
        coeffs_scaled = coeffs / HBAR
        poles_a, poles_c = DispersionFitter._unpack_coeffs(coeffs_scaled)
        return list(zip(poles_a, poles_c))

    @staticmethod
    def _poles_to_coeffs(poles):
        """Convert poles to model coefficients.

        Parameters
        ----------
        poles : List[Tuple[complex, complex]]
            List of complex poles (a, c)

        Returns
        -------
        np.ndarray[float]
            Array of real coefficients for the pole residue fit.
        """
        poles = np.array(poles, dtype=complex)
        coeffs = DispersionFitter._pack_coeffs(poles[:, 0], poles[:, 1])
        return coeffs * HBAR

    @staticmethod
    def _eV_to_Hz(f_eV: float):
        """Convert frequency in unit of eV to Hz.

        Parameters
        ----------
        f_eV : float
            Frequency in unit of eV
        """
        return f_eV / (HBAR * 2 * np.pi)

    @staticmethod
    def _Hz_to_eV(f_Hz: float):
        """Convert frequency in unit of Hz to eV.

        Parameters
        ----------
        f_Hz : float
            Frequency in unit of Hz
        """
        return f_Hz * HBAR * 2 * np.pi

    def fit(
        self,
        num_poles: int = 1,
        num_tries: int = 50,
        tolerance_rms: float = 1e-2,
        guess: PoleResidue = None,
    ) -> Tuple[PoleResidue, float]:
        """Fit data a number of times and returns best results.

        Parameters
        ----------
        num_poles : int, optional
            Number of poles in the model.
        num_tries : int, optional
            Number of optimizations to run with different initial guesses.
        tolerance_rms : float, optional
            RMS error below which the fit is successful and the result is returned.
        guess : :class:`.PoleResidue` = None
            A :class:`.PoleResidue` medium to use as the initial guess in the first optimization
            run.

        Returns
        -------
        Tuple[:class:`.PoleResidue`, float]
            Best results of multiple fits: (dispersive medium, RMS error).
        """

        # Run it a number of times.
        best_medium = None
        best_rms = np.inf

        with Progress(console=get_logging_console()) as progress:
            task = progress.add_task(
                f"Fitting with {num_poles} to RMS of {tolerance_rms}...", total=num_tries
            )

            while not progress.finished:
                # if guess is provided use it in the first optimization run
                if guess is not None and progress.tasks[0].completed == 0:
                    medium, rms_error = self._fit_single(num_poles=num_poles, guess=guess)
                else:
                    medium, rms_error = self._fit_single(num_poles=num_poles)

                # if improvement, set the best RMS and coeffs
                if rms_error < best_rms:
                    best_rms = rms_error
                    best_medium = medium

                progress.update(
                    task,
                    advance=1,
                    description=f"Best RMS error so far: {best_rms:.3g}",
                    refresh=True,
                )

                # if below tolerance, return
                if best_rms < tolerance_rms:
                    progress.update(
                        task,
                        completed=num_tries,
                        description=f"Best RMS error: {best_rms:.3g}",
                        refresh=True,
                    )
                    log.info("Found optimal fit with RMS error %.3g", best_rms)
                    return best_medium, best_rms

        # if exited loop, did not reach tolerance (warn)
        log.warning("Unable to fit with RMS error under 'tolerance_rms' of %.3g", tolerance_rms)
        log.info("Returning best fit with RMS error %.3g", best_rms)
        return best_medium, best_rms

    def _make_medium(self, coeffs):
        """Return medium from coeffs from optimizer.

        Parameters
        ----------
        coeffs : np.ndarray[float]
            Array of real coefficients for the pole residue fit.

        Returns
        -------
        :class:`.PoleResidue`
            Dispersive medium corresponding to this set of ``coeffs``.
        """
        poles_complex = DispersionFitter._coeffs_to_poles(coeffs)
        return PoleResidue(poles=poles_complex, frequency_range=self.frequency_range)

    def _fit_single(
        self,
        num_poles: int = 3,
        guess: PoleResidue = None,
    ) -> Tuple[PoleResidue, float]:
        """Perform a single fit to the data and return optimization result.

        Parameters
        ----------
        num_poles : int = 3
            Number of poles in the model.
        guess : :class:`.PoleResidue` = None
            A PoleResidue object to use a guess instead of a random one.

        Returns
        -------
        Tuple[:class:`.PoleResidue`, float]
            Results of single fit: (dispersive medium, RMS error).

        Raises
        ------
        ValueError
            If ``guess`` is given and its number of poles differs from ``num_poles``.
        """

        # ``freqs`` and ``eps_data`` are plain properties recomputed on every access;
        # they do not depend on the coefficients, so evaluate them once here rather
        # than inside every objective call made by the optimizer.
        freqs = self.freqs
        eps_data = self.eps_data

        # NOTE: handed to scipy below as a ``NonlinearConstraint`` (lb=0)
        def constraint(coeffs, _grad=None):
            """Evaluate the nonlinear stability criterion of
            Hongjin Choi, Jae-Woo Baek, and Kyung-Young Jung,
            "Comprehensive Study on Numerical Aspects of Modified
            Lorentz Model Based Dispersive FDTD Formulations,"
            IEEE TAP 2019.

            Parameters
            ----------
            coeffs : np.ndarray[float]
                Array of real coefficients for the pole residue fit.
            _grad : np.ndarray[float]
                Gradient of ``constraint`` w.r.t coeffs, not used.

            Returns
            -------
            float
                Value of constraint.
            """
            poles_a, poles_c = DispersionFitter._unpack_coeffs(coeffs)
            a_real = poles_a.real
            a_imag = poles_a.imag
            c_real = poles_c.real
            c_imag = poles_c.imag

            prstar = a_real * c_real + a_imag * c_imag
            res = 2 * prstar * a_real - c_real * (a_real * a_real + a_imag * a_imag)
            # only negative per-pole terms contribute; the sum is 0 when none is negative
            res[res >= 0] = 0
            return np.sum(res)

        def objective(coeffs, _grad=None):
            """Objective function for fit

            Parameters
            ----------
            coeffs : np.ndarray[float]
                Array of real coefficients for the pole residue fit.
            _grad : np.ndarray[float]
                Gradient of ``objective`` w.r.t coeffs, not used.

            Returns
            -------
            float
                RMS error corresponding to current coeffs.
            """
            medium = self._make_medium(coeffs)
            eps_model = medium.eps_model(freqs)
            residual = eps_data - eps_model
            return np.sqrt(np.sum(np.square(np.abs(residual))) / len(eps_data))

        # set initial guess
        num_coeffs = num_poles * 4
        if guess is not None:
            if len(guess.poles) != num_poles:
                raise ValueError(
                    f"The number of poles ({len(guess.poles)}) in provided guess 'PoleResidue' "
                    f"medium does not match argument 'num_poles' = {num_poles}."
                )
            coeffs0 = self._poles_to_coeffs(guess.poles)
        else:
            # uniform random start in [-1, 1)
            coeffs0 = 2 * (np.random.random(num_coeffs) - 0.5)

        # set bounds
        bounds_upper = np.zeros(num_coeffs, dtype=float)
        bounds_lower = np.zeros(num_coeffs, dtype=float)

        if self.lossy:
            # if lossy, the real parts can take on values
            bounds_lower[0::4] = -np.inf
            bounds_upper[2::4] = np.inf
            # move the starting point inside those bounds
            coeffs0[0::4] = -np.abs(coeffs0[0::4])
            coeffs0[2::4] = +np.abs(coeffs0[2::4])
        else:
            # otherwise, they need to be 0
            coeffs0[0::2] = 0

        # the imaginary parts are unbounded in both cases
        bounds_lower[1::2] = -np.inf
        bounds_upper[1::2] = np.inf
        bounds = list(zip(bounds_lower, bounds_upper))

        # TODO: set up constraint properly
        scipy_constraint = opt.NonlinearConstraint(constraint, lb=0, ub=np.inf)

        # TODO: set options properly
        res = opt.minimize(
            objective,
            coeffs0,
            args=(),
            method="SLSQP",
            bounds=bounds,
            constraints=(scipy_constraint,),
            tol=1e-7,
            callback=None,
            options=dict(maxiter=10000),
        )

        coeffs = res.x
        rms_error = objective(coeffs)

        # set the latest fit
        medium = self._make_medium(coeffs)
        return medium, rms_error

    @add_ax_if_none
    def plot(
        self,
        medium: PoleResidue = None,
        wvl_um: ArrayFloat1D = None,
        ax: Ax = None,
    ) -> Ax:
        """Make plot of model vs data, at a set of wavelengths (if supplied).

        Parameters
        ----------
        medium : :class:`.PoleResidue` = None
            medium containing model to plot against data
        wvl_um : ArrayFloat1D = None
            Wavelengths to evaluate model at for plot in micrometers.
        ax : matplotlib.axes._subplots.Axes = None
            Axes to plot the data on, if None, a new one is created.

        Returns
        -------
        matplotlib.axis.Axes
            Matplotlib axis corresponding to plot.
        """

        ax.plot(self.wvl_um, self.n_data, "x", label="n (data)")
        if self.lossy:
            ax.plot(self.wvl_um, self.k_data, "+", label="k (data)")

        if medium is not None:
            if wvl_um is None:
                # default to the wavelengths of the filtered input data
                wvl_um = C_0 / self.freqs

            eps_model = medium.eps_model(C_0 / wvl_um)
            n_model, k_model = AbstractMedium.eps_complex_to_nk(eps_model)

            ax.plot(wvl_um, n_model, label="n (model)")
            if self.lossy:
                ax.plot(wvl_um, k_model, label="k (model)")

        ax.set_ylabel("n, k")
        ax.set_xlabel("Wavelength ($\\mu m$)")
        ax.legend()

        return ax

    @staticmethod
    def _validate_url_load(data_load: List):
        """Validate if the loaded data from URL is valid

        The data list should be in this format:
        [["wl",     "n"],
        [float,     float],
        .           .
        .           .
        .           .
        (if lossy)
        ["wl",      "k"],
        [float,     float],
        .           .
        .           .
        .           .]]

        Parameters
        ----------
        data_load : List
            Loaded data from URL

        Raises
        ------
        ValidationError
            If the header is missing or wrong, a row has fewer than two entries,
            a data entry cannot be parsed as float, or there is more than one
            ``["wl", "k"]`` label row.
        """
        has_k = 0

        # An empty download or an empty first row is reported as a bad header
        # rather than surfacing as an IndexError.
        header = data_load[0] if data_load else []
        if len(header) < 2 or header[0] != "wl" or header[1] != "n":
            raise ValidationError(
                "Invalid URL. The file should begin with ['wl','n']. "
                "Or make sure that you have supplied an appropriate delimiter."
            )

        for row in data_load[1:]:
            # blank or single-entry rows cannot hold a (wavelength, value) pair
            if len(row) < 2:
                raise ValidationError(
                    "Invalid URL. Each row should contain a wavelength and a value."
                )

            if row[0] == "wl":
                if row[1] == "k":
                    has_k += 1
                else:
                    raise ValidationError(
                        "Invalid URL. The file is not well formatted for ['wl', 'k'] data."
                    )
            else:
                # make sure the rest is float type
                try:
                    _ = [float(x) for x in row]
                except (TypeError, ValueError) as e:
                    raise ValidationError(
                        "Invalid URL. Float data cannot be recognized."
                    ) from e

        if has_k > 1:
            raise ValidationError("Invalid URL. Too many k labels.")

    @classmethod
    def from_url(
        cls, url_file: str, delimiter: str = ",", ignore_k: bool = False, **kwargs
    ) -> DispersionFitter:
        """loads :class:`DispersionFitter` from url linked to a csv/txt file that
        contains wavelength (micron), n, and optionally k data. Preferred from
        refractiveindex.info.

        Hint
        ----
        The data file from url should be in this format (delimiter not displayed
        here, and note that the strings such as "wl", "n" need to be included
        in the file):

        * For lossless media::

            wl       n
            [float] [float]
            .       .
            .       .
            .       .

        * For lossy media::

            wl       n
            [float] [float]
            .       .
            .       .
            .       .
            wl       k
            [float] [float]
            .       .
            .       .
            .       .

        Parameters
        ----------
        url_file : str
            Url link to the data file.
            e.g. "https://refractiveindex.info/data_csv.php?datafile=database/data-nk/main/Ag/Johnson.yml"
        delimiter : str = ","
            E.g. in refractiveindex.info, it'll be "," for csv file, and "\\\\t" for txt file.
        ignore_k : bool = False
            Ignore the k data if they are present, so the fitted material is lossless.
        **kwargs
            Extra keyword arguments forwarded to the :class:`DispersionFitter` constructor.

        Returns
        -------
        :class:`DispersionFitter`
            A :class:`DispersionFitter` instance.

        Raises
        ------
        WebError
            If the HTTP response reports an error status.
        ValidationError
            If the downloaded data is malformed, or the n and k blocks do not share
            the same wavelengths.
        """
        resp = requests.get(url_file, verify=Env.current.ssl_verify)

        try:
            resp.raise_for_status()
        except Exception as e:
            raise WebError("Connection to the website failed. Please provide a valid URL.") from e

        data_url = list(
            csv.reader(codecs.iterdecode(resp.iter_lines(), "utf-8"), delimiter=delimiter)
        )

        # first validate data
        cls._validate_url_load(data_url)

        # parsing the data
        n_lam = []
        k_lam = []  # the two variables contain [wvl_um, n(k)]
        has_k = 0  # whether k is in the data

        for row in data_url[1:]:
            if has_k == 1:
                # everything after the ["wl", "k"] label row belongs to the k block
                k_lam.append([float(x) for x in row])
            elif row[0] == "wl":
                has_k += 1
            else:
                n_lam.append([float(x) for x in row])

        n_lam = np.array(n_lam)
        k_lam = np.array(k_lam)

        if has_k == 1 and not ignore_k:
            # n and k must be tabulated on the same wavelength grid
            if n_lam.shape == k_lam.shape and np.allclose(n_lam[:, 0], k_lam[:, 0]):
                return cls(wvl_um=n_lam[:, 0], n_data=n_lam[:, 1], k_data=k_lam[:, 1], **kwargs)
            raise ValidationError(
                "Invalid URL. Both n and k should be provided at each wavelength."
            )

        return cls(wvl_um=n_lam[:, 0], n_data=n_lam[:, 1], **kwargs)

    @classmethod
    def from_file(cls, fname: str, **loadtxt_kwargs) -> DispersionFitter:
        """Loads :class:`DispersionFitter` from file containing wavelength, n, k data.

        Parameters
        ----------
        fname : str
            Path to file containing wavelength (um), n, k (optional) data in columns.
        **loadtxt_kwargs
            Kwargs passed to ``np.loadtxt``, such as ``skiprows``, ``delimiter``.

        Hint
        ----
        The data file should be in this format (``delimiter`` and ``skiprows`` can be
        customized in ``**loadtxt_kwargs``):

        * For lossless media::

            wl       n
            [float] [float]
            .       .
            .       .
            .       .

        * For lossy media::

            wl       n       k
            [float] [float] [float]
            .       .       .
            .       .       .
            .       .       .

        Returns
        -------
        :class:`DispersionFitter`
            A :class:`DispersionFitter` instance.

        Raises
        ------
        ValueError
            If the loaded data is not two-dimensional or does not have 2 or 3 columns.
        """
        data = np.loadtxt(fname, **loadtxt_kwargs)
        if len(data.shape) != 2:
            raise ValueError("data must contain [wavelength, ndata, kdata] in columns")
        # the last axis indexes the columns (wavelength, n[, k])
        if data.shape[-1] not in (2, 3):
            raise ValueError("data must have either 2 or 3 columns (if k data)")
        if data.shape[-1] == 2:
            wvl_um, n_data = data.T
            k_data = None
        else:
            wvl_um, n_data, k_data = data.T
        return cls(wvl_um=wvl_um, n_data=n_data, k_data=k_data)

    @classmethod
    def from_complex_permittivity(
        cls,
        wvl_um: ArrayFloat1D,
        eps_real: ArrayFloat1D,
        eps_imag: ArrayFloat1D = None,
        wvl_range: Tuple[Optional[float], Optional[float]] = (None, None),
    ) -> DispersionFitter:
        """Loads :class:`DispersionFitter` from wavelength and complex relative permittivity data

        Parameters
        ----------
        wvl_um : ArrayFloat1D
            Wavelength data in micrometers.
        eps_real : ArrayFloat1D
            Real parts of relative permittivity data
        eps_imag : Optional[ArrayFloat1D]
            Imaginary parts of relative permittivity data; `None` for lossless medium.
        wvl_range : Tuple[Optional[float], Optional[float]]
            Wavelength range [wvl_min,wvl_max] for fitting.

        Returns
        -------
        :class:`DispersionFitter`
            A :class:`DispersionFitter` instance.
        """
        if eps_imag is None:
            # lossless: only n is passed on, k_data is left to the constructor default
            n, _ = AbstractMedium.eps_complex_to_nk(eps_real + 0j)
            return cls(wvl_um=wvl_um, n_data=n, wvl_range=wvl_range)
        n, k = AbstractMedium.eps_complex_to_nk(eps_real + eps_imag * 1j)
        return cls(wvl_um=wvl_um, n_data=n, k_data=k, wvl_range=wvl_range)

    @classmethod
    def from_loss_tangent(
        cls,
        wvl_um: ArrayFloat1D,
        eps_real: ArrayFloat1D,
        loss_tangent: ArrayFloat1D,
        wvl_range: Tuple[Optional[float], Optional[float]] = (None, None),
    ) -> DispersionFitter:
        """Loads :class:`DispersionFitter` from wavelength and loss tangent data.

        Parameters
        ----------
        wvl_um : ArrayFloat1D
            Wavelength data in micrometers.
        eps_real : ArrayFloat1D
            Real parts of relative permittivity data
        loss_tangent : ArrayFloat1D
            Loss tangent data, defined as the ratio of imaginary and real parts of permittivity.
        wvl_range : Tuple[Optional[float], Optional[float]]
            Wavelength range [wvl_min,wvl_max] for fitting.

        Returns
        -------
        :class:`DispersionFitter`
            A :class:`DispersionFitter` instance.
        """
        eps_complex = AbstractMedium.eps_loss_tangent_to_eps_complex(eps_real, loss_tangent)
        n, k = AbstractMedium.eps_complex_to_nk(eps_complex)
        return cls(wvl_um=wvl_um, n_data=n, k_data=k, wvl_range=wvl_range)