import os
import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt
import torch
import torchaudio
import PyTCI.tci as tci
import PyTCI.audio as fx
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
In this example script, we create a mock model that convolves the input waveform by 7 causal filters emulating gamma distribution probability density functions, with different theoretical integration windows specified by the $\text{rate}$ parameter.
Since we know the filters, we can estimate the integration windows directly from the cumulative density function of the gamma distributions. Here, we approximate the integration window by measuring how long it takes for the area under the curve of each filter to reach 90% of its maximum value.
Our toy model consists of 5 output channels, each one the result of convolving the input with a causal gamma distributed window, with increasing temporal rates (decreasing temporal widths):
# Expected input sampling rate
in_sr = 16_000
# Different rates for the filter
rates = [10, 5, 2.5, 1.25, 0.625]
# Create filters gamma distributions
weights = []
for rate in rates:
dist = torch.distributions.gamma.Gamma(2, rate)
dist = torch.tensor([torch.exp(dist.log_prob(i)) for i in np.arange(1e-7, 10, 5/in_sr)])
weights.append(dist.flip(0) / sum(dist))
weights = torch.stack(weights, dim=0).unsqueeze(dim=1).to(device).float()
Here we plot the 5 integration windows of the model:
plt.figure(figsize=(8, 3))
plt.plot(np.arange(in_sr)/in_sr, weights.squeeze(1).T.flip(0)[:in_sr].cpu(), linewidth=1.5)
plt.xlabel('Time lag (s)', fontsize=11)
plt.ylabel('Amplitude', fontsize=11)
plt.title('Causal filters used as model', fontsize=12)
plt.xticks([0, 0.5, 1.0], fontsize=11)
plt.yticks([0, 0.0012], fontsize=11)
plt.grid(axis='y')
plt.legend(rates, loc='upper right', fontsize=11)
plt.show()
We create a function that applies each filter defined above to the input and returns a 5-channel output:
def model(x):
x = torch.as_tensor(x, device=device)
x = torch.nn.functional.pad(x, (0, 0, weights.shape[-1]-1, 0)).float()
x = x.T.unsqueeze(dim=0)
x = torch.nn.functional.conv1d(x, weights, stride=16)
x = x.squeeze(dim=0).T
return x
Note that the output sampling rate is 1/16 of the input sampling rate (i.e., 1KHz), due to the stride of 16 in the model.
Knowing the exact filters the model applies to the input and since the filters are static, we can directly measure the window of stimulus within which inputs affect the output, and outside of which they have little effect. Here, we approximate this window size by taking the integral of the filters, and finding the smallest time lag $t_I$ such that the integral of the filter for $0 ≤ t ≤ t_I$ is at least 90% of the total integral of the filter. In other words, the area under the filter curve from $0 ≤ t ≤ t_I$ is at least 90% of the total area under the filter curve.
cdf = weights.squeeze(1).cumsum(axis=1).cpu()
intwin_cdf = [(np.where(c.flip(0) >= 0.10)[0][-1])/in_sr+0.001 for c in cdf]
print('Directly estimated integration windows:', [round(_, 3) for _ in intwin_cdf], 'second')
Directly estimated integration windows: [0.079, 0.157, 0.312, 0.623, 1.198] second
The TCI analysis consists of three main steps:
This step entails extracting segments of a specified duration from a list of source stimuli, here excerpts of speech extracted from the LibriSpeech corpus. Then randomly shuffling the extracted segments and concatenating them (with crossfade) to form a long sequence. We repeat this shuffling-concatenation process twice for a given segment duration to achieve a sequence pair, and repeat the entire process above multiple times for different values of segment durations, to achieve a list of sequence pairs, each corresponding to a segment duration. Sequences should have shape $[time \times channel]$. Check the cells below for audio examples.
First, we load the source stimuli that will be used for analysis:
stimuli, in_sr = tci.load_stimuli('resources/segments-librispeech-1k/', fmt='audio')
ipd.display(ipd.Audio(stimuli[2].T, rate=in_sr))
ipd.display(ipd.Audio(stimuli[10].T, rate=in_sr))
Note that the model output had a sampling rate 1KHz due to the stride:
out_sr = in_sr / 16
We can generate a single sequence from random orderings of 0.5 second segments:
seq = tci.generate_sequence(stimuli, in_sr, segment_dur=0.5, seed=1)
ipd.display(ipd.Audio(seq[:4*in_sr].T, rate=in_sr))
We can generate a sequence pair from two random orderings of 0.5 second subsegments:
seq_A, seq_B = tci.generate_sequence_pair(stimuli, in_sr, segment_dur=0.5)
ipd.display(ipd.Audio(seq_A[:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(seq_B[:4*in_sr].T, rate=in_sr))
Finally, we can generate multiple sequence pairs, one for each segment duration from a list of segment durations (as described above), to use in our analysis:
# Default list of segment durations from 20ms to 2.48s
segment_durs = tci.SEGMENT_DURS
# Generate all sequence pairs
sequence_pairs = tci.generate_sequence_pair(stimuli, in_sr, segment_durs)
print('Sequence pair for segment duration of 280ms:')
ipd.display(ipd.Audio(sequence_pairs[10][0][:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(sequence_pairs[10][1][:4*in_sr].T, rate=in_sr))
print('Sequence pair for segment duration of 1.72s:')
ipd.display(ipd.Audio(sequence_pairs[20][0][:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(sequence_pairs[20][1][:4*in_sr].T, rate=in_sr))
Sequence pair for segment duration of 280ms:
Sequence pair for segment duration of 1.72s:
The second step entails getting the model responses to all sequences generated above. Each sequence pair, is a tuple pair of two sequences with shapes $[time \times channel]$. For each sequence in a pair, the model should return a response time-course with shape $[time \times units]$.
response_pairs = [(model(seq_A), model(seq_B)) for seq_A, seq_B in sequence_pairs]
Since the duration of the generated sequences grow with the number of source stimuli and the segment duration, there is a helper function called infer_sequence_pair
that takes the model and a sequence pair (or a list of sequence pairs), batches the input stimulus across time into block_size
second batches, with an overlap of 2*context_size
seconds between two adjacent batches to avoid boundary effects, and runs through the model to get the model responses to each sequence:
response_pairs = tci.infer_sequence_pair(
model, sequence_pairs, segment_durs, in_sr=in_sr, out_sr=out_sr,
block_size=48.0, context_size=8.0, device=device
)
Note, that this function needs sampling rate of input and output (model activations) to batch correctly:
The third step entails rearranging the response sequence to achieve the pre-shuffling order, which we call the segment-aligned responses (SAR), then calculating the correlation between the responses across the two shuffling conditions (across SAR pair) to obtain what we call the cross-context correlation:
SAR_pairs = tci.rearrange_sequence_pair(
response_pairs, out_sr, segment_durs
)
cross_context_corrs = tci.cross_context_corrs(SAR_pairs)
Finally, now that for each model unit we have a cross-context correlation value per segment duration, we use a correlation threshold to determine the integration window of the model. In other words, we determine the smallest segment duration for which at some point the response of that unit in the two different context conditions achieves the threshold, i.e., becomes invariant to context.
intwin_tci = tci.estimate_integration_window(
cross_context_corrs, segment_durs, threshold=0.98
)
Now we can compare the integration windows estimated using TCI, with the integration windows estimated directly from the filter kernels:
plt.figure(figsize=(3, 3))
plt.plot(np.log([0.06, 1.5]), np.log([0.06, 1.5]), 'k--', linewidth=0.5)
for i, rate in enumerate(rates):
plt.scatter(np.log(intwin_cdf[i]), np.log(intwin_tci[i]), 40, label=rate)
plt.xticks(np.log([0.06, 0.3, 1.5]), [20, 160, 1200], fontsize=12)
plt.yticks(np.log([0.06, 0.3, 1.5]), [20, 160, 1200], fontsize=12)
plt.xlabel('Integration window\nfrom CDF (ms)', fontsize=12)
plt.ylabel('Integration window\nfrom TCI (ms)', fontsize=12)
plt.legend(loc='lower right', fontsize=9)
plt.show()
As expected, our analysis has identified the proper integration window for each kernel.
There are two methods for generating a TCI sequence pair, as described in the paper:
The default behavior is "random-random". To change the method to "natural-random", set the $\text{comparison_method}$ parameter to $\text{'natural-random'}$, when generating and rearranging sequence pairs.
# Default list of segment durations from 20ms to 2.48s
segment_durs = tci.SEGMENT_DURS
# Generate all sequence pairs
sequence_pairs = tci.generate_sequence_pair(stimuli, in_sr, segment_durs, comparison_method='natural-random')
print('Sequence pair for segment duration of 280ms:')
ipd.display(ipd.Audio(sequence_pairs[10][0][:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(sequence_pairs[10][1][:4*in_sr].T, rate=in_sr))
print('Sequence pair for segment duration of 1.72s:')
ipd.display(ipd.Audio(sequence_pairs[20][0][:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(sequence_pairs[20][1][:4*in_sr].T, rate=in_sr))
response_pairs = tci.infer_sequence_pair(
model, sequence_pairs, segment_durs, in_sr=in_sr, out_sr=out_sr,
block_size=48.0, context_size=8.0, device=device
)
SAR_pairs = tci.rearrange_sequence_pair(
response_pairs, out_sr, segment_durs, comparison_method='natural-random'
)
cross_context_corrs = tci.cross_context_corrs(SAR_pairs)
intwin_tci = tci.estimate_integration_window(
cross_context_corrs, segment_durs, threshold=0.98
)
Sequence pair for segment duration of 280ms:
Sequence pair for segment duration of 1.72s:
As shown below, the results from both cases are quite similar:
plt.figure(figsize=(3, 3))
plt.plot(np.log([0.06, 1.5]), np.log([0.06, 1.5]), 'k--', linewidth=0.5)
for i, rate in enumerate(rates):
plt.scatter(np.log(intwin_cdf[i]), np.log(intwin_tci[i]), 40, label=rate)
plt.xticks(np.log([0.06, 0.3, 1.5]), [20, 160, 1200], fontsize=12)
plt.yticks(np.log([0.06, 0.3, 1.5]), [20, 160, 1200], fontsize=12)
plt.xlabel('Integration window\nfrom CDF (ms)', fontsize=12)
plt.ylabel('Integration window\nfrom TCI (ms)', fontsize=12)
plt.legend(loc='lower right', fontsize=9)
plt.show()
It is possible to manipulate the source stimuli before generating the TCI sequences, using the $\text{process}$ parameter of $\text{load_segments}$ or $\text{process_segments}$ functions. This is mainly useful in two cases:
The $\text{process}$ parameter takes in a function or an ordered list of functions, that will be applied to each stimulus tensor in series. There are a number of basic audio manipulation functions provided in the $\text{audio}$ module. This module utilizes the SoX backend of the torchaudio library.
Here are some examples:
stimuli, in_sr = tci.load_stimuli(
'resources/segments-librispeech-1k/'
)
ipd.display(ipd.Audio(stimuli[2][:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(stimuli[10][:4*in_sr].T, rate=in_sr))
stimuli, in_sr = tci.load_stimuli(
'resources/segments-librispeech-1k/',
process=fx.reverb_fx(reverberance=40, room_scale=70)
)
ipd.display(ipd.Audio(stimuli[2][:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(stimuli[10][:4*in_sr].T, rate=in_sr))
stimuli, in_sr = tci.load_stimuli(
'resources/segments-librispeech-1k/',
process=fx.tempo_fx(scale_factor=1.5)
)
ipd.display(ipd.Audio(stimuli[2][:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(stimuli[10][:4*in_sr].T, rate=in_sr))
stimuli, in_sr = tci.load_stimuli(
'resources/segments-librispeech-1k/',
process=fx.speed_fx(scale_factor=1.5)
)
ipd.display(ipd.Audio(stimuli[2][:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(stimuli[10][:4*in_sr].T, rate=in_sr))
stimuli, in_sr = tci.load_stimuli(
'resources/segments-librispeech-1k/',
process=fx.pitch_fx(300)
)
ipd.display(ipd.Audio(stimuli[2][:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(stimuli[10][:4*in_sr].T, rate=in_sr))
stimuli, in_sr = tci.load_stimuli(
'resources/segments-librispeech-1k/',
process=[fx.reverb_fx(reverberance=40, room_scale=70), fx.pitch_fx(100)]
)
ipd.display(ipd.Audio(stimuli[2][:4*in_sr].T, rate=in_sr))
ipd.display(ipd.Audio(stimuli[10][:4*in_sr].T, rate=in_sr))