We will fit a `Polyclonal` model to the RBD antibody mix we simulated.
First, we read in that simulated data. Recall that we simulated both "exact" and "noisy" data, with several average per-library mutation rates, and at six different concentrations. Here we analyze the noisy data for the library with an average of 2 mutations per gene, measured at three different concentrations, as this is a fairly realistic representation of a real experiment:
import pandas as pd
import polyclonal
noisy_data = (
    pd.read_csv('RBD_variants_escape_noisy.csv', na_filter=None)
    .query('library == "avg2muts"')
    .query('concentration in [0.25, 1, 4]')
    .reset_index(drop=True)
)
noisy_data
|  | library | aa_substitutions | concentration | prob_escape | IC90 |
|---|---|---|---|---|---|
| 0 | avg2muts |  | 0.25 | 0.050440 | 0.1128 |
| 1 | avg2muts |  | 0.25 | 0.143100 | 0.1128 |
| 2 | avg2muts |  | 0.25 | 0.054520 | 0.1128 |
| 3 | avg2muts |  | 0.25 | 0.084730 | 0.1128 |
| 4 | avg2muts |  | 0.25 | 0.041740 | 0.1128 |
| ... | ... | ... | ... | ... | ... |
| 89995 | avg2muts | Y396T Y473L | 4.00 | 0.000000 | 0.5832 |
| 89996 | avg2muts | Y421W S359K | 4.00 | 0.044600 | 0.5777 |
| 89997 | avg2muts | Y449L V503T L335M | 4.00 | 0.000000 | 1.0520 |
| 89998 | avg2muts | Y473E L518F D427L | 4.00 | 0.002918 | 1.1600 |
| 89999 | avg2muts | Y505N H519T | 4.00 | 0.000000 | 0.3505 |

90000 rows × 5 columns
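As a quick check on the data (this cell is not in the original notebook), each variant should be measured at every concentration, so the three concentrations should each contribute 30,000 of the 90,000 rows:

# hypothetical check: rows per concentration should be equal (30,000 each)
noisy_data.groupby('concentration').size()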
Initialize a `Polyclonal` model with these data, including three epitopes.
We know from prior work the three most important epitopes and a key mutation in each, so we use this prior knowledge to "seed" initial guesses that assign large escape values to a key site in each epitope:
poly_abs = polyclonal.Polyclonal(
    data_to_fit=noisy_data,
    activity_wt_df=pd.DataFrame.from_records(
        [('1', 1.0),
         ('2', 3.0),
         ('3', 2.0),
         ],
        columns=['epitope', 'activity'],
    ),
    site_escape_df=pd.DataFrame.from_records(
        [('1', 417, 10.0),
         ('2', 484, 10.0),
         ('3', 444, 10.0),
         ],
        columns=['epitope', 'site', 'escape'],
    ),
    data_mut_escape_overlap='fill_to_data',
)
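Before fitting, one could inspect the initialized (pre-optimization) parameters to confirm the seeding took effect. This cell is not in the original run; it just displays the model's `activity_wt_df` attribute (the seeded escape values are likewise in `poly_abs.mut_escape_df`):

# hypothetical inspection: initial wildtype activities for each epitope
poly_abs.activity_wt_df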
Now fit the `Polyclonal` model using the default optimization settings and logging output every 100 steps.
Note how the fitting first just fits a site-level model to estimate the average effects of mutations at each site, and then fits the full model:
# NBVAL_IGNORE_OUTPUT
opt_res, lossreg = poly_abs.fit(logfreq=100)
# First fitting site-level model.
# Starting optimization of 522 parameters at Tue Dec 14 09:51:18 2021.
    step  time_sec    loss  fit_loss  reg_escape  regspread
       0  0.058178  9144.4    9144.2     0.29701          0
     100    6.9439  1337.1    1333.6       3.532          0
     200    13.383  1313.1    1308.8      4.3331          0
     300    20.263  1304.4    1299.2      5.1782          0
     400    26.882  1301.1      1295       6.019          0
     500    34.004  1297.9    1291.5      6.3803          0
     600    40.905  1297.3    1290.7      6.5574          0
     700    47.681  1296.5    1289.8      6.7586          0
     800    54.673  1296.1    1289.2      6.8988          0
     900    66.549  1295.6    1288.6      6.9791          0
    1000    73.827  1295.3    1288.3      7.0348          0
    1100    81.151  1295.1      1288      7.1436          0
    1200    88.188  1294.9    1287.6       7.326          0
    1300    95.754  1294.6    1287.1       7.518          0
    1357    99.936  1294.5      1287      7.5161          0
# Successfully finished at Tue Dec 14 09:52:58 2021.
# Starting optimization of 5799 parameters at Tue Dec 14 09:52:58 2021.
    step  time_sec    loss  fit_loss  reg_escape   regspread
       0  0.085596  1646.3    1551.7      94.634  2.2843e-29
     100    8.8431  845.71    738.75      94.388      12.571
     200    17.843  831.92    722.59      93.203       16.12
     300    26.284  823.55    715.48      90.135       17.94
     400    35.075  814.84    709.89      85.955      18.995
     500    43.314  805.74     705.6      80.771      19.365
     600    52.035  797.64    703.24      74.597      19.807
     700    62.589  788.33    702.33      64.906      21.091
     800    73.745  779.89    701.85      57.442      20.598
     900    83.739  773.57    700.36      52.572       20.63
    1000    96.357  769.17    698.58      49.943      20.654
    1100    106.85  763.85    696.45       46.09      21.304
    1200    117.55  756.77     691.6      43.319      21.852
    1300    128.49  752.71    687.58        42.5      22.621
    1400    139.88  748.93    682.52      42.569      23.844
    1500    151.01  744.21    675.52      43.276      25.417
    1600    162.39  737.87    665.84      44.426      27.603
    1700    174.12   733.4    658.15      45.618      29.638
    1800    186.24  728.47    650.56       46.77      31.138
    1900    197.15  719.11    636.75      48.475      33.883
    2000    208.01  705.27    618.49      50.204      36.569
    2100    218.28  686.35    597.18      51.389      37.787
    2200    229.13  673.08    584.12      51.914      37.044
    2300    239.41  667.73    578.73      52.328      36.667
    2400    251.85  666.16    577.31      52.558      36.291
    2500    262.01  665.63    576.75      52.767      36.109
    2600    272.64  665.39    576.32      52.995      36.073
    2700    283.95  665.28    576.22      53.059      35.994
    2723    287.47  665.26    576.22      53.063      35.979
# Successfully finished at Tue Dec 14 09:57:45 2021.
# grab the fitted parameter vector (a private attribute of the Polyclonal object)
previously_fit_params = poly_abs._params
We make some shim functions that will let us use Will's code.
Note that no new regularization is actually happening here: we are just using Will's code to fit the existing objective.
def g_shim(params):
    # smooth part of the objective: the full regularized polyclonal loss
    return lossreg.loss_reg(params)[0]

def grad_shim(params):
    # gradient of the smooth part
    return lossreg.loss_reg(params)[1]

def zero_function(params):
    # non-smooth part of the objective: there is none, so it is identically zero
    return 0.

def trivial_prox(params, t):
    # proximal operator of the zero function is just the identity
    return params
from polyclonal import optimization
prox_grad = optimization.AccProxGrad(g_shim, grad_shim, zero_function, trivial_prox, verbose=True)
When we start from the previously fit parameters, essentially nothing changes, which is a good thing: the optimizer declares convergence after a single iteration.
prox_grad.run(previously_fit_params)
initial objective 6.652598e+02
iteration 1, objective 6.653e+02, relative change 8.699e-08
relative change in objective function 8.7e-08 is within tolerance 1e-06 after 1 iterations
array([1.06698781, 3.22729169, 1.94937872, ..., 0.35142203, 0.69839784, 0.3097154 ])
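As a quick sanity check (this cell is not in the original run), one could confirm that the returned solution is essentially the parameter vector we started from:

import numpy as np

# hypothetical check: starting from the already-fitted parameters, the
# returned solution should differ only negligibly from them
np.allclose(prox_grad.x, previously_fit_params, atol=1e-2)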
Let's try with some stupid starting parameters: all ones.
This stops after fewer than a thousand iterations, at a much higher objective value. Perhaps it hit a local minimum.
import numpy as np
new_prox_grad = optimization.AccProxGrad(g_shim, grad_shim, zero_function, trivial_prox, verbose=True)
new_params = np.ones(previously_fit_params.shape[0])
new_prox_grad.run(new_params, max_iter=3000)
initial objective 1.782483e+04
/home/ematsen/re/polyclonal/polyclonal/polyclonal.py:1583: RuntimeWarning: overflow encountered in exp
  exp_minus_phi_e_v = numpy.exp(-phi_e_v)
iteration 636, objective 1.108e+03, relative change 9.955e-07
relative change in objective function 1e-06 is within tolerance 1e-06 after 636 iterations
array([2.27645376, 2.27645376, 2.27645376, ..., 0.54204399, 0.54204399, 0.54204399])
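One way to probe the local-minimum hypothesis (not done in the original run) is to check the gradient norm of the smooth objective at the stopping point; a large norm would mean the relative-change criterion fired on a slow, flat stretch rather than at a stationary point:

# hypothetical diagnostic: gradient norm where the all-ones run stopped
print(np.linalg.norm(grad_shim(new_prox_grad.x)))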
OK, let's try something easier: start from a small random perturbation of the previously fit parameters. This works!!
# note: this does NOT copy, so the in-place multiplication below also
# perturbs `previously_fit_params` itself
new_params = previously_fit_params
new_params *= np.random.uniform(0.9, 1.1, size=new_params.shape[0])
new_params
array([0.97677739, 3.30086873, 1.76248978, ..., 0.32171498, 0.73367543, 0.28246869])
new_prox_grad.run(new_params, max_iter=3000)
initial objective 7.195510e+02
iteration 766, objective 6.654e+02, relative change 9.976e-07
relative change in objective function 1e-06 is within tolerance 1e-06 after 766 iterations
array([1.06829617, 3.2274661 , 1.94871031, ..., 0.36086176, 0.69005368, 0.32927252])
Interestingly, we don't seem to get exactly the same parameters. (Keep in mind that `new_params` above aliased `previously_fit_params`, so the in-place perturbation also modified `previously_fit_params`; part of the difference below just reflects that perturbation.)
new_prox_grad.x - previously_fit_params
array([ 0.09151878, -0.07340263, 0.18622053, ..., 0.03914678, -0.04362175, 0.04680383])
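Even so, the objective values look essentially the same. A quick comparison one could run (not in the original notebook) is to evaluate the smooth objective at both converged parameter vectors:

# hypothetical check: objective at the solution from the fitted start vs. the perturbed start
print(g_shim(prox_grad.x), g_shim(new_prox_grad.x))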
import copy
new_poly_abs = copy.deepcopy(poly_abs)
new_poly_abs._params = new_prox_grad.x
BUT, the fit is very close to optimal according to `polyclonal` itself: running its own `fit` (without the site-level pre-fit) from these parameters barely changes the loss.
new_poly_abs.fit(logfreq=100, fit_site_level_first=False)
# Starting optimization of 5799 parameters at Tue Dec 14 10:03:40 2021.
    step  time_sec    loss  fit_loss  reg_escape  regspread
       0   0.11994  665.37    576.04       53.15     36.181
     100    10.045  665.23    576.15      53.124      35.96
     101    10.047  665.23    576.15      53.124      35.96
# Successfully finished at Tue Dec 14 10:03:50 2021.
(      fun: 665.2345519341873
  hess_inv: <5799x5799 LbfgsInvHessProduct with dtype=float64>
       jac: array([ 0.59013223,  0.93013372,  0.90565914, ..., -0.00113136,
       -0.00192544, -0.00261379])
   message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
      nfev: 116
       nit: 100
      njev: 116
    status: 0
   success: True
         x: array([1.06887108, 3.22690596, 1.94935764, ..., 0.3569214 , 0.69360593,
       0.31413904]),
 <polyclonal.polyclonal.Polyclonal.fit.<locals>.LossReg at 0x79237ab84ca0>)
I wonder what happens if we start with the site-level optimization first. I couldn't get this to work right away.
site_model = poly_abs.site_level_model()
site_model.fit(logfreq=100)
# First fitting site-level model.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-16-a2b0ed72a1be> in <module>
      1 site_model = poly_abs.site_level_model()
----> 2 site_model.fit(logfreq=100)

~/re/polyclonal/polyclonal/polyclonal.py in fit(self, loss_delta, reg_escape_weight, reg_escape_delta, reg_spread_weight, fit_site_level_first, scipy_minimize_kwargs, log, logfreq)
   1193         fit_kwargs = {key: values[key] for key in keys if key != 'self'}
   1194         fit_kwargs['fit_site_level_first'] = False
-> 1195         site_model = self.site_level_model()
   1196         site_model.fit(**fit_kwargs)
   1197         self._params = self._params_from_dfs(

~/re/polyclonal/polyclonal/polyclonal.py in site_level_model(self, aggregate_mut_escapes)
    998             )
    999         site_escape_df = (
-> 1000             polyclonal.utils.site_level_variants(
   1001                 self.mut_escape_df
   1002                 .rename(columns={'mutation': 'aa_substitutions'})

~/re/polyclonal/polyclonal/utils.py in site_level_variants(df, original_alphabet, wt_char, mut_char)
    115         site_subs = []
    116         for sub in subs.split():
--> 117             wt, site, _ = mutparser.parse_mut(sub)
    118             if site in wts and wts[site] != wt:
    119                 raise ValueError(f"inconsistent wildtype at {site}: "

~/re/polyclonal/polyclonal/utils.py in parse_mut(self, mutation)
     49         m = self._mutation_regex.fullmatch(mutation)
     50         if not m:
---> 51             raise ValueError(f"invalid mutation {mutation}")
     52         else:
     53             return (m.group('wt'), int(m.group('site')), m.group('mut'))

ValueError: invalid mutation w331m
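Judging from the traceback, `fit` by default first fits a site-level version of the model, but `site_model` is already site-level, so its mutations are written in the reduced `w`/`m` alphabet that the default mutation parser rejects (`invalid mutation w331m`). A possible workaround (not verified here) would be to skip that redundant step:

# possible workaround (untested): skip the site-level pre-fit
site_model.fit(logfreq=100, fit_site_level_first=False)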
site_fit_params = poly_abs._params_from_dfs(
    activity_wt_df=site_model.activity_wt_df,
    mut_escape_df=(
        site_model.mut_escape_df
        [['epitope', 'site', 'escape']]
        .merge(poly_abs.mut_escape_df.drop(columns='escape'),
               on=['epitope', 'site'],
               how='right',
               validate='one_to_many',
               )
    ),
)
new_prox_grad.run(new_params)
We can now visualize the resulting fits for the activities and escape values, and compare them back to the earlier "true" values used to simulate the data:
# NBVAL_IGNORE_OUTPUT
poly_abs.activity_wt_barplot()
# NBVAL_IGNORE_OUTPUT
poly_abs.mut_escape_lineplot()
# NBVAL_IGNORE_OUTPUT
poly_abs.mut_escape_heatmap()
For these simulated data, we can also see how well the fit model does on the "true" simulated values from a library with a different (higher) mutation rate. We therefore read in the "exact" simulated data for that library:
exact_data = (
    pd.read_csv('RBD_variants_escape_exact.csv', na_filter=None)
    .query('library == "avg3muts"')
    .query('concentration in [0.25, 1, 0.5]')
    .reset_index(drop=True)
)
First, we compare the true simulated IC90 values to those predicted by the fit model. We make the comparison on a log scale, and clip IC90s at 50, since larger values are likely well outside the dynamic range given the concentrations used:
import numpy
from plotnine import *
max_ic90 = 50
# we only need the variants, not the concentration, for the IC90 comparison
ic90s = (exact_data[['aa_substitutions', 'IC90']]
         .assign(IC90=lambda x: x['IC90'].clip(upper=max_ic90))
         .drop_duplicates()
         )
ic90s = poly_abs.icXX(ic90s, x=0.9, col='predicted_IC90', max_c=max_ic90)
ic90s = (
    ic90s
    .assign(log_IC90=lambda x: numpy.log10(x['IC90']),
            predicted_log_IC90=lambda x: numpy.log10(x['predicted_IC90']),
            )
)
corr = ic90s['log_IC90'].corr(ic90s['predicted_log_IC90'])
print(f"Correlation is {corr:.2f}")
ic90_corr_plot = (
    ggplot(ic90s) +
    aes('log_IC90', 'predicted_log_IC90') +
    geom_point(alpha=0.1, size=1) +
    theme_classic() +
    theme(figure_size=(3, 3))
)
_ = ic90_corr_plot.draw()
Next we see how well the model predicts the variant-level escape probabilities $p_v\left(c\right)$ by making predictions for the exact simulated data read in above. We both compute and plot the correlations:
exact_vs_pred = poly_abs.prob_escape(variants_df=exact_data)
print("Correlations at each concentration:")
display(exact_vs_pred
        .groupby('concentration')
        .apply(lambda x: x['prob_escape'].corr(x['predicted_prob_escape']))
        .rename('correlation')
        .reset_index()
        .round(2)
        )
pv_corr_plot = (
    ggplot(exact_vs_pred) +
    aes('prob_escape', 'predicted_prob_escape') +
    geom_point(alpha=0.1, size=1) +
    facet_wrap('~ concentration', nrow=1) +
    theme_classic() +
    theme(figure_size=(3 * exact_vs_pred['concentration'].nunique(), 3))
)
_ = pv_corr_plot.draw()
We also examine the correlation between the "true" and inferred mutation-escape values, $\beta_{m,e}$.
In general, it's necessary to ensure the epitopes match up for this type of comparison, as it is arbitrary which epitope in the model is given which name.
But above we seeded the epitopes at the site level using `site_escape_df` when we initialized the `Polyclonal` object, so they match up with classes 1, 2, and 3:
# NBVAL_IGNORE_OUTPUT
import altair as alt
mut_escape_pred = (
    pd.read_csv('RBD_mut_escape_df.csv')
    .merge((poly_abs.mut_escape_df
            .assign(epitope=lambda x: 'class ' + x['epitope'].astype(str))
            .rename(columns={'escape': 'predicted escape'})
            ),
           on=['mutation', 'epitope'],
           validate='one_to_one',
           )
)
print('Correlation between predicted and true values:')
corr = (mut_escape_pred
        .groupby('epitope')
        .apply(lambda x: x['escape'].corr(x['predicted escape']))
        .rename('correlation')
        .reset_index()
        )
display(corr.round(2))
# for testing since we nbval ignore cell output
numpy.allclose(corr['correlation'], numpy.array([0.82, 0.96, 0.93]), atol=0.02)
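As an aside (hypothetical, not needed here because we seeded the epitopes): if the model's epitope labels were arbitrary, one way to match them to the "true" classes would be to pick the label permutation that maximizes the summed correlation of mutation-escape values:

import itertools

true_escape = pd.read_csv('RBD_mut_escape_df.csv')  # true values: epitope, mutation, escape
pred_escape = poly_abs.mut_escape_df  # model estimates, with labels '1', '2', '3'

def summed_corr(mapping):
    # total correlation when model epitope m_e is matched to true epitope t_e
    total = 0
    for m_e, t_e in mapping.items():
        merged = (pred_escape.query('epitope == @m_e')
                  .merge(true_escape.query('epitope == @t_e'),
                         on='mutation', suffixes=('_pred', '_true'))
                  )
        total += merged['escape_pred'].corr(merged['escape_true'])
    return total

model_labels = sorted(pred_escape['epitope'].unique())
true_labels = sorted(true_escape['epitope'].unique())
best_match = max(
    (dict(zip(model_labels, perm)) for perm in itertools.permutations(true_labels)),
    key=summed_corr,
)
print(best_match)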
corr_chart = (
    alt.Chart(mut_escape_pred)
    .encode(x='escape',
            y='predicted escape',
            color='epitope',
            tooltip=['mutation', 'epitope'],
            )
    .mark_point(opacity=0.5)
    .properties(width=250, height=250)
    .facet(column='epitope')
    .resolve_scale(x='independent',
                   y='independent',
                   )
)
corr_chart
The correlations are strongest for the dominant epitope (class 2), which makes sense, as escape at the most dominant epitope produces the largest signal in the measured escape probabilities.