# Source code for tidy3d.plugins.adjoint.utils.filter

"""Spatial filtering Functions for adjoint plugin."""
from abc import ABC, abstractmethod

import pydantic.v1 as pd
import numpy as np
import jax.numpy as jnp
import jax.scipy as jsp

from ....components.base import Tidy3dBaseModel
from ....constants import MICROMETER
from ....log import log


class Filter(Tidy3dBaseModel, ABC):
    """Base interface for operations applied to an array of spatial (design) data.

    A concrete filter is constructed with its parameters and then applied to data
    through :meth:`evaluate`.
    """

    @abstractmethod
    def evaluate(self, spatial_data: jnp.array) -> jnp.array:
        """Return the result of applying this filter to ``spatial_data``."""
        ...


class AbstractCircularFilter(Filter, ABC):
    """Abstract base for filters that convolve spatial data with a radially defined kernel.

    Subclasses provide the radial profile of the kernel through :meth:`make_kernel`.
    Building the coordinate grid, guarding against oversized kernels, and the normalized
    convolution itself are implemented here in :meth:`evaluate`.
    """

    radius: float = pd.Field(
        ...,
        title="Filter Radius",
        description="Radius of the filter to convolve with supplied spatial data. "
        "Note: the corresponding feature size expressed in the device is typically "
        "sqrt(3) times smaller than the radius. For best results, it is recommended to make your "
        "radius about twice as large as the desired feature size.",
        units=MICROMETER,
    )

    design_region_dl: float = pd.Field(
        ...,
        title="Grid Size in Design Region",
        description="Grid size in the design region. "
        "This sets the length scale for the conic convolution filter.",
        units=MICROMETER,
    )

    @property
    def filter_radius_pixels(self) -> int:
        """Filter radius in pixels, rounded up to a whole number of pixels."""
        # ``np.ceil`` returns a float; cast so the value matches the annotated ``int``.
        return int(np.ceil(self.radius / self.design_region_dl))

    @pd.root_validator(pre=True)
    def _deprecate_feature_size(cls, values):
        """Reject the pre-release ``feature_size`` field and point the user to ``radius``.

        Raises
        ------
        ValueError
            If ``feature_size`` is present in the supplied values. Pydantic converts this
            into a ``ValidationError`` that carries the message below.
        """
        if "feature_size" in values:
            # Validators must raise ValueError / TypeError / AssertionError. Instantiating
            # ``pd.ValidationError`` with only a message fails (it requires ``errors`` and
            # ``model``), which would hide this guidance behind a missing-argument error.
            raise ValueError(
                "The 'feature_size' field of circular filters available in 2.4 pre-releases was "
                "renamed to 'radius' for the official 2.4.0 release. "
                "If you're seeing this message, please change your script to use that field name."
            )
        return values

    @abstractmethod
    def make_kernel(self, coords_rad: jnp.array) -> jnp.array:
        """Function to make the kernel out of a coordinate grid of radius values."""

    @staticmethod
    def _check_kernel_size(kernel: jnp.array, signal_in: jnp.array) -> jnp.array:
        """Make sure kernel isn't larger than signal and warn and truncate if so.

        Parameters
        ----------
        kernel
            Convolution kernel.
        signal_in
            Array the kernel is going to be convolved with.

        Returns
        -------
        The original ``kernel`` when it fits along every axis, otherwise a copy that was
        cropped symmetrically along each axis on which it exceeded ``signal_in``.
        """

        kernel_shape = kernel.shape
        input_shape = signal_in.shape

        if any(k_shape > in_shape for k_shape, in_shape in zip(kernel_shape, input_shape)):
            # remove some pixels from the kernel to make things right
            new_kernel = kernel.copy()
            for axis, (len_kernel, len_input) in enumerate(zip(kernel_shape, input_shape)):
                if len_kernel > len_input:
                    rm_pixels_total = len_kernel - len_input
                    # Crop the same number of pixels from both ends so the kernel stays
                    # centered; rounding up means an odd excess removes one pixel more
                    # than strictly required.
                    rm_pixels_edge = int(np.ceil(rm_pixels_total / 2))
                    indices_truncated = np.arange(rm_pixels_edge, len_kernel - rm_pixels_edge)
                    new_kernel = new_kernel.take(indices=indices_truncated.astype(int), axis=axis)

            log.warning(
                f"The filter input has shape {input_shape} whereas the "
                f"kernel has shape {kernel_shape}. "
                "These shapes are incompatible as the input must "
                "be larger than the kernel along all dimensions. "
                "The kernel will automatically be "
                f"resized to {new_kernel.shape} to be less than the input shape. "
                "If this is unexpected, "
                "either reduce the filter 'radius' or increase the input array's size."
            )

            return new_kernel

        return kernel

    def evaluate(self, spatial_data: jnp.array) -> jnp.array:
        """Convolve the supplied spatial data with this filter's normalized kernel.

        Parameters
        ----------
        spatial_data
            Array of values to filter. Singleton dimensions are squeezed out before
            filtering, so the result has the squeezed shape.

        Returns
        -------
        The filtered array: the convolution of the data with the kernel, divided by the
        convolution of an all-ones array with the same kernel.
        """

        rho = jnp.squeeze(spatial_data)
        num_dims = len(rho.shape)

        # Build a grid of radial distances (in pixels) from the kernel center. The same 1D
        # coordinates are used along every axis.
        coords_1d = np.arange(-self.filter_radius_pixels, self.filter_radius_pixels + 1)
        meshgrid_args = [coords_1d.copy() for _ in range(num_dims)]
        meshgrid_coords = np.meshgrid(*meshgrid_args)
        coords_rad = np.sqrt(np.sum([np.square(v) for v in meshgrid_coords], axis=0))

        # construct the kernel
        kernel = self.make_kernel(coords_rad)

        # handle when kernel is too large compared to input
        kernel = self._check_kernel_size(kernel=kernel, signal_in=rho)

        # normalize by the kernel operating on a spatial_data of all ones
        num = jsp.signal.convolve(rho, kernel, mode="same")
        den = jsp.signal.convolve(jnp.ones_like(rho), kernel, mode="same")

        return num / den


class ConicFilter(AbstractCircularFilter):
    """Filter whose kernel decreases linearly with distance from its center (a cone).

    Intended for larger feature sizes.

    Note
    ----
    .. math::

        filter(r) = max(radius - r, 0)

    """

    def make_kernel(self, coords_rad: jnp.array) -> jnp.array:
        """Build the conical kernel from a grid of radial distances (in pixels)."""
        radius_px = self.filter_radius_pixels
        beyond_radius = coords_rad > radius_px

        # linear fall-off from the center, zeroed outside the filter radius
        cone = radius_px - coords_rad
        cone[beyond_radius] = 0.0
        return cone
class CircularFilter(AbstractCircularFilter):
    """Filter whose kernel is a uniform disk, used for larger feature sizes.

    Note
    ----
    .. math::

        filter(r) = 1 if r <= radius else 0

    Note
    ----
    This uniform circular mask produces results that are harder to binarize than the
    conical mask. We recommend you use the ``ConicFilter`` instead for most applications.

    """

    def make_kernel(self, coords_rad: jnp.array) -> jnp.array:
        """Build the uniform disk kernel from a grid of radial distances (in pixels)."""
        beyond_radius = coords_rad > self.filter_radius_pixels

        # ones everywhere, then zero out everything outside the filter radius
        disk = np.ones_like(coords_rad)
        disk[beyond_radius] = 0.0
        return disk
class BinaryProjector(Filter):
    """Pushes a grayscale image towards its min and max values with a smooth ``tanh`` curve.

    Note
    ----
    .. math::

        v(x) = v_{min} + (v_{max} - v_{min})
        \\frac{\\mathrm{tanh}(\\beta \\eta) + \\mathrm{tanh}(\\beta (x - \\eta))}
        {\\mathrm{tanh}(\\beta \\eta) + \\mathrm{tanh}(\\beta (1 - \\eta))}

    """

    vmin: float = pd.Field(
        ...,
        title="Min Value",
        description="Minimum value to project to.",
    )

    vmax: float = pd.Field(
        ...,
        title="Max Value",
        description="Maximum value to project to.",
    )

    beta: float = pd.Field(
        1.0,
        title="Beta",
        description="Steepness of the binarization, "
        "higher means more sharp transition "
        "at the expense of gradient accuracy and ease of optimization. "
        "Can be useful to ramp up in a scheduled way during optimization.",
    )

    eta: float = pd.Field(
        0.5,
        title="Eta",
        description="Halfway point in projection function.",
    )

    strict_binarize: bool = pd.Field(
        False,
        title="Binarize strictly",
        description="If ``False``, the binarization is still continuous between min and max. "
        "If ``True``, the values are snapped to the min and max values after projection.",
    )

    def evaluate(self, spatial_data: jnp.array) -> jnp.array:
        """Project ``spatial_data`` and scale the result onto the ``vmin``..``vmax`` range."""
        lower, upper = self.vmin, self.vmax
        steepness, midpoint = self.beta, self.eta

        # tanh(beta * eta) is shared by the numerator and the denominator
        offset = jnp.tanh(steepness * midpoint)
        numerator = offset + jnp.tanh(steepness * (spatial_data - midpoint))
        denominator = offset + jnp.tanh(steepness * (1.0 - midpoint))
        projected = numerator / denominator

        # map the projected values onto the requested output range
        scaled = lower + (upper - lower) * projected

        if self.strict_binarize:
            # snap each value to whichever bound lies on its side of the midpoint
            return jnp.where(scaled < (lower + upper) / 2, lower, upper)

        # otherwise only clamp values that fall outside the bounds
        scaled = jnp.where(scaled < lower, lower, scaled)
        return jnp.where(scaled > upper, upper, scaled)