General SAX utilities
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pytest
import sax
# block_diag places batched square blocks along the diagonal and zero-fills
# everything else: a (1, 2, 2) block followed by a (1, 3, 3) block yields the
# (1, 5, 5) array spelled out below.
arr1 = 1 * jnp.ones((1, 2, 2))
arr2 = 2 * jnp.ones((1, 3, 3))
block_diag_result = sax.block_diag(arr1, arr2)
block_diag_expected = jnp.array(
    [
        [
            [1.0, 1.0, 0.0, 0.0, 0.0],
            [1.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 2.0, 2.0, 2.0],
            [0.0, 0.0, 2.0, 2.0, 2.0],
            [0.0, 0.0, 2.0, 2.0, 2.0],
        ]
    ]
)
assert (block_diag_result == block_diag_expected).all()
# clean_string rewrites characters that are awkward in identifiers: here ","
# and " " each become "_" and "." becomes "p".
cleaned = sax.clean_string("Hello, string 1.0")
assert cleaned == "Hello__string_1p0"
# copy_settings returns a new dict. Rebinding a key in the copy leaves the
# original untouched, while array values are shared, not duplicated (the final
# identity check).
orig_settings = {"a": 3, "c": jnp.array([9.0, 10.0, 11.0])}
new_settings = sax.copy_settings(orig_settings)

# Freshly copied: same contents.
assert orig_settings["a"] == new_settings["a"]
assert jnp.all(orig_settings["c"] == new_settings["c"])

# Rebinding a key on the copy must not leak back into the original.
new_settings["a"] = jnp.array(5.0)
assert orig_settings["a"] == 3
assert new_settings["a"] == 5

# The untouched array is the very same object in both dicts.
assert orig_settings["c"] is new_settings["c"]
# flatten_dict joins nested keys with `sep`; unflatten_dict undoes it, so the
# nested dict survives a round trip.
nested_dict = {"a": 3.0, "b": {"c": 4.0}}
flat_dict = sax.flatten_dict(nested_dict, sep=",")
expected_flat_dict = {"a": 3.0, "b,c": 4.0}
assert flat_dict == expected_flat_dict
roundtrip_dict = sax.unflatten_dict(flat_dict, sep=",")
assert roundtrip_dict == nested_dict
# Smoke test: converting a single-entry SDict to SCOO form must not raise.
minimal_sdict = {("in0", "out0"): 1.0}
sax.scoo(minimal_sdict)
def coupler(coupling=0.5):
    """Return the SDict of an ideal, lossless 2x2 directional coupler.

    Args:
        coupling: fraction of input power sent to the cross port
            (0 = no coupling, 1 = full cross-over).

    Returns:
        Dict mapping ``(input_port, output_port)`` pairs to complex amplitudes:
        ``sqrt(1 - coupling)`` on the through paths and
        ``1j * sqrt(coupling)`` on the cross paths. Hence
        ``|through|**2 + |cross|**2 == 1`` for every input.
    """
    kappa = coupling**0.5  # cross-coupling amplitude
    # Through amplitude must use the *remaining* power, not `coupling` itself;
    # otherwise the device is only lossless at coupling == 0.5.
    tau = (1 - coupling) ** 0.5
    return {
        ("in0", "out0"): tau,
        ("in0", "out1"): 1j * kappa,
        ("in1", "out0"): 1j * kappa,
        ("in1", "out1"): tau,
    }
# get_ports gives the same answer for a model function and for each S-matrix
# representation (SDict, SCOO, SDense) derived from it.
coupler_ports = ("in0", "in1", "out0", "out1")

model = coupler
assert sax.get_ports(model) == coupler_ports

sdict_ = coupler()
assert sax.get_ports(sdict_) == coupler_ports

scoo_ = sax.scoo(sdict_)
assert sax.get_ports(scoo_) == coupler_ports

sdense_ = sax.sdense(sdict_)
assert sax.get_ports(sdense_) == coupler_ports
# For a model, an SDict or an SCOO, get_port_combinations reports only the
# port pairs that are present. For the SDense form every ordered pair of its
# ports is reported.
coupler_pairs = (
    ("in0", "out0"),
    ("in0", "out1"),
    ("in1", "out0"),
    ("in1", "out1"),
)

model = coupler
assert sax.get_port_combinations(model) == coupler_pairs

sdict_ = coupler()
assert sax.get_port_combinations(sdict_) == coupler_pairs

scoo_ = sax.scoo(sdict_)
assert sax.get_port_combinations(scoo_) == coupler_pairs

sdense_ = sax.sdense(sdict_)
dense_port_order = ("in0", "in1", "out0", "out1")
assert sax.get_port_combinations(sdense_) == tuple(
    (src, dst) for src in dense_port_order for dst in dense_port_order
)
# get_settings reports a model's keyword defaults.
coupler_defaults = sax.get_settings(coupler)
assert coupler_defaults == {"coupling": 0.5}
# hide
# Phase samples `phis` taken at wavelengths `wls` (listed two per line, as the
# samples come in closely spaced pairs). grouped_interp is evaluated at `wl`,
# which includes one point (1.31) below the sampled range. The result is
# compared against known values within a squared-error tolerance.
# NOTE(review): phi_ref appears to be wrapped to (-pi, pi] while phis lie in
# [0, 2*pi) -- presumably grouped_interp wraps its output; confirm in sax docs.
wls = jnp.array(
    [
        2.19999, 2.20001,
        2.22499, 2.22501,
        2.24999, 2.25001,
        2.27499, 2.27501,
        2.29999, 2.30001,
        2.32499, 2.32501,
        2.34999, 2.35001,
        2.37499, 2.37501,
        2.39999, 2.40001,
        2.42499, 2.42501,
        2.44999, 2.45001,
    ]
)
phis = jnp.array(
    [
        5.17317336, 5.1219654,
        4.71259842, 4.66252492,
        5.65699608, 5.60817922,
        2.03697377, 1.98936119,
        6.010146, 5.96358061,
        4.96336733, 4.91777933,
        5.13912198, 5.09451137,
        0.22347545, 0.17979684,
        2.74501894, 2.70224092,
        0.10403192, 0.06214664,
        4.83328794, 4.79225525,
    ]
)
wl = jnp.array([2.21, 2.27, 1.31, 2.424])
phi = jnp.array(sax.grouped_interp(wl, wls, phis))
phi_ref = jnp.array([-1.4901831, 1.3595749, -1.110012, 2.1775336])
squared_error = (phi - phi_ref) ** 2
assert (squared_error < 1e-5).all()
# merge_dicts combines its arguments left to right. On a key clash the later
# value wins, even when one side is a nested dict and the other is not.

# Disjoint keys: both survive.
d = sax.merge_dicts({"a": 3}, {"b": 4})
assert (d["a"], d["b"]) == (3, 4)
assert tuple(sorted(d)) == ("a", "b")

# Scalar overridden by scalar.
d = sax.merge_dicts({"a": 3}, {"a": 4})
assert d["a"] == 4
assert tuple(d) == ("a",)

# Scalar overridden by a nested dict.
d = sax.merge_dicts({"a": 3}, {"a": {"b": 5}})
assert d["a"]["b"] == 5
assert tuple(d) == ("a",)

# Nested dict overridden by a scalar.
d = sax.merge_dicts({"a": {"b": 5}}, {"a": 3})
assert d["a"] == 3
assert tuple(d) == ("a",)
# By default only same-mode pairs are produced; cross=True produces every
# ordered pair of modes.
same_mode_pairs = sax.mode_combinations(modes=["te", "tm"])
assert same_mode_pairs == (("te", "te"), ("tm", "tm"))
all_mode_pairs = sax.mode_combinations(modes=["te", "tm"], cross=True)
assert all_mode_pairs == tuple((m1, m2) for m1 in ("te", "tm") for m2 in ("te", "tm"))
# reciprocal adds the reversed port pair of every entry, with the same value.
sdict_ = {("in0", "out0"): 1.0}
reciprocal_expected = {("in0", "out0"): 1.0, ("out0", "in0"): 1.0}
assert sax.reciprocal(sdict_) == reciprocal_expected
def model(x=jnp.array(3.0), y=jnp.array(4.0), z=jnp.array([3.0, 4.0])) -> sax.SDict:
    """Toy model whose keyword defaults exercise ``sax.rename_params``.

    The parameters are intentionally unused; only their names and default
    values matter to the renaming check below.
    """
    # Body indented: the original unindented `return` was an IndentationError.
    return {("in0", "out0"): jnp.array(3.0)}


# rename_params remaps parameter names (x -> a, and y/z swapped); the renamed
# model reports each default under its new name.
renamings = {"x": "a", "y": "z", "z": "y"}
new_model = sax.rename_params(model, renamings)
settings = sax.get_settings(new_model)
assert settings["a"] == 3.0
assert settings["z"] == 4.0
assert jnp.all(settings["y"] == jnp.array([3.0, 4.0]))
# rename_ports must yield the same renamed port set whether it is handed the
# SDict itself or its SCOO / SDense conversion.
d = sax.reciprocal({("p0", "p1"): 0.1, ("p1", "p2"): 0.2})
origports = sax.get_ports(d)
renamings = {"p0": "in0", "p1": "out0", "p2": "in1"}
_expected_renamed = tuple(sorted(renamings[p] for p in origports))
for _to_format in (lambda s: s, sax.scoo, sax.sdense):
    d_ = sax.rename_ports(_to_format(d), renamings)
    assert tuple(sorted(sax.get_ports(d_))) == _expected_renamed
Assuming you have a settings dictionary for a circuit containing a directional coupler "dc" and a waveguide "wg":
settings = {
    "wl": 1.55,
    "dc": {"coupling": 0.5},
    "wg": {"wl": 1.56, "neff": 2.33},
}
You can update this settings dictionary with some global settings as follows. When updating settings globally like this, the top-level settings and every subdictionary are updated with these values, but only for keys that already exist there (no new keys are added):
settings = sax.update_settings(settings, wl=1.3, coupling=0.3, neff=3.0)
globally_updated = {
    "wl": 1.3,
    "dc": {"coupling": 0.3},
    "wg": {"wl": 1.3, "neff": 3.0},
}
assert settings == globally_updated
Alternatively, you can update the settings of a single component (here "wg") by passing its name as the second argument:
settings = sax.update_settings(settings, "wg", wl=2.0)
# Only the "wl" nested under "wg" changes; the top-level "wl" stays at 1.3.
wg_only_updated = {
    "wl": 1.3,
    "dc": {"coupling": 0.3},
    "wg": {"wl": 2.0, "neff": 3.0},
}
assert settings == wg_only_updated
Note that only the "wl" belonging to "wg" has changed.
# validate_not_mixedmode accepts an S-matrix whose ports are all plain
# ("in0") or all mode-suffixed ("in0@te"), and raises ValueError when the two
# naming styles are mixed.
sdict = {("in0", "out0"): 1.0, ("out0", "in0"): 1.0}
sax.validate_not_mixedmode(sdict)

sdict = {("in0@te", "out0@te"): 1.0, ("out0@tm", "in0@tm"): 1.0}
sax.validate_not_mixedmode(sdict)

sdict = {("in0@te", "out0@te"): 1.0, ("out0", "in0@tm"): 1.0}
# Body indented: the original unindented call was an IndentationError.
with pytest.raises(ValueError):
    sax.validate_not_mixedmode(sdict)
# validate_multimode accepts an S-matrix whose ports all carry a mode suffix
# ("@te", "@tm") and raises ValueError for plain or mixed port names.
sdict = {("in0", "out0"): 1.0, ("out0", "in0"): 1.0}
# Bodies indented: the original unindented calls were IndentationErrors.
with pytest.raises(ValueError):
    sax.validate_multimode(sdict)

sdict = {("in0@te", "out0@te"): 1.0, ("out0@tm", "in0@tm"): 1.0}
sax.validate_multimode(sdict)

sdict = {("in0@te", "out0@te"): 1.0, ("out0", "in0@tm"): 1.0}
with pytest.raises(ValueError):
    sax.validate_multimode(sdict)
# validate_sdict returns None for a well-formed SDict and raises ValueError for
# a malformed one. Here the malformed key is a single "p0,p1" string rather
# than a port-pair tuple.
good_sdict = sax.reciprocal({("p0", "p1"): 0.1, ("p1", "p2"): 0.2})
assert sax.validate_sdict(good_sdict) is None

bad_sdict = {
    "p0,p1": 0.1,  # malformed: string key instead of a (port, port) tuple
    ("p1", "p2"): 0.2,
}
# Body indented: the original unindented call was an IndentationError.
with pytest.raises(ValueError):
    sax.validate_sdict(bad_sdict)
# get_inputs_outputs splits port names into an (inputs, outputs) pair of
# tuples. In the last two cases an unprefixed name ("dc0") lands on whichever
# side the other port does not occupy.
for _ports, _expected_split in [
    (["in0", "out0"], (("in0",), ("out0",))),
    (["in0", "in1"], (("in0", "in1"), ())),
    (["out0", "out1"], ((), ("out0", "out1"))),
    (["out0", "dc0"], (("dc0",), ("out0",))),
    (["dc0", "in0"], (("in0",), ("dc0",))),
]:
    assert sax.get_inputs_outputs(_ports) == _expected_split