SAX types
from typing import Callable
import jax.numpy as jnp
import numpy as np
import pytest
import sax
# Scalar type predicates, checked table-style. ``_scalar_values`` lists the
# probe inputs; each predicate maps to the expected truthiness of its result
# for those inputs, position by position. Predicates are checked one after
# another, each against every value in order.
_scalar_values = [
    3.0,
    3,
    3.0 + 2j,
    jnp.array(3.0, dtype=complex),
    jnp.array(3, dtype=int),
]
_scalar_expectations = {
    sax.is_float: [True, False, False, False, False],
    sax.is_complex: [False, False, True, True, False],
    sax.is_complex_float: [True, False, True, True, False],
}
for _predicate, _flags in _scalar_expectations.items():
    for _value, _flag in zip(_scalar_values, _flags):
        assert bool(_predicate(_value)) is _flag
# SDict fixture: a plain dict keyed by (port name, port name) pairs.
_sdict: sax.SDict = {("in0", "out0"): 3.0}

# SCoo fixture: (row indices, column indices, values, port map).
# With port indices in0 -> 0, out0 -> 1, in1 -> 2, the three stored entries
# correspond to (in0, in0) = 3.0, (out0, out0) = 4.0 and (in1, in0) = 1.0.
Si = jnp.array([0, 1, 2], dtype=int)
Sj = jnp.array([0, 1, 0], dtype=int)
Sx = jnp.array([3.0, 4.0, 1.0])
port_map = dict(in0=0, in1=2, out0=1)
_scoo: sax.SCoo = (Si, Sj, Sx, port_map)

# SDense fixture: a 3x3 matrix filled with 0.0 .. 8.0 in row-major order,
# paired with its own (equal but distinct) port-map dict.
Sd = jnp.arange(9, dtype=float).reshape(3, 3)
port_map = dict(in0=0, in1=2, out0=1)
_sdense = (Sd, port_map)
# Each S-matrix predicate must accept its own fixture and reject the other two
# formats. All three are also probed with the same arbitrary object instance,
# which none of them should accept.
_not_an_s_matrix = object()

assert not sax.is_sdict(_not_an_s_matrix)
assert sax.is_sdict(_sdict)
assert not sax.is_sdict(_scoo)
assert not sax.is_sdict(_sdense)

assert not sax.is_scoo(_not_an_s_matrix)
assert not sax.is_scoo(_sdict)
assert sax.is_scoo(_scoo)
assert not sax.is_scoo(_sdense)

assert not sax.is_sdense(_not_an_s_matrix)
assert not sax.is_sdense(_sdict)
assert not sax.is_sdense(_scoo)
assert sax.is_sdense(_sdense)
def good_model(x=jnp.array(3.0), y=jnp.array(4.0)) -> sax.SDict:
    """Fixture model whose parameters all have default values.

    Returns a one-entry SDict; ``x`` and ``y`` are intentionally unused.
    """
    s_in0_out0 = jnp.array(3.0)
    return {("in0", "out0"): s_in0_out0}


assert sax.is_model(good_model)
def bad_model(positional_argument, x=jnp.array(3.0), y=jnp.array(4.0)) -> sax.SDict:
    """Fixture model with a required positional parameter (no default).

    Apart from the extra parameter it returns the same one-entry SDict as the
    all-defaults fixture; ``sax.is_model`` is expected to reject it.
    """
    s_in0_out0 = jnp.array(3.0)
    return {("in0", "out0"): s_in0_out0}


assert not sax.is_model(bad_model)
Note: For a `Callable` to be considered a `ModelFactory` in SAX, it MUST have a `Callable` or `Model` return annotation. Otherwise SAX will view it as a `Model` and things might break!
def func() -> sax.Model:
    """Stub whose only relevant feature is its ``sax.Model`` return annotation."""


# Only the return annotation is inspected for now; the body is irrelevant.
assert sax.is_model_factory(func)
def func():
    """Stub with no return annotation at all."""


# Without a return annotation the callable is not treated as a model factory
# (only the annotation is inspected for now).
assert not sax.is_model_factory(func)
def good_model(x=jnp.array(3.0), y=jnp.array(4.0)) -> sax.SDict:
    """Fixture model with all-default parameters, used to exercise validation."""
    result = {("in0", "out0"): jnp.array(3.0)}
    return result


# For a valid model, validate_model returns None instead of raising.
assert sax.validate_model(good_model) is None
def bad_model(positional_argument, x=jnp.array(3.0), y=jnp.array(4.0)) -> sax.SDict:
    """Fixture model with a required positional parameter, used to exercise validation."""
    result = {("in0", "out0"): jnp.array(3.0)}
    return result


# A model with a parameter lacking a default must be rejected with ValueError.
with pytest.raises(ValueError):
    sax.validate_model(bad_model)
a.k.a. SDict, SDense, SCoo helpers
Convert an `SDict`, `SCoo` or `SDense` into an `SDict` (or convert a model generating any of these types into a model generating an `SDict`):
# An SDict is already in the target format and comes back as the same object.
assert sax.sdict(_sdict) is _sdict

# SCoo -> SDict: one key per stored (Si, Sj, Sx) entry, with indices replaced
# by the corresponding port names.
_expected_from_scoo = {
    ("in0", "in0"): 3.0,
    ("out0", "out0"): 4.0,
    ("in1", "in0"): 1.0,
}
assert sax.sdict(_scoo) == _expected_from_scoo

# SDense -> SDict: every matrix element becomes a key; Sd[i, j] is stored
# under (port with index i, port with index j).
_expected_from_sdense = {
    ("in0", "in0"): 0.0,
    ("in0", "out0"): 1.0,
    ("in0", "in1"): 2.0,
    ("out0", "in0"): 3.0,
    ("out0", "out0"): 4.0,
    ("out0", "in1"): 5.0,
    ("in1", "in0"): 6.0,
    ("in1", "out0"): 7.0,
    ("in1", "in1"): 8.0,
}
assert sax.sdict(_sdense) == _expected_from_sdense
Convert an `SDict`, `SCoo` or `SDense` into an `SCoo` (or convert a model generating any of these types into a model generating an `SCoo`):
# An SCoo is already in the target format and comes back as the same object.
assert sax.scoo(_scoo) is _scoo

# SDict -> SCoo: the single ("in0", "out0") entry becomes one (row, col, value)
# triple. Compare element-wise with np.testing instead of comparing whole
# tuples with ``==``: the latter calls bool() on array comparisons and only
# works while every array happens to hold exactly one element.
Si, Sj, Sx, port_map = sax.scoo(_sdict)  # type: ignore
np.testing.assert_array_equal(Si, 0)
np.testing.assert_array_equal(Sj, 1)
np.testing.assert_array_equal(Sx, 3.0)
assert port_map == {"in0": 0, "out0": 1}

# SDense -> SCoo: every matrix element becomes an entry. The resulting port
# map is in0 -> 0, in1 -> 1, out0 -> 2, so Sx holds Sd with its rows and
# columns reordered to match those new indices, flattened row by row.
Si, Sj, Sx, port_map = sax.scoo(_sdense)  # type: ignore
np.testing.assert_array_equal(Si, jnp.array([0, 0, 0, 1, 1, 1, 2, 2, 2]))
np.testing.assert_array_equal(Sj, jnp.array([0, 1, 2, 0, 1, 2, 0, 1, 2]))
np.testing.assert_array_almost_equal(Sx, jnp.array([0.0, 2.0, 1.0, 6.0, 8.0, 7.0, 3.0, 5.0, 4.0]))
assert port_map == {"in0": 0, "in1": 1, "out0": 2}
Convert an `SDict`, `SCoo` or `SDense` into an `SDense` (or convert a model generating any of these types into a model generating an `SDense`):
# An SDense is already in the target format and comes back as the same object.
assert sax.sdense(_sdense) is _sdense

# SCoo -> SDense: each (Si[k], Sj[k], Sx[k]) entry lands at row Si[k],
# column Sj[k]; positions not listed in the SCoo are zero. The port map is
# carried over unchanged.
Sd, port_map = sax.sdense(_scoo)  # type: ignore
Sd_ = jnp.array(
    [
        [3.0, 0.0, 0.0],
        [0.0, 4.0, 0.0],
        [1.0, 0.0, 0.0],
    ],
    dtype=complex,
)
np.testing.assert_array_almost_equal(Sd, Sd_)
assert port_map == {"in0": 0, "in1": 2, "out0": 1}