#!/usr/bin/env python
# coding: utf-8

# ## Using Aug_DiskCachedDataset for efficient caching of augmented copies

# - `Aug_DiskCachedDataset` is a modified version of `DiskCachedDataset` that is useful when applying deterministic augmentations to data samples.
#
# - This is the case when the parameter space of the augmentation is discrete, for instance applying `pitchshift` to audio data, where the shift parameter (semitones) can only take N values.
#
# - Using `DiskCachedDataset` and setting `num_copies` to N is likely to cause two issues:
#
#     - Copies might not be unique, as the copy index is not linked to the augmentation parameter.
#     - There is no guarantee that the copies cover the desired augmentation space.
#
# - `Aug_DiskCachedDataset` resolves this limitation by mapping and linking the copy index to the augmentation parameter. The following considerations need to be taken into account:
#
#     - The user needs to pass an `all_transforms` dict as input with separated transforms `pre_aug`, `augmentations`, and `post_aug` (specifying the transforms applied before the augmentation, the augmentation transforms themselves, and the transforms applied after them).
#
#     - The augmentation class receives `aug_index` (aug_index = copy index) as an initialization parameter, and `caching=True` needs to be set (please see `tonic.audio_augmentations`).
#
# - The following is a simple example showing how `Aug_DiskCachedDataset` works.

# ### A simple dataset

# In[1]:

# %%writefile mini_dataset.py
import warnings

warnings.filterwarnings("ignore")

import numpy as np
from torch.utils.data import Dataset


class mini_dataset(Dataset):
    """A toy dataset of 10 random audio-like samples, each 16000 points long."""

    def __init__(self) -> None:
        super().__init__()
        np.random.seed(0)
        self.data = np.random.rand(10, 16000)
        self.transform = None
        self.target_transform = None

    def __getitem__(self, index):
        sample = self.data[index]
        label = 1
        # add a channel dimension if the sample is 1D
        if sample.ndim == 1:
            sample = sample[None, ...]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label


# ### Initializing `Aug_DiskCachedDataset` with transforms

# In[2]:

from tonic.audio_augmentations import RandomPitchShift
from tonic.audio_transforms import AmplitudeScale, FixLength
from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache

all_transforms = {}
all_transforms["pre_aug"] = [AmplitudeScale(max_amplitude=0.150)]
all_transforms["augmentations"] = [RandomPitchShift(samplerate=16000, caching=True)]
all_transforms["post_aug"] = [FixLength(16000)]

# the number of copies is set to the number of augmentation parameters (factors)
n = len(RandomPitchShift(samplerate=16000, caching=True).factors)

Aug_cach = Aug_DiskCachedDataset(
    dataset=mini_dataset(),
    cache_path="cache/",
    all_transforms=all_transforms,
    num_copies=n,
)


# ### Generating all copies of a data sample

# - 10 augmented versions of the data sample with index = 0 are generated.

# In[3]:

sample_index = 0
Aug_cach.generate_all(sample_index)


# ### To verify

# - load the saved copies
# - compare them with the ones generated outside of `Aug_DiskCachedDataset`, using the same transforms and the matching augmentation parameter
# - they are equal

# In[7]:

from torchvision.transforms import Compose

for i in range(n):
    transform = Compose(
        [
            AmplitudeScale(max_amplitude=0.150),
            RandomPitchShift(samplerate=16000, caching=True, aug_index=i),
            FixLength(16000),
        ]
    )
    ds = mini_dataset()
    ds.transform = transform
    sample = ds[sample_index][0]

    data, targets = load_from_disk_cache("cache/" + "0_" + str(i) + ".hdf5")
    print((sample == data).all())
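# ### Inspecting the copy-to-parameter mapping

# - A small sanity-check sketch: assuming copy index `i` corresponds to entry `i` of the augmentation's `factors` list (as the verification loop above suggests), we can print which pitch-shift parameter each cached file encodes. The filename pattern `<sample_index>_<copy_index>.hdf5` is taken from the loop above.

# In[ ]:

aug = RandomPitchShift(samplerate=16000, caching=True)
for i, factor in enumerate(aug.factors):
    # assumption: copy i of sample 0 was produced with factors[i]
    print(f"copy {i}: semitone factor {factor} -> cache/0_{i}.hdf5")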
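# ### Consuming the cache like a regular dataset

# - Once the copies are generated, `Aug_cach` can be consumed as a map-style PyTorch dataset. The sketch below is a minimal illustration, assuming `Aug_DiskCachedDataset` keeps `DiskCachedDataset`'s `__getitem__`/`__len__` semantics of serving cached copies.

# In[ ]:

from torch.utils.data import DataLoader

# assumption: indexing Aug_cach returns (sample, target) pairs read from the cache
loader = DataLoader(Aug_cach, batch_size=2, shuffle=True)
batch, labels = next(iter(loader))
print(batch.shape)  # expected (2, 1, 16000) after FixLength(16000)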