#!/usr/bin/env python
# coding: utf-8

# # Layout Aware Monte Carlo with GDSFactory
# > Towards layout-aware optimization and monte-carlo simulations

# In[ ]:


# hide
import os

os.environ["LOGURU_LEVEL"] = "CRITICAL"
# import warnings
# warnings.filterwarnings("ignore")


# In[ ]:


import itertools
import json
import os
import sys
from functools import partial
from typing import List

import gdsfactory as gf  # conda install gdsfactory
import jax
import jax.example_libraries.optimizers as opt
import jax.numpy as jnp
import matplotlib.pyplot as plt
import meow as mw
import numpy as np
import sax
from numpy.fft import fft2, fftfreq, fftshift, ifft2
from tqdm.notebook import tqdm, trange

# ## Simple MZI Layout

# In[ ]:


@gf.cell
def simple_mzi():
    """Build an MZI: an mmi1x2 splitter and an mmi2x2 combiner joined by two bend arms.

    The top arm contains two 10 um straights (``straight_top1``/``straight_top2``)
    plus four euler bends; the bottom arm contains only four euler bends and feeds
    ``mmi_out``. Exposed ports: ``o1`` (mmi_in input), ``o2``/``o3`` (mmi_out
    outputs ``o3``/``o4``).
    """
    c = gf.Component()

    # components
    mmi_in = gf.components.mmi1x2()
    mmi_out = gf.components.mmi2x2()
    bend = gf.components.bend_euler()
    half_delay_straight = gf.components.straight(length=10.0)

    # references (the instance names are what SAX uses to address each element)
    mmi_in = c.add_ref(mmi_in, name="mmi_in")
    mmi_out = c.add_ref(mmi_out, name="mmi_out")
    straight_top1 = c.add_ref(half_delay_straight, name="straight_top1")
    straight_top2 = c.add_ref(half_delay_straight, name="straight_top2")
    bend_top1 = c.add_ref(bend, name="bend_top1")
    bend_top2 = c.add_ref(bend, name="bend_top2").dmirror()
    bend_top3 = c.add_ref(bend, name="bend_top3").dmirror()
    bend_top4 = c.add_ref(bend, name="bend_top4")
    bend_btm1 = c.add_ref(bend, name="bend_btm1").dmirror()
    bend_btm2 = c.add_ref(bend, name="bend_btm2")
    bend_btm3 = c.add_ref(bend, name="bend_btm3")
    bend_btm4 = c.add_ref(bend, name="bend_btm4").dmirror()

    # connections: top arm
    bend_top1.connect("o1", mmi_in.ports["o2"])
    straight_top1.connect("o1", bend_top1.ports["o2"])
    bend_top2.connect("o1", straight_top1.ports["o2"])
    bend_top3.connect("o1", bend_top2.ports["o2"])
    straight_top2.connect("o1", bend_top3.ports["o2"])
    bend_top4.connect("o1", straight_top2.ports["o2"])
    # connections: bottom arm (mmi_out is placed relative to this arm)
    bend_btm1.connect("o1", mmi_in.ports["o3"])
    bend_btm2.connect("o1", bend_btm1.ports["o2"])
    bend_btm3.connect("o1", bend_btm2.ports["o2"])
    bend_btm4.connect("o1", bend_btm3.ports["o2"])
    mmi_out.connect("o1", bend_btm4.ports["o2"])

    # ports
    c.add_port(
        "o1",
        port=mmi_in.ports["o1"],
    )
    c.add_port("o2", port=mmi_out.ports["o3"])
    c.add_port("o3", port=mmi_out.ports["o4"])
    return c


# In[ ]:


mzi = simple_mzi()
mzi

# ## Simulate MZI
#
# We used the following components to construct the MZI circuit:
#
# - mmi1x2
# - mmi2x2
# - straight
# - bend_euler
#
# We need a model for each of those components to be able to simulate the circuit with SAX. Let's create some dummy models for now.

# In[ ]:


def mmi1x2():
    """Ideal, lossless 50/50 splitter: o1 couples to o2 and o3 with amplitude 1/sqrt(2)."""
    S = {
        ("o1", "o2"): 0.5**0.5,
        ("o1", "o3"): 0.5**0.5,
    }
    return sax.reciprocal(S)


# In[ ]:


def mmi2x2():
    """Ideal 2x2 coupler: equal-amplitude split with a 90 degree (factor 1j) cross phase."""
    S = {
        ("o1", "o3"): 0.5**0.5,
        ("o1", "o4"): 1j * 0.5**0.5,
        ("o2", "o3"): 1j * 0.5**0.5,
        ("o2", "o4"): 0.5**0.5,
    }
    return sax.reciprocal(S)


# In[ ]:


def straight(length=10.0, width=0.5):
    """Placeholder waveguide: full transmission and no phase (length/width are ignored)."""
    S = {("o1", "o2"): 1.0}  # we'll improve this model later!
    return sax.reciprocal(S)


# In[ ]:


def bend_euler(length=10.0, width=0.5, dy=10.0, radius_min=7, radius=10):
    """Placeholder bend model that delegates to ``straight``.

    ``straight`` is looked up by name at call time, so redefining ``straight``
    later in this module changes what this stub returns. Note that no ``wl``
    is forwarded to it.
    """
    return straight(length=length, width=width)  # stub with straight for now


# Let's create a SAX circuit with our very simple placeholder models:

# In[ ]:


models = {
    "mmi1x2": mmi1x2,
    "mmi2x2": mmi2x2,
    "straight": straight,
    "bend_euler": bend_euler,
}
mzi1, _ = sax.circuit(mzi.get_netlist(recursive=True), models=models)
try:
    # notebook equivalent of `mzi1?`; `get_ipython` only exists inside IPython/Jupyter
    get_ipython().run_line_magic('pinfo', 'mzi1')
except NameError:
    # running as a plain Python script: show the docstring instead of crashing
    print(mzi1.__doc__)

# The resulting circuit is just a model function on its own! Hence, calling it will give the result:

# In[ ]:


mzi1()

# ## Waveguide Model
#
# Our waveguide model is not very good (it just has 100% transmission and no phase). Let's do something about the phase calculation. To do this, we need to find the effective index of the waveguide in relation to its parameters. We can use [meow](https://github.com/flaport/meow) to obtain the waveguide effective index.
# Let's first create a `find_waveguide_modes` function:

# In[ ]:


def find_waveguide_modes(
    wl: float = 1.55,
    n_box: float = 1.4,
    n_clad: float = 1.4,
    n_core: float = 3.4,
    t_slab: float = 0.1,
    t_soi: float = 0.22,
    w_core: float = 0.45,
    du=0.02,
    n_modes: int = 10,
    cache_path: str = "modes",
    replace_cached: bool = False,
):
    """Compute (or load from a JSON cache) the modes of a waveguide cross-section with meow.

    The cross-section is a box layer (``y < 0``), a cladding, a slab of core material
    of thickness ``t_slab`` spanning the full width, and a core of width ``w_core`` and
    height ``t_soi``. Lengths are in um (consistent with the ``λ [nm]`` plots below,
    which use ``wl * 1000``).

    Args:
        wl: wavelength.
        n_box, n_clad, n_core: constant refractive indices of box, cladding and core/slab.
        t_slab, t_soi, w_core: slab thickness, core thickness and core width.
        du: mesh step used for both x and y.
        n_modes: number of modes requested from ``mw.compute_modes``.
        cache_path: directory holding the JSON mode cache (created if missing).
        replace_cached: recompute even if a cache file for these parameters exists.

    Returns:
        The list of meow ``Mode`` objects, in the order returned by ``mw.compute_modes``
        (index 0 is used as the fundamental mode below).
    """
    length = 10.0
    delta = 10 * du
    env = mw.Environment(wl=wl)
    cache_path = os.path.abspath(cache_path)
    os.makedirs(cache_path, exist_ok=True)
    # 4 decimals for wl/indices so nearby (but different) inputs don't share a cache file
    fn = f"{wl=:.4f}-{n_box=:.4f}-{n_clad=:.4f}-{n_core=:.4f}-{t_slab=:.3f}-{t_soi=:.3f}-{w_core=:.3f}-{du=:.3f}-{n_modes=}.json"
    path = os.path.join(cache_path, fn)
    if not replace_cached and os.path.exists(path):
        with open(path, "r") as file:
            return [mw.Mode.model_validate(mode) for mode in json.load(file)]

    # fmt: off
    m_core = mw.SampledMaterial(name="slab", n=np.asarray([n_core, n_core]), params={"wl": np.asarray([1.0, 2.0])}, meta={"color": (0.9, 0, 0, 0.9)})
    m_clad = mw.SampledMaterial(name="clad", n=np.asarray([n_clad, n_clad]), params={"wl": np.asarray([1.0, 2.0])})
    m_box = mw.SampledMaterial(name="box", n=np.asarray([n_box, n_box]), params={"wl": np.asarray([1.0, 2.0])})
    box = mw.Structure(material=m_box, geometry=mw.Box(x_min=- 2 * w_core - delta, x_max= 2 * w_core + delta, y_min=- 2 * t_soi - delta, y_max=0.0, z_min=0.0, z_max=length))
    slab = mw.Structure(material=m_core, geometry=mw.Box(x_min=-2 * w_core - delta, x_max=2 * w_core + delta, y_min=0.0, y_max=t_slab, z_min=0.0, z_max=length))
    clad = mw.Structure(material=m_clad, geometry=mw.Box(x_min=-2 * w_core - delta, x_max=2 * w_core + delta, y_min=0, y_max=3 * t_soi + delta, z_min=0.0, z_max=length))
    core = mw.Structure(material=m_core, geometry=mw.Box(x_min=-w_core / 2, x_max=w_core / 2, y_min=0.0, y_max=t_soi, z_min=0.0, z_max=length))

    cell = mw.Cell(structures=[box, clad, slab, core], mesh=mw.Mesh2D(
        x=np.arange(-2*w_core, 2*w_core, du),
        y=np.arange(-2*t_soi, 3*t_soi, du),
    ), z_min=0.0, z_max=length)
    cross_section = mw.CrossSection.from_cell(cell=cell, env=env)
    modes = mw.compute_modes(cross_section, num_modes=n_modes)
    # fmt: on

    # context manager guarantees the cache file is flushed and closed
    with open(path, "w") as file:
        json.dump([json.loads(mode.json()) for mode in modes], file)
    return modes


# We can also create a rudimentary model for the silicon refractive index:

# In[ ]:


def silicon_index(wl):
    """a rudimentary silicon refractive index model"""
    a, b = 0.2411478522088102, 3.3229394315868976
    return a / wl + b


# We can now easily calculate the modes of the waveguide (note: with the default
# `t_slab=0.1`, a 100 nm slab of core material is included, i.e. a rib rather than a
# pure strip cross-section):

# In[ ]:


modes = find_waveguide_modes(wl=1.5, n_core=silicon_index(wl=1.5))

# The fundamental mode is the mode with index 0:

# In[ ]:


mw.visualize(modes[0])

# In[ ]:


wavelengths, widths = np.mgrid[1.5:1.6:10j, 0.4:0.6:5j]
neffs = np.zeros_like(wavelengths)
neffs_ = neffs.ravel()  # flat view into `neffs` (contiguous), so writes fill `neffs`

for i, (wl, w) in enumerate(zip(tqdm(wavelengths.ravel()), widths.ravel())):
    modes = find_waveguide_modes(
        wl=wl, n_core=silicon_index(wl), w_core=w, replace_cached=False
    )
    neffs_[i] = np.real(modes[0].neff)

# This results in the following effective indices:

# In[ ]:


_wls = np.unique(wavelengths.ravel())
_widths = np.unique(widths.ravel())

plt.figure(figsize=(8, 3))
plt.plot(_wls * 1000, neffs)  # one line per width (columns of `neffs`)
plt.ylabel("neff")
plt.xlabel("λ [nm]")
plt.title("Effective Index")
plt.grid(True)
plt.figlegend(
    [f"{w=:.2f}um" for w in _widths], ncol=len(_widths), bbox_to_anchor=(0.95, -0.05)
)
plt.show()

# We can do a grid interpolation on those effective indices:

# In[ ]:


_grid = [jnp.sort(jnp.unique(wavelengths)), jnp.sort(jnp.unique(widths))]
_data = jnp.asarray(neffs)


@jax.jit
def _get_coordinate(arr1d: jnp.ndarray, value: jnp.ndarray):
    """Map `value` to a fractional index into the sorted 1D grid `arr1d` (clamped at the ends)."""
    return jnp.interp(value, arr1d, jnp.arange(arr1d.shape[0]))


@jax.jit
def _get_coordinates(arrs1d: List[jnp.ndarray], values: jnp.ndarray):
    """Apply `_get_coordinate` per grid axis, pairing each grid with the matching row of `values`."""
    # don't use vmap as arrays in arrs1d could have different shapes...
    return jnp.array([_get_coordinate(a, v) for a, v in zip(arrs1d, values)])


@jax.jit
def neff(wl=1.55, width=0.5):
    """Linearly interpolate the tabulated effective index at (wl, width); inputs broadcast."""
    params = jnp.stack(jnp.broadcast_arrays(jnp.asarray(wl), jnp.asarray(width)), 0)
    coords = _get_coordinates(_grid, params)
    return jax.scipy.ndimage.map_coordinates(_data, coords, 1, mode="nearest")


neff(wl=[1.52, 1.58], width=[0.5, 0.55])

# In[ ]:


wavelengths_ = np.linspace(wavelengths.min(), wavelengths.max(), 100)
widths_ = np.linspace(widths.min(), widths.max(), 100)
wavelengths_, widths_ = np.meshgrid(wavelengths_, widths_)
neffs_ = neff(wavelengths_, widths_)

plt.figure(figsize=(8, 3))
plt.pcolormesh(wavelengths_ * 1000, widths_, neffs_)
plt.colorbar(label="neff")  # the colour encodes neff; the y-axis is the width
plt.ylabel("w [μm]")
plt.xlabel("λ [nm]")
plt.title("Effective Index")
plt.grid(True)
plt.show()

# In[ ]:


def straight(wl=1.55, length=10.0, width=0.5):
    """Lossless waveguide whose phase is 2*pi*neff(wl, width)/wl*length."""
    S = {
        ("o1", "o2"): jnp.exp(2j * np.pi * neff(wl=wl, width=width) / wl * length),
    }
    return sax.reciprocal(S)


# Even though this still is lossless transmission, we're at least modeling the phase correctly.

# In[ ]:


straight()

# ## Simulate MZI again

# In[ ]:


models["straight"] = straight
mzi2, _ = sax.circuit(mzi.get_netlist(recursive=True), models=models)
mzi2()

# In[ ]:


wl = jnp.linspace(1.51, 1.59, 1000)
S = mzi2(wl=wl)
plt.plot(wl, abs(S["o1", "o2"]) ** 2)
plt.xlabel("λ [μm]")
plt.ylabel("T")
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.show()

# ## Optimize MZI
# We'd like to optimize an MZI such that one of the minima is at 1550nm. To do this, we need to define a loss function for the circuit at 1550nm.
# This function should take the parameters that you want to optimize as positional arguments:

# In[ ]:


@jax.jit
def loss_fn(delta_length):
    """Transmitted power o1->o2 at 1550 nm when the top straights total `delta_length`."""
    S = mzi2(
        wl=1.55,
        straight_top1={"length": delta_length / 2},
        straight_top2={"length": delta_length / 2},
    )
    return jnp.mean(jnp.abs(S["o1", "o2"]) ** 2)


# We can use this loss function to define a grad function which works on the parameters of the loss function:

# In[ ]:


grad_fn = jax.jit(
    jax.grad(
        loss_fn,
        argnums=0,  # JAX gradient function for the first positional argument, jitted
    )
)
# Returns (loss, gradient) from one evaluation, so the training step below
# doesn't have to run the circuit once for the loss and again for the gradient.
loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=0))

# Next, we need to define a JAX optimizer, which on its own is nothing more than three functions:
# an initialization function to initialize the optimizer state, an update function that updates the
# optimizer state (and with it the model parameters), and a function that returns the model
# parameters for a given optimizer state.

# In[ ]:


loss_fn(20.0)

# In[ ]:


initial_delta_length = 10.0
init_fn, update_fn, params_fn = opt.adam(step_size=0.1)
state = init_fn(initial_delta_length)


# Given all this, a single training step can be defined:

# In[ ]:


def step_fn(step, state):
    """Run one optimizer step; return the loss at the pre-update parameters and the new state.

    `params_fn`, `update_fn` and `loss_and_grad_fn` are module globals looked up at call
    time, so re-binding them (as done for the width optimization below) changes what is
    optimized.
    """
    params = params_fn(state)
    loss, grad = loss_and_grad_fn(params)
    state = update_fn(step, grad, state)
    return loss, state


# And we can use this step function to start the training of the MZI:

# In[ ]:


# the first two iterations take a while because the circuit is being jitted...
for step in (pb := trange(300)):
    loss, state = step_fn(step, state)
    pb.set_postfix(loss=f"{loss:.6f}")

# In[ ]:


delta_length = params_fn(state)
delta_length

# Let's see what we've got over a range of wavelengths:

# In[ ]:


S = mzi2(
    wl=wl,
    straight_top1={"length": delta_length / 2},
    straight_top2={"length": delta_length / 2},
)
plt.plot(wl * 1e3, abs(S["o1", "o2"]) ** 2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.plot([1550, 1550], [-1, 2], ls=":", color="black")
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.show()

# Note that we could've just as well optimized the waveguide width:

# In[ ]:


@jax.jit
def loss_fn(width):
    """Transmitted power o1->o2 at 1550 nm when both top straights have width `width`."""
    S = mzi2(
        wl=1.55,
        straight_top1={"width": width},
        straight_top2={"width": width},
    )
    return jnp.mean(jnp.abs(S["o1", "o2"]) ** 2)


grad_fn = jax.jit(
    jax.grad(
        loss_fn,
        argnums=0,  # JAX gradient function for the first positional argument, jitted
    )
)
# re-bind so that `step_fn` now optimizes the width loss
loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=0))

initial_width = 0.5
init_fn, update_fn, params_fn = opt.adam(step_size=0.01)
state = init_fn(initial_width)
# the first two iterations take a while because the circuit is being jitted...
for step in (pb := trange(300)):
    loss, state = step_fn(step, state)
    pb.set_postfix(loss=f"{loss:.6f}")

optim_width = params_fn(state)
S = Sw = mzi2(
    wl=wl,
    straight_top1={"width": optim_width},
    straight_top2={"width": optim_width},
)
plt.plot(wl * 1e3, abs(S["o1", "o2"]) ** 2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.plot([1550, 1550], [-1, 2], color="black", ls=":")
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.show()

# ## Layout-aware Monte Carlo
#
# Let's assume the waveguide width changes with a certain correlation length. We can create a 'wafermap' of width variations by randomly varying the width and low pass filtering with a spatial frequency being the inverse of the correlation length (there are probably better ways to do this, but this works for this tutorial).
# In[ ]:


def create_wafermaps(
    placements, correlation_length=1.0, num_maps=1, mean=0.0, std=1.0, seed=None
):
    """Create random, spatially correlated 'wafer maps' covering the given placements.

    White Gaussian noise is low-pass filtered (spatial frequencies above
    ``1 / correlation_length`` are removed) and every map is then rescaled to
    have exactly the requested ``mean`` and ``std``.

    Args:
        placements: mapping of instance name -> dict with at least ``"x"`` and ``"y"``.
        correlation_length: filter cut-off length, in the same units as x/y; also
            sets the grid step (``correlation_length / 200``).
        num_maps: number of independent maps to generate.
        mean, std: target mean and standard deviation of each map.
        seed: optional seed for reproducible maps (global ``np.random`` otherwise).

    Returns:
        ``(x, y, W)``: 1D grid coordinates and an array of shape
        ``(num_maps, len(x), len(y))``.
    """
    dx = dy = correlation_length / 200
    xs = [p["x"] for p in placements.values()]
    ys = [p["y"] for p in placements.values()]
    xmin, xmax, ymin, ymax = min(xs), max(xs), min(ys), max(ys)
    wx, wy = xmax - xmin, ymax - ymin
    # pad the placement bounding box by its own extent on every side
    xmin, xmax, ymin, ymax = xmin - wx, xmax + wx, ymin - wy, ymax + wy
    x, y = np.arange(xmin, xmax + dx, dx), np.arange(ymin, ymax + dy, dy)

    r = np.random if seed is None else np.random.RandomState(seed=seed)
    W0 = r.randn(num_maps, x.shape[0], y.shape[0])

    # Only the two spatial axes are transformed/shifted; axis 0 indexes the maps and
    # must not be permuted.
    spatial = (-2, -1)
    fx = fftshift(fftfreq(x.shape[0], d=x[1] - x[0]))
    fy = fftshift(fftfreq(y.shape[0], d=y[1] - y[0]))
    fY, fX = np.meshgrid(fy, fx)  # both shaped (len(x), len(y)) like the maps
    fW = fftshift(fft2(W0, axes=spatial), axes=spatial)
    # NOTE(review): this compares a length with a sample count (kept from the
    # original heuristic) — confirm the intended condition.
    if correlation_length >= min(x.shape[0], y.shape[0]):
        fW = np.zeros_like(fW)
    else:
        fW = np.where(np.sqrt(fX**2 + fY**2)[None] > 1 / correlation_length, 0, fW)
    # undo the centring with ifftshift before inverting; the filtered field of real
    # noise with a symmetric mask is real, so keep the real part (not |.|**2, which
    # would skew the distribution)
    W = np.real(ifft2(np.fft.ifftshift(fW, axes=spatial), axes=spatial))

    # per-map statistics over both spatial axes at once
    mean_ = W.mean(axis=spatial, keepdims=True)
    std_ = W.std(axis=spatial, keepdims=True)
    std_ = np.where(std_ == 0, 1, std_)  # constant maps: avoid dividing by zero
    W = (W - mean_) / std_
    W = W * std + mean
    return x, y, W


# In[ ]:


placements = mzi.get_netlist()["placements"]
xm, ym, wmaps = create_wafermaps(
    placements,
    correlation_length=100,
    mean=0.5,
    std=0.002,
    num_maps=100,
    seed=42,
)

# show the first two maps
for wmap in wmaps[:2]:
    plt.imshow(wmap, cmap="RdBu")
    plt.show()

# In[ ]:


def widths(xw, yw, wmaps, x, y):
    """Sample every wafer map at position (x, y) with linear interpolation.

    (This function re-binds the name `widths`, which earlier held the width grid.)

    Args:
        xw, yw: 1D grid coordinates of the maps (as returned by `create_wafermaps`).
        wmaps: stack of maps, indexed by map along axis 0.
        x, y: sample position(s); broadcast against each other.

    Returns:
        One sampled value per map (stacked along the leading axis).
    """
    _wmap_grid = [xw, yw]
    params = jnp.stack(jnp.broadcast_arrays(jnp.asarray(x), jnp.asarray(y)), 0)
    coords = _get_coordinates(_wmap_grid, params)
    map_coordinates = partial(
        jax.scipy.ndimage.map_coordinates, coordinates=coords, order=1, mode="nearest"
    )
    w = jax.vmap(map_coordinates)(wmaps)
    return w


# Let's now sample the MZI width variation on the wafer map (let's assume a single width variation per point):

# In[ ]:


mzi_params = sax.get_settings(mzi2)
placements = mzi.get_netlist()["placements"]
width_params = {
    k: {"width": widths(xm, ym, wmaps, v["x"], v["y"])}
    for k, v in placements.items()
    if "width" in mzi_params[k]
}

S0 = mzi2(wl=wl)
S = mzi2(
    wl=wl[:, None],
    **width_params,
)
ps = plt.plot(wl * 1e3, abs(S["o1", "o2"]) ** 2, color="C0", lw=1, alpha=0.1)
nps = plt.plot(wl * 1e3, abs(S0["o1", "o2"]) ** 2, color="C1", lw=2, alpha=1)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.plot([1550, 1550], [-1, 2], color="black", ls=":")
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.figlegend([*ps[-1:], *nps], ["MC", "nominal"], bbox_to_anchor=(1.1, 0.9))
# root-mean-square deviation of the MC transmissions from the nominal transmission
rmse = jnp.sqrt(
    jnp.mean(
        jnp.abs(jnp.abs(S["o1", "o2"]) ** 2 - jnp.abs(S0["o1", "o2"][:, None]) ** 2)
        ** 2
    )
)
plt.title(f"{rmse=}")
plt.show()

# ## Compact MZI
#
# Let's see if we can improve variability (i.e. the RMSE w.r.t. nominal) by making the MZI more compact:

# In[ ]:


@gf.cell
def compact_mzi():
    """Build a more compact MZI variant of `simple_mzi`.

    Top arm: bend, straight, bend, 6 um straight, bend, straight, bend. Bottom arm:
    3 um straight, four bends, 3 um straight, then ``mmi_out``. Exposed ports match
    `simple_mzi`: ``o1`` (mmi_in input), ``o2``/``o3`` (mmi_out outputs).
    """
    c = gf.Component()

    # components
    mmi_in = gf.components.mmi1x2()
    mmi_out = gf.components.mmi2x2()
    bend = gf.components.bend_euler()
    half_delay_straight = gf.components.straight()
    middle_straight = gf.components.straight(length=6.0)
    half_middle_straight = gf.components.straight(3.0)

    # references (the instance names are what SAX uses to address each element)
    mmi_in = c.add_ref(mmi_in, name="mmi_in")

    bend_top1 = c.add_ref(bend, name="bend_top1")
    straight_top1 = c.add_ref(half_delay_straight, name="straight_top1")
    bend_top2 = c.add_ref(bend, name="bend_top2").dmirror()
    straight_top2 = c.add_ref(middle_straight, name="straight_top2")
    bend_top3 = c.add_ref(bend, name="bend_top3").dmirror()
    straight_top3 = c.add_ref(half_delay_straight, name="straight_top3")
    bend_top4 = c.add_ref(bend, name="bend_top4")

    straight_btm1 = c.add_ref(half_middle_straight, name="straight_btm1")
    bend_btm1 = c.add_ref(bend, name="bend_btm1")
    bend_btm2 = c.add_ref(bend, name="bend_btm2").dmirror()
    bend_btm3 = c.add_ref(bend, name="bend_btm3").dmirror()
    bend_btm4 = c.add_ref(bend, name="bend_btm4")
    straight_btm2 = c.add_ref(half_middle_straight, name="straight_btm2")

    mmi_out = c.add_ref(mmi_out, name="mmi_out")

    # connections: top arm
    bend_top1.connect("o1", mmi_in.ports["o2"])
    straight_top1.connect("o1", bend_top1.ports["o2"])
    bend_top2.connect("o1", straight_top1.ports["o2"])
    straight_top2.connect("o1", bend_top2.ports["o2"])
    bend_top3.connect("o1", straight_top2.ports["o2"])
    straight_top3.connect("o1", bend_top3.ports["o2"])
    bend_top4.connect("o1", straight_top3.ports["o2"])

    # connections: bottom arm (mmi_out is placed relative to this arm)
    straight_btm1.connect("o1", mmi_in.ports["o3"])
    bend_btm1.connect("o1", straight_btm1.ports["o2"])
    bend_btm2.connect("o1", bend_btm1.ports["o2"])
    bend_btm3.connect("o1", bend_btm2.ports["o2"])
    bend_btm4.connect("o1", bend_btm3.ports["o2"])
    straight_btm2.connect("o1", bend_btm4.ports["o2"])
    mmi_out.connect("o1", straight_btm2.ports["o2"])

    # ports
    c.add_port(
        "o1",
        port=mmi_in.ports["o1"],
    )
    c.add_port("o2", port=mmi_out.ports["o3"])
    c.add_port("o3", port=mmi_out.ports["o4"])
    return c


# In[ ]:


compact_mzi1 = compact_mzi()
compact_mzi1

# In[ ]:


placements = compact_mzi1.get_netlist()["placements"]
mzi3, _ = sax.circuit(compact_mzi1.get_netlist(recursive=True), models=models)
mzi3()

# In[ ]:


mzi_params = sax.get_settings(mzi3)
placements = compact_mzi1.get_netlist()["placements"]
width_params = {
    k: {"width": widths(xm, ym, wmaps, v["x"], v["y"])}
    for k, v in placements.items()
    if "width" in mzi_params[k]
}

S0 = mzi3(wl=wl)
S = mzi3(
    wl=wl[:, None],
    **width_params,
)
ps = plt.plot(wl * 1e3, abs(S["o1", "o2"]) ** 2, color="C0", lw=1, alpha=0.1)
nps = plt.plot(wl * 1e3, abs(S0["o1", "o2"]) ** 2, color="C1", lw=2, alpha=1)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.plot([1550, 1550], [-1, 2], color="black", ls=":")
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.figlegend([*ps[-1:], *nps], ["MC", "nominal"], bbox_to_anchor=(1.1, 0.9))
# root-mean-square deviation of the MC transmissions from the nominal transmission
rmse = jnp.sqrt(
    jnp.mean(
        jnp.abs(jnp.abs(S["o1", "o2"]) ** 2 - jnp.abs(S0["o1", "o2"][:, None]) ** 2)
        ** 2
    )
)
plt.title(f"{rmse=}")
plt.show()