import seaborn as sns
fig, ax = plt.subplots(1, 3, figsize=[11, 5], sharey=True)

# Long running-mean window (samples) used for the background state; loop-invariant.
window = 40
# Histogram bin edges over the full axis range [-pi/2, pi/2] in steps of pi/32.
# The previous np.arange(-1.5, 1.6, 0.1) edges stopped at +-1.5, silently
# discarding angles in (1.5, pi/2]; with density=True the histogram was then
# renormalised over the truncated sample and no longer matched the KDE drawn on
# the same axes. pi/32 steps also put edges exactly on 0 and +-pi/4, which are
# the shading boundaries below.
tu_bins = np.linspace(-math.pi / 2, math.pi / 2, 33)

for i, dat in enumerate([dat_saz, dat_pfz, dat_miz]):
    # 4-point running means of the smoothed mixed-layer T and S, computed once
    # and reused for both the anomalies and the gsw coefficients below.
    ml_t_4 = rolling_mean(dat.ml_t_smooth, 4)
    ml_s_4 = rolling_mean(dat.ml_s_smooth, 4)

    # Background state: long running mean with `window` samples trimmed from
    # each end (presumably where the long mean is incomplete -- confirm
    # against rolling_mean).
    ml_t_roll = rolling_mean(dat.ml_t_smooth, window=window)[window:-window]
    ml_s_roll = rolling_mean(dat.ml_s_smooth, window=window)[window:-window]

    # Anomalies about the background. These appear to be xarray objects (they
    # support .diff(dim='time')), so arithmetic aligns on the shared 'time'
    # coordinate and the trimmed ends drop out.
    T_anom = ml_t_4 - ml_t_roll
    S_anom = ml_s_4 - ml_s_roll

    # Expansion/contraction coefficients at p = 0.
    # NOTE(review): gsw.alpha / gsw.beta take (SA, CT, p); confirm that
    # ml_s_smooth / ml_t_smooth are Absolute Salinity / Conservative Temperature.
    alpha = gsw.alpha(ml_s_4, ml_t_4, 0)
    beta = gsw.beta(ml_s_4, ml_t_4, 0)

    # Density-weighted step changes and their ratio R = alpha*dT' / (beta*dS').
    dT = T_anom.diff(dim='time') * alpha
    dS = S_anom.diff(dim='time') * beta
    R = dT / dS

    # Angle = arctan(R), in [-pi/2, pi/2]. |angle| > pi/4 <=> |alpha*dT'| > |beta*dS'|
    # (temperature-dominated); the sign of R separates the compensated (R > 0)
    # and anti-compensated (R < 0) sectors labelled on the figure. Note this is
    # arctan of the ratio, not the angle returned by gsw.Turner_Rsubrho.
    # dS == 0 gives R = +-inf -> +-pi/2; 0/0 gives NaN, which is dropped so the
    # histogram and the KDE are built from the same sample.
    Tu = np.arctan(np.asarray(R, dtype=float))
    Tu = Tu[np.isfinite(Tu)]

    ax[i].hist(Tu, bins=tu_bins, facecolor=blue, zorder=10, alpha=0.55,
               edgecolor='0.95', density=True)
    ax[i].set_xlim(-math.pi / 2, math.pi / 2)
    ax[i].xaxis.set_ticks([-math.pi / 2, -math.pi / 4, 0, math.pi / 4, math.pi / 2])
    # Raw strings: '\p' is an invalid escape sequence in a normal string literal.
    ax[i].xaxis.set_ticklabels([r'-$\pi$/2', r'-$\pi$/4', 0, r'$\pi$/4', r'$\pi$/2'])
    ax[i].plot([0, 0], [0, 1.2], c='k', lw=1, alpha=0.5, ls='--')
    ax[i].set_ylim(0, 1.2)
    # Temperature-dominated sectors (|angle| > pi/4) in red, salinity-dominated in light blue.
    ax[i].fill_between(x=[-math.pi / 2, -math.pi / 4], y1=0, y2=1.2, facecolor=red, alpha=0.15)
    ax[i].fill_between(x=[math.pi / 2, math.pi / 4], y1=0, y2=1.2, facecolor=red, alpha=0.15)
    ax[i].fill_between(x=[-math.pi / 4, math.pi / 4], y1=0, y2=1.2, facecolor=lightblue, alpha=0.15)
    sns.kdeplot(Tu, ax=ax[i], zorder=10, c=blue, lw=3)
    ax[i].set_ylabel('')
    ax[i].set_xlabel('Turner angle', labelpad=20)
# Regime annotations on the middle panel, drawn in this order. Each entry is
# (x, y, text, Axes.text keyword arguments); y = 1.25 lies above the 0-1.2 y-range,
# so the dominance labels sit over the top of the axes.
middle_panel = ax[1]
for x_pos, y_pos, label, text_kwargs in (
    (-1.5, 1.0, 'Anti- \n compensated', {'fontsize': 13}),
    (0.1, 1.0, 'Compensated', {'fontsize': 13}),
    (-0.5, 1.25, 'Sal dom.', {'fontsize': 16, 'color': blue}),
    (-2, 1.25, 'Temp dom.', {'fontsize': 16, 'color': red}),
    (0.7, 1.25, 'Temp dom.', {'fontsize': 16, 'color': red}),
):
    middle_panel.text(x_pos, y_pos, label, **text_kwargs)

# Bold panel letters: 'a', 'b', 'c' on the first, second and third panel.
for panel, letter in zip(ax, 'abc'):
    panel.text(-1.5, 1.13, letter, fontweight='bold')

plt.savefig('../figs_submission2/fig9.png', dpi=300, bbox_inches='tight')