"""Fit PoleResidue Dispersion models to optical NK data based on web service"""
from __future__ import annotations
import ssl
from typing import Tuple, Optional
from enum import Enum
import requests
import pydantic.v1 as pydantic
from pydantic.v1 import PositiveInt, NonNegativeFloat, PositiveFloat, Field, validator
from ...log import log
from ...components.base import Tidy3dBaseModel, skip_if_fields_missing
from ...components.types import Literal
from ...components.medium import PoleResidue
from ...constants import MICROMETER, HERTZ
from ...exceptions import WebError, Tidy3dError, SetupError
from tidy3d.web.core.http_util import get_headers
from tidy3d.web.core.environment import Env
from .fit import DispersionFitter
# Multiplier applied to ``fitter.frequency_range[1]`` when ``FitterData.create``
# has to derive the automatic upper bounds ``bound_f`` and ``bound_amp``.
BOUND_MAX_FACTOR = 10

# Base URL of the fitting web service, keyed by environment name.
URL_ENV = {
    "local": "http://127.0.0.1:8000",
    "dev": "https://tidy3d-service.dev-simulation.cloud",
    "prod": "https://tidy3d-service.simulation.cloud",
}
class ExceptionCodes(Enum):
    """HTTP status codes of failed requests that are handled individually."""

    # Compared against the response status in ``FitterData.run`` to produce a
    # timeout-specific error message.
    GATEWAY_TIMEOUT = 504
    # NOTE(review): not referenced in the visible part of this module.
    NOT_FOUND = 404
# [docs]
class AdvancedFitterParam(Tidy3dBaseModel):
    """Advanced fitter parameters.

    ``bound_amp`` and ``bound_f`` may be left as ``None``; they are then filled
    in by :meth:`FitterData.create` from the frequency range of the fitter.
    """

    # ``None`` is a legal value (automatic setup), hence ``Optional``.
    bound_amp: Optional[NonNegativeFloat] = Field(
        None,
        title="Upper bound of oscillator strength",
        description="Upper bound of real and imaginary part of oscillator "
        "strength ``c`` in the model :class:`.PoleResidue` (The default 'None' will trigger "
        "automatic setup based on the frequency range of interest).",
        units=HERTZ,
    )

    # ``None`` is a legal value (automatic setup), hence ``Optional``.
    bound_f: Optional[NonNegativeFloat] = Field(
        None,
        title="Upper bound of pole frequency",
        description="Upper bound of real and imaginary part of ``a`` that corresponds to pole "
        "damping rate and frequency in the model :class:`.PoleResidue` (The default 'None' "
        "will trigger automatic setup based on the frequency range of interest).",
        units=HERTZ,
    )

    bound_f_lower: NonNegativeFloat = Field(
        0.0,
        title="Lower bound of pole frequency",
        description="Lower bound of imaginary part of ``a`` that corresponds to pole "
        "frequency in the model :class:`.PoleResidue`.",
        units=HERTZ,
    )

    bound_eps_inf: float = Field(
        10.0,
        title="Upper bound of epsilon at infinity frequency",
        description="Upper bound of epsilon at infinity frequency. It must be no less than 1.",
        ge=1,
    )

    constraint: Literal["hard", "soft"] = Field(
        "hard",
        title="Type of constraint for stability",
        description="Stability constraint: 'hard' constraints are generally recommended since "
        "they are faster to compute per iteration, and they often require fewer iterations to "
        "converge since the search space is smaller. But sometimes the search space is "
        "so restrictive that all good solutions are missed, then please try the 'soft' constraints "
        "for larger search space. However, both constraints improve stability equally well.",
    )

    nlopt_maxeval: PositiveInt = Field(
        5000,
        title="Number of inner iterations",
        description="Number of iterations in each inner optimization.",
    )

    random_seed: Optional[int] = Field(
        0,
        title="Random seed for starting coefficients",
        description="The fitting tool performs global optimizations with random "
        "starting coefficients. With the same random seed, one obtains identical "
        "results when re-running the fitter; on the other hand, if "
        "one wants to re-run the fitter several times to obtain the best results, "
        "the value of the seed should be changed, or set to ``None`` so that "
        "the starting coefficients are different each time. ",
        ge=0,
        lt=2**32,
    )

    @validator("bound_f_lower", always=True)
    @skip_if_fields_missing(["bound_f"])
    def _validate_lower_frequency_bound(cls, val, values):
        """Ensure ``bound_f_lower`` is not larger than ``bound_f``.

        Parameters
        ----------
        val : float
            Candidate value of ``bound_f_lower``.
        values : dict
            Field values validated so far; ``bound_f`` is read from it.

        Returns
        -------
        float
            ``val`` unchanged when the check passes.

        Raises
        ------
        SetupError
            If ``bound_f`` is set and ``val`` exceeds it.
        """
        upper_bound = values["bound_f"]
        # ``None`` requests automatic setup of the upper bound, so there is
        # nothing to compare against in that case.
        if upper_bound is not None and val > upper_bound:
            raise SetupError(
                "The upper bound 'bound_f' cannot be smaller "
                "than the lower bound 'bound_f_lower'."
            )
        return val
class FitterData(AdvancedFitterParam):
    """Data class for request body of Fitter where dispersion data is input through tuple."""

    wvl_um: Tuple[float, ...] = Field(
        ...,
        title="Wavelengths",
        description="A set of wavelengths for dispersion data.",
        units=MICROMETER,
    )

    n_data: Tuple[float, ...] = Field(
        ...,
        title="Index of refraction",
        description="Real part of the complex index of refraction at each wavelength.",
    )

    # ``create`` passes ``None`` here when the fitter is not lossy.
    k_data: Optional[Tuple[float, ...]] = Field(
        None,
        title="Extinction coefficient",
        description="Imaginary part of the complex index of refraction at each wavelength.",
    )

    num_poles: PositiveInt = Field(
        1, title="Number of poles", description="Number of poles in model."
    )

    num_tries: PositiveInt = Field(
        50,
        title="Number of tries",
        description="Number of optimizations to run with different initial guess.",
    )

    tolerance_rms: NonNegativeFloat = Field(
        0.0,
        title="RMS error tolerance",
        description="RMS error below which the fit is successful and result is returned.",
    )

    # The two bounds below override the parent's fields: here they are required
    # to be positive and are declared in eV rather than Hz.
    bound_amp: PositiveFloat = Field(
        100.0,
        title="Upper bound of oscillator strength",
        description="Upper bound of oscillator strength in the model.",
        units="eV",
    )

    bound_f: PositiveFloat = Field(
        100.0,
        title="Upper bound of pole frequency",
        description="Upper bound of pole frequency in the model.",
        units="eV",
    )

    @staticmethod
    def create(
        fitter: DispersionFitter,
        num_poles: PositiveInt,
        num_tries: PositiveInt,
        tolerance_rms: NonNegativeFloat,
        advanced_param: AdvancedFitterParam,
    ) -> FitterData:
        """Setup FitterData to be provided to web service

        Parameters
        ----------
        fitter : DispersionFitter
            Fitter with the data to fit.
        num_poles : PositiveInt
            Number of poles in the model.
        num_tries : PositiveInt
            Number of optimizations to run with random initial guess.
        tolerance_rms : NonNegativeFloat
            RMS error below which the fit is successful and the result is returned.
        advanced_param : :class:`AdvancedFitterParam`
            Other advanced parameters.

        Returns
        -------
        :class:`FitterData`
            Data class for request body of Fitter where dispersion
            data is input through tuple.
        """
        # Fill in bounds left as ``None`` from ``fitter.frequency_range[1]``
        # (presumably the upper end of the range -- confirm against
        # ``DispersionFitter``). ``copy(update=...)`` leaves the caller's
        # object untouched.
        if advanced_param.bound_f is None:
            new_bound_f = (
                advanced_param.bound_f_lower + fitter.frequency_range[1] * BOUND_MAX_FACTOR
            )
            advanced_param = advanced_param.copy(update={"bound_f": new_bound_f})

        if advanced_param.bound_amp is None:
            new_bound_amp = fitter.frequency_range[1] * BOUND_MAX_FACTOR
            advanced_param = advanced_param.copy(update={"bound_amp": new_bound_amp})

        wvl_um, n_data, k_data = fitter.data_in_range
        # The extinction coefficient is only sent for lossy data.
        k_data = k_data.tolist() if fitter.lossy else None

        # Frequency-like bounds are converted from Hz to eV for the request.
        return FitterData(
            wvl_um=wvl_um.tolist(),
            n_data=n_data.tolist(),
            k_data=k_data,
            num_poles=num_poles,
            num_tries=num_tries,
            tolerance_rms=tolerance_rms,
            bound_amp=fitter._Hz_to_eV(advanced_param.bound_amp),
            bound_f=fitter._Hz_to_eV(advanced_param.bound_f),
            bound_f_lower=fitter._Hz_to_eV(advanced_param.bound_f_lower),
            bound_eps_inf=advanced_param.bound_eps_inf,
            constraint=advanced_param.constraint,
            nlopt_maxeval=advanced_param.nlopt_maxeval,
            random_seed=advanced_param.random_seed,
        )

    @staticmethod
    def _set_url(config_env: Literal["default", "dev", "prod", "local"] = "default"):
        """Set the url of python web service

        Parameters
        ----------
        config_env : Literal["default", "dev", "prod", "local"], optional
            Service environment to pick from

        Returns
        -------
        str
            Base URL of the service, looked up in ``URL_ENV``.
        """
        _env = config_env
        if _env == "default":
            # Follow the currently configured web API: "dev" if its endpoint
            # mentions "dev", otherwise "prod".
            _env = "dev" if "dev" in Env.current.web_api_endpoint else "prod"
        return URL_ENV[_env]

    @staticmethod
    def _setup_server(url_server: str):
        """set up web server access

        Sends a request to ``{url_server}/health``. If that fails with an SSL
        error, ``Env.current.ssl_verify`` is set to ``False`` and the request is
        retried once.

        Parameters
        ----------
        url_server : str
            URL for the server

        Returns
        -------
        The result of ``get_headers()``, used as headers of later requests.

        Raises
        ------
        WebError
            If the health check fails, including a failure of the retry.
        """
        health_url = f"{url_server}/health"
        try:
            try:
                # test connection
                resp = requests.get(health_url, verify=Env.current.ssl_verify)
            except (requests.exceptions.SSLError, ssl.SSLError):
                log.info("Retrying with SSL verification disabled.")
                Env.current.ssl_verify = False
                resp = requests.get(health_url, verify=Env.current.ssl_verify)
            # Checked for whichever attempt produced the response, so a retry
            # that returns an HTTP error status is not treated as success.
            resp.raise_for_status()
        except Exception as e:
            # Also covers a failing retry, which would otherwise escape as a
            # raw ``requests`` exception.
            raise WebError("Connection to the server failed. Please try again.") from e

        return get_headers()

    def run(self) -> Tuple[PoleResidue, float]:
        """Execute the data fit using the stable fitter in the server.

        Returns
        -------
        Tuple[:class:`.PoleResidue`, float]
            Best results of multiple fits: (dispersive medium, RMS error).

        Raises
        ------
        Tidy3dError
            If the server answers with a gateway timeout (504).
        WebError
            If the server cannot be reached or answers with any other error status.
        """
        url_server = self._set_url("default")
        headers = self._setup_server(url_server)

        resp = requests.post(
            f"{url_server}/dispersion/fit",
            headers=headers,
            data=self.json(),
            verify=Env.current.ssl_verify,
        )

        try:
            resp.raise_for_status()
        except Exception as e:
            if resp.status_code == ExceptionCodes.GATEWAY_TIMEOUT.value:
                # Switching to the 'hard' constraint is only suggested when it
                # is not already in use.
                if self.constraint != "hard":
                    msg = (
                        "Fitter failed due to timeout. Try to decrease the number of tries or "
                        "inner iterations, to relax the RMS tolerance, or to use the 'hard' "
                        "constraint."
                    )
                else:
                    msg = (
                        "Fitter failed due to timeout. Try to decrease the number of tries or "
                        "inner iterations, or to relax the RMS tolerance."
                    )
                raise Tidy3dError(msg) from e

            raise WebError(
                "Fitter failed. Try again, tune the parameters, or contact us for more help."
            ) from e

        run_result = resp.json()
        best_medium = PoleResidue.parse_raw(run_result["message"])
        best_rms = float(run_result["rms"])

        if best_rms < self.tolerance_rms:
            log.info("Found optimal fit with RMS error %.3g", best_rms)
        else:
            # The best fit is still returned; the caller is only warned that
            # the requested tolerance was not reached.
            log.warning(
                "Unable to fit with RMS error under 'tolerance_rms' of %.3g", self.tolerance_rms
            )
            log.info("Returning best fit with RMS error %.3g", best_rms)

        return best_medium, best_rms
# [docs]
def run(
    fitter: DispersionFitter,
    num_poles: PositiveInt = 1,
    num_tries: PositiveInt = 50,
    tolerance_rms: NonNegativeFloat = 1e-2,
    advanced_param: AdvancedFitterParam = AdvancedFitterParam(),
) -> Tuple[PoleResidue, float]:
    """Fit the data held by ``fitter`` with the stable fitter on the server.

    Builds the request body with :meth:`FitterData.create` and submits it with
    :meth:`FitterData.run`.

    Parameters
    ----------
    fitter : DispersionFitter
        Fitter with the data to fit.
    num_poles : PositiveInt, optional
        Number of poles in the model.
    num_tries : PositiveInt, optional
        Number of optimizations to run with random initial guess.
    tolerance_rms : NonNegativeFloat, optional
        RMS error below which the fit is successful and the result is returned.
    advanced_param : :class:`AdvancedFitterParam`, optional
        Advanced parameters passed on to the server.

    Returns
    -------
    Tuple[:class:`.PoleResidue`, float]
        Best results of multiple fits: (dispersive medium, RMS error).
    """
    request_body = FitterData.create(
        fitter=fitter,
        num_poles=num_poles,
        num_tries=num_tries,
        tolerance_rms=tolerance_rms,
        advanced_param=advanced_param,
    )
    medium_and_rms = request_body.run()
    return medium_and_rms
# [docs]
class StableDispersionFitter(DispersionFitter):
    """Deprecated.

    Kept as a thin wrapper: use :class:`DispersionFitter` together with the
    module-level ``run`` function instead.
    """

    @pydantic.root_validator()
    def _deprecate_stable_fitter(cls, values):
        """Log a deprecation warning during validation; ``values`` is returned unchanged."""
        log.warning(
            "'StableDispersionFitter' has been deprecated. Use 'DispersionFitter' with "
            "'tidy3d.plugins.dispersion.web.run' to access the stable fitter from the web server."
        )
        return values

    def fit(
        self,
        num_poles: PositiveInt = 1,
        num_tries: PositiveInt = 50,
        tolerance_rms: NonNegativeFloat = 1e-2,
        guess: Optional[PoleResidue] = None,
        advanced_param: AdvancedFitterParam = AdvancedFitterParam(),
    ) -> Tuple[PoleResidue, float]:
        """Deprecated.

        Forwards to the module-level ``run`` function with this fitter.

        Parameters
        ----------
        num_poles : PositiveInt, optional
            Number of poles in the model.
        num_tries : PositiveInt, optional
            Number of optimizations to run with random initial guess.
        tolerance_rms : NonNegativeFloat, optional
            RMS error below which the fit is successful and the result is returned.
        guess : :class:`.PoleResidue`, optional
            Accepted but ignored; it is not forwarded to the web service.
        advanced_param : :class:`AdvancedFitterParam`, optional
            Advanced parameters passed on to the server.

        Returns
        -------
        Tuple[:class:`.PoleResidue`, float]
            Best results of multiple fits: (dispersive medium, RMS error).
        """
        # ``run`` here is the module-level function, not a method of this class.
        return run(self, num_poles, num_tries, tolerance_rms, advanced_param)