When training spiking neural networks, we typically face long training times, depending on the number of time steps and the training algorithm used. One thing that should not add to those training times is loading a (potentially transformed) sample. To start, let's measure how long it takes to apply a transform to 100 NMNIST samples without any tricks.
import tonic
import tonic.transforms as transforms
sensor_size = tonic.datasets.NMNIST.sensor_size
transform = transforms.Compose(
    [
        transforms.Denoise(filter_time=10000),
        transforms.ToFrame(sensor_size=sensor_size, n_time_bins=3),
    ]
)
dataset = tonic.datasets.NMNIST(save_to="./data", train=False, transform=transform)
def load_sample_simple():
    for i in range(100):
        events, target = dataset[i]
%timeit -o load_sample_simple()
# `_` holds the TimeitResult of the %timeit -o call above; `average` is the mean
# time in seconds to load 100 samples, so x600 gives one epoch of 60k samples.
print(
    f"Loading time for 60k samples and 200 epochs: ~{int(_.average*600*200/3600)} hours."
)
To speed things up a bit, we can make use of more sophisticated dataloaders, which support pre-fetching, multiple worker processes, batching and more. Let's try the PyTorch DataLoader; you can find all of its supported functionality in the official documentation.
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, num_workers=2, shuffle=True)
def load_sample_pytorch():
    for i, (events, target) in enumerate(iter(dataloader)):
        if i > 99:
            break
%timeit load_sample_pytorch()
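Because the ToFrame transform above gives every sample the same shape, we can also take advantage of batching right away. The following is only a quick sketch (the dataloader name and batch size are arbitrary choices for illustration): PyTorch's default collation stacks the equally-sized frame arrays into a single batch tensor.
# Sketch: batched loading. Since every sample has the same frame shape,
# the default collate function can stack them without a custom collate_fn.
batched_dataloader = DataLoader(dataset, batch_size=32, num_workers=2, shuffle=True)
frames, targets = next(iter(batched_dataloader))
# For NMNIST we expect something like (32, 3, 2, 34, 34): batch, time bins,
# polarities, height, width; targets should have shape (32,).
print(frames.shape, targets.shape)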
Even with a smarter dataloader, we still do 2 things for every sample in every epoch:
1. read and decode the events from the original file format, which is comparatively slow, and
2. apply the same deterministic transforms over and over again.
To address these two issues, Tonic provides a DiskCachedDataset. A DiskCachedDataset wraps around your dataset object of choice. Whenever you load a sample, it applies the original transforms to your data and saves the result to disk in an efficient and convenient format. The next time you read the same sample, it is simply loaded from that cached file instead. In practice, this means that while your first epoch might be about as slow as before, the following epochs will load much faster.
from tonic import DiskCachedDataset
cached_dataset = DiskCachedDataset(dataset, cache_path="./cache/fast_dataloading")
cached_dataloader = DataLoader(cached_dataset, num_workers=2)
def load_sample_cached():
    for i, (events, target) in enumerate(iter(cached_dataloader)):
        if i > 99:
            break
%timeit -o -r 20 load_sample_cached()
print(
    f"Loading time for 60k samples and 200 epochs with cache: ~{int(_.average*600*200/3600)} hours."
)
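To make the difference between the cold first epoch and the warm epochs afterwards explicit, we can time two consecutive passes over a freshly cached copy of the dataset. This is only a sketch; the cache path below is a throwaway name chosen for illustration.
import time
# Use a fresh cache path so that the first pass really starts from an empty cache.
fresh_cached_dataset = DiskCachedDataset(dataset, cache_path="./cache/cold_vs_warm")
fresh_dataloader = DataLoader(fresh_cached_dataset, num_workers=2)
def time_100_samples(loader):
    start = time.time()
    for i, (events, target) in enumerate(iter(loader)):
        if i > 99:
            break
    return time.time() - start
cold = time_100_samples(fresh_dataloader)  # applies the transforms and writes the cache
warm = time_100_samples(fresh_dataloader)  # reads the cached files back from disk
print(f"Cold pass: {cold:.2f} s, warm pass: {warm:.2f} s")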
If we want to apply stochastic transformations as well, we can pass another set of transforms to the DiskCachedDataset, which will then apply them after reading the cached samples from disk. In the following example, we convert our cached samples (which are already frames) to PyTorch tensors and then apply a random rotation to the whole recording.
import torch
import torchvision
transform = tonic.transforms.Compose(
    [torch.tensor, torchvision.transforms.RandomRotation([-30, 30])]
)
augmented_dataset = DiskCachedDataset(
    dataset, cache_path="./cache/fast_dataloading2", transform=transform
)
augmented_dataloader = DataLoader(augmented_dataset, num_workers=2)
def load_sample_augmented():
    for i, (events, target) in enumerate(iter(augmented_dataloader)):
        if i > 99:
            break
%timeit -r 20 load_sample_augmented()
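Finally, here is a rough sketch of how the cached, augmented data would feed a training loop. The tiny linear model and the batch size below are placeholders for illustration; they are not part of Tonic, and in practice you would plug in your spiking network of choice.
import torch.nn as nn
# Placeholder model: flattens the (time bins, polarities, height, width) frames
# of one NMNIST sample and maps them to the 10 digit classes.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 2 * 34 * 34, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
train_dataloader = DataLoader(augmented_dataset, batch_size=32, num_workers=2, shuffle=True)
for frames, targets in train_dataloader:
    optimizer.zero_grad()
    output = model(frames.float())  # frames arrive as (batch, 3, 2, 34, 34) tensors
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()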