# Photon extractor

The photon extractor challenge is based on "[Inverse-designed photon extractors for optically addressable defect qubits](https://opg.optica.org/optica/fulltext.cfm?uri=optica-7-12-1805)" by Chakravarthi et al. It involves optimizing a patterned GaP layer on a diamond substrate above an implanted nitrogen-vacancy defect. The oxide hard mask used to pattern the GaP is left in place after the etch.

The goal of the optimization is to maximize extraction of the 637 nm emission, i.e. to maximize the power coupled from the defect into the ambient above the extractor. Such a device could be useful for quantum information processing applications.
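In the field-plotting code later in this notebook, depth z increases downward from the top of the plotted domain, and the layers are stacked as ambient, oxide hard mask, GaP extractor, and diamond substrate. The interface depths are therefore running sums of the layer thicknesses, mirroring the `spec.thickness_*` arithmetic in that code. A minimal sketch of this bookkeeping, using placeholder thicknesses that are assumptions and not the challenge's actual values:

```python
import numpy as np

# Placeholder layer thicknesses in micrometers. These are hypothetical values
# for illustration only, not the photon extractor challenge's spec.
thickness_ambient = 1.0
thickness_oxide = 0.1
thickness_extractor = 0.25

# Depth increases downward. Running sums give the interface depths:
# z0 = ambient/oxide, z1 = oxide/GaP, z2 = GaP/diamond (top of the substrate).
z0, z1, z2 = np.cumsum([thickness_ambient, thickness_oxide, thickness_extractor])
print(f"z0={z0:.2f}, z1={z1:.2f}, z2={z2:.2f}")  # z0=1.00, z1=1.10, z2=1.35
```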

## Simulating an existing design

We'll begin by loading, visualizing, and simulating the design from the [invrs-gym paper](https://arxiv.org/abs/2410.24132). Later, we'll plot an x-z cross section of fields in the extractor, which is indicated below with the dashed black line.

In [None]:
import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure
from totypes import json_utils

file = "240118_mfschubert_8569349cf4b44541ee37aa3eeed0127b70e29fb52674a4e97370fe7d95323bc7.json"
with open(f"../../../reference_designs/photon_extractor/{file}", "r") as f:
    serialized = f.read()
params = json_utils.pytree_from_json(serialized)


plt.figure(figsize=(4, 4))
ax = plt.subplot(111)
im = ax.imshow(1 - params.array.T, cmap="gray")
im.set_clim([-2, 1])
contours = measure.find_contours(params.array.T)
for c in contours:
    plt.plot(c[:, 1], c[:, 0], "k", lw=1)
ax.set_xticks([])
ax.set_yticks([])

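# The x-z cross section plotted below is taken at the middle index of the
# second array axis (y), which is a horizontal row of the transposed image.
midpoint = params.shape[1] / 2
ax.plot([0, params.shape[0]], [midpoint, midpoint], "k--")
ax.set_xlim([120, params.shape[0] - 120])
_ = ax.set_ylim([120, params.shape[1] - 120])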
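The image above shows `params.array.T`, so the horizontal image axis corresponds to the first array axis (x) and the vertical image axis to the second (y). The field-plotting code below extracts the structure along x at the middle y index, `params.array[:, params.shape[1] // 2]`. A small NumPy sketch with a toy array (a hypothetical stand-in, not the reference design) illustrates this indexing:

```python
import numpy as np

# Toy 4 x 6 density array indexed as density[ix, iy]; a hypothetical stand-in
# for params.array, with values encoding the position for easy checking.
nx, ny = 4, 6
density = np.arange(nx * ny, dtype=float).reshape(nx, ny)

# imshow(density.T) draws density.T[iy, ix]: image rows are y, columns are x.
image = density.T
assert image.shape == (ny, nx)
assert image[2, 3] == density[3, 2]

# An x-z cross section through the middle of the design keeps every x index
# at the central y index, i.e. a single row of the transposed image.
x_slice = density[:, ny // 2]
assert np.array_equal(image[ny // 2, :], x_slice)
print(x_slice)
```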

We will use the `Challenge` object returned by `challenges.photon_extractor` to carry out the simulation.

In [None]:
from invrs_gym import challenges

challenge = challenges.photon_extractor()

We are now ready to simulate the photon extractor using the `component.response` method. By default, this method does not compute the fields emitted by the source (for improved performance), but we do so here for visualization purposes.

In [None]:
response, aux = challenge.component.response(params, compute_fields=True)

In this challenge, we care about the enhancement in flux relative to a bare diamond substrate. This quantity is included in the challenge metrics for x-, y-, and z-oriented dipoles.

In [None]:
metrics = challenge.metrics(response, params=params, aux=aux)
for i, orientation in enumerate(["x", "y", "z"]):
    print(
        f"Flux enhancement for {orientation} dipole is "
        f"{metrics['enhancement_flux_per_dipole'][i]:.2f}"
    )

The values are similar to those reported by Chakravarthi et al.
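Conceptually, the enhancement compares the flux emitted into the ambient with and without the extractor, separately for each dipole orientation. A minimal sketch of that ratio, assuming the metric is a plain per-orientation ratio and using made-up flux values (hypothetical numbers, not simulation output):

```python
import numpy as np

# Hypothetical upward flux for x-, y- and z-oriented dipoles, with and without
# the extractor. These numbers are illustrative placeholders only.
flux_extractor = np.array([3.0, 2.0, 0.75])
flux_bare = np.array([0.5, 0.5, 0.5])

# Per-orientation enhancement relative to the bare diamond substrate.
enhancement = flux_extractor / flux_bare
for orientation, value in zip("xyz", enhancement):
    print(f"Flux enhancement for {orientation} dipole is {value:.2f}")
```

With these placeholder values the printed enhancements are 6.00, 4.00, and 1.50.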

Now let's visualize the fields, which are computed on an x-z slice for each of the three dipole orientations. We plot the field magnitude for each orientation with the structure overlaid.

In [None]:
import ccmaps
from skimage import measure

x, y, z = aux["field_coordinates"]
ex, ey, ez = aux["efield"]

assert ex.ndim == 3 and ex.shape[-1] == 3
field_magnitude = onp.sqrt(onp.abs(ex) ** 2 + onp.abs(ey) ** 2 + onp.abs(ez) ** 2)
maxval = onp.amax(field_magnitude)


# Define a function that will plot the fields and overlay the structure.
def plot_field_and_structure(ax, field, title):
    # Plot the field.
    xplot, zplot = onp.meshgrid(x, z, indexing="ij")
    im = ax.pcolormesh(xplot, zplot, field, cmap=ccmaps.wbgyr())
    plt.colorbar(im)

    # Overlay the structure.
    spec = challenge.component.spec

    z0 = spec.thickness_ambient
    z1 = z0 + spec.thickness_oxide
    z2 = z1 + spec.thickness_extractor

    # Plot line at the top of the substrate.
    ax.plot([0, onp.amax(x)], [z2, z2], "k", lw=1)

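    # Take the design slice at the middle index of the second (y) axis as a
    # column vector and tile it into two identical columns. `find_contours`
    # then returns one contour per material boundary, and the loop treats
    # consecutive pairs of contours as the two edges of a solid feature.
    density_plot = params.array
    density_plot_slice = density_plot[:, density_plot.shape[1] // 2, onp.newaxis]
    contours = measure.find_contours(onp.tile(density_plot_slice, (1, 2)))
    for c1, c2 in zip(contours[::2], contours[1::2]):
        # Column coordinates (0 or 1) map to the top or bottom of a layer;
        # row coordinates are converted from pixel indices to x positions.
        zc = onp.concatenate([c1[:, 1], c2[:, 1], [0]])
        xc = onp.concatenate([c1[:, 0], c2[:, 0], [c1[0, 0]]]) + 0.5
        xc = xc * (x[1] - x[0]) + x[0]
        zcp = onp.where(zc == 0, z0, z1)
        ax.plot(xc, zcp, "k", lw=1)  # Oxide
        zcp = onp.where(zc == 0, z1, z2)
        ax.plot(xc, zcp, "k", lw=1)  # GaP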

    ax.set_xlim([0, onp.amax(x)])
    ax.set_ylim([onp.amax(z), 0])
    ax.axis("equal")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)


plt.figure(figsize=(5, 12))
enhancement = metrics["enhancement_flux_per_dipole"]
for i, orientation in enumerate(["x", "y", "z"]):
    plot_field_and_structure(
        plt.subplot(3, 1, i + 1),
        field_magnitude[:, :, i],
        title=f"{orientation} dipole\nflux enhancement={enhancement[i]:.2f}",
    )