from zntrack import Node, dvc, config, zn
config.nb_name = "06_named_nodes.ipynb"
from zntrack.utils import cwd_temp_dir
temp_dir = cwd_temp_dir()
!git init
!dvc init
Initialized empty Git repository in /tmp/tmpcn8yts4z/.git/
Initialized DVC repository.

You can now commit the changes to git.

+---------------------------------------------------------------------+
|                                                                     |
|        DVC has enabled anonymous aggregate usage analytics.         |
|     Read the analytics documentation (and how to opt-out) here:     |
|             <https://dvc.org/doc/user-guide/analytics>              |
|                                                                     |
+---------------------------------------------------------------------+

What's next?
------------
- Check out the documentation: <https://dvc.org/doc>
- Get help and share ideas: <https://dvc.org/chat>
- Star us on GitHub: <https://github.com/iterative/dvc>
Named Nodes allow us to use the same Node multiple times in a single graph, e.g. at different steps. To do so, we pass a `name` argument to the `__init__` of our Node.
Note that this is one of very few scenarios in which we pass an argument directly to `__init__`. The class below therefore forwards `**kwargs` to `super().__init__`.
class HelloWorld(Node):
    inputs = zn.params()
    outputs = zn.outs()

    def __init__(self, inputs=None, **kwargs):
        super().__init__(**kwargs)
        self.inputs = inputs

    def run(self):
        self.outputs = self.inputs
HelloWorld(inputs=3).write_graph(no_exec=False)
HelloWorld(name="Test01", inputs=17).write_graph(no_exec=False)
HelloWorld(name="Test02", inputs=42).write_graph(no_exec=False)
2022-03-16 13:28:59,293 (WARNING): Jupyter support is an experimental feature! Please save your notebook before running this command!
Submit issues to https://github.com/zincware/ZnTrack.
2022-03-16 13:29:03,070 (WARNING): Running DVC command: 'dvc run -n HelloWorld ...'
2022-03-16 13:29:10,857 (WARNING): Running DVC command: 'dvc run -n Test01 ...'
2022-03-16 13:29:18,410 (WARNING): Running DVC command: 'dvc run -n Test02 ...'
!dvc dag
+------------+
| HelloWorld |
+------------+
+--------+
| Test01 |
+--------+
+--------+
| Test02 |
+--------+
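The way `name` travels through `**kwargs` to the base class can be illustrated without ZnTrack. The following is a minimal plain-Python sketch, not ZnTrack's implementation: `BaseNode` and `Echo` are hypothetical stand-ins, and the fallback to the class name for unnamed instances is an assumption inferred only from the `dvc dag` output above, where the unnamed stage is called `HelloWorld`.

```python
class BaseNode:
    """Hypothetical stand-in for a Node base class (not ZnTrack's code)."""

    def __init__(self, name=None):
        # Assumption: an unnamed node falls back to its class name.
        self.node_name = name if name is not None else type(self).__name__


class Echo(BaseNode):
    """Hypothetical node that accepts its own argument plus an optional name."""

    def __init__(self, inputs=None, **kwargs):
        # Forward everything we do not handle ourselves (e.g. name=...)
        # to the base class, mirroring the HelloWorld example.
        super().__init__(**kwargs)
        self.inputs = inputs


nodes = [
    Echo(inputs=3),
    Echo(name="Test01", inputs=17),
    Echo(name="Test02", inputs=42),
]
stage_names = [node.node_name for node in nodes]
print(stage_names)  # ['Echo', 'Test01', 'Test02']
```

Because the subclass never inspects `name` itself, it can add its own parameters without having to know which keyword arguments the base class accepts.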
We can now also build a Node that depends on multiple instances of the same Node.
class FindMaximum(Node):
    deps = zn.deps(
        [
            HelloWorld.load(),
            HelloWorld.load(name="Test01"),
            HelloWorld.load(name="Test02"),
        ]
    )
    maximum = zn.outs()

    def run(self):
        self.maximum = 0
        for node in self.deps:
            if node.outputs > self.maximum:
                self.maximum = node.outputs
                print(f"New maximum found {node.outputs}.")
FindMaximum().write_graph(run=True)
2022-03-16 13:29:28,486 (WARNING): Running DVC command: 'dvc run -n FindMaximum ...'
!dvc dag
+------------+         +--------+         +--------+
| HelloWorld |         | Test01 |         | Test02 |
+------------+**       +--------+      ***+--------+
               ***          *       ***
                  ****      *   ****
                      **    *  **
                  +-------------+
                  | FindMaximum |
                  +-------------+
Using this combined Node we can e.g. find the maximum of the generated values.
FindMaximum.load().maximum
42
# Running it manually to highlight the print statements
FindMaximum.load().run()
New maximum found 3.
New maximum found 17.
New maximum found 42.
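The loop in `FindMaximum.run` starts from `self.maximum = 0`, so it only records a new maximum when a dependency output is greater than zero. As a hedged aside, the sketch below performs the same selection on plain values with Python's built-in `max`. Here `find_maximum` is a hypothetical helper, not part of ZnTrack, and the first list simply repeats the three outputs from the example.

```python
def find_maximum(values):
    """Return the largest value, or None if there are no values."""
    # max() with a default avoids seeding the search with 0, which
    # would hide the true maximum if every value were negative.
    return max(values, default=None)


outputs = [3, 17, 42]  # the outputs of HelloWorld, Test01 and Test02
result = find_maximum(outputs)
negative_result = find_maximum([-5, -2, -9])
print(result, negative_result)  # 42 -2
```

For the positive values used in this notebook both approaches give 42. The difference only matters for inputs that are all negative or for an empty dependency list.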
In addition to the classmethod `Node.load(name="nodename")` introduced above, it is also possible to use `Node["nodename"]`. Note that this only works on the class, `Node["nodename"]`, and not on an instance, `Node()["nodename"]`. Using this, we can also write the following:
print(HelloWorld["Test01"].outputs)
print(HelloWorld["Test01"].node_name)
17
Test01
This is equivalent to the classmethod `load()`. It is also possible to pass a dictionary, which is unpacked as keyword arguments into `load(**kwargs)`.
print(HelloWorld.load("Test02").outputs)
print(HelloWorld.load("Test02").node_name)
print(HelloWorld[{"name": "Test02"}].outputs)
42
Test02
42
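Class-level indexing such as `HelloWorld["Test01"]` can be reproduced in plain Python with the `__class_getitem__` hook. The sketch below is only one possible way to obtain this behaviour. ZnTrack may implement it differently (for example with a metaclass), and `Registry`, `Hello` and the simplified `load` are hypothetical names that build a fresh instance instead of reading stored results.

```python
class Registry:
    """Hypothetical base class that supports Class["name"] lookups."""

    def __init__(self, name=None):
        self.node_name = name if name is not None else type(self).__name__

    @classmethod
    def load(cls, name=None):
        # Stand-in for loading a stored node; here it only builds an instance.
        return cls(name=name)

    def __class_getitem__(cls, item):
        # A dict is unpacked into keyword arguments for load();
        # any other value is treated as the node name.
        if isinstance(item, dict):
            return cls.load(**item)
        return cls.load(name=item)


class Hello(Registry):
    pass


by_name = Hello["Test01"]
by_dict = Hello[{"name": "Test02"}]
print(by_name.node_name, by_dict.node_name)  # Test01 Test02

# Indexing an instance fails because no instance-level __getitem__ exists.
try:
    Hello()["Test01"]
    instance_lookup_works = True
except TypeError:
    instance_lookup_works = False
print(instance_lookup_works)  # False
```

In this sketch the hook is looked up on the class only, which matches the restriction described above: the subscript works on `Hello` but not on `Hello()`.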
temp_dir.cleanup()