Bayer sorter

The Bayer sorter challenge involves designing a Si3N4 metasurface that splits light by wavelength, so that red, green, and blue light is predominantly collected in the red, green, and blue subpixels of a sensor array. The challenge is based on “Pixel-level Bayer-type colour router based on metasurfaces” by Zou et al.

Simulating an existing design

We’ll begin by loading, visualizing, and simulating the design from Supplementary Figure 2 of Zou et al.

import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure

design = onp.genfromtxt("../../../reference_designs/bayer/zou.csv", delimiter=",")

plt.figure(figsize=(3, 3))
ax = plt.subplot(111)
im = plt.imshow(1 - design, cmap="gray")
im.set_clim([-2, 1])
ax.set_xticks([])
ax.set_yticks([])
for c in measure.find_contours(design):
    plt.plot(c[:, 1], c[:, 0], 'k', lw=1)

Next, create the bayer_sorter challenge, which lets us simulate and evaluate a Bayer sorter design.

The default simulation parameters are chosen to balance accuracy and simulation cost. For this notebook, we’ll override these with settings that yield more accurate results: more terms in the Fourier basis, and more wavelengths.

import dataclasses
import jax.numpy as jnp
from invrs_gym.challenges.bayer import challenge as bayer_challenge

challenge = bayer_challenge.bayer_sorter(
    sim_params=dataclasses.replace(
        bayer_challenge.BAYER_SIM_PARAMS,
        approximate_num_terms=1000,
        wavelength=jnp.arange(0.405, 0.7, 0.02),
    )
)
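The override above relies on `dataclasses.replace`, which returns a modified copy and leaves the defaults untouched. Here is a minimal sketch of that pattern, assuming a hypothetical stand-in class `SimParams`. It is not the library's actual parameter class, and its default values are invented for illustration. The sketch also checks how many wavelengths the `arange` call produces.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class SimParams:
    """Hypothetical stand-in for the simulation parameters (not the library class)."""

    approximate_num_terms: int = 100  # illustrative default, not the library's value
    wavelength: tuple = (0.45, 0.55, 0.65)  # microns; illustrative default


defaults = SimParams()

# Same wavelength grid as in the notebook: 0.405 up to (but excluding) 0.7, step 0.02.
wavelengths = np.arange(0.405, 0.7, 0.02)

# `replace` builds a new instance with the named fields changed.
accurate = dataclasses.replace(
    defaults,
    approximate_num_terms=1000,
    wavelength=tuple(wavelengths.tolist()),
)

print(defaults.approximate_num_terms, accurate.approximate_num_terms)
print(len(accurate.wavelength))
```

The grid contains 15 wavelengths, which is why a 3-by-5 grid of axes is enough to show one field plot per wavelength later in this notebook.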

The params, or optimization variables, of the challenge include the metasurface pattern as well as the metasurface thickness and the metasurface-to-focal-plane separation. We’ll obtain default initial parameters from the challenge, and then overwrite the metasurface pattern with the array loaded and plotted above.

import jax

params = challenge.component.init(jax.random.PRNGKey(0))
assert params["density_metasurface"].shape == design.shape
params["density_metasurface"].array = design

Now, simulate the bayer sorter.

response, aux = challenge.component.response(params)

The response contains the transmission for a normally incident plane wave at the specified wavelengths, for both x-polarized and y-polarized fields. The transmission is reported for four sub-pixels: following the indexing used in the plotting code below, the first is for blue, the second and third are for green, and the last is for red.
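The array layout described above can be illustrated with synthetic data. This is a sketch only: the shape `(num_wavelengths, num_polarizations, num_subpixels)` is an assumption inferred from how this notebook indexes the transmission, and the values are random placeholders rather than simulation output.

```python
import numpy as np

num_wavelengths = 15
rng = np.random.default_rng(0)

# Synthetic stand-in for the transmission, with the assumed shape
# (num_wavelengths, num_polarizations, num_subpixels).
transmission_raw = rng.uniform(0.0, 0.25, size=(num_wavelengths, 2, 4))

# Average over the polarization axis, which is second from last.
transmission = np.mean(transmission_raw, axis=-2)

# The two green sub-pixels (indices 1 and 2) are summed to give the total
# transmission collected as green.
green = transmission[:, 1] + transmission[:, 2]

print(transmission.shape, green.shape)
```

Summing the two green sub-pixels reflects the Bayer layout, in which each unit cell contains two green sub-pixels and one each of the other two colors.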

Plot the transmission into the red, green, and blue subpixels.

# Average the transmission over the two different incident polarizations.
transmission = onp.mean(response.transmission, axis=-2)

transmission_blue_pixel = transmission[:, 0]
transmission_green_pixel = transmission[:, 1] + transmission[:, 2]
transmission_red_pixel = transmission[:, 3]

plt.plot(response.wavelength, transmission_blue_pixel, "bo-", lw=3)
plt.plot(response.wavelength, transmission_green_pixel, "go-", lw=3)
plt.plot(response.wavelength, transmission_red_pixel, "ro-", lw=3)
plt.xlabel(r"Wavelength ($\mu$m)")
plt.ylabel("Sub-pixel transmission")
_ = plt.ylim(-0.05, 0.7)

The result is in very close agreement with Supplementary Figure 7 of Zou et al.

The transmission is computed by calculating the Poynting flux on the real-space grid at the focal plane, and summing it within each sub-pixel quadrant. Because the fields are computed during the simulation anyway, they are returned in aux.
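The quadrant summation described above can be sketched with NumPy on synthetic fields. This is an illustration under stated assumptions, not the library's implementation: the helper names `poynting_flux_z` and `quadrant_sums` are hypothetical, the quadrant ordering is arbitrary, and the fields are a uniform plane wave in normalized units rather than a simulation result.

```python
import numpy as np


def poynting_flux_z(ex, ey, hx, hy):
    """Time-averaged z-directed Poynting flux, 0.5 * Re(Ex Hy* - Ey Hx*)."""
    return 0.5 * np.real(ex * np.conj(hy) - ey * np.conj(hx))


def quadrant_sums(sz):
    """Sum a 2D flux map over its four quadrants (assumed sub-pixel layout)."""
    half_x, half_y = sz.shape[0] // 2, sz.shape[1] // 2
    return np.array(
        [
            sz[:half_x, :half_y].sum(),
            sz[:half_x, half_y:].sum(),
            sz[half_x:, :half_y].sum(),
            sz[half_x:, half_y:].sum(),
        ]
    )


# Synthetic x-polarized plane wave on a 40 x 40 grid, in normalized units.
shape = (40, 40)
ex = np.ones(shape, dtype=complex)
ey = np.zeros(shape, dtype=complex)
hx = np.zeros(shape, dtype=complex)
hy = np.ones(shape, dtype=complex)

sz = poynting_flux_z(ex, ey, hx, hy)

# Fraction of the total flux landing in each quadrant.
fractions = quadrant_sums(sz) / sz.sum()
print(fractions)
```

For a uniform plane wave each quadrant receives an equal share of the flux. A color-sorting metasurface would instead concentrate the flux at each wavelength into the matching quadrant.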

Let’s plot the fields in the focal plane for each of the wavelengths.

import ccmaps

x, y = aux["coordinates_xy"]
x = jnp.squeeze(x, axis=0)
y = jnp.squeeze(y, axis=0)
ex, ey, ez = aux["efield_xy"]
intensity = jnp.abs(ex)**2 + jnp.abs(ey)**2 + jnp.abs(ez)**2
intensity = jnp.mean(intensity, axis=-1)  # Average over polarizations

fig, axs = plt.subplots(ncols=5, nrows=3, figsize=(9, 7))
axs = axs.flatten()
for i, wavelength in enumerate(response.wavelength):
    cmap = ccmaps.cmap_for_wavelength(wavelength_nm=wavelength * 1000)
    axs[i].pcolormesh(x, y, intensity[i, :, :], cmap=cmap)
    axs[i].set_ylim(axs[i].get_ylim()[::-1])
    axs[i].axis("equal")
    axs[i].axis(False)
    axs[i].plot([jnp.amin(x), jnp.amax(x)], [jnp.mean(y), jnp.mean(y)], "w--", lw=1)
    axs[i].plot([jnp.mean(x), jnp.mean(x)], [jnp.amin(y), jnp.amax(y)], "w--", lw=1)
    # Raw f-string so that \lambda and \mu are not treated as invalid escape sequences.
    axs[i].set_title(rf"$\lambda$={wavelength:.3f}$\mu$m", fontsize=10)

plt.subplots_adjust(wspace=0.05, hspace=0.25)