#!/usr/bin/env python
# coding: utf-8

# # Dataset slicing

# An event recording is somewhat similar to a video. Sometimes it is desirable to slice a single event recording into multiple samples, so that during training we can load just a slice of a recording rather than the whole thing. This is typically the case when training an ANN on event frames, when one recording contains multiple labels, or when recordings are simply very long. We specify a `slicer` method which decides how recordings are cut into smaller chunks; the smaller the chunks, the larger the overall dataset becomes. Let's look at how we can cut a sample of the N-MNIST dataset, which is roughly 300 ms long, into smaller pieces of 50 ms each.

# In[ ]:


import tonic
from tonic import SlicedDataset
from tonic.slicers import SliceByTime

dataset = tonic.datasets.NMNIST(save_to="./data", train=False)

slicing_time_window = 50000  # microseconds
slicer = SliceByTime(time_window=slicing_time_window)
sliced_dataset = SlicedDataset(
    dataset, slicer=slicer, metadata_path="./metadata/nmnist"
)


# In[ ]:


print(
    f"Went from {len(dataset)} samples in the original dataset to {len(sliced_dataset)} in the sliced version."
)


# In[ ]:


events, targets = sliced_dataset[100]


# We can verify that the difference between the last and first timestamps in the slice is not greater than the slicing time window we chose earlier.

# In[ ]:


slice_time_difference = events["t"][-1] - events["t"][0]
print(
    f"Difference between last and first timestamp in slice: {slice_time_difference} us"
)
assert slice_time_difference <= slicing_time_window


# ## Applying transforms post-slicing

# We can specify a `transform` and/or a `target_transform`, which will be applied to each slice after loading.

# In[ ]:


frame_transform = tonic.transforms.ToImage(
    sensor_size=tonic.datasets.NMNIST.sensor_size
)

sliced_dataset = SlicedDataset(
    dataset, slicer=slicer, transform=frame_transform, metadata_path="./metadata/nmnist"
)


# In[ ]:


frames, targets = sliced_dataset[100]


# We can verify that the sum of all pixel values in the frames equals the number of events in the raw slice we loaded earlier, since each event is binned into exactly one pixel.

# In[ ]:


print("Number of spikes: " + str(frames.sum()))
assert frames.sum() == len(events)


# ## Caching a SlicedDataset

# Retrieving a slice from our new dataset means opening and loading the original recording, finding the desired slice, and returning it. This adds considerable computational overhead. To speed things up, we can make use of caching to store the slices either on disk or in memory.
#
# In the next example we wrap our sliced dataset in a `MemoryCachedDataset`, which writes the slices to working memory, from where they can be retrieved very quickly the next time you need them (for example in the next training epoch). We also provide an augmentation transform that is applied after loading from the cache.

# In[ ]:


import torch
import torchvision
from tonic import MemoryCachedDataset

torch.manual_seed(1234)

# Convert the frames to tensors and apply a random rotation every time a
# cached slice is loaded.
augmentation = tonic.transforms.Compose(
    [torch.tensor, torchvision.transforms.RandomRotation([-45, 45])]
)
augmented_dataset = MemoryCachedDataset(sliced_dataset, transform=augmentation)


# In[ ]:


rotated_frames, targets = augmented_dataset[200]


# In[ ]:


get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pyplot as plt

plt.imshow(rotated_frames[0]);
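

# If the sliced dataset is too large to fit into working memory, we can cache the slices to disk instead. Below is a minimal sketch using tonic's `DiskCachedDataset`; the cache directory chosen here is an arbitrary example path.

# In[ ]:


from tonic import DiskCachedDataset

# Slices are written to files under cache_path the first time they are
# accessed and are read back from disk on subsequent accesses.
# The cache_path below is just an example location.
disk_cached_dataset = DiskCachedDataset(
    sliced_dataset, cache_path="./cache/nmnist_sliced", transform=augmentation
)

frames_from_disk, targets = disk_cached_dataset[200]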
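

# Once the slices are cached, the dataset can be consumed like any other PyTorch dataset during training. The following sketch wraps it in a standard `torch.utils.data.DataLoader`; the batch size is an arbitrary choice.

# In[ ]:


from torch.utils.data import DataLoader

# Every slice has been converted to a fixed-size image, so the default
# collate function can stack samples into batches directly.
dataloader = DataLoader(augmented_dataset, batch_size=32, shuffle=True)
batch_frames, batch_targets = next(iter(dataloader))
print(batch_frames.shape)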
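

# Finally, note that `SliceByTime` is only one possible slicing strategy. The sketch below assumes that `tonic.slicers` also provides a `SliceByEventCount` slicer, which cuts a recording into chunks containing a fixed number of events rather than covering a fixed duration.

# In[ ]:


from tonic.slicers import SliceByEventCount

# Cut each recording into chunks of 1000 events; the count and the
# metadata path are arbitrary example values.
event_count_slicer = SliceByEventCount(event_count=1000)
count_sliced_dataset = SlicedDataset(
    dataset, slicer=event_count_slicer, metadata_path="./metadata/nmnist_by_count"
)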