# SAX multimode utilities: round-trip checks for sax.multimode / sax.singlemode
import sax
import jax.numpy as jnp
# Single-mode SDict -> multimode SDict.  The expected result shows every port
# name gaining an "@TE" and an "@TM" suffix, the single-mode value being copied
# into both mode-diagonal entries, and no TE <-> TM cross terms being produced.
sdict_s = {
    ("in0", "out0"): 1.0,
}
sdict_m = sax.multimode(sdict_s)
assert sdict_m == {
    ("in0@TE", "out0@TE"): 1.0,
    ("in0@TM", "out0@TM"): 1.0,
}
# Single-mode SCoo -> multimode SCoo.  Tuple layout checked below:
#   [0] index of the first port of each stored pair,
#   [1] index of the second port of each pair,
#   [2] the stored values,
#   [3] the port-name -> index map.
# jnp.array_equal is used rather than `(a == b).all()`: the latter broadcasts,
# so a result with the wrong length or shape could still compare "all equal".
scoo_s = sax.scoo(sdict_s)
scoo_m = sax.multimode(scoo_s)
assert jnp.array_equal(scoo_m[0], jnp.array([0, 2], dtype=int))
assert jnp.array_equal(scoo_m[1], jnp.array([1, 3], dtype=int))
assert jnp.array_equal(scoo_m[2], jnp.array([1.0, 1.0], dtype=float))
assert scoo_m[3] == {"in0@TE": 0, "out0@TE": 1, "in0@TM": 2, "out0@TM": 3}
# Single-mode SDense -> multimode SDense: element [0] is the dense matrix,
# element [1] the port-name -> index map.  With ports ordered
# in0@TE, out0@TE, in0@TM, out0@TM, the only nonzero entries are the
# mode-diagonal in0 -> out0 couplings at (0, 1) and (2, 3).
# jnp.array_equal also checks the shape, which `(a == b).all()` would not
# guarantee because of broadcasting.
sdense_s = sax.sdense(sdict_s)
sdense_m = sax.multimode(sdense_s)
assert jnp.array_equal(
    sdense_m[0],
    jnp.array(
        [
            [0.0 + 0.0j, 1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
            [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
            [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 1.0 + 0.0j],
            [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
        ]
    ),
)
assert sdense_m[1] == {"in0@TE": 0, "out0@TE": 1, "in0@TM": 2, "out0@TM": 3}
# Multimode SDict -> single-mode SDict: converting the multimode dict built
# earlier must give back the original single-mode dict exactly.
sdict_s = sax.singlemode(sdict_m)
assert {
    ("in0", "out0"): 1.0,
} == sdict_s
# Multimode SCoo -> single-mode SCoo.  Convert the *multimode* SCoo built
# earlier (`scoo_m`), as the SDict and SDense round-trips do.  Passing
# `scoo_s`, which is already single-mode, would leave the
# multimode -> single-mode SCoo conversion untested.
# Expected: one in0 -> out0 entry of value 1.0, with plain (suffix-free) port
# names.  jnp.array_equal checks shape too: with `(a == b).all()`, a length-1
# expected array would broadcast and accept e.g. a length-2 result.
scoo_s = sax.singlemode(scoo_m)
assert jnp.array_equal(scoo_s[0], jnp.array([0], dtype=int))
assert jnp.array_equal(scoo_s[1], jnp.array([1], dtype=int))
assert jnp.array_equal(scoo_s[2], jnp.array([1.0], dtype=float))
assert scoo_s[3] == {"in0": 0, "out0": 1}
# Multimode SDense -> single-mode SDense: the 4x4 multimode matrix built
# earlier must collapse back to the 2x2 single-mode matrix whose only nonzero
# entry is in0 -> out0 = 1.  jnp.array_equal also verifies the 2x2 shape
# instead of relying on broadcasting.
sdense_s = sax.singlemode(sdense_m)
assert jnp.array_equal(
    sdense_s[0],
    jnp.array(
        [
            [0.0 + 0.0j, 1.0 + 0.0j],
            [0.0 + 0.0j, 0.0 + 0.0j],
        ]
    ),
)
assert sdense_s[1] == {"in0": 0, "out0": 1}