Bayer sorter

The Bayer sorter challenge involves the design of a Si3N4 metasurface that splits light in a wavelength-dependent way, so that red, green, and blue light is predominantly collected in the red, green, and blue sub-pixels of a sensor array. The challenge is based on "Pixel-level Bayer-type colour router based on metasurfaces" by Zou et al.

Simulating an existing design

We’ll begin by loading, visualizing, and simulating the design from Supplementary Figure 2 of Zou et al.

import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure

design = onp.genfromtxt("../../../reference_designs/bayer/zou.csv", delimiter=",")

plt.figure(figsize=(3, 3))
ax = plt.subplot(111)
im = plt.imshow(1 - design, cmap="gray")
im.set_clim([-2, 1])
ax.set_xticks([])
ax.set_yticks([])
for c in measure.find_contours(design):
    plt.plot(c[:, 1], c[:, 0], 'k', lw=1)

Next, create the bayer_sorter challenge, which lets us simulate and evaluate a Bayer sorter design.

The default simulation parameters are chosen to balance accuracy and simulation cost. For this notebook, we’ll override these with settings that yield more accurate results: more terms in the Fourier basis, and more wavelengths.

import dataclasses
import jax.numpy as jnp
from invrs_gym.challenges.bayer import challenge as bayer_challenge

challenge = bayer_challenge.bayer_sorter(
    sim_params=dataclasses.replace(
        bayer_challenge.BAYER_SIM_PARAMS,
        approximate_num_terms=1000,
        wavelength=jnp.arange(0.405, 0.7, 0.02),
    )
)
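The override pattern above relies on `dataclasses.replace`, which returns a modified copy and leaves the defaults untouched. The following is a minimal sketch of that behavior. `DemoSimParams` and its default values are hypothetical stand-ins chosen for illustration; they are not the gym's class or its actual defaults.

```python
import dataclasses

import numpy as onp


@dataclasses.dataclass(frozen=True)
class DemoSimParams:
    """Hypothetical stand-in for the simulation parameters (not the invrs_gym class)."""

    approximate_num_terms: int
    wavelength: onp.ndarray


# Placeholder defaults, made up for this sketch.
defaults = DemoSimParams(
    approximate_num_terms=200,
    wavelength=onp.asarray([0.45, 0.55, 0.65]),
)

# `replace` builds a new instance with the named fields overridden.
accurate = dataclasses.replace(
    defaults,
    approximate_num_terms=1000,
    wavelength=onp.arange(0.405, 0.7, 0.02),
)

print(defaults.approximate_num_terms)  # The original object is unchanged.
print(accurate.wavelength.shape)  # Number of wavelengths on the finer grid.
```

With a start of 0.405, a stop of 0.7, and a step of 0.02, the grid holds 15 wavelengths, which matches the 5-by-3 grid of panels used for the field plots later in this notebook.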

The params, or optimization variables, of the challenge include the metasurface pattern as well as the metasurface thickness and the metasurface-to-focal-plane separation. We’ll obtain default initial parameters from the challenge, and then overwrite the metasurface pattern with the array loaded and plotted above.

import jax

params = challenge.component.init(jax.random.PRNGKey(0))
assert params["density_metasurface"].shape == design.shape
params["density_metasurface"].array = design

Now, simulate the Bayer sorter.

response, aux = challenge.component.response(params)

The response contains the transmission for a normally incident plane wave at the specified wavelengths, for both x-polarized and y-polarized fields. The transmission is reported for four sub-pixels; the first is for red, the second and third for green, and the final is for blue.

Plot the transmission into the red, green, and blue sub-pixels.

# Average the transmission over the two different incident polarizations.
transmission = onp.mean(response.transmission, axis=-2)

transmission_blue_pixel = transmission[:, 0]
transmission_green_pixel = transmission[:, 1] + transmission[:, 2]
transmission_red_pixel = transmission[:, 3]

plt.plot(response.wavelength, transmission_blue_pixel, "bo-", lw=3)
plt.plot(response.wavelength, transmission_green_pixel, "go-", lw=3)
plt.plot(response.wavelength, transmission_red_pixel, "ro-", lw=3)
plt.xlabel(r"Wavelength ($\mu$m)")
plt.ylabel("Sub-pixel transmission")
_ = plt.ylim(-0.05, 0.7)

The result is in very close agreement with Supplementary Figure 7 of Zou et al.
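The polarization average and the green sub-pixel sum can be checked on synthetic data. The sketch below assumes an array laid out as (wavelength, polarization, sub-pixel), which is inferred from the indexing used in this notebook rather than from the library documentation. The values are random placeholders, not simulation results.

```python
import numpy as onp

num_wavelengths, num_polarizations, num_subpixels = 15, 2, 4

# Synthetic placeholder data; the axis order is an assumption.
rng = onp.random.default_rng(0)
demo_transmission = rng.uniform(
    0.0, 0.25, size=(num_wavelengths, num_polarizations, num_subpixels)
)

# Averaging over axis -2 removes the polarization axis.
averaged = onp.mean(demo_transmission, axis=-2)

# The two middle sub-pixels are both green, so their transmissions add.
green = averaged[:, 1] + averaged[:, 2]

print(averaged.shape)
print(green.shape)
```

Because the two green sub-pixels are summed, the green curve can reach roughly twice the value of either single-sub-pixel curve for the same per-quadrant transmission.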

The transmission is computed by calculating the Poynting flux on the real-space grid at the focal plane, and summing within each sub-pixel quadrant. Since the fields are automatically computed during the course of a simulation, they are returned in aux.
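One way to picture the flux-summing step is the NumPy sketch below. It is an illustration under stated assumptions, not the gym's implementation: the function name `quadrant_flux`, the field arrays, and the quadrant ordering are all hypothetical, and the fields are a synthetic uniform plane wave rather than a simulation result.

```python
import numpy as onp


def quadrant_flux(ex, ey, hx, hy):
    """Sums the time-averaged z-directed Poynting flux within each quadrant.

    Hypothetical helper: fields are complex arrays on an (nx, ny) grid.
    """
    # Time-averaged z-component of the Poynting vector, 0.5 * Re(E x H*)_z.
    sz = 0.5 * onp.real(ex * onp.conj(hy) - ey * onp.conj(hx))
    half_x, half_y = sz.shape[0] // 2, sz.shape[1] // 2
    # The quadrant ordering here is arbitrary and chosen for illustration.
    return onp.asarray(
        [
            sz[:half_x, :half_y].sum(),
            sz[:half_x, half_y:].sum(),
            sz[half_x:, :half_y].sum(),
            sz[half_x:, half_y:].sum(),
        ]
    )


# Synthetic x-polarized plane wave with uniform amplitude over the grid.
shape = (8, 8)
ex = onp.ones(shape, dtype=complex)
ey = onp.zeros(shape, dtype=complex)
hx = onp.zeros(shape, dtype=complex)
hy = onp.ones(shape, dtype=complex)

flux = quadrant_flux(ex, ey, hx, hy)
fractions = flux / flux.sum()
print(fractions)
```

For a uniform field each quadrant receives an equal share of the flux. A wavelength-sorting metasurface is intended to break that symmetry so that most of the power at a given wavelength lands in the matching sub-pixel.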

Let’s plot the fields in the focal plane for each of the wavelengths.

import base64
import zlib
from matplotlib import colors

# Code is from https://github.com/rsmith-nl/wavelength_to_rgb.
_CTBL = (
    b"eNrV0ltr0AUAxuHn1rqoiAqSiA6EEJ3ogBUJeeFFRUdJOjOyWau1bLZpGztnM1oyy"
    b"2yYmcmMETPNlixNVpmssobKMJuYDVnNGksz0zTe5C9BXyF4v8Dz4y1RMlPpLGVlKs"
    b"pVVqh+Vu0cDVVa5mqdp61Ge63FdTrqLWuwolFno64m3U3WNutp1ttsY7O+Jpub9Df"
    b"a2migwY56O+sM1dpTY3iekblGq4zNcWC2QxWOlDtWJqVSIg/JfTJd7pRbZZpMlZtk"
    b"slwtl8skuUjOk3PkDDlFnNipcWZMjAtjUlwR18aNcXNMi9virpgRD0ZJlMZTMTuqo"
    b"ibqoyUWxCuxKJbEm/F2dEZXrI4PYn1siL7YHP3xTWyLwfg+9sRwjMT+GI/f4884Fj"
    b"mxP2S/7JVB+Vr6pEe65C1ZJC9KjTxduKcX1smF74TMhDgtzopz4/y4uGBdFlfFdXF"
    b"DTImpBe6WuD3ujnvj/ni4ID4WTxTKZ6IyqgtoXTTGC9EaL8fCgvt6dPwrXhmrCnR3"
    b"rIl18VH0xicF/fPYEl/G1hiI7UWA72KoaPBj7Iufigxj8VtR4nAcjeMnYxyXo3JYD"
    b"sq4/CqjMiLD8oPsll1Fp+0yUNTqly/kU9kkG2S9fChrpLtIuErekeWyVN6Q16Rd2m"
    b"SBzJcmqSvqVkulVMiTMkselUfkAZkh98gd/znZFLlerpEr5VK5RC6QiXK2nC4TTv7"
    b"sf7C/OcZfHOEwhzjIAcYZ4xdG+ZkR9jHMXvawmyF2sZNBdrCNAb5lK1/RzxY28xl9"
    b"bGIjH9PLenpYx1rep5v36OJdOlnJCpazjKV0sITFvEo7C2njJVqZTwtNNFBHLc9Tz"
    b"XNUMpsKyinjcUqZSQn/AJ7p9HY="
)


_CTBL = zlib.decompress(base64.b64decode(_CTBL))


def rgb(wavelength_nm):
    """Converts a wavelength between 380 and 780 nm to an RGB color tuple.

    Args:
        wavelength_nm: Wavelength in nanometers. It is rounded to the nearest integer.

    Returns:
        A 3-tuple (red, green, blue) of integers in the range 0-255.
    """
    wavelength_nm = int(round(wavelength_nm))
    if wavelength_nm < 380 or wavelength_nm > 780:
        raise ValueError("wavelength out of range")
    idx = (wavelength_nm - 380) * 3
    color_str = _CTBL[idx : idx + 3]
    return onp.asarray([int(i) for i in color_str]) / 255


def cmap_for_wavelength(wavelength_um, background_color="k"):
    """Generates a colormap from the background color to the color of the wavelength."""
    color = rgb(wavelength_nm=wavelength_um * 1000)
    return colors.LinearSegmentedColormap.from_list(
        "b", [background_color, color], N=256
    )


x, y = aux["coordinates_xy"]
x = jnp.squeeze(x, axis=0)
y = jnp.squeeze(y, axis=0)
ex, ey, ez = aux["efield_xy"]
intensity = jnp.abs(ex)**2 + jnp.abs(ey)**2 + jnp.abs(ez)**2
intensity = jnp.mean(intensity, axis=-1)  # Average over polarizations

fig, axs = plt.subplots(ncols=5, nrows=3, figsize=(9, 7))
axs = axs.flatten()
for i, wavelength in enumerate(response.wavelength):
    cmap = cmap_for_wavelength(wavelength)
    axs[i].pcolormesh(x, y, intensity[i, :, :], cmap=cmap)
    axs[i].set_ylim(axs[i].get_ylim()[::-1])
    axs[i].axis("equal")
    axs[i].axis(False)
    axs[i].plot([jnp.amin(x), jnp.amax(x)], [jnp.mean(y), jnp.mean(y)], "w--", lw=1)
    axs[i].plot([jnp.mean(x), jnp.mean(x)], [jnp.amin(y), jnp.amax(y)], "w--", lw=1)
    axs[i].set_title(rf"$\lambda$={wavelength:.3f}$\mu$m", fontsize=10)

plt.subplots_adjust(wspace=0.05, hspace=0.25)
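The per-wavelength colormaps used above run linearly from a background color to the color associated with the wavelength. The NumPy sketch below reproduces that kind of two-color linear ramp without matplotlib. The helper name `linear_ramp` and the example color are hypothetical, chosen only to illustrate the interpolation.

```python
import numpy as onp


def linear_ramp(background_rgb, color_rgb, num_levels=256):
    """Linearly interpolates between two RGB colors (hypothetical helper)."""
    background_rgb = onp.asarray(background_rgb, dtype=float)
    color_rgb = onp.asarray(color_rgb, dtype=float)
    # Each row is one colormap entry; the result has shape (num_levels, 3).
    return onp.linspace(background_rgb, color_rgb, num_levels)


# Black background and an arbitrary example color with channels in [0, 1].
example_color = (0.0, 0.8, 0.2)
ramp = linear_ramp((0.0, 0.0, 0.0), example_color)

print(ramp.shape)
print(ramp[0], ramp[-1])
```

With a black background, zero intensity maps to black and the peak intensity maps to the wavelength's own color, so each panel reads like the light a sensor would see at that wavelength.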