#| default_exp aa_mixer_toy
Toy model of the (intended) Audio Algebra mixer model, or "abusing VICReg for fun and profit"
NOTE: to view this notebook properly, open it via NBViewer or Colab. GitHub won't show you everything.
We stick an autoencoder in the middle of some "given" encoder(-decoder) system, and we'll try to get our new system to preserve linearity with respect to the (original) inputs.
Basic Idea: Inputs $x$ -> Given Model encodings $y$ -> AA embeddings $z$
Flowchart:
We want this new "audio algebra" map $h$ to have two properties, which we'll write as loss functions:
i.e. $$ "zmix = zsum" $$ So the mix loss will be
mix_loss
= MSE( zmix
- zsum
).
aa_recon_loss
= MSE ( ymix
- y_recon
) + MSE( y_i
- y_i_recon
),
where the "recon" variables are in terms of $h^{-1}$ applied to the z's, e.g. ymix_recon
:= $h^{-1}$(zmix) ).
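Here's a minimal sketch of those two loss terms in PyTorch (the argument names are placeholders of my choosing; the actual losses get computed inside the training loop further below):
import torch
import torch.nn.functional as F

def aa_losses_sketch(zmix, zsum, ymix, ymix_recon, ys, ys_recon):
    "illustrative only: zmix/zsum are aa embeddings, ymix/ys are given-model embeddings, *_recon come from the aa decoder"
    mix_loss = F.mse_loss(zmix, zsum)                # enforce "zmix = zsum"
    aa_recon_loss = F.mse_loss(ymix, ymix_recon)     # h^{-1} should invert h on the mix...
    for y_i, y_i_recon in zip(ys, ys_recon):         # ...and on each individual stem
        aa_recon_loss = aa_recon_loss + F.mse_loss(y_i, y_i_recon)
    return mix_loss, aa_recon_loss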
...the challenge is that the "mix loss" by itself leads to "collapse": all the points either shrink toward the origin or toward some constant vector. We'll use the VICReg method to fix these problems. More on VICReg further below.
Instead of real audio, in order to keep things simple and visualizable, we'll just work in terms of 2D data points. The real problem we want to solve eventually involves huge numbers of dimensions that we can't visualize. The hope is to check our work and understanding and gain intuition using this simple "toy model".
# install a few non-standard dependencies (for colab)
%pip install -Uqq pip einops tqdm wandb
Note: you may need to restart the kernel to use updated packages.
#| export
import math
import os
import matplotlib.pyplot as plt
from PIL import Image
import random
import string
import numpy as np
import torch
from torch import optim, nn, Tensor
from torch.nn import functional as F
from torch.utils import data as torchdata
from einops import rearrange
from tqdm.auto import tqdm
import wandb
from aeiou.core import get_device
Get/set a few "environment" related variables
device = get_device()
print("device = ",device)
SCRATCH_DIR = "/scratch" if os.path.isdir('/scratch') else "/tmp"
SCRATCH_DIR += '/aa' # give audio alg scratch its own spot
if not os.path.exists(SCRATCH_DIR):
os.makedirs(SCRATCH_DIR)
print("SCRATCH_DIR =",SCRATCH_DIR)
device = cuda SCRATCH_DIR = /tmp/aa
# generate a unique 2-character code for the run
suffix_len = 2
RUN_SUFFIX = '_'+''.join(random.choices(string.ascii_lowercase, k=suffix_len))
print("RUN_SUFFIX =",RUN_SUFFIX)
RUN_SUFFIX = _ao
For starters, just a bunch of random numbers
seed = 5
torch.manual_seed(seed)
batch_size = 1024
n_train_points, train_val_split = batch_size*2000, 0.8 # we're going to grab random data anyway
n_points = int(n_train_points*1.25)
in_dims = 2 # number of dimensions the input data will live in, e.g. 2 or 3
emb_dims = 2 # number of dimensions for embeddings
train_val_split = 0.8
train_len, val_len = round(n_points*train_val_split), round(n_points*(1-train_val_split))
class RandVecDataset(torchdata.Dataset):
"very simple dataset"
def __init__(self, length, dims):
super().__init__()
self.data = 2*torch.rand(length, dims) - 1
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return self.data.shape[0]
train_dataset = RandVecDataset(train_len, in_dims)
val_dataset = RandVecDataset(val_len, in_dims)
v = train_dataset.__getitem__(0)
print(train_dataset.__len__(), v.shape)
print(val_dataset.__len__(), v.shape)
2048000 torch.Size([2]) 512000 torch.Size([2])
train_dl = torchdata.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = torchdata.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
val_iter = iter(val_dl)
DEMO BATCH: Visualization demo batch (same each time)
#| export
def get_demo_batch(val_iter,
demo_line = True,
debug=False):
demo_batch = next(val_iter).clone()
absmax = (torch.abs(demo_batch)).max().item()
if demo_line: # make demo batch lie along a line in order to showcase nonlinearity
demo_batch[:,1] = demo_batch[:,0] # line y = x
demo_batch[::2,0] *= -1 # even-numbered points go to y = -x
demo_batch[::4,0] = 0.01 # data along y axis
demo_batch[1::4,1] = 0.01 # data along x axis
demo_batch[::5] = absmax*(2*torch.rand((demo_batch[::5]).shape)-1) # make some points random
demo_batch[1::7] = absmax*(2*torch.rand((demo_batch[1::7]).shape)-1) # and more points random
return demo_batch
# plot what the demo batch looks like in input space -- you're going to be seeing this a lot!
fig, ax = plt.subplots(1,3, figsize=(12,4))
demo_batch = get_demo_batch(val_iter, debug=True)
ax[0].plot(demo_batch[:,0], demo_batch[:,1], marker='.', linestyle='None')
ax[0].set_title('single "lines+dots" stem')
print("")
demo_mix = demo_batch + get_demo_batch(val_iter, demo_line=True, debug=True) # add another 'line stem'
ax[1].plot(demo_mix[:,0], demo_mix[:,1], marker='.', linestyle='None')
ax[1].set_title('mix of two "lines+dots" stems')
demo_mix = demo_batch + get_demo_batch(val_iter, demo_line=False, debug=True) # random noise will swamp lines
ax[2].plot(demo_mix[:,0], demo_mix[:,1], marker='.', linestyle='None')
ax[2].set_title('mix of "lines+dots" + unit noise')
...so in general our mix of random numbers -- even if one of the stems has a nice pattern -- will just be a bunch of random numbers.
We'll use the far-right version for "demo_mix" since that's what a typical mix will look like, whereas the middle graph is contrived.
#| export
class EmbedBlock(nn.Module):
def __init__(self, in_dims:int, out_dims:int, act=nn.GELU(), resid=True, use_bn=False, requires_grad=True, **kwargs) -> None:
"generic little block for embedding stuff. note residual-or-not doesn't seem to make a huge difference for a-a"
super().__init__()
self.in_dims, self.out_dims, self.act, self.resid = in_dims, out_dims, act, resid
self.lin = nn.Linear(in_dims, out_dims, **kwargs)
self.bn = nn.BatchNorm1d(out_dims) if use_bn else None # even though points in 2d, only one non-batch dim in data
if requires_grad == False:
self.lin.weight.requires_grad = False
self.lin.bias.requires_grad = False
def forward(self, xin: Tensor) -> Tensor:
x = self.lin(xin)
if self.act is not None: x = self.act(x)
if self.bn is not None: x = self.bn(x) # re. "BN before or after Activation? cf. https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md"
return xin + x if (self.resid and self.in_dims==self.out_dims) else x
# test that
emb_test = EmbedBlock(2, 2, use_bn=True, requires_grad=False)
#res_test = ResBlock(2, 2)
This is a stand-in for whatever the main encoder will be,
i.e. the model into whose middle the audio algebra is going to be inserted.
This could be an actual audio encoder, or... just something random.
Now, for word embeddings these are typically just weights from a linear transformation, but we're going to assume that there's maybe some set of nonlinear transformations that led us to this point. I made up a twisty nonlinear model for the toy model to use.
Oh, first I'm going to define a couple of nonlinearities that might be used in the model.
#| export
def friendly_tanh(x, a=3, b=0.1, s=0.9):
"like a tanh but not as harsh"
return (torch.tanh(a*x) + b*x)*s
#return torch.tanh(a*x) + b*torch.sgn(x)*torch.sqrt(torch.abs(x))
def compressor(x, thresh=0.7, ratio=10):
"another nonlinearity: kind of like an audio compressor"
return torch.where(torch.abs(x) < thresh, x, torch.sgn(x)*(1/ratio*(torch.abs(x)-thresh)+thresh))
...and we'll just quickly visualize them.
x = torch.linspace(-2.5,2.5,500)
fig, ax = plt.subplots(figsize=(9,4))
ax.plot(x.cpu().numpy(), torch.tanh(3*x).cpu().numpy(), label='tanh')
ax.plot(x.cpu().numpy(), friendly_tanh(x,a=3).cpu().numpy(), label='friendly_tanh')
ax.plot(x.cpu().numpy(), compressor(x).cpu().numpy(), label='"compressor"')
ax.legend()
#| export
class TwistAndScrunch(nn.Module):
"""for now just something random with a bit of nonlinearity
"""
def __init__(self,
in_dims=2, # number of input dimensions to use
emb_dims=2, # number of output dimensions to use
hidden_dims=32, # number of hidden dimensions
scrunch_fac=1.7, twist_fac=0.2, debug=False, act=nn.GELU(),
):
super().__init__()
self.sf, self.debug, = scrunch_fac, debug,
self.pi = torch.acos(torch.zeros(1)).item() * 2
self.decoder = nn.Sequential(
EmbedBlock(emb_dims, hidden_dims, act=act),
EmbedBlock(hidden_dims, hidden_dims, act=act),
nn.Linear(hidden_dims, in_dims) )
# note we'll never use this encoder. Just included this to keep other codes from crashing. ;-)
self.encoder = nn.Sequential( # note that this is not going to be used.
EmbedBlock(emb_dims, hidden_dims, act=act),
EmbedBlock(hidden_dims, hidden_dims, act=act),
nn.Linear(hidden_dims, in_dims)
)
def encode(self, v): # some nonlinear function i made up
vp = v.clone()
vp[:,0], vp[:,1] = vp[:,0] - 0.5*friendly_tanh(v[:,1],a=3), vp[:,1] + 0.7*friendly_tanh(v[:,0],a=1.2)
return friendly_tanh(self.sf*vp,a=2,b=0.1)
def decode(self, emb, exact=False):
return self.decoder(emb)
def forward(self, x, debug=False, exact=False):
debug_save, self.debug = self.debug, debug # only turn on debug for this one call
emb = self.encode(x)
outs = self.decode(emb, exact=exact), emb
self.debug = debug_save
return outs
given_model.encode on Inputs
torch.manual_seed(seed)
#given_model = SimpleAutoEncoder(in_dims=in_dims, emb_dims=emb_dims).to(device)
given_model = TwistAndScrunch()
given_model = given_model.to(device)
#| export
def viz_given_batch(
val_batch, # should be from val_batch
given_model, # the main autoencoder model we are splicing into
aa_model=None, # our audio algebra model
debug=False,
device=get_device(),
return_img = True,
show_recon=False,
):
markers=['o', '+', 'x', '*', '.', 'X', "v", "^", "<", ">", "8", "s", "p", "P", "h", "H", "D", "d","|","_"]
colors=['blue', 'magenta','orange', 'green', 'red', 'cyan', 'black']
markersize=2.5
val_batch = val_batch.to(device)
with torch.no_grad():
emb_batch = given_model.encode(val_batch)
recon_batch = given_model.decode(emb_batch)
fig, ax = plt.subplots(1,2, figsize=(11,5))
ax[0].set_title(f"input space. $x \in X$")
ax[1].set_title(f"given encoder space. $y \in Y$")
if debug:
print("absmax of val_batch = ",torch.abs(val_batch).max().item())
print("absmax of emb_batch = ",torch.abs(emb_batch).max().item())
nm = len(markers)
for mi in range(nm):
#for i in range(val_batch.shape[0]): # TODO: looping over every points is very slow, but apparently needed for arrows
for bi,b in enumerate([val_batch, emb_batch, recon_batch]):
x, y = b[mi::nm,0].cpu().numpy(), b[mi::nm,1].cpu().numpy()
if (bi<2) or show_recon: ax[bi%2].plot(x, y, marker=markers[mi], markersize=markersize, color=colors[bi], linestyle='None', label='' if mi > 0 else ['input', 'embedding', 'recon'][bi])
ax[0].legend(loc='upper left')
val_plot_fname = f'{SCRATCH_DIR}/val_plot{RUN_SUFFIX}{random.randint(1,5)}.png'
plt.savefig(val_plot_fname, bbox_inches='tight', dpi=300)
plt.close()
im = Image.open(val_plot_fname)
return im
print("for single stem:")
viz_given_batch(demo_batch, given_model, debug=True)
for single stem: absmax of val_batch = 0.9996815919876099 absmax of emb_batch = 1.141883134841919
print("Also operate on the demo mix:")
viz_given_batch(demo_mix, given_model, debug=True)
Also operate on the demo mix: absmax of val_batch = 1.9301079511642456 absmax of emb_batch = 1.2925496101379395
This is the model that we want to train.
#| export
class AudioAlgebra(nn.Module):
"""
Main AudioAlgebra model. Contrast to aa-mixer code, keep this one simple & move mixing stuff outside
"""
def __init__(self, dims=2, hidden_dims=64, act=nn.GELU(), use_bn=False, resid=True, block=EmbedBlock):
super().__init__()
self.encoder = nn.Sequential(
block( dims, hidden_dims, act=act, use_bn=use_bn, resid=resid),
block( hidden_dims, hidden_dims, act=act, use_bn=use_bn, resid=resid),
block( hidden_dims, hidden_dims, act=act, use_bn=use_bn, resid=resid),
block( hidden_dims, dims, act=None, use_bn=use_bn, resid=resid),
)
self.decoder = nn.Sequential( # same as encoder, in fact.
block( dims, hidden_dims, act=act, use_bn=use_bn, resid=resid),
block( hidden_dims, hidden_dims, act=act, use_bn=use_bn, resid=resid),
block( hidden_dims, hidden_dims, act=act, use_bn=use_bn, resid=resid),
block( hidden_dims, dims, act=None, use_bn=use_bn, resid=resid),
)
def encode(self,x):
return x + self.encoder(x)
def decode(self,x):
return x + self.decoder(x)
def forward(self,
        x # the embedding vector from the given encoder
):
xprime = self.encode(x)
xprimeprime = self.decode(xprime) # train system to invert itself (and hope it doesn't all collapse to nothing!)
return xprime, xprimeprime # encoder output, decoder output
Really this is more of a dataloader utility, but for now it's being called from the main loop because it involves less change to the dataloader. ;-)
#| export
def get_stems_faders(batch,      # "1 stem" (or batch thereof) already drawn from the dataloader (val or train)
                     dl_iter,    # pre-made iterator for the/a dataloader
                     dl,         # the dataloader itself, for restarting
                     maxstems=2, # how many total stems will be used, i.e. draw maxstems-1 new stems from dl_iter
                     unity_gain=False, # this will force all faders to be +/-1 instead of random numbers
                     debug=False):
    "grab some more input batches ('stems') and some gain values ('faders') to go with them"
nstems = random.randint(2, maxstems)
if debug: print("maxstems, nstems =",maxstems, nstems)
device=batch.device
faders = torch.sgn(2*torch.rand(nstems)-1) # random +/- 1's
if not unity_gain:
faders += 0.5*torch.tanh(2*(2*torch.rand(nstems)-1)) # gain is now between 0.5 and 1.5
stems = [batch] # note that stems is a list
    for i in range(nstems-1): # in addition to the batch of stems passed in, grab some more
try:
next_stem = next(dl_iter).to(device) # this is just another batch of input data
except StopIteration:
dl_iter = iter(dl) # time to restart. hoping this propagates out as a pointer
next_stem = next(dl_iter).to(device)
if debug: print(" next_stem.shape = ",next_stem.shape)
stems.append(next_stem)
return stems, faders.to(device), dl_iter # also return the iterator
Code to do the encoding of stems and faders into zmix and zsum
def do_mixing(stems, faders, given_model, aa_model, device, debug=False, **kwargs):
"""
here we actually encode inputs.
    0's denote values in the given model space, non-0's denote those in our aa_model
e.g., "y" denotes an embedding from the frozen encoder, "z" denotes re-mapped embeddings
"""
zs, ys, zsum, ysum, yrecon_sum, fadedstems = [], [], None, None, None, []
mix = torch.zeros_like(stems[0]).to(device)
#if debug: print("do_mixing: stems, faders =",stems, faders)
for s, f in zip(stems, faders): # iterate through list of stems, encode a bunch of stems at different fader settings
fadedstem = (s * f).to(device) # audio stem adjusted by gain fader f
with torch.no_grad():
y = given_model.encode(fadedstem) # encode the stem
z, y_recon = aa_model(y) # <-- this is the main work of the model
zsum = z if zsum is None else zsum + z # <---- compute the sum of all the z's so far. we'll end up using this in our (metric) loss as "pred"
mix += fadedstem # make full mix in input space
with torch.no_grad():
ymix = given_model.encode(mix) # encode the mix in the given model
zmix, ymix_recon = aa_model(ymix) # <----- map that according to our learned re-embedding. this will be the "target" in the metric loss
# for diagnostics:
ysum = y if ysum is None else ysum + y # = sum of embeddings in original model space; we don't really care about ysum except for diagnostics
        yrecon_sum = y_recon if yrecon_sum is None else yrecon_sum + y_recon # = sum of reconstructed embeddings in original model space; also only for diagnostics
zs.append(z) # save a list of individual z's
ys.append(y) # save a list of individual y's
        fadedstems.append(fadedstem) # save a list of each thing that went into the mix
archive = {'zs':zs, 'mix':mix, 'znegsum':None, 'ys': ys, 'ysum':ysum, 'ymix':ymix, 'ymix_recon':ymix_recon, 'yrecon_sum':yrecon_sum, 'fadedstems':fadedstems} # more info for diagnostics
return zsum, zmix, archive # we will try to get these two to be close to each other via loss. archive is for diagnostics
aa_use_bn = False # batch norm?
aa_use_resid = True # use residual connections? (doesn't make much difference tbh)
hidden_dims = 64 # number of hidden dimensions in aa model. usually was 64
torch.manual_seed(seed) # chose this value because it shows off nice nonlinearity
aa_model = AudioAlgebra(dims=emb_dims, hidden_dims=hidden_dims, use_bn=aa_use_bn, resid=aa_use_resid).to(device)
...and visualize its effects
#| export
def multimarkerplot(ax, data, color='blue', markersize=2.7, label=''):
"little thingy for using lots of symbols for a plot"
markers = ['o', '+', 'x', '*', '.', 'X', "v", "^", "<", ">", "8", "s", "p", "P", "h", "H", "D", "d","|","_"]
nm = len(markers)
for mi in range(nm):
ax.plot(data[mi::nm,0], data[mi::nm,1], marker=markers[mi], markersize=markersize, color=color, linestyle='None', label=label if mi==0 else '')
def viz_aa_demo(
zsum, zmix,
archive,
aa_model,
debug=False,
):
"for visualizing what aa model is doing"
colors=['blue', 'red', 'green', 'magenta', 'orange', 'cyan', 'black',]
col_labels = ["inputs", "given representations", "a-a embeddings"]
nrows, ncols = 2, len(col_labels)
figsize = (4*ncols, 4*nrows)
fig, ax = plt.subplots(nrows,ncols, figsize=figsize)
# top row: stems*faders, ys, zs: each of these is a list
keys, varnames = ['fadedstems', 'ys', 'zs'], ['x', 'y', 'z']
for coli, key in enumerate(keys):
ax[0,coli].set_title(f"{col_labels[coli]} ${varnames[coli]}$")
datalist = [x.detach().cpu().numpy() for x in archive[key]]
for i, stem in enumerate(datalist):
label = f"{varnames[coli]}{i+1}"
if coli > 0:
label += f" := f(x{i+1})" if coli==1 else f" := h(y{i+1})"
multimarkerplot( ax[0,coli], stem, color=colors[i], label=label)
ax[0,coli].legend(fancybox=True, framealpha=0.6, loc='upper left', prop={'size': 9})
# bottom row: mix
archive["zmix"] = zmix # lil prep
keys = ['mix', 'ymix', 'zmix']
annotations = [' := x1 + x2', r' := $f\,$(mix)', r' := $h$(ymix)']
for coli, key in enumerate(keys):
ax[1,coli].set_title(f"{col_labels[coli]}")
data = archive[key].detach().cpu().numpy()
multimarkerplot( ax[1,coli], data, label=key+annotations[coli], color=colors[2])
if key=='ymix': # add another plot
multimarkerplot( ax[1,coli], archive["ymix_recon"].detach().cpu().numpy(), label=r'ymix_recon := $h^{-1}$(zmix)', color=colors[4])
elif key=='zmix':
multimarkerplot( ax[1,coli], zsum.detach().cpu().numpy(), label='zsum := z1 + z2', color=colors[3])
ax[1,coli].legend(fancybox=True, framealpha=0.6, loc='upper left', prop={'size': 9})
if debug: print("saving to file")
filename = f'{SCRATCH_DIR}/aa_val_plot{RUN_SUFFIX}.png'
plt.savefig(filename, bbox_inches='tight', dpi=225)
plt.close()
if debug: print("loading file")
im = Image.open(filename)
return im
# test the viz
torch.manual_seed(seed)
demo_stems, demo_faders, val_iter = get_stems_faders(demo_batch, val_iter, val_dataset, debug=True, unity_gain=False, maxstems=2)
print("demo_faders = ",demo_faders)
print("calling do_mixing...")
demo_zsum, demo_zmix, demo_archive = do_mixing(demo_stems, demo_faders, given_model, aa_model, device, debug=True)
print("calling viz_aa_batch...")
im = viz_aa_demo(demo_zsum, demo_zmix, demo_archive, aa_model, debug=False)
im
maxstems, nstems = 2 2 next_stem.shape = torch.Size([1024, 2]) demo_faders = tensor([ 1.4630, -0.5718]) calling do_mixing... calling viz_aa_batch...
The fact that zsum and zmix (bottom right) are not the same is the problem that our aa-mixer model will try to put right.
Define some losses
mseloss = nn.MSELoss()
def rel_loss(y_pred: torch.Tensor, y: torch.Tensor, eps=1e-3) -> float:
"relative error loss --- note we're never going to actually use this. it was just part of development"
e = torch.abs(y.view_as(y_pred) - y_pred) / ( torch.abs(y.view_as(y_pred)) + eps )
return torch.median(e)
We could try to train the model now, but the representations would collapse, either toward the origin or toward a line.
Avoiding collapse has been one of the primary thrusts of Meta AI's (Yann LeCun et al.) work on Self-Supervised Learning (SSL) of representations. Mostly they're interested in classification of images, but we will try to adapt this to our purposes. A few notable papers, in chronological and topical order:
The VICReg paper's architecture diagram and figure caption is as follows:
"Figure 1: **VICReg: joint embedding architecture with variance, invariance and covariance
regularization.** Given a batch of images $I$, two batches of different views $X$ and $X'$ are produced and are then encoded into representations $Y$ and $Y'$. The representations are fed to an expander producing the embeddings $Z$ and $Z'$. The distance between two embeddings from the same image is minimized, the variance of each embedding variable over a batch is maintained above a threshold, and the covariance between pairs of embedding variables over a batch are attracted to zero, decorrelating the variables from each other. Although the two branches do not require identical architectures nor share weights, in most of our experiments, they are Siamese with shared weights: the encoders are ResNet-50 backbones with output dimension 2048. The expanders have 3 fully-connected layers of size 8192." -- Bardes et al
...Note that we're going to be using this scheme a little differently than the original authors intended:
Instead of sending augmented crops of an image down the lower "path" of the above diagram, we're going to send the mix to get $z_{mix}$! And along the top path, we're going to send multiple stems independently $x_i \mapsto z_i$, and then sum these up to get $z_{sum} := \sum_i z_i$ before comparing that against $z_{mix}$. So our similarity function will have the form $s(z_{mix}, z_{sum})$ instead of the $s(Z, Z')$ shown in the diagram above.
Our diagram looks like this:
...and the "augmentation" or "views" will take the form of us just grabbing random pairs of stems over & over.
As explained in the "VICReg: Intuition" section of the paper:
"The basic idea is to use a loss function with three terms:
- Invariance: the mean square distance between the embedding vectors.
- Variance: a hinge loss to maintain the standard deviation (over a batch) of each variable of
the embedding above a given threshold. This term forces the embedding vectors of samples within a batch to be different.
- Covariance: a term that attracts the covariances (over a batch) between every pair of
(centered) embedding variables towards zero. This term decorrelates the variables of each embedding and prevents an informational collapse in which the variables would vary together or be highly correlated."
The invariance part is easy. But that's what leads to collapse.
The variance part is about using a hinge loss on some scale $\gamma$ ("fixed to $1$ in our experiments") minus the standard deviation $S$ across each dimension $j$ (for $d$ dimensions) within the batch $Z$:
$$ v(Z) = \frac{1}{d} \sum_{j=1}^{d} \max(0, \gamma - S(z^{j}, \epsilon)), $$ where $z^j$ denotes "the vector composed of each value at dimension $j$ in all vectors in $Z$," and $\epsilon$ is some small constant to prevent numerical instability, i.e.:
$$ S(x, \epsilon) = \sqrt{\mathrm{Var}(x) + \epsilon}. $$ In other words, they're trying to make sure the (per-dimension) standard deviation doesn't get any smaller than $\gamma$ (but larger than $\gamma$ is fine).
#| export
def vicreg_var_loss(z, gamma=1, eps=1e-4):
std_z = torch.sqrt(z.var(dim=0) + eps)
return torch.mean(F.relu(gamma - std_z)) # the relu gets us the max(0, ...)
Note: The authors say, "Using the standard deviation and not directly the variance is crucial."
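As a quick illustrative sanity check (my own example, not from the paper): a batch collapsed to a single point gets penalized by roughly $\gamma$, while a batch whose per-dimension standard deviation already exceeds $\gamma$ incurs (almost) no penalty.
# illustrative check of vicreg_var_loss (not from the VICReg paper)
z_collapsed = torch.zeros(1024, 2)     # everything collapsed to the origin: std = 0
z_spread = 2*torch.randn(1024, 2)      # std of ~2 in each dimension, i.e. > gamma
print(vicreg_var_loss(z_collapsed))    # ~ gamma = 1 (minus sqrt(eps))
print(vicreg_var_loss(z_spread))       # ~ 0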
What about the covariance term? This is another way to prevent collapse. While the invariance can prevent collapse toward zero, the covariance can prevent collapse to a constant vector. This part I did not have! Or at least didn't have it right. For the covariance, we need a notion of "what should be different" -- easy! We'll just use other elements within the same batch!
Let's adapt the VICReg code:
#| export
def off_diagonal(x):
    "returns a flattened view of the off-diagonal elements of a square matrix (same trick as in the official VICReg code)"
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
def vicreg_cov_loss(z):
    "the regularization term: the (scaled) sum of the squares of the off-diagonal terms of the covariance matrix"
    cov_z = torch.cov(z.T)
    return off_diagonal(cov_z).pow_(2).sum().div(z.shape[1])
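And a similar illustrative check for the covariance term (again my own example, not from the paper): perfectly correlated embedding dimensions get penalized, independent ones don't.
# illustrative check of vicreg_cov_loss (not from the VICReg paper)
t = torch.randn(1024, 1)
z_correlated = torch.cat([t, t], dim=1)   # 2nd dimension is a copy of the 1st: fully correlated
z_independent = torch.randn(1024, 2)      # dimensions are independent, so covariance ~ 0
print(vicreg_cov_loss(z_correlated))      # ~ var(t)^2 ~ 1
print(vicreg_cov_loss(z_independent))     # ~ 0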
#| export
model_path = 'aa_model.pth'
def load_aa_checkpoint(model_path='aa_model.pth', emb_dims=2, device=get_device()):
aa_model = None
try:
aa_model = torch.load(model_path, map_location=torch.device('cpu')).to(device)
except Exception as e:
print(f"Error trying to load presaved models:\n{e}\n\nStarting from scratch.")
        return None  # single None, so the `is not None` check at the call site behaves
print("Successfully loaded pretrained models")
return aa_model
def save_aa_checkpoint(aa_model, model_path='aa_model.pth', suffix=''):
"save a model checkpoint -- full version. note we save to scratch so as not to overwrite existing file"
torch.save(aa_model, SCRATCH_DIR+'/'+model_path.replace(".pth", f"{suffix}.pth"))
start_from_pretrained = False
result = "Starting from scratch."
if start_from_pretrained:
aa_load = load_aa_checkpoint(emb_dims=emb_dims, device=device)
if aa_load is not None:
aa_model = aa_load
# re-viz showing loaded model
demo_zsum, demo_zmix, demo_archive = do_mixing(demo_stems, demo_faders, given_model, aa_model, device, debug=True)
result = viz_aa_demo(demo_zsum, demo_zmix, demo_archive, aa_model=aa_model)
print("Successfully loaded checkpoint")
result
'Starting from scratch.'
Train aa_model with VICReg
wandb.login()
#| export
train_new_aa = True # set to false to disable this training
def sayno(b, suffix=''): # just a little util for naming runs
return f'no{suffix}' if (b is None) or (b==False) else ''
def train_aa_vicreg(debug = False):
"packaging as def to re-use later. this will use global variables, not sorry"
max_epochs = 40
lossinfo_every, viz_demo_every = 20, 1000 # in units of steps
checkpoint_every = 10000
max_lr= 0.002
# optimizer and learning rate scheduler
opt = optim.Adam([*aa_model.parameters()], lr=5e-4)
total_steps = len(train_dataset)//batch_size * max_epochs
print("total_steps =",total_steps) # for when I'm checking wandb
scheduler = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=max_lr, total_steps=total_steps)
# rename our wandb run with some significant info
wandb.init(project='aa-toy-vicreg')
run_name = wandb.run.name
wandb.run.name = f"h{hidden_dims}_{'bn_' if aa_use_bn else ''}bs{batch_size}_lr{max_lr}{RUN_SUFFIX}"
print("New run name =",wandb.run.name)
wandb.run.save()
# training loop
train_iter = iter(train_dl) # this is only for use with get_stems_faders
epoch, step = 0, 0
torch.manual_seed(seed) # for reproducibility
while (epoch < max_epochs) or (max_epochs < 0): # training loop
with tqdm(train_dl, unit="batch") as tepoch:
for batch in tepoch: # train
opt.zero_grad()
log_dict = {}
batch = (1.2*(2*torch.rand(batch.shape, device=device)-1)) # overwrite batch with unending random data!
#batch = batch.to(device)
stems, faders, train_iter = get_stems_faders(batch, train_iter, train_dl, debug=debug)
# vicreg: 1. invariance
zsum, zmix, archive = do_mixing(stems, faders, given_model, aa_model, device, debug=debug)
mix_loss = mseloss(zsum, zmix)
var_loss = (vicreg_var_loss(zsum) + vicreg_var_loss(zmix))/2 # vicreg: 2. variance
cov_loss = (vicreg_cov_loss(zsum) + vicreg_cov_loss(zmix))/2 # vicreg: 3. covariance
# reconstruction loss: inversion of aa map h^{-1}(z): z -> y, i.e. train the aa decoder
y = given_model.encode(batch)
z, yrecon = aa_model.forward(y) # (re)compute ys for one input batch (not stems&faders)
aa_recon_loss = mseloss(y, yrecon)
                aa_recon_loss = aa_recon_loss + mseloss(archive['ymix'], archive['ymix_recon']) # also recon of the mix encoding
#aa_recon_loss = aa_recon_loss + mseloss(archive['ysum'], archive['yrecon_sum']) # Never use this: ysum shouldn't matter / is poorly defined
loss = mix_loss + var_loss + cov_loss + aa_recon_loss # --- full loss function
log_dict['train_loss'] = loss.detach() # --- this is the full loss
log_dict['mix_loss'] = mix_loss.detach()
log_dict['aa_recon_loss'] = aa_recon_loss.detach()
log_dict['var_loss'] = var_loss.detach()
log_dict['cov_loss'] = cov_loss.detach()
log_dict['learning_rate'] = opt.param_groups[0]['lr']
log_dict['epoch'] = epoch
if step % lossinfo_every == 0:
tepoch.set_description(f"Epoch {epoch+1}/{max_epochs}")
tepoch.set_postfix(loss=loss.item()) # TODO: use EMA for loss display?
loss.backward()
opt.step()
with torch.no_grad():
# run on demo data (same each time)
demo_zsum, demo_zmix, demo_archive = do_mixing(demo_stems, demo_faders, given_model, aa_model, device, debug=debug)
demo_mix_loss = mseloss(demo_zsum, demo_zmix)
log_dict['demo_mix_loss'] = demo_mix_loss.detach()
if step % viz_demo_every == 0: # but only draw the viz every now and then.
try: # to avoid "OSError: image file is truncated"
im = viz_aa_demo(demo_zsum, demo_zmix, demo_archive, aa_model)
log_dict["aa_mapping"] = wandb.Image(im)
except:
pass
if step % checkpoint_every == 0:
save_aa_checkpoint(aa_model, suffix=RUN_SUFFIX+f"_{step}")
wandb.log(log_dict)
scheduler.step()
step += 1
epoch += 1
#----- training loop finished
save_aa_checkpoint(aa_model, suffix=RUN_SUFFIX+f"_{step}")
if train_new_aa:
train_aa_vicreg()
total_steps = 80000
New run name = h64_bs1024_lr0.002_ao
Let's see how we did...
torch.manual_seed(seed)
demo_stems, demo_faders, val_iter = get_stems_faders(demo_batch, val_iter, val_dataset, maxstems=2)
demo_zsum, demo_zmix, demo_archive = do_mixing(demo_stems, demo_faders, given_model, aa_model, device)
viz_aa_demo(demo_zsum, demo_zmix, demo_archive, aa_model)
^^ If the green markers in the bottom right and bottom middle panels are covered up, then we're good!
BTW, if you want to make a movie of those images over the course of training, see the instructions by Scott Condon of Weights and Biases, except you should use this here Colab link to my modification of his code. Here's a movie I made:
from IPython.display import Video
video_url = "https://hedges.belmont.edu/aa_1280.mp4"
Video(video_url, width=720)
# NOTE that GitHub won't display this video; you have to execute the above code
Using aa_model (requires a Decoder for given_model)
The next part involves mapping back to input space, so you need a trained decoder for the given model $f(x): x\mapsto y$.
i.e. we're going to be doing a lot of $z := h(f(x))$ and $x := f^{-1}(h^{-1}(z))$.
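For example, here's a minimal round-trip sketch (assuming both models are trained; the utility routines inp_to_aa and aa_to_inp defined further below package this up):
# illustrative round trip x -> z -> x; assumes given_model's decoder has been trained (see below)
with torch.no_grad():
    x = (2*torch.rand(8, in_dims) - 1).to(device)
    z = aa_model.encode(given_model.encode(x))        # z := h(f(x))
    x_recon = given_model.decode(aa_model.decode(z))  # x ~ f^{-1}(h^{-1}(z))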
# try to load a checkpoint for the given model
given_path = 'given_model.pth'
train_given_model = True
try:
given_model = TwistAndScrunch() # GivenAutoEncoder(in_dims=in_dims, act=act, emb_dims=emb_dims)
given_model.load_state_dict(torch.load(given_path))
given_model.to(device)
given_model.eval()
train_given_model = False
print("Loaded trained given_model just fine. Proceed.")
except Exception as e:
print(f"{e}:\nFailed to load given model checkpoint. You'll need to (re)train the decoder.")
given_model.to(device)
Loaded trained given_model just fine. Proceed.
TwistAndScrunch( (decoder): Sequential( (0): EmbedBlock( (act): GELU() (lin): Linear(in_features=2, out_features=32, bias=True) ) (1): EmbedBlock( (act): GELU() (lin): Linear(in_features=32, out_features=32, bias=True) ) (2): Linear(in_features=32, out_features=2, bias=True) ) (encoder): Sequential( (0): EmbedBlock( (act): GELU() (lin): Linear(in_features=2, out_features=32, bias=True) ) (1): EmbedBlock( (act): GELU() (lin): Linear(in_features=32, out_features=32, bias=True) ) (2): Linear(in_features=32, out_features=2, bias=True) ) )
if train_given_model:
opt_given = optim.Adam(given_model.parameters(), lr=5e-4)
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_given, 'min', factor=0.3, verbose=True)
mseloss = nn.MSELoss()
wandb.finish()
wandb.init(project='aa-toy-given')
epoch, step, max_epochs = 0, 0, 20
    lossinfo_every, demo_every = 20, 1000 # in units of steps
freeze_enc_at = 65 # if we don't freeze the encoder early on, then it linearizes everything, making the challenge for aa too easy.
debug = False
total_steps = len(train_dataset)//batch_size * max_epochs
print("total_steps = ",total_steps)
scheduler = torch.optim.lr_scheduler.OneCycleLR(opt_given, max_lr=0.025, total_steps=total_steps)
while epoch < max_epochs: # training loop
with tqdm(train_dl, unit="batch") as tepoch:
for batch in tepoch: # training
opt_given.zero_grad()
batch = (2*torch.rand(batch.shape)-1).to(device)
batch_out, emb = given_model(batch)
batch_loss = mseloss(batch_out, batch)
mix = batch + (2*torch.rand(batch.shape)-1).to(device) # "the mix"
mix_out, emb2 = given_model(mix)
mix_loss = mseloss(mix_out, mix)
loss = batch_loss + mix_loss
#if step <= freeze_enc_at:
# loss += 0.001*mseloss(emb, 0*emb) # wee bit of L2 decay on given embeddings
log_dict = {}
log_dict['train_loss'] = loss.detach()
log_dict['learning_rate'] = opt_given.param_groups[0]['lr']
if step % lossinfo_every == 0:
tepoch.set_description(f"Epoch {epoch+1}/{max_epochs}")
tepoch.set_postfix(loss=loss.item())
if step == freeze_enc_at: # freeze encoder after this many steps
print(f"Step = {step}: Freezing encoder.")
for param in given_model.encoder.parameters():
param.requires_grad = False
given_model.encoder.train(False)
given_model.encoder.eval()
loss.backward()
opt_given.step()
# run on validation set
with torch.no_grad():
val_batch = next(iter(val_dl)).to(device)
val_mix = val_batch + next(iter(val_dl)).to(device)
val_batch_out, emb = given_model(val_batch)
val_batch_loss = mseloss(val_batch_out, val_batch)
val_mix_out, emb2 = given_model(val_mix)
val_mix_loss = mseloss(val_mix_out, val_mix)
val_loss = val_batch_loss + val_mix_loss
log_dict['val_batch_loss'] = val_batch_loss.detach()
log_dict['val_mix_loss'] = val_mix_loss.detach()
log_dict['val_loss'] = val_loss.detach()
if step % demo_every == 0:
log_dict["given_map_demo_batch"] = wandb.Image(viz_given_batch(demo_batch, given_model))
log_dict["given_map_demo_mix"] = wandb.Image(viz_given_batch(demo_mix, given_model))
log_dict["outs_hist"] = wandb.Histogram(batch_out.cpu().detach().numpy())
log_dict["emb_hist"] = wandb.Histogram(emb.cpu().detach().numpy())
wandb.log(log_dict)
step += 1
scheduler.step()
epoch += 1
wandb.finish()
torch.save(given_model.state_dict(), given_path)
else:
print("Nevermind all that. Proceed.")
Nevermind all that. Proceed.
Let's check that we can do reconstruction to input space: $f^{-1}(y): y \mapsto x$
print(r"viz model ops for single 'stem' x:")
viz_given_batch(demo_batch, given_model, show_recon=True)
viz model ops for single 'stem' x:
print(r"viz model ops for demo mix x1+x2:")
viz_given_batch(demo_mix, given_model, show_recon=True)
viz model ops for demo mix x1+x2:
If, in the left panel, the yellow markers are covering the blue markers, then we're good to continue...
ok, first some utility routines...
#| export
def inp_to_aa(x):
"z = h(f(x))"
with torch.no_grad():
y = given_model.encode(x.to(device))
return aa_model.encode(y), y
def aa_to_inp(z):
"x = h^-1 ( f^-1 ( z ) )"
with torch.no_grad():
y = aa_model.decode(z.to(device))
return given_model.decode(y), y
def plot_kmwq_demo(q, forward=True, labels = ['man', 'king','woman','queen']):
    # forward = inputs to z's, backward = z's to inputs
" little routine to plot our math ops"
if forward:
xs = q
zs, ys = inp_to_aa(q)
maybe = zs[1] - zs[0] + zs[2]
else:
zs = q
xs, ys = aa_to_inp(q)
maybe = xs[1] - xs[0] + xs[2]
[xs, zs, ys, maybe] = [ q.cpu().numpy() for q in [xs, zs, ys, maybe]]
fig, ax = plt.subplots(1,3,figsize=(12,4))
titles, qs = ['input','given emb','aa emb'], [xs, ys, zs]
for i, q in enumerate(qs):
ax[i].set_title(titles[i])
ax[i].plot(q[:,0], q[:,1], marker='o')
print("orange dot shows the 'guess'")
ax[forward*2].plot(maybe[0], maybe[1], marker="o")
for i,lab in enumerate(labels): # show the individual data points and label them
ax[0].text(xs[i,0],xs[i,1], lab)
ax[2].text(zs[i,0],zs[i,1], lab)
points = np.vstack((xs,zs))
valmin, valmax = 1.1*points.min(axis=0), 1.1*points.max(axis=0)
for i in range(3):
ax[i].set_xlim(-2,2)
ax[i].set_ylim(-2,2)
plt.plot()
# vectors in input space
xs = 2*torch.rand(4,2, requires_grad=False)-1 #
xs[3] = xs[1] - xs[0] + xs[2] # queen = king - man + woman
plot_kmwq_demo(xs)
orange dot shows the 'guess'
# vectors in input space
zs = 2*torch.rand(4,2, requires_grad=False)-1
zs[3] = zs[1] - zs[0] + zs[2]
plot_kmwq_demo(zs, forward=False)
orange dot shows the 'guess'
This is kind of the same thing as we just did, but more of a (theoretical) "music" example.
Pretend we're mixing 4 stems, but then we want to remove one of them. We can do the math for that in the input space. But can we also do it in the aa embedding space?
def demix_fun(labels = ['guitar','drums','bass','vocal']):
i_rm = random.randint(0,len(labels)-1)
# input space
xs = 2*torch.rand(4,2, requires_grad=False) -1
mix = xs.sum(dim=0).unsqueeze(0)
mix_rm = mix - xs[i_rm,:] # subtract a stem in input space
print(f"mix_remove_{labels[i_rm]} = {mix_rm}")
# do the op in embedding space
with torch.no_grad():
zs = aa_model.encode( given_model.encode(xs.to(device)) )
zmix = aa_model.encode( given_model.encode(mix.to(device)) )
zmix_rm = zmix - zs[i_rm,:] # subtract a stem in aa embedding space!
guess = given_model.decode( aa_model.decode(zmix_rm) ).cpu() # convert to input space
print(f"guess = {guess}")
#plot what we got, in input space
fig, ax = plt.subplots(figsize=(7,7))
ax.text(0,0 ,f"origin")
for i in range(4): # plot all the stems and put labels on them
ax.arrow(0,0, xs[i,0], xs[i,1], length_includes_head=True, head_width=0.05, head_length=0.1) # plot vectors for stems
ax.annotate("", xy=(xs[i,0], xs[i,1]), xytext=(0, 0), arrowprops=dict(arrowstyle="->"))
ax.text(xs[i,0],xs[i,1],f"{labels[i]}")
plt.plot(mix[:,0],mix[:,1], marker='o', markersize=10) # also plot and label the mix
ax.text(mix[:,0],mix[:,1], "mix")
# connect mix and removed value
dx = mix_rm - mix
ax.arrow(mix[0][0], mix[0][1], dx[0][0], dx[0][1], length_includes_head=True, head_width=0.05, head_length=0.1)
ax.annotate("", xy=(xs[i,0], xs[i,1]), xytext=(0, 0), arrowprops=dict(arrowstyle="->"))
ax.plot(mix_rm[:,0],mix_rm[:,1], marker='o', label=f'mix_remove_{labels[i_rm]}', linestyle='None', markersize=10) # plot the point in real space
ax.text(mix_rm[:,0],mix_rm[:,1], f'mix_remove_{labels[i_rm]}')
ax.plot(guess[:,0],guess[:,1], marker='o', label='guess from aa', linestyle='None', markersize=10) # guess = input recon of what was done in aa emb space
plt.axis('square')
ax.set_aspect('equal', adjustable='box')
ax.set_title("input space")
ax.legend(fancybox=True, framealpha=0.6, prop={'size': 10})
plt.show()
demix_fun()
mix_remove_drums = tensor([[0.4061, 0.8386]]) guess = tensor([[0.4127, 0.8588]])
If the green dot is covering the orange dot, then it worked!