# Combined plot
# Create an empty figure; every panel below adds its own axes with
# plt.subplot2grid. (plt.subplots would also create a full-figure Axes, which
# recent matplotlib versions no longer remove automatically when overlapping
# subplot2grid axes are added, leaving an empty frame behind the panels.)
f = plt.figure(figsize=[12, 11])
# Panel A: moving-average MSE curves for the training, validation and test sets.
ax = plt.subplot2grid((5, 2), (0, 0), rowspan=2)
nma = 10  # moving-average window length
# Positions of iterations that have a recorded training MSE
# (missing entries are stored as the string 'None').
idx = np.where([x != 'None' for x in loss['train_mse']])[0]
# idx holds positions, so select positionally via np.asarray: plain [] on a
# pandas Series is label-based for integer indexes and only matches when the
# index is exactly 0..n-1.
# Assumes moving_average returns len(idx) - nma + 1 values — TODO confirm.
n_smoothed = len(idx) - nma + 1
for split_col, split_color in (('train_mse', 'C2'),
                               ('valid_mse', 'C1'),
                               ('test_mse', 'C0')):
    split_values = np.array([float(x) for x in np.asarray(loss[split_col])[idx]])
    plt.plot(np.arange(n_smoothed), moving_average(split_values, n=nma),
             alpha=0.8, color=split_color)
plt.xlabel('Training iterations')
plt.ylabel('Mean Squared Error')
# Proxy handles so the legend shows plain solid lines in the curve colours.
custom_lines = [Line2D([0], [0], color=c) for c in ('C2', 'C1', 'C0')]
# plt.legend attaches the legend to ax itself; no add_artist is needed.
legend = plt.legend(custom_lines, ['Training', 'Validation', 'Test'],
                    loc='upper right', frameon=False)
plt.xticks([0, 1000, 2000, 3000, 4000])
plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
plt.title("Prediction error decreases during training",
          weight='bold', size=15)
# Panel letter just outside the top-left corner (axes coordinates).
plt.text(-0.13, 1.02, 'A', weight='bold', transform=ax.transAxes)
# Panel B: ODE simulation compared with experimental data for one condition.
ax = plt.subplot2grid((5, 2), (0, 1), rowspan=2)
# b() is defined elsewhere; it presumably draws the simulated and experimental
# traces (with legend labels) onto the current axes — TODO confirm. The
# returned trace/real values are not used in this panel.
trace, real = b(index = index, condition = 0, nT = 400)
#plt.axvline(x = 40, ymin=0.03, ymax=0.73, color="black", ls="dashed", alpha = 0.8, linewidth=2)
plt.xlabel('ODE simulation steps')
plt.ylabel('Cell Response', labelpad=-8)
plt.xlim([-1,41])
plt.ylim([-1.1,1.3])
# Tick positions 0-40 are labelled 0-400 (nT=400); presumably b() plots one
# point every 10 simulation steps — TODO confirm.
plt.xticks([0,10,20,30,40],[0,100,200,300,400])
plt.yticks([-1,-0.5,0,0.5,1.0])
# Proxy handles for the line-style key: solid = Simulation, dashed = Experimental.
custom_lines = [Line2D([0], [0], color='k'), Line2D([0], [0], color='k', ls="dashed")]
# Two stacked legends. legend1 is built from the labelled artists already on
# the axes. The second plt.legend call replaces it as the axes legend, so
# legend1 is re-attached below with add_artist to keep both visible.
legend1 = plt.legend(loc='upper right', bbox_to_anchor=(1.02, 0.93), frameon=False,
ncol=3, prop={'size': 12, "weight": "normal"}, columnspacing=1.2)
legend2 = plt.legend(custom_lines, ['Simulation', 'Experimental'], loc='upper right',
ncol=2, bbox_to_anchor=(0.82, 1.01), frameon=False,
prop={'size': 12,'style':'normal','weight':'normal','variant':'normal','stretch':'normal'},
columnspacing=1.8)
ax.add_artist(legend1)
# NOTE(review): legend2 is already the axes legend, so this add_artist looks
# redundant (it may cause legend2 to be drawn twice).
ax.add_artist(legend2)
plt.title("ODE simulation agrees with experiments",
weight='bold', size=15)
# Panel letter just outside the top-left corner (axes coordinates).
plt.text(-0.13,1.02,'B', weight='bold',transform=ax.transAxes)
# Panel C: predicted vs. experimental response, pooled over all nodes and
# conditions.
# The bottom panels occupy rows 7-13 of a 14-row grid. A rowspan of 8 would
# run one row past the grid (matplotlib silently clips it to 7), so the span
# is stated exactly.
ax = plt.subplot2grid((14, 2), (7, 0), rowspan=7)
x_all = y.values.flatten()
y_all = y_hat.values.flatten()
# Columns 0-81 are drawn as molecular nodes, columns 82-86 as phenotypic nodes.
x_prot = y.iloc[:, 0:82]
y_prot = y_hat.iloc[:, 0:82]
x_pheno = y.iloc[:, 82:87]
y_pheno = y_hat.iloc[:, 82:87]
plt.scatter(x_prot, y_prot, s=15, alpha=0.7, color="#74A6D1", zorder=3)
#plt.scatter(x_prot, y_prot, s = 15, alpha = 0.7, color="#FC5A5B",zorder=3)  # alternative palette
plt.scatter(x_pheno, y_pheno, s=15, alpha=0.7, color="#3D6CA3", zorder=4)
#plt.scatter(x_pheno, y_pheno, s = 15, alpha = 0.7, color="#FECD7F",zorder=4)  # alternative palette
# Labels-only legend: matplotlib assigns the labels to existing artists in
# order, which at this point are exactly the two scatter layers above.
plt.legend(["Molecular nodes", "Phenotypic nodes"], loc="lower right",
           frameon=False, handletextpad=0.1)
# Invisible diagonal (alpha=0): draws nothing but widens autoscaling to
# [-10, 10] before regplot runs; the final limits are set explicitly below.
plt.plot([-10, 10], [-10, 10], c='white', alpha=0, ls='--')
# Regression line only (scatter markers hidden with alpha=0). x/y are passed
# by keyword: seaborn >= 0.12 accepts only `data` positionally.
sns.regplot(x=x_all, y=y_all, scatter_kws={'s': 15, 'alpha': 0},
            line_kws={'color': '#1B406C', 'alpha': 1})
#sns.regplot(x=x_all, y=y_all, scatter_kws={'s': 15, 'alpha': 0},line_kws={'color': '#F18A64', 'alpha': 1})  # alternative palette
plt.xticks(np.arange(-6, 3))
plt.yticks(np.arange(-6, 3))
# Shared square limits, scaled by 20%. This only pads when lower < 0 < upper;
# assumes the pooled data spans zero — TODO confirm.
lower = np.min([x_all, y_all])
upper = np.max([x_all, y_all])
plt.xlim([lower * 1.2, upper * 1.2])
plt.ylim([lower * 1.2, upper * 1.2])
# Pooled Pearson correlation; panel D also marks this value.
r = np.corrcoef(x_all, y_all)[0][1]
plt.text(x=-5.6, y=1.6, s='Pearson\'s correlation: ρ=%1.3f' % r,
         size=15)
plt.xlabel('Experimental response')
plt.ylabel('Predicted response')
plt.title("Correlation between predictions and \n experiments across all conditions",
          weight='bold', size=15)
# Panel letter just outside the top-left corner (axes coordinates).
plt.text(-0.13, 1.06, 'C', weight='bold', transform=ax.transAxes)
# Panel D (across different conditions): histogram of per-condition
# correlations, with the pooled correlation `r` from panel C marked.
# Same 14-row bottom band as panel C. rowspan=7 is the span matplotlib
# actually produced for rowspan=8, which overran the grid.
ax = plt.subplot2grid((14, 2), (7, 1), rowspan=7)
x_all = y.values
y_all = y_hat.values
# One Pearson correlation per row (condition) of y and y_hat.
rs = [np.corrcoef(x_row, y_row)[0][1] for x_row, y_row in zip(x_all, y_all)]
plt.hist(rs, bins=22, color='grey', alpha=0.6, rwidth=0.93)
# The vertical line marks the pooled correlation over all conditions (r),
# not the median of rs, so it is labelled accordingly.
plt.axvline(x=r, linewidth=2, label='All conditions', color="#1B406C")
plt.xlabel('Experiment-prediction correlation')
plt.ylabel('Number of conditions')
plt.xticks([0.2, 0.4, 0.6, 0.8, 1.0])
plt.yticks([0, 10, 20, 30, 40])
plt.text(0.62, 33, "correlation for \nall conditions", color="#1B406C",
         size=15)
plt.title("Correlation between predictions and \n experiments for individual conditions",
          weight='bold', size=15)
# Panel letter just outside the top-left corner (axes coordinates).
plt.text(-0.13, 1.06, 'D', weight='bold', transform=ax.transAxes)
# NOTE(review): the panels come from 5-row and 14-row grids. tight_layout only
# handles grids whose row counts are multiples of one another, so matplotlib
# warns and leaves the layout unchanged here. Put all panels on one shared
# grid if automatic spacing is wanted.
plt.tight_layout()