---
skip_showdoc: true
skip_exec: true
---

#| default_exp aa_mixer_toy

# install a few non-standard dependencies (for colab)
%pip install -Uqq pip einops tqdm wandb

#| export
import math
import os
import matplotlib.pyplot as plt
from PIL import Image
import random
import string
import numpy as np
import torch
from torch import optim, nn, Tensor
from torch.nn import functional as F
from torch.utils import data as torchdata
from einops import rearrange
from tqdm.auto import tqdm
import wandb
from aeiou.core import get_device

device = get_device()
print("device =", device)

SCRATCH_DIR = "/scratch" if os.path.isdir('/scratch') else "/tmp"
SCRATCH_DIR += '/aa'  # give audio-algebra scratch its own spot
if not os.path.exists(SCRATCH_DIR):
    os.makedirs(SCRATCH_DIR)
print("SCRATCH_DIR =", SCRATCH_DIR)

# generate a unique 2-character code for the run
suffix_len = 2
RUN_SUFFIX = '_' + ''.join(random.choices(string.ascii_lowercase, k=suffix_len))
print("RUN_SUFFIX =", RUN_SUFFIX)

seed = 5
torch.manual_seed(seed)

batch_size = 1024
n_train_points = batch_size * 2000  # we're going to grab random data anyway
n_points = int(n_train_points * 1.25)
in_dims = 2   # number of dimensions the input data will live in, e.g. 2 or 3
emb_dims = 2  # number of dimensions for embeddings
train_val_split = 0.8
train_len, val_len = round(n_points * train_val_split), round(n_points * (1 - train_val_split))

class RandVecDataset(torchdata.Dataset):
    "very simple dataset: uniform random vectors in [-1, 1]^dims"
    def __init__(self, length, dims):
        super().__init__()
        self.data = 2 * torch.rand(length, dims) - 1

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return self.data.shape[0]

train_dataset = RandVecDataset(train_len, in_dims)
val_dataset = RandVecDataset(val_len, in_dims)
v = train_dataset[0]
print(len(train_dataset), v.shape)
print(len(val_dataset), v.shape)

train_dl = torchdata.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = torchdata.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
val_iter = iter(val_dl)

#| export
def get_demo_batch(val_iter, demo_line=True, debug=False):
    demo_batch = next(val_iter).clone()
    absmax = (torch.abs(demo_batch)).max().item()
    if demo_line:  # make the demo batch lie along lines, in order to showcase nonlinearity
        demo_batch[:,1] = demo_batch[:,0]   # line y = x
        demo_batch[::2,0] *= -1             # even-numbered points go to y = -x
        demo_batch[::4,0] = 0.01            # data along the y axis
        demo_batch[1::4,1] = 0.01           # data along the x axis
        demo_batch[::5] = absmax * (2 * torch.rand((demo_batch[::5]).shape) - 1)    # make some points random
        demo_batch[1::7] = absmax * (2 * torch.rand((demo_batch[1::7]).shape) - 1)  # and more points random
    return demo_batch
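# Added sanity check (not in the original): a rough look at how much of a demo batch
# still lies on the y = x and y = -x lines after the random overwrites above.
# Just a sketch for intuition, assuming the slicing in get_demo_batch stays as written.
check_batch = get_demo_batch(iter(val_dl))
on_diag = ((check_batch[:,1] - check_batch[:,0]).abs() < 1e-6).float().mean().item()
on_anti = ((check_batch[:,1] + check_batch[:,0]).abs() < 1e-6).float().mean().item()
print(f"fraction on y=x: {on_diag:.2f}, fraction on y=-x: {on_anti:.2f}")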
# plot what the demo batch looks like in input space -- you're going to be seeing this a lot!
fig, ax = plt.subplots(1, 3, figsize=(12,4))
demo_batch = get_demo_batch(val_iter, debug=True)
ax[0].plot(demo_batch[:,0], demo_batch[:,1], marker='.', linestyle='None')
ax[0].set_title('single "lines+dots" stem')
print("")

demo_mix = demo_batch + get_demo_batch(val_iter, demo_line=True, debug=True)  # add another 'line stem'
ax[1].plot(demo_mix[:,0], demo_mix[:,1], marker='.', linestyle='None')
ax[1].set_title('mix of two "lines+dots" stems')

demo_mix = demo_batch + get_demo_batch(val_iter, demo_line=False, debug=True)  # random noise will swamp the lines
ax[2].plot(demo_mix[:,0], demo_mix[:,1], marker='.', linestyle='None')
ax[2].set_title('mix of "lines+dots" + unit noise')

#| export
class EmbedBlock(nn.Module):
    def __init__(self, in_dims:int, out_dims:int, act=nn.GELU(), resid=True, use_bn=False, requires_grad=True, **kwargs) -> None:
        "generic little block for embedding stuff. note that residual-or-not doesn't seem to make a huge difference for a-a"
        super().__init__()
        self.in_dims, self.out_dims, self.act, self.resid = in_dims, out_dims, act, resid
        self.lin = nn.Linear(in_dims, out_dims, **kwargs)
        self.bn = nn.BatchNorm1d(out_dims) if use_bn else None  # even though points are in 2d, there's only one non-batch dim in the data
        if requires_grad == False:
            self.lin.weight.requires_grad = False
            self.lin.bias.requires_grad = False

    def forward(self, xin: Tensor) -> Tensor:
        x = self.lin(xin)
        if self.act is not None: x = self.act(x)
        if self.bn is not None: x = self.bn(x)  # re. "BN before or after activation?" cf. https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md
        return xin + x if (self.resid and self.in_dims == self.out_dims) else x

# test that
emb_test = EmbedBlock(2, 2, use_bn=True, requires_grad=False)
#res_test = ResBlock(2, 2)

#| export
def friendly_tanh(x, a=3, b=0.1, s=0.9):
    "like a tanh but not as harsh"
    return (torch.tanh(a*x) + b*x)*s
    #return torch.tanh(a*x) + b*torch.sgn(x)*torch.sqrt(torch.abs(x))

def compressor(x, thresh=0.7, ratio=10):
    "another nonlinearity: kind of like an audio compressor"
    return torch.where(torch.abs(x) < thresh, x, torch.sgn(x)*(1/ratio*(torch.abs(x)-thresh)+thresh))

x = torch.linspace(-2.5, 2.5, 500)
fig, ax = plt.subplots(figsize=(9,4))
ax.plot(x.cpu().numpy(), torch.tanh(3*x).cpu().numpy(), label='tanh')
ax.plot(x.cpu().numpy(), friendly_tanh(x, a=3).cpu().numpy(), label='friendly_tanh')
ax.plot(x.cpu().numpy(), compressor(x).cpu().numpy(), label='"compressor"')
ax.legend()
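# Worked example (added): plugging numbers into compressor() above. Below thresh it's
# the identity; above thresh, the excess over thresh gets divided by `ratio`:
#   compressor(0.5) = 0.5
#   compressor(1.7) = 0.7 + (1.7 - 0.7)/10 = 0.8
print(compressor(torch.tensor([0.5, 1.7, -1.7])))  # expect tensor([0.5, 0.8, -0.8])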
#| export
class TwistAndScrunch(nn.Module):
    """for now just something random with a bit of nonlinearity"""
    def __init__(self,
                 in_dims=2,       # number of input dimensions to use
                 emb_dims=2,      # number of output dimensions to use
                 hidden_dims=32,  # number of hidden dimensions
                 scrunch_fac=1.7,
                 twist_fac=0.2,
                 debug=False,
                 act=nn.GELU(),
                 ):
        super().__init__()
        self.sf, self.debug = scrunch_fac, debug
        self.pi = torch.acos(torch.zeros(1)).item() * 2
        self.decoder = nn.Sequential(
            EmbedBlock(emb_dims, hidden_dims, act=act),
            EmbedBlock(hidden_dims, hidden_dims, act=act),
            nn.Linear(hidden_dims, in_dims),
        )
        # note we'll never use this encoder. Just included it to keep other code from crashing. ;-)
        self.encoder = nn.Sequential(  # note that this is not going to be used
            EmbedBlock(emb_dims, hidden_dims, act=act),
            EmbedBlock(hidden_dims, hidden_dims, act=act),
            nn.Linear(hidden_dims, in_dims),
        )

    def encode(self, v):
        # some nonlinear function i made up
        vp = v.clone()
        vp[:,0], vp[:,1] = vp[:,0] - 0.5*friendly_tanh(v[:,1], a=3), vp[:,1] + 0.7*friendly_tanh(v[:,0], a=1.2)
        return friendly_tanh(self.sf*vp, a=2, b=0.1)

    def decode(self, emb, exact=False):
        return self.decoder(emb)

    def forward(self, x, debug=False, exact=False):
        debug_save, self.debug = self.debug, debug  # only turn on debug for this one call
        emb = self.encode(x)
        outs = self.decode(emb, exact=exact), emb
        self.debug = debug_save
        return outs

torch.manual_seed(seed)
#given_model = SimpleAutoEncoder(in_dims=in_dims, emb_dims=emb_dims).to(device)
given_model = TwistAndScrunch()
given_model = given_model.to(device)

#| export
def viz_given_batch(
        val_batch,      # should be from the val set
        given_model,    # the main autoencoder model we are splicing into
        aa_model=None,  # our audio algebra model
        debug=False,
        device=get_device(),
        return_img=True,
        show_recon=False,
    ):
    markers = ['o', '+', 'x', '*', '.', 'X', "v", "^", "<", ">", "8", "s", "p", "P", "h", "H", "D", "d", "|", "_"]
    colors = ['blue', 'magenta', 'orange', 'green', 'red', 'cyan', 'black']
    markersize = 2.5
    val_batch = val_batch.to(device)
    with torch.no_grad():
        emb_batch = given_model.encode(val_batch)
        recon_batch = given_model.decode(emb_batch)
    fig, ax = plt.subplots(1, 2, figsize=(11,5))
    ax[0].set_title(r"input space. $x \in X$")
    ax[1].set_title(r"given encoder space. $y \in Y$")
    if debug:
        print("absmax of val_batch = ", torch.abs(val_batch).max().item())
        print("absmax of emb_batch = ", torch.abs(emb_batch).max().item())
    nm = len(markers)
    for mi in range(nm):
        #for i in range(val_batch.shape[0]):  # TODO: looping over every point is very slow, but apparently needed for arrows
        for bi, b in enumerate([val_batch, emb_batch, recon_batch]):
            x, y = b[mi::nm,0].cpu().numpy(), b[mi::nm,1].cpu().numpy()
            if (bi < 2) or show_recon:
                ax[bi%2].plot(x, y, marker=markers[mi], markersize=markersize, color=colors[bi], linestyle='None',
                              label='' if mi > 0 else ['input', 'embedding', 'recon'][bi])
    ax[0].legend(loc='upper left')
    val_plot_fname = f'{SCRATCH_DIR}/val_plot{RUN_SUFFIX}{random.randint(1,5)}.png'
    plt.savefig(val_plot_fname, bbox_inches='tight', dpi=300)
    plt.close()
    im = Image.open(val_plot_fname)
    return im

print("for single stem:")
viz_given_batch(demo_batch, given_model, debug=True)

print("Also operate on the demo mix:")
viz_given_batch(demo_mix, given_model, debug=True)
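# Motivation check (added sketch, using only the models defined above): the given
# encoder f is nonlinear, so embeddings don't add: f(x1 + x2) != f(x1) + f(x2).
# That mismatch is exactly what the AudioAlgebra model below is meant to repair.
with torch.no_grad():
    x1 = 2*torch.rand(4, in_dims, device=device) - 1
    x2 = 2*torch.rand(4, in_dims, device=device) - 1
    y_of_sum = given_model.encode(x1 + x2)
    sum_of_y = given_model.encode(x1) + given_model.encode(x2)
    print("mean |f(x1+x2) - (f(x1)+f(x2))| =", (y_of_sum - sum_of_y).abs().mean().item())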
#| export
class AudioAlgebra(nn.Module):
    """
    Main AudioAlgebra model. In contrast to the aa-mixer code, we keep this one simple & move the mixing stuff outside
    """
    def __init__(self, dims=2, hidden_dims=64, act=nn.GELU(), use_bn=False, resid=True, block=EmbedBlock):
        super().__init__()
        self.encoder = nn.Sequential(
            block(dims,        hidden_dims, act=act,  use_bn=use_bn, resid=resid),
            block(hidden_dims, hidden_dims, act=act,  use_bn=use_bn, resid=resid),
            block(hidden_dims, hidden_dims, act=act,  use_bn=use_bn, resid=resid),
            block(hidden_dims, dims,        act=None, use_bn=use_bn, resid=resid),
        )
        self.decoder = nn.Sequential(  # same as the encoder, in fact
            block(dims,        hidden_dims, act=act,  use_bn=use_bn, resid=resid),
            block(hidden_dims, hidden_dims, act=act,  use_bn=use_bn, resid=resid),
            block(hidden_dims, hidden_dims, act=act,  use_bn=use_bn, resid=resid),
            block(hidden_dims, dims,        act=None, use_bn=use_bn, resid=resid),
        )

    def encode(self, x):
        return x + self.encoder(x)

    def decode(self, x):
        return x + self.decoder(x)

    def forward(self,
                x  # the embedding vector from the given encoder
                ):
        xprime = self.encode(x)
        xprimeprime = self.decode(xprime)  # train the system to invert itself (and hope it doesn't all collapse to nothing!)
        return xprime, xprimeprime  # encoder output, decoder output

#| export
def get_stems_faders(batch,    # "1 stem" (or batch thereof) already drawn from the dataloader (val or train)
                     dl_iter,  # pre-made iterator for the/a dataloader
                     dl,       # the dataloader itself, for restarting
                     maxstems=2,        # how many total stems will be used, i.e. draw maxstems-1 new stems from dl_iter
                     unity_gain=False,  # if True, force all faders to be +/-1 instead of random gains
                     debug=False):
    "grab some more input stems, and some gain values to go with them"
    nstems = random.randint(2, maxstems)
    if debug: print("maxstems, nstems =", maxstems, nstems)
    device = batch.device
    faders = torch.sgn(2*torch.rand(nstems)-1)  # random +/- 1's
    if not unity_gain:
        faders += 0.5*torch.tanh(2*(2*torch.rand(nstems)-1))  # gain magnitude is now roughly between 0.5 and 1.5 (sign stays random)
    stems = [batch]  # note that stems is a list
    for i in range(nstems-1):  # in addition to the batch of stems passed in, grab some more
        try:
            next_stem = next(dl_iter).to(device)  # this is just another batch of input data
        except StopIteration:
            dl_iter = iter(dl)  # time to restart. hoping this propagates out as a pointer
            next_stem = next(dl_iter).to(device)
        if debug: print("  next_stem.shape = ", next_stem.shape)
        stems.append(next_stem)
    return stems, faders.to(device), dl_iter  # also return the iterator

def do_mixing(stems, faders, given_model, aa_model, device, debug=False, **kwargs):
    """
    here we actually encode the inputs. "y" denotes an embedding from the given (frozen) encoder,
    "z" denotes one of our aa_model's re-mapped embeddings
    """
    zs, ys, zsum, ysum, yrecon_sum, fadedstems = [], [], None, None, None, []
    mix = torch.zeros_like(stems[0]).to(device)
    #if debug: print("do_mixing: stems, faders =", stems, faders)
    for s, f in zip(stems, faders):  # iterate through the list of stems, encoding each stem at a different fader setting
        fadedstem = (s * f).to(device)  # audio stem adjusted by gain fader f
        with torch.no_grad():
            y = given_model.encode(fadedstem)  # encode the stem
        z, y_recon = aa_model(y)  # <-- this is the main work of the model
        zsum = z if zsum is None else zsum + z  # <-- compute the sum of all the z's so far. we'll end up using this in our (metric) loss as "pred"
        mix += fadedstem  # make the full mix in input space
        with torch.no_grad():
            ymix = given_model.encode(mix)  # encode the mix
        zmix, ymix_recon = aa_model(ymix)  # <-- map that according to our learned re-embedding; this will be the "target" in the metric loss
        # for diagnostics:
        ysum = y if ysum is None else ysum + y  # = sum of embeddings in the original model space; we don't really care about ysum except for diagnostics
        yrecon_sum = y_recon if yrecon_sum is None else yrecon_sum + y_recon  # = sum of reconstructed embeddings, likewise only for diagnostics
        zs.append(z)  # save a list of individual z's
        ys.append(y)  # save a list of individual y's
        fadedstems.append(fadedstem)  # save a list of each thing that went into the mix

    archive = {'zs':zs, 'mix':mix, 'znegsum':None, 'ys':ys, 'ysum':ysum, 'ymix':ymix, 'ymix_recon':ymix_recon,
               'yrecon_sum':yrecon_sum, 'fadedstems':fadedstems}  # more info for diagnostics
    return zsum, zmix, archive  # we will try to get zsum & zmix to be close to each other via the loss. archive is for diagnostics

aa_use_bn = False    # batch norm?
aa_use_resid = True  # use residual connections? (doesn't make much difference tbh)
hidden_dims = 64     # number of hidden dimensions in the aa model. usually 64
torch.manual_seed(seed)  # chose this seed value because it shows off the nonlinearity nicely
aa_model = AudioAlgebra(dims=emb_dims, hidden_dims=hidden_dims, use_bn=aa_use_bn, resid=aa_use_resid).to(device)

#| export
def multimarkerplot(ax, data, color='blue', markersize=2.7, label=''):
    "little thingy for using lots of symbols in a plot"
    markers = ['o', '+', 'x', '*', '.', 'X', "v", "^", "<", ">", "8", "s", "p", "P", "h", "H", "D", "d", "|", "_"]
    nm = len(markers)
    for mi in range(nm):
        ax.plot(data[mi::nm,0], data[mi::nm,1], marker=markers[mi], markersize=markersize, color=color,
                linestyle='None', label=label if mi==0 else '')

def viz_aa_demo(zsum, zmix, archive, aa_model, debug=False):
    "for visualizing what the aa model is doing"
    colors = ['blue', 'red', 'green', 'magenta', 'orange', 'cyan', 'black']
    col_labels = ["inputs", "given representations", "a-a embeddings"]
    nrows, ncols = 2, len(col_labels)
    figsize = (4*ncols, 4*nrows)
    fig, ax = plt.subplots(nrows, ncols, figsize=figsize)

    # top row: stems*faders, ys, zs: each of these is a list
    keys, varnames = ['fadedstems', 'ys', 'zs'], ['x', 'y', 'z']
    for coli, key in enumerate(keys):
        ax[0,coli].set_title(f"{col_labels[coli]} ${varnames[coli]}$")
        datalist = [x.detach().cpu().numpy() for x in archive[key]]
        for i, stem in enumerate(datalist):
            label = f"{varnames[coli]}{i+1}"
            if coli > 0:
                label += f" := f(x{i+1})" if coli==1 else f" := h(y{i+1})"
            multimarkerplot(ax[0,coli], stem, color=colors[i], label=label)
        ax[0,coli].legend(fancybox=True, framealpha=0.6, loc='upper left', prop={'size': 9})

    # bottom row: mix
    archive["zmix"] = zmix  # lil prep
    keys = ['mix', 'ymix', 'zmix']
    annotations = [' := x1 + x2', r' := $f\,$(mix)', r' := $h$(ymix)']
    for coli, key in enumerate(keys):
        ax[1,coli].set_title(f"{col_labels[coli]}")
        data = archive[key].detach().cpu().numpy()
        multimarkerplot(ax[1,coli], data, label=key+annotations[coli], color=colors[2])
        if key=='ymix':  # add another plot
            multimarkerplot(ax[1,coli], archive["ymix_recon"].detach().cpu().numpy(), label=r'ymix_recon := $h^{-1}$(zmix)', color=colors[4])
        elif key=='zmix':
            multimarkerplot(ax[1,coli], zsum.detach().cpu().numpy(), label='zsum := z1 + z2', color=colors[3])
        ax[1,coli].legend(fancybox=True, framealpha=0.6, loc='upper left', prop={'size': 9})

    if debug: print("saving to file")
    filename = f'{SCRATCH_DIR}/aa_val_plot{RUN_SUFFIX}.png'
    plt.savefig(filename, bbox_inches='tight', dpi=225)
    plt.close()
    if debug: print("loading file")
    im = Image.open(filename)
    return im
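# Added sketch: a quick contract check on the pieces defined above, before wiring them
# into the viz. With the untrained aa_model, zsum and zmix generally won't match yet;
# the training below pushes mseloss(zsum, zmix) toward zero.
with torch.no_grad():
    _y = given_model.encode(demo_batch.to(device))
    _z, _y_recon = aa_model(_y)
print("y / z / y_recon shapes:", _y.shape, _z.shape, _y_recon.shape)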
# test the viz
torch.manual_seed(seed)
demo_stems, demo_faders, val_iter = get_stems_faders(demo_batch, val_iter, val_dl,  # n.b. the third arg is the dataloader, not the dataset
                                                     debug=True, unity_gain=False, maxstems=2)
print("demo_faders = ", demo_faders)
print("calling do_mixing...")
demo_zsum, demo_zmix, demo_archive = do_mixing(demo_stems, demo_faders, given_model, aa_model, device, debug=True)
print("calling viz_aa_demo...")
im = viz_aa_demo(demo_zsum, demo_zmix, demo_archive, aa_model, debug=False)
im

mseloss = nn.MSELoss()

def rel_loss(y_pred: torch.Tensor, y: torch.Tensor, eps=1e-3) -> torch.Tensor:
    "relative error loss --- note we're never going to actually use this. it was just part of development"
    e = torch.abs(y.view_as(y_pred) - y_pred) / (torch.abs(y.view_as(y_pred)) + eps)
    return torch.median(e)

#| export
def vicreg_var_loss(z, gamma=1, eps=1e-4):
    "vicreg-style variance loss: a hinge to keep the std of each embedding dimension above gamma"
    std_z = torch.sqrt(z.var(dim=0) + eps)
    return torch.mean(F.relu(gamma - std_z))  # the relu gets us the max(0, ...)

#| export
def off_diagonal(x):
    "returns a flattened view of the off-diagonal elements of a square matrix"
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

def vicreg_cov_loss(z):
    "the regularization term that is the sum of the off-diagonal terms of the covariance matrix"
    cov_z = torch.cov(z.T)
    return off_diagonal(cov_z).pow_(2).sum().div(z.shape[1])
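# Added sketch: sanity-checking the VICReg-style terms defined above on toy tensors.
# A collapsed batch (all points identical) should incur a large variance penalty,
# while a well-spread batch should incur (almost) none.
_spread, _collapsed = torch.randn(256, 2), torch.zeros(256, 2)
print("var_loss, spread   :", vicreg_var_loss(_spread).item())     # ~0: per-dim std ~ 1 >= gamma
print("var_loss, collapsed:", vicreg_var_loss(_collapsed).item())  # ~gamma: per-dim std ~ 0
print("cov_loss, spread   :", vicreg_cov_loss(_spread).item())     # small: dims nearly uncorrelated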
#| export
model_path = 'aa_model.pth'

def load_aa_checkpoint(model_path='aa_model.pth', emb_dims=2, device=get_device()):
    aa_model = None
    try:
        aa_model = torch.load(model_path, map_location=torch.device('cpu')).to(device)
    except Exception as e:
        print(f"Error trying to load presaved models:\n{e}\n\nStarting from scratch.")
        return None
    print("Successfully loaded pretrained models")
    return aa_model

def save_aa_checkpoint(aa_model, model_path='aa_model.pth', suffix=''):
    "save a model checkpoint -- full version. note we save to scratch so as not to overwrite any existing file"
    torch.save(aa_model, SCRATCH_DIR+'/'+model_path.replace(".pth", f"{suffix}.pth"))

start_from_pretrained = False
result = "Starting from scratch."
if start_from_pretrained:
    aa_load = load_aa_checkpoint(emb_dims=emb_dims, device=device)
    if aa_load is not None:
        aa_model = aa_load
        # re-viz showing the loaded model
        demo_zsum, demo_zmix, demo_archive = do_mixing(demo_stems, demo_faders, given_model, aa_model, device, debug=True)
        result = viz_aa_demo(demo_zsum, demo_zmix, demo_archive, aa_model=aa_model)
        print("Successfully loaded checkpoint")
result

wandb.login()

#| export
train_new_aa = True  # set to False to disable this training

def sayno(b, suffix=''):  # just a little util for naming runs
    return f'no{suffix}' if (b is None) or (b == False) else ''

def train_aa_vicreg(debug=False):
    "packaging as a def to re-use later. this will use global variables, not sorry"
    max_epochs = 40
    lossinfo_every, viz_demo_every = 20, 1000  # in units of steps
    checkpoint_every = 10000
    max_lr = 0.002

    # optimizer and learning rate scheduler
    opt = optim.Adam([*aa_model.parameters()], lr=5e-4)
    total_steps = len(train_dataset)//batch_size * max_epochs
    print("total_steps =", total_steps)  # for when I'm checking wandb
    scheduler = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=max_lr, total_steps=total_steps)

    # rename our wandb run with some significant info
    wandb.init(project='aa-toy-vicreg')
    run_name = wandb.run.name
    wandb.run.name = f"h{hidden_dims}_{'bn_' if aa_use_bn else ''}bs{batch_size}_lr{max_lr}{RUN_SUFFIX}"
    print("New run name =", wandb.run.name)
    wandb.run.save()

    # training loop
    train_iter = iter(train_dl)  # this is only for use with get_stems_faders
    epoch, step = 0, 0
    torch.manual_seed(seed)  # for reproducibility
    while (epoch < max_epochs) or (max_epochs < 0):  # training loop
        with tqdm(train_dl, unit="batch") as tepoch:
            for batch in tepoch:  # train
                opt.zero_grad()
                log_dict = {}
                batch = (1.2*(2*torch.rand(batch.shape, device=device)-1))  # overwrite batch with unending random data!
                #batch = batch.to(device)
                stems, faders, train_iter = get_stems_faders(batch, train_iter, train_dl, debug=debug)

                # vicreg: 1. invariance
                zsum, zmix, archive = do_mixing(stems, faders, given_model, aa_model, device, debug=debug)
                mix_loss = mseloss(zsum, zmix)
                var_loss = (vicreg_var_loss(zsum) + vicreg_var_loss(zmix))/2  # vicreg: 2. variance
                cov_loss = (vicreg_cov_loss(zsum) + vicreg_cov_loss(zmix))/2  # vicreg: 3. covariance

                # reconstruction loss: inversion of the aa map, h^{-1}(z): z -> y, i.e. train the aa decoder
                y = given_model.encode(batch)
                z, yrecon = aa_model.forward(y)  # (re)compute y's for one input batch (not stems & faders)
                aa_recon_loss = mseloss(y, yrecon)
                aa_recon_loss = aa_recon_loss + mseloss(archive['ymix'], archive['ymix_recon'])  # also recon of the mix encoding
                #aa_recon_loss = aa_recon_loss + mseloss(archive['ysum'], archive['yrecon_sum'])  # Never use this: ysum shouldn't matter / is poorly defined

                loss = mix_loss + var_loss + cov_loss + aa_recon_loss  # --- full loss function
                log_dict['train_loss'] = loss.detach()  # --- this is the full loss
                log_dict['mix_loss'] = mix_loss.detach()
                log_dict['aa_recon_loss'] = aa_recon_loss.detach()
                log_dict['var_loss'] = var_loss.detach()
                log_dict['cov_loss'] = cov_loss.detach()
                log_dict['learning_rate'] = opt.param_groups[0]['lr']
                log_dict['epoch'] = epoch

                if step % lossinfo_every == 0:
                    tepoch.set_description(f"Epoch {epoch+1}/{max_epochs}")
                    tepoch.set_postfix(loss=loss.item())  # TODO: use EMA for loss display?

                loss.backward()
                opt.step()

                with torch.no_grad():  # run on demo data (same each time)
                    demo_zsum, demo_zmix, demo_archive = do_mixing(demo_stems, demo_faders, given_model, aa_model, device, debug=debug)
                    demo_mix_loss = mseloss(demo_zsum, demo_zmix)
                    log_dict['demo_mix_loss'] = demo_mix_loss.detach()

                if step % viz_demo_every == 0:  # but only draw the viz every now and then
                    try:  # to avoid "OSError: image file is truncated"
                        im = viz_aa_demo(demo_zsum, demo_zmix, demo_archive, aa_model)
                        log_dict["aa_mapping"] = wandb.Image(im)
                    except OSError:
                        pass

                if step % checkpoint_every == 0:
                    save_aa_checkpoint(aa_model, suffix=RUN_SUFFIX+f"_{step}")

                wandb.log(log_dict)
                scheduler.step()
                step += 1
        epoch += 1
    #----- training loop finished
    save_aa_checkpoint(aa_model, suffix=RUN_SUFFIX+f"_{step}")

if train_new_aa:
    train_aa_vicreg()

torch.manual_seed(seed)
demo_stems, demo_faders, val_iter = get_stems_faders(demo_batch, val_iter, val_dl, maxstems=2)  # n.b. pass the dataloader, not the dataset
demo_zsum, demo_zmix, demo_archive = do_mixing(demo_stems, demo_faders, given_model, aa_model, device)
viz_aa_demo(demo_zsum, demo_zmix, demo_archive, aa_model)

from IPython.display import Video
video_url = "https://hedges.belmont.edu/aa_1280.mp4"
Video(video_url, width=720)
# NOTE that GitHub won't display this video; you have to execute the above code

# try to load a checkpoint for the given model
given_path = 'given_model.pth'
train_given_model = True
try:
    given_model = TwistAndScrunch()  # GivenAutoEncoder(in_dims=in_dims, act=act, emb_dims=emb_dims)
    given_model.load_state_dict(torch.load(given_path))
    given_model.to(device)
    given_model.eval()
    train_given_model = False
    print("Loaded trained given_model just fine. Proceed.")
except Exception as e:
    print(f"{e}:\nFailed to load given model checkpoint. You'll need to (re)train the decoder.")
given_model.to(device)

if train_given_model:
    opt_given = optim.Adam(given_model.parameters(), lr=5e-4)
    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_given, 'min', factor=0.3, verbose=True)
    mseloss = nn.MSELoss()
    wandb.finish()
    wandb.init(project='aa-toy-given')
    epoch, step, max_epochs = 0, 0, 20
    lossinfo_every, demo_every = 20, 1000  # in units of steps
    freeze_enc_at = 65  # if we don't freeze the encoder early on, then it linearizes everything, making the challenge for aa too easy
    debug = False
    total_steps = len(train_dataset)//batch_size * max_epochs
    print("total_steps = ", total_steps)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(opt_given, max_lr=0.025, total_steps=total_steps)

    while epoch < max_epochs:  # training loop
        with tqdm(train_dl, unit="batch") as tepoch:
            for batch in tepoch:  # training
                opt_given.zero_grad()
                batch = (2*torch.rand(batch.shape)-1).to(device)
                batch_out, emb = given_model(batch)
                batch_loss = mseloss(batch_out, batch)
                mix = batch + (2*torch.rand(batch.shape)-1).to(device)  # "the mix"
                mix_out, emb2 = given_model(mix)
                mix_loss = mseloss(mix_out, mix)
                loss = batch_loss + mix_loss
                #if step <= freeze_enc_at:
                #    loss += 0.001*mseloss(emb, 0*emb)  # wee bit of L2 decay on given embeddings
                log_dict = {}
                log_dict['train_loss'] = loss.detach()
                log_dict['learning_rate'] = opt_given.param_groups[0]['lr']

                if step % lossinfo_every == 0:
                    tepoch.set_description(f"Epoch {epoch+1}/{max_epochs}")
                    tepoch.set_postfix(loss=loss.item())

                if step == freeze_enc_at:  # freeze the encoder after this many steps
                    # (for TwistAndScrunch the nn encoder is never actually used, so this is a holdover from earlier autoencoder versions)
                    print(f"Step = {step}: Freezing encoder.")
                    for param in given_model.encoder.parameters():
                        param.requires_grad = False
                    given_model.encoder.train(False)
                    given_model.encoder.eval()

                loss.backward()
                opt_given.step()

                # run on validation set
                with torch.no_grad():
                    val_batch = next(iter(val_dl)).to(device)  # note: a fresh iter() always yields the same first val batch
                    val_mix = val_batch + next(iter(val_dl)).to(device)
                    val_batch_out, emb = given_model(val_batch)
                    val_batch_loss = mseloss(val_batch_out, val_batch)
                    val_mix_out, emb2 = given_model(val_mix)
                    val_mix_loss = mseloss(val_mix_out, val_mix)
                    val_loss = val_batch_loss + val_mix_loss
                    log_dict['val_batch_loss'] = val_batch_loss.detach()
                    log_dict['val_mix_loss'] = val_mix_loss.detach()
                    log_dict['val_loss'] = val_loss.detach()
                    if step % demo_every == 0:
                        log_dict["given_map_demo_batch"] = wandb.Image(viz_given_batch(demo_batch, given_model))
                        log_dict["given_map_demo_mix"] = wandb.Image(viz_given_batch(demo_mix, given_model))
                        log_dict["outs_hist"] = wandb.Histogram(batch_out.cpu().detach().numpy())
                        log_dict["emb_hist"] = wandb.Histogram(emb.cpu().detach().numpy())

                wandb.log(log_dict)
                step += 1
                scheduler.step()
        epoch += 1
    wandb.finish()
    torch.save(given_model.state_dict(), given_path)
else:
    print("Nevermind all that. Proceed.")
Proceed.") print(r"viz model ops for single 'stem' x:") viz_given_batch(demo_batch, given_model, show_recon=True) print(r"viz model ops for demo mix x1+x2:") viz_given_batch(demo_mix, given_model, show_recon=True) #| export def inp_to_aa(x): "z = h(f(x))" with torch.no_grad(): y = given_model.encode(x.to(device)) return aa_model.encode(y), y def aa_to_inp(z): "x = h^-1 ( f^-1 ( z ) )" with torch.no_grad(): y = aa_model.decode(z.to(device)) return given_model.decode(y), y def plot_kmwq_demo(q, forward=True, labels = ['man', 'king','woman','queen']): # forward = inputs to z's, backaward = z's to inputs " little routine to plot our math ops" if forward: xs = q zs, ys = inp_to_aa(q) maybe = zs[1] - zs[0] + zs[2] else: zs = q xs, ys = aa_to_inp(q) maybe = xs[1] - xs[0] + xs[2] [xs, zs, ys, maybe] = [ q.cpu().numpy() for q in [xs, zs, ys, maybe]] fig, ax = plt.subplots(1,3,figsize=(12,4)) titles, qs = ['input','given emb','aa emb'], [xs, ys, zs] for i, q in enumerate(qs): ax[i].set_title(titles[i]) ax[i].plot(q[:,0], q[:,1], marker='o') print("orange dot shows the 'guess'") ax[forward*2].plot(maybe[0], maybe[1], marker="o") for i,lab in enumerate(labels): # show the individual data points and label them ax[0].text(xs[i,0],xs[i,1], lab) ax[2].text(zs[i,0],zs[i,1], lab) points = np.vstack((xs,zs)) valmin, valmax = 1.1*points.min(axis=0), 1.1*points.max(axis=0) for i in range(3): ax[i].set_xlim(-2,2) ax[i].set_ylim(-2,2) plt.plot() # vectors in input space xs = 2*torch.rand(4,2, requires_grad=False)-1 # xs[3] = xs[1] - xs[0] + xs[2] # queen = king - man + woman plot_kmwq_demo(xs) # vectors in input space zs = 2*torch.rand(4,2, requires_grad=False)-1 zs[3] = zs[1] - zs[0] + zs[2] plot_kmwq_demo(zs, forward=False) def demix_fun(labels = ['guitar','drums','bass','vocal']): i_rm = random.randint(0,len(labels)-1) # input space xs = 2*torch.rand(4,2, requires_grad=False) -1 mix = xs.sum(dim=0).unsqueeze(0) mix_rm = mix - xs[i_rm,:] # subtract a stem in input space print(f"mix_remove_{labels[i_rm]} = {mix_rm}") # do the op in embedding space with torch.no_grad(): zs = aa_model.encode( given_model.encode(xs.to(device)) ) zmix = aa_model.encode( given_model.encode(mix.to(device)) ) zmix_rm = zmix - zs[i_rm,:] # subtract a stem in aa embedding space! 
def demix_fun(labels=['guitar', 'drums', 'bass', 'vocal']):
    i_rm = random.randint(0, len(labels)-1)

    # input space
    xs = 2*torch.rand(4, 2, requires_grad=False) - 1
    mix = xs.sum(dim=0).unsqueeze(0)
    mix_rm = mix - xs[i_rm,:]  # subtract a stem in input space
    print(f"mix_remove_{labels[i_rm]} = {mix_rm}")

    # do the op in embedding space
    with torch.no_grad():
        zs = aa_model.encode(given_model.encode(xs.to(device)))
        zmix = aa_model.encode(given_model.encode(mix.to(device)))
        zmix_rm = zmix - zs[i_rm,:]  # subtract a stem in aa embedding space!
        guess = given_model.decode(aa_model.decode(zmix_rm)).cpu()  # convert to input space
    print(f"guess = {guess}")

    # plot what we got, in input space
    fig, ax = plt.subplots(figsize=(7,7))
    ax.text(0, 0, "origin")
    for i in range(4):  # plot all the stems and put labels on them
        ax.arrow(0, 0, xs[i,0], xs[i,1], length_includes_head=True, head_width=0.05, head_length=0.1)  # plot vectors for stems
        ax.annotate("", xy=(xs[i,0], xs[i,1]), xytext=(0, 0), arrowprops=dict(arrowstyle="->"))
        ax.text(xs[i,0], xs[i,1], f"{labels[i]}")
    plt.plot(mix[:,0], mix[:,1], marker='o', markersize=10)  # also plot and label the mix
    ax.text(mix[0,0], mix[0,1], "mix")
    # connect mix and removed value
    dx = mix_rm - mix
    ax.arrow(mix[0][0], mix[0][1], dx[0][0], dx[0][1], length_includes_head=True, head_width=0.05, head_length=0.1)
    ax.plot(mix_rm[:,0], mix_rm[:,1], marker='o', label=f'mix_remove_{labels[i_rm]}', linestyle='None', markersize=10)  # plot the point in real space
    ax.text(mix_rm[0,0], mix_rm[0,1], f'mix_remove_{labels[i_rm]}')
    ax.plot(guess[:,0], guess[:,1], marker='o', label='guess from aa', linestyle='None', markersize=10)  # guess = input-space recon of what was done in aa emb space
    plt.axis('square')
    ax.set_aspect('equal', adjustable='box')
    ax.set_title("input space")
    ax.legend(fancybox=True, framealpha=0.6, prop={'size': 10})
    plt.show()

demix_fun()
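# Added sketch: a rough quantitative check of the demixing trick above -- averaged over
# several random draws, how far is the aa-space "guess" from the true stem-removed mix?
errs = []
with torch.no_grad():
    for _ in range(20):
        xs_t = 2*torch.rand(4, 2) - 1
        mix_t = xs_t.sum(dim=0, keepdim=True)
        target = mix_t - xs_t[0:1]  # remove stem 0 in input space
        zs_t = aa_model.encode(given_model.encode(xs_t.to(device)))
        zmix_t = aa_model.encode(given_model.encode(mix_t.to(device)))
        guess_t = given_model.decode(aa_model.decode(zmix_t - zs_t[0:1])).cpu()
        errs.append((guess_t - target).abs().mean().item())
print(f"mean |guess - target| over 20 trials: {np.mean(errs):.4f}")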