Aug_DiskCachedDataset is a modified version of DiskCachedDataset that is useful when applying deterministic augmentations to data samples. This is the case when the parameter space of the augmentation is discrete, for instance applying a pitch shift to audio data, where the shift parameter (in semitones) can only take N values.
Using DiskCachedDataset and setting num_copies to N is likely to cause problems: the augmentation parameter is selected at random for each copy, so some parameter values may be cached several times while others are never cached, and the copy index carries no information about which parameter produced a given copy.
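As a quick, self-contained illustration (the semitone values here are made up for the sketch), drawing N parameters at random rarely covers all N discrete values exactly once:

import numpy as np

rng = np.random.default_rng(0)
semitones = np.arange(-5, 5)  # hypothetical stand-in for the N discrete shift values
# one random draw per cached copy, which is effectively what happens
# when the augmentation picks its parameter at random for each copy
draws = rng.choice(semitones, size=len(semitones))
print(sorted(draws))  # some values repeat, others never appear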
Aug_DiskCachedDataset resolves this limitation by mapping and linking the copy index to the augmentation parameter. The following considerations need to be taken into account:
- The user needs to pass an all_transforms dict as input with separate transform lists pre_aug, augmentations, and post_aug (specifying the transforms applied before augmentation, the augmentation transforms themselves, and the transforms applied after augmentation).
- The augmentation class receives aug_index (aug_index = copy index) as an initialization parameter, and caching=True needs to be set (please see tonic.audio_augmentations); a minimal sketch of this follows the list below.
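As a minimal sketch of these two points, using only the constructor arguments that appear in the example below: setting caching=True together with aug_index pins the otherwise random shift factor to one entry of the augmentation's factors list.

from tonic.audio_augmentations import RandomPitchShift

# copy index i is linked to the i-th entry of the discrete parameter list,
# so every regeneration of copy i uses the same shift factor
aug = RandomPitchShift(samplerate=16000, caching=True, aug_index=3)
print(len(aug.factors))  # the number of discrete shift values, i.e. num_copies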
The following is a simple example showing how Aug_DiskCachedDataset works.
# %%writefile mini_dataset.py
import warnings

warnings.filterwarnings("ignore")

import numpy as np
from torch.utils.data import Dataset


class mini_dataset(Dataset):
    """A toy dataset of 10 random 1-second audio clips at 16 kHz."""

    def __init__(self) -> None:
        super().__init__()
        np.random.seed(0)
        self.data = np.random.rand(10, 16000)
        self.transform = None
        self.target_transform = None

    def __getitem__(self, index):
        sample = self.data[index]
        label = 1
        # add a channel dimension: (16000,) -> (1, 16000)
        if sample.ndim == 1:
            sample = sample[None, ...]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        return len(self.data)
Aug_DiskCachedDataset with transforms

from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache
from tonic.audio_transforms import AmplitudeScale, FixLength
from tonic.audio_augmentations import RandomPitchShift
all_transforms = {}
all_transforms["pre_aug"] = [AmplitudeScale(max_amplitude=0.150)]
all_transforms["augmentations"] = [RandomPitchShift(samplerate=16000, caching=True)]
all_transforms["post_aug"] = [FixLength(16000)]

# the number of copies is set to the number of augmentation parameters (factors)
n = len(RandomPitchShift(samplerate=16000, caching=True).factors)

Aug_cach = Aug_DiskCachedDataset(
    dataset=mini_dataset(),
    cache_path="cache/",
    all_transforms=all_transforms,
    num_copies=n,
)
- 10 augmented versions of the data sample with index = 0 are generated (one per augmentation factor):
sample_index = 0
Aug_cach.generate_all(sample_index)
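The copies land in the cache directory; assuming the `<sample_index>_<copy_index>.hdf5` naming that the loading code below relies on, you can list them directly:

import os

# expected something like ['0_0.hdf5', '0_1.hdf5', ..., '0_9.hdf5']
print(sorted(os.listdir("cache/")))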
- loading the saved copies
- and comparing them with samples generated outside of `Aug_DiskCachedDataset` by composing the same transforms with the matching augmentation parameter
- they are equal
from torchvision.transforms import Compose

for i in range(n):
    transform = Compose(
        [
            AmplitudeScale(max_amplitude=0.150),
            RandomPitchShift(samplerate=16000, caching=True, aug_index=i),
            FixLength(16000),
        ]
    )
    ds = mini_dataset()
    ds.transform = transform
    sample = ds[sample_index][0]

    # load cached copy i of sample 0 and compare with the directly transformed sample
    data, targets = load_from_disk_cache("cache/" + "0_" + str(i) + ".hdf5")
    print((sample == data).all())
True True True True True True True True True True