# Source code for tidy3d.config.migrations

"""Schema versioning and migration helpers for tidy3d configuration files."""

from __future__ import annotations

import os
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any

import tomlkit

from tidy3d.log import log

from .registry import get_sections
from .schema_utils import TOP_LEVEL_METADATA_KEYS, _resolve_model_type

if TYPE_CHECKING:
    from typing import Optional

    from pydantic import BaseModel

# Top-level key under which a config file records its schema version.
CONFIG_VERSION_KEY = "config_version"
# Current (latest) config schema version known to this module.
CURRENT_CONFIG_VERSION = 1

# Environment variable controlling whether migrated configs are written back
# automatically (see ``auto_migrate_enabled``; enabled unless set to a false-like value).
AUTO_MIGRATE_ENV = "TIDY3D_CONFIG_AUTO_MIGRATE"
# Environment variable selecting how configs carrying a newer version are
# handled (see ``forward_compat_mode``).
FORWARD_COMPAT_ENV = "TIDY3D_CONFIG_FORWARD_COMPAT"

# Accepted values for ``FORWARD_COMPAT_ENV``; best-effort is the default.
FORWARD_COMPAT_STRICT = "strict"
FORWARD_COMPAT_BEST_EFFORT = "best-effort"

# A migration step receives the TOML document and mutates it in place (returns nothing).
MigrationFunc = Callable[[tomlkit.TOMLDocument], None]
# Registered steps keyed by source version: ``_MIGRATIONS[v]`` holds the functions that
# upgrade a document from v to v + 1, in registration order.
_MIGRATIONS: dict[int, list[MigrationFunc]] = {}
# Target version up to which the migration chain has been checked for gaps;
# reset to 0 whenever a new migration is registered.
_MIGRATION_CHAIN_VALIDATED_UP_TO = 0


def register_migration(version: int) -> Callable[[MigrationFunc], MigrationFunc]:
    """Build a decorator that records a schema migration step.

    The decorated function upgrades a document from schema ``version`` to
    ``version + 1``. Several functions may be registered for the same
    ``version``; they are kept (and later run) in registration order.
    Registering a step discards any cached migration-chain validation.
    The decorated function is returned unchanged.
    """

    def _register(migration: MigrationFunc) -> MigrationFunc:
        steps_for_version = _MIGRATIONS.setdefault(version, [])
        steps_for_version.append(migration)
        # A new step may fill a previously-missing link, so re-validate lazily.
        _invalidate_migration_chain()
        return migration

    return _register
def get_config_version(source: Any) -> int:
    """Return the config schema version stored in a mapping or TOML document.

    Args:
        source: A ``dict``, ``tomlkit.TOMLDocument`` or any other mapping. Any
            non-mapping value is treated as carrying no version.

    Returns:
        The stored version as a plain ``int``. A missing key yields ``0``;
        invalid values (booleans, non-integer strings, floats, negative
        numbers, other types) log a warning and also yield ``0``.
    """
    if not isinstance(source, Mapping):
        return 0
    raw = source.get(CONFIG_VERSION_KEY)
    if raw is None:
        return 0
    # ``bool`` subclasses ``int``; reject it so ``config_version = true`` is not read as 1.
    if isinstance(raw, bool):
        return _warn_invalid_version(raw)
    if isinstance(raw, int):
        # Normalize int subclasses (e.g. tomlkit integer items) to a plain int.
        version = int(raw)
    elif isinstance(raw, str):
        try:
            version = int(raw)
        except ValueError:
            return _warn_invalid_version(raw)
    else:
        return _warn_invalid_version(raw)
    if version < 0:
        return _warn_invalid_version(version)
    return version


def _warn_invalid_version(value: Any) -> int:
    """Log that ``value`` is not a usable config version and return the fallback ``0``."""
    log.warning(f"Invalid '{CONFIG_VERSION_KEY}' value {value!r}; falling back to version 0.")
    return 0


def set_config_version(document: tomlkit.TOMLDocument, version: int) -> None:
    """Set the config version in a TOML document (in place, coerced to ``int``)."""
    document[CONFIG_VERSION_KEY] = int(version)


def strip_config_version(data: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``data`` without the config version key.

    A new dict is always returned, so callers may mutate the result without
    affecting ``data``.
    """
    return {key: value for key, value in data.items() if key != CONFIG_VERSION_KEY}


def inject_config_version(data: dict[str, Any], version: int) -> dict[str, Any]:
    """Return a shallow copy of ``data`` with the config version key set to ``int(version)``."""
    updated = dict(data)
    updated[CONFIG_VERSION_KEY] = int(version)
    return updated


def auto_migrate_enabled() -> bool:
    """Return True if automatic write-back of migrated configs is enabled.

    Reads ``AUTO_MIGRATE_ENV``. Unset or blank means enabled; ``0``, ``false``,
    ``no``, ``off`` (case-insensitive, surrounding whitespace ignored) disable
    it and ``1``, ``true``, ``yes``, ``on`` enable it. Any other value logs a
    warning and keeps it enabled.
    """
    raw = os.getenv(AUTO_MIGRATE_ENV)
    value = (raw or "").strip().lower()
    # Treat a blank variable like an unset one, matching ``forward_compat_mode``.
    if not value:
        return True
    if value in {"0", "false", "no", "off"}:
        return False
    if value in {"1", "true", "yes", "on"}:
        return True
    log.warning(f"Unrecognized '{AUTO_MIGRATE_ENV}' value {raw!r}; defaulting to auto-migrate.")
    return True


def forward_compat_mode() -> str:
    """Return the forward-compat behavior for newer config versions.

    Reads ``FORWARD_COMPAT_ENV`` and returns ``FORWARD_COMPAT_STRICT`` or
    ``FORWARD_COMPAT_BEST_EFFORT`` (case-insensitive, surrounding whitespace
    ignored). Unset or blank means best-effort; any other value logs a warning
    and also falls back to best-effort.
    """
    raw = os.getenv(FORWARD_COMPAT_ENV)
    value = (raw or "").strip().lower()
    if not value:
        return FORWARD_COMPAT_BEST_EFFORT
    if value in {FORWARD_COMPAT_STRICT, FORWARD_COMPAT_BEST_EFFORT}:
        return value
    log.warning(f"Unrecognized '{FORWARD_COMPAT_ENV}' value {raw!r}; defaulting to best-effort.")
    return FORWARD_COMPAT_BEST_EFFORT


def apply_migrations(document: tomlkit.TOMLDocument, from_version: int, to_version: int) -> None:
    """Apply registered migrations to ``document`` in place.

    Runs every step registered for versions ``from_version`` through
    ``to_version - 1``: in ascending version order and, within one version, in
    registration order. Nothing happens when ``from_version >= to_version``;
    a negative ``from_version`` is treated as ``0``.

    Raises:
        RuntimeError: If no migration is registered for some step ``v -> v + 1``
            with ``0 <= v < to_version``.
    """
    if from_version >= to_version:
        return
    if from_version < 0:
        from_version = 0
    # Fail before touching the document if the chain has a gap.
    _ensure_migration_chain(to_version)
    for version in range(from_version, to_version):
        for migrator in _MIGRATIONS.get(version, []):
            migrator(document)


def best_effort_filter(data: dict[str, Any]) -> dict[str, Any]:
    """Drop unknown keys from a config payload using the registered schemas.

    The config version key is removed and top-level metadata keys are kept
    verbatim. Each known section (including ``plugins.<name>`` sections) is
    reduced to the fields its schema declares, recursively. Unknown sections
    are dropped with a warning. A known section whose value is not a table is
    replaced by an empty dict, also with a warning. When no sections are
    registered, only the version key is removed.
    """
    sections = get_sections()
    if not sections:
        return strip_config_version(data)
    filtered: dict[str, Any] = {}
    for key, value in data.items():
        if key == CONFIG_VERSION_KEY:
            continue
        if key in TOP_LEVEL_METADATA_KEYS:
            filtered[key] = value
            continue
        if key == "plugins":
            filtered_plugins = _filter_plugins(value, sections)
            if filtered_plugins is not None:
                filtered["plugins"] = filtered_plugins
            continue
        schema = sections.get(key)
        if schema is None:
            log.warning(
                f"Ignoring unknown configuration section '{key}' during best-effort parsing."
            )
            continue
        if isinstance(value, dict):
            filtered[key] = _filter_section_data(schema, value)
        else:
            log.warning(
                f"Configuration section '{key}' should be a table; "
                "ignoring non-table value during best-effort parsing."
            )
            # Keep the section present but empty rather than dropping it.
            filtered[key] = {}
    return filtered


def _filter_plugins(value: Any, sections: dict[str, type[BaseModel]]) -> Optional[dict[str, Any]]:
    """Filter the ``plugins`` table against the ``plugins.<name>`` schemas.

    Returns None (so the caller omits ``plugins``) when ``value`` is not a table.
    Unknown plugins are dropped with a warning. A known plugin whose value is
    not a table becomes an empty dict.
    """
    if not isinstance(value, dict):
        log.warning(
            "Configuration section 'plugins' should be a table; "
            "ignoring non-table value during best-effort parsing."
        )
        return None
    filtered: dict[str, Any] = {}
    for plugin_name, plugin_data in value.items():
        schema = sections.get(f"plugins.{plugin_name}")
        if schema is None:
            log.warning(
                f"Ignoring unknown plugin configuration section '{plugin_name}' "
                "during best-effort parsing."
            )
            continue
        if isinstance(plugin_data, dict):
            filtered[plugin_name] = _filter_section_data(schema, plugin_data)
        else:
            log.warning(
                f"Configuration plugin section '{plugin_name}' should be a table; "
                "ignoring non-table value during best-effort parsing."
            )
            filtered[plugin_name] = {}
    return filtered


def _filter_section_data(schema: type[BaseModel], data: dict[str, Any]) -> dict[str, Any]:
    """Return the subset of ``data`` whose keys are field names of ``schema``.

    When a field's annotation resolves to a nested model (via
    ``_resolve_model_type``), a dict value is filtered recursively. A list
    value is filtered element-wise, with non-dict elements kept as-is. All
    other values are kept unchanged.
    """
    filtered: dict[str, Any] = {}
    for field_name, field in schema.model_fields.items():
        if field_name not in data:
            continue
        value = data[field_name]
        nested_model = _resolve_model_type(field.annotation)
        if nested_model is not None:
            if isinstance(value, dict):
                filtered[field_name] = _filter_section_data(nested_model, value)
                continue
            if isinstance(value, list):
                filtered[field_name] = [
                    _filter_section_data(nested_model, item) if isinstance(item, dict) else item
                    for item in value
                ]
                continue
        filtered[field_name] = value
    return filtered


def _ensure_migration_chain(target_version: int) -> None:
    """Validate the chain up to ``target_version`` unless a prior call already covered it."""
    global _MIGRATION_CHAIN_VALIDATED_UP_TO
    if target_version <= _MIGRATION_CHAIN_VALIDATED_UP_TO:
        return
    _validate_migration_chain(target_version)
    _MIGRATION_CHAIN_VALIDATED_UP_TO = target_version


def _invalidate_migration_chain() -> None:
    """Forget the cached chain validation so the next check runs again."""
    global _MIGRATION_CHAIN_VALIDATED_UP_TO
    _MIGRATION_CHAIN_VALIDATED_UP_TO = 0


def _validate_migration_chain(target_version: int) -> None:
    """Raise RuntimeError if any step ``v -> v + 1`` with ``0 <= v < target_version`` is missing."""
    for version in range(target_version):
        if version not in _MIGRATIONS or not _MIGRATIONS[version]:
            raise RuntimeError(f"Missing config migration step for v{version} -> v{version + 1}.")


@register_migration(0)
def _migrate_v0_to_v1(document: tomlkit.TOMLDocument) -> None:
    """Initial schema migration (no-op)."""
    return None


__all__ = [
    "AUTO_MIGRATE_ENV",
    "CONFIG_VERSION_KEY",
    "CURRENT_CONFIG_VERSION",
    "FORWARD_COMPAT_BEST_EFFORT",
    "FORWARD_COMPAT_ENV",
    "FORWARD_COMPAT_STRICT",
    "apply_migrations",
    "auto_migrate_enabled",
    "best_effort_filter",
    "forward_compat_mode",
    "get_config_version",
    "inject_config_version",
    "register_migration",
    "set_config_version",
    "strip_config_version",
]