ZnTrack allows you to create a custom ZnTrackOption similar to zn.outs.
ZnTrack tries to handle some standard types automatically within the zn.outs option, but it can be useful to write custom ones.
In the following example we use the Atomic Simulation Environment (ASE) to store objects in, and load them from, a custom database file.
from zntrack import config
# When using ZnTrack we can write our code inside a Jupyter notebook.
# We can make use of this functionality by setting the `nb_name` config as follows:
config.nb_name = "08_custom_zntrackoptions.ipynb"
from zntrack.utils import cwd_temp_dir
temp_dir = cwd_temp_dir()
!git init
!dvc init
Initialized empty Git repository in /tmp/tmp5twr4kp_/.git/
Initialized DVC repository.

You can now commit the changes to git.

+---------------------------------------------------------------------+
|                                                                     |
|        DVC has enabled anonymous aggregate usage analytics.         |
|     Read the analytics documentation (and how to opt-out) here:     |
|             <https://dvc.org/doc/user-guide/analytics>              |
|                                                                     |
+---------------------------------------------------------------------+

What's next?
------------
- Check out the documentation: <https://dvc.org/doc>
- Get help and share ideas: <https://dvc.org/chat>
- Star us on GitHub: <https://github.com/iterative/dvc>
We will use the ZnTrackOption, zntrack.Field in the code below, to build our new custom option.
import zntrack
import ase.db
import ase.io
import tqdm
class Atoms(zntrack.Field):
    # we will save the file as dvc run --outs
    dvc_option = "outs"
    group = zntrack.FieldGroup.RESULT  # you can choose from RESULT or PARAMETER

    def get_files(self, instance) -> list:
        """Define the filename that is passed to dvc (used if tracked=True)"""
        # self.name is the name of the class attribute we use for this database
        return [instance.nwd / f"{self.name}.db"]

    def save(self, instance):
        """Save the values to file"""
        # we gather the actual values using __get__
        atoms = getattr(instance, self.name)
        # get the file name
        file = self.get_files(instance)[0]
        # save the data to the file
        with ase.db.connect(file) as db:
            for atom in tqdm.tqdm(atoms, ncols=70, desc=f"Writing atoms to {file}"):
                db.write(atom, group=instance.name)

    def get_data(self, instance):
        """Load data with ase.db.connect from file"""
        # get the file name
        file = self.get_files(instance)[0]
        # load the data
        atoms = []
        with ase.db.connect(file) as db:
            for row in tqdm.tqdm(
                db.select(), ncols=70, desc=f"Loading atoms from {file}"
            ):
                atoms.append(row.toatoms())
        # return the data so it can be saved in __dict__
        return atoms
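The comments above mention `__get__` and `__dict__`, which refer to Python's descriptor protocol. The following is a minimal standard-library sketch of that pattern, not ZnTrack's actual implementation: `JsonField` and `DemoNode` are hypothetical names, JSON stands in for the ASE database, and the `nwd` attribute is assumed to be a plain `pathlib.Path`.

```python
import json
import tempfile
from pathlib import Path


class JsonField:
    """Hypothetical stand-in for a custom field (not part of ZnTrack)."""

    def __set_name__(self, owner, name):
        # Python passes in the attribute name when the owning class is created.
        self.name = name

    def get_files(self, instance) -> list:
        # One file per attribute, inside the instance's working directory.
        return [instance.nwd / f"{self.name}.json"]

    def save(self, instance):
        # Read the current value through __get__ and write it to disk.
        file = self.get_files(instance)[0]
        file.parent.mkdir(parents=True, exist_ok=True)
        file.write_text(json.dumps(getattr(instance, self.name)))

    def get_data(self, instance):
        # Read the stored value back from disk.
        file = self.get_files(instance)[0]
        return json.loads(file.read_text())

    def __get__(self, instance, owner=None):
        if instance is None:
            return self  # class-level access returns the descriptor itself
        if self.name not in instance.__dict__:
            # Load lazily on first access and cache the result in __dict__.
            instance.__dict__[self.name] = self.get_data(instance)
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        instance.__dict__[self.name] = value


class DemoNode:
    """Hypothetical host class with a working directory, like a node's nwd."""

    values = JsonField()

    def __init__(self, nwd: Path):
        self.nwd = nwd


with tempfile.TemporaryDirectory() as tmp:
    writer = DemoNode(Path(tmp) / "nodes" / "DemoNode")
    writer.values = [1, 2, 3]
    DemoNode.values.save(writer)  # persist the value to values.json

    reader = DemoNode(writer.nwd)  # fresh instance with nothing cached yet
    loaded = reader.values  # first access triggers get_data
    print(loaded)
```

In this sketch the descriptor owns both the file name and the serialization format, so the host class only declares the attribute. The `Atoms` field above follows the same split, with `ase.db` doing the reading and writing.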
Now that we have defined our custom ZnTrackOption we can use it as follows.
class AtomsClass(zntrack.Node):
    atoms = Atoms()

    def run(self):
        self.atoms = [ase.Atoms("N2", positions=[[0, 0, -1], [0, 0, 1]])]

with zntrack.Project() as project:
    node = AtomsClass()

project.run(repro=False)
Running DVC command: 'stage add --name AtomsClass --force ...'
Creating 'dvc.yaml'
Adding stage 'AtomsClass' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml nodes/AtomsClass/.gitignore

To enable auto staging, run:

    dvc config core.autostage true
Jupyter support is an experimental feature! Please save your notebook before running this command!
Submit issues to https://github.com/zincware/ZnTrack.
[NbConvertApp] Converting notebook 08_custom_zntrackoptions.ipynb to script
[NbConvertApp] Writing 2881 bytes to 08_custom_zntrackoptions.py
!dvc repro
Running stage 'AtomsClass':
> zntrack run src.AtomsClass.AtomsClass --name AtomsClass
Loading atoms from nodes/AtomsClass/atoms.db: 0it [00:00, ?it/s]
Writing atoms to nodes/AtomsClass/atoms.db: 100%|█| 1/1 [00:00<00:00,
Generating lock file 'dvc.lock'
Updating lock file 'dvc.lock'

To track the changes with git, run:

    git add dvc.lock

To enable auto staging, run:

    dvc config core.autostage true
Use `dvc push` to send your updates to remote storage.
node.load()
print(node.atoms)
# or
AtomsClass.from_rev().atoms
Loading atoms from nodes/AtomsClass/atoms.db: 1it [00:00, 1151.02it/s]
[Atoms(symbols='N2', pbc=False)]
Loading atoms from nodes/AtomsClass/atoms.db: 1it [00:00, 3231.36it/s]
[Atoms(symbols='N2', pbc=False)]
temp_dir.cleanup()