#!/usr/bin/env python
# coding: utf-8

# This is a simple, pure-python notebook to reproduce the constraints on the Gaia BH1 orbit from joint fitting of RVs and the astrometric constraints. 

# In[1]:


import numpy as np
import matplotlib.pyplot as plt
import emcee
import corner

# Notebook-only: ask IPython to render matplotlib figures inline.
# `get_ipython` is not imported anywhere in this file, so this line only works when
# the script is executed by IPython/Jupyter (which injects that name).
get_ipython().run_line_magic('matplotlib', 'inline')


# In[2]:


# The Gaia orbital solution for BH1: packed correlation coefficients, best-fit
# parameter vector, and per-parameter uncertainties.
# These can be retrieved from the Gaia archive; they are copied here so that
# this notebook can be run without an internet connection or TAP instance.

# Off-diagonal correlation coefficients of the 12 fitted parameters, packed into a
# flat vector. The covariance-building cell further down reads the (i, j) element
# (with i < j) from index j*(j-1)/2 + i.
corr_vec_BH1 = np.array([-0.9385065 , -0.10964689,  0.09768217,  0.02599761, -0.09885585,
                        -0.04858593,  0.04631094, -0.18919176,  0.12997928,  0.08754877,
                        -0.7026825 ,  0.7199984 , -0.1333831 , -0.16456777,  0.06737773,
                         0.976118  , -0.90516096, -0.0790209 , -0.0389514 ,  0.14850111,
                        -0.64331377,  0.7340439 , -0.7446727 , -0.32354277, -0.02299099,
                         0.10600532, -0.12226893,  0.70452565,  0.31863543, -0.2744563 ,
                         0.2560132 ,  0.1476755 , -0.23248082, -0.85903525,  0.25882402,
                        -0.35804698,  0.9656088 , -0.9219766 , -0.2000572 ,  0.05323359,
                         0.00881674, -0.5612588 ,  0.9152694 ,  0.8657674 ,  0.1358042 ,
                        -0.01977275,  0.06291951, -0.08401399,  0.08238114, -0.81962156,
                        -0.07304911, -0.14815669,  0.03363582,  0.13756153,  0.0689249 ,
                        -0.21668243,  0.16537377, -0.2642203 , -0.15781169,  0.33261746,
                         0.7964632 , -0.14971647,  0.4369154 , -0.98625296, -0.0418497 ,
                        -0.25472474])

# Best-fit values, in this parameter order:
# ra, dec, parallax, pmra, pmdec, a_thiele_innes, b_thiele_innes, f_thiele_innes, g_thiele_innes, ecc, period, t_peri
# (the likelihood below compares parallax and the Thiele-Innes terms in milliarcsec and the
#  period in days; t_peri is in days relative to jd_ref)
mu_vec_BH1 = np.array([262.171208162297, -0.5810920153840334, 2.0954515888832432, -7.702050443675317, 
                       -25.85042074524512, -0.262289119199277, 2.9291159041806485, 1.5248071482113947,
                       0.5343685858549198, 0.48893589298452034, 185.7656578920495, -12.024680365644883])

# 1-sigma uncertainties, in the same parameter order as mu_vec_BH1; their squares
# become the diagonal of the covariance matrix.
err_vec_BH1 = np.array([0.49589708, 0.15092158, 0.01745574, 0.02040691, 0.02699432, 0.16984475, 
                        0.17521776, 0.15387644, 0.54653114, 0.074341,0.30688563, 6.34896183])

# The RV data. The four arrays below are parallel: entry k of each describes one observation.

# observation times (Julian dates)
jds_BH1 = np.array([2457881.2899, 2458574.3663, 2459767.6226, 2459791.9186, 2459795.6461, 2459796.4995, 
                2459798.8399, 2459805.5101, 2459808.7388224313,  2459813.6045, 2459814.58740, 
                2459815.5927, 2459817.5278, 2459818.5266, 2459818.78698893, 2459819.5543, 
                2459820.5465, 2459821.5669, 2459822.5745, 2459823.5422430662, 2459824.5305, 
                2459825.5361, 2459823.8525, 2459824.8516, 2459826.7920, 2459828.5677, 2459829.5373, 
                2459829.5768, 2459830.6452, 2459831.6223, 2459833.7523, 2459834.5509, 2459834.7691, 
                2459835.7678, 2459838.8082, 2459838.7208, 2459840.7729, 2459845.5069, 2459855.5012, 
                2459868.5128, 2459877.6978])
# measured radial velocities (plotted below in km/s)
rvs_BH1 = np.array([20.0, 8.9, 63.8, 131.90, 141.4, 142.7, 140.6, 127.7, 118, 90.5, 86.1, 81.90, 74.5, 71.0, 
                67.8, 67.0, 64.0, 60.7,  57.8, 54.8, 52.1, 49.8, 53.76, 51.18, 46.59, 42.2, 42.1, 40.4, 
                38.5, 36.5, 33.23, 32.3, 31.74, 30.14, 26.35, 27.5, 24.20, 19.4, 14.2, 9.3, 10.5])  
# RV uncertainties, used as Gaussian sigmas in the RV term of the likelihood
rv_errs_BH1 = np.array([4.1, 5.6, 3, 0.1, 3, 3, 4, 1.0, 4, 0.3, 0.3, 0.3, 0.3, 0.3, 4, 0.3, 0.3, 0.3, 0.3, 
                    0.3, 0.3, 0.3, 0.1, 0.1, 0.1, 0.3, 3, 0.3, 0.3, 0.5, 0.1, 0.3, 0.1, 0.1, 0.1, 4, 
                    0.1, 1.0, 1.0, 1.0, 1.5])

# instrument that produced each RV; only used to group and colour points in the plots
rv_names_BH1 = np.array(['LAMOST', 'LAMOST', 'MagE', 'HIRES', 'MagE', 'MagE', 'GMOS', 'XSHOOTER', 'GMOS', 'FEROS', 
                     'FEROS', 'FEROS', 'FEROS', 'FEROS', 'GMOS', 'FEROS', 'FEROS', 'FEROS', 'FEROS', 'FEROS', 
                     'FEROS', 'FEROS', 'HIRES', 'HIRES', 'HIRES', 'FEROS', 'MagE', 'FEROS',  'FEROS', 'FEROS', 
                     'HIRES', 'FEROS', 'HIRES', 'HIRES', 'HIRES', 'GMOS', 'HIRES', 'XSHOOTER', 'XSHOOTER', 
                     'XSHOOTER', 'ESI'])

# Gaussian prior on the mass of the luminous star (mean, sigma), in Msun
m1_obs, m1_err = 0.95, 0.05 

# the reference epoch of the gaia t_periastron; i.e., JD 2016.0
# (subtracted from the RV Julian dates so that times share the zero point of t_peri)
jd_ref = 2457389.0


# In[3]:


# convert the correlation matrix to a covariance matrix
#
# corr_vec_BH1 stores only the strictly-upper-triangular correlation coefficients,
# packed column by column: element (i, j) with i < j sits at index j*(j-1)/2 + i.

# take the dimension from the data so it cannot drift out of sync with err_vec_BH1
Nparam = len(err_vec_BH1)
# triangle_nums[j-1] = j*(j-1)/2 is the offset of column j inside the packed vector
triangle_nums = (np.arange(1, Nparam)*(np.arange(1, Nparam) - 1)//2)

# fail loudly instead of silently mis-indexing if the two Gaia vectors disagree in size
if len(corr_vec_BH1) != Nparam*(Nparam - 1)//2:
    raise ValueError('corr_vec_BH1 has %d entries, but %d parameters require %d'
                     % (len(corr_vec_BH1), Nparam, Nparam*(Nparam - 1)//2))

# unpack into a full symmetric correlation matrix with ones on the diagonal
corr_mat_BH1 = np.eye(Nparam)
upper_i, upper_j = np.triu_indices(Nparam, k = 1)
corr_mat_BH1[upper_i, upper_j] = corr_vec_BH1[triangle_nums[upper_j - 1] + upper_i]
corr_mat_BH1[upper_j, upper_i] = corr_mat_BH1[upper_i, upper_j]

# Sigma_ij = rho_ij * sigma_i * sigma_j  (the diagonal reduces to sigma_i**2)
cov_mat_BH1 = corr_mat_BH1 * np.outer(err_vec_BH1, err_vec_BH1)


# In[4]:


# functions to predict RVs at a given time
def fsolve_newton(Mi, ecc, xtol = 1e-10, maxiter = 100):
    '''
    numerically solve Kepler's equation E - ecc*sin(E) = Mi for E with newton's method.
    Mi: float or array, mean anomaly 2*pi/P*(t - T_p)
    ecc: float, eccentricity
    xtol: float, tolerance; iteration stops once the largest change in E between
        successive steps is no larger than this
    maxiter: int, upper bound on the number of newton steps, so that a non-converging
        input cannot hang the caller. If the bound is hit a RuntimeWarning is issued and
        the last iterate is returned.
    returns: eccentric anomaly E, same shape as Mi
    '''
    eps = 1
    n_steps = 0
    # third-order series expansion in ecc, used as the starting guess
    EE = Mi + ecc*np.sin(Mi) + ecc**2/2*np.sin(2*Mi) + ecc**3/8*(3*np.sin(3*Mi) - np.sin(Mi))
    while eps > xtol and n_steps < maxiter:
        EE0 = EE
        # newton step on g(E) = Mi + ecc*sin(E) - E
        EE = EE0 + (Mi + ecc*np.sin(EE0) - EE0)/(1 - ecc*np.cos(EE0))
        step = np.abs(EE0 - EE)
        # reduce with max so scalar and array inputs share one stopping rule;
        # a nan step makes the comparison above False and ends the loop
        eps = np.max(step) if np.size(step) else 0.0
        n_steps += 1
    if eps > xtol:
        import warnings
        warnings.warn('fsolve_newton did not converge within %d iterations' % maxiter,
                      RuntimeWarning)
    return EE

def solve_kep_eqn(t, T_p, P, ecc):
    '''
    Eccentric anomaly E at each time in t, from Kepler's equation E - e*sin(E) = M,
    where the mean anomaly is M = 2*pi/P*(t - T_p).
    t: 1d array of times
    T_p: float; periastron time (same units and zero point as t)
    P: float, period (same units as t)
    ecc: float, eccentricity
    returns: float array with one E per entry of t
    '''
    mean_anom = 2*np.pi/P * (t - T_p)
    # the newton solver is called once per epoch
    solutions = [fsolve_newton(Mi = this_M, ecc = ecc, xtol = 1e-10) for this_M in mean_anom]
    return np.array(solutions, dtype = float)

def get_radial_velocities(t, P, T_p, ecc, K, omega, gamma):
    '''
    Keplerian radial velocity curve evaluated at the times t.
    t: array of times
    P: float, period
    T_p: float, periastron time
    ecc: float, eccentricity
    K: float, RV semi-amplitude
    omega: float, periastron angle entering cos(f + omega), in radians (not degrees)
    gamma: float, center-of-mass RV
    returns: array of predicted RVs, in the units of K and gamma
    '''
    ecc_anom = solve_kep_eqn(t = t, T_p = T_p, P = P, ecc = ecc)
    half_E = ecc_anom*.5
    # true anomaly from the half-angle relation; arctan2 keeps the correct quadrant
    y_part = np.sqrt(1+ecc)*np.sin(half_E)
    x_part = np.sqrt(1-ecc)*np.cos(half_E)
    true_anom = 2*np.arctan2(y_part, x_part)
    return K*(np.cos(true_anom + omega) + ecc*np.cos(omega)) + gamma


# In[5]:


# likelihood for joint fit of astrometry and RVs


def get_a0_mas(period, m1, m2, parallax):
    '''
    Angular semi-major axis a0 of the photocenter orbit, in milliarcsec, for a
    binary whose companion emits no light (so the photocenter follows the
    luminous star).
    period: float, days
    m1: float, Msun, (mass of luminous star)
    m2: float, Msun, (mass of companion)
    parallax: float, milliarcsec
    returns: float, a0 in milliarcsec
    '''
    G = 6.6743e-11 # SI units
    Msun = 1.98840987069805e+30 # kg
    AU = 1.4959787e+11 # m
    period_sec = period*86400
    total_mass_kg = m1*Msun + m2*Msun
    # Kepler's third law gives the relative semi-major axis in metres
    a_metres = (period_sec**2 * G * total_mass_kg/(4*np.pi**2))**(1/3.)
    a_au = a_metres/AU
    # an orbit of 1 AU subtends `parallax` mas
    a_mas = a_au*parallax
    # the luminous star moves on the fraction q/(1+q) = m2/(m1+m2) of the relative orbit
    q = m2/m1
    return a_mas*q/(1+q)

def get_Kstar_kms(period, inc_deg, m1, m2, ecc):
    '''
    RV semi-amplitude of the luminous star, in km/s.
    period: float, days
    inc_deg: float, inclination in degrees
    m1: float, mass of luminous star in Msun
    m2: float, mass of companion in Msun
    ecc: float, eccentricity
    returns: float, K in km/s
    '''
    G = 6.6743e-11 # SI units
    Msun = 1.98840987069805e+30 # kg
    sin_inc = np.sin(inc_deg*np.pi/180)
    companion_frac = m2/(m1 + m2)
    # K**3 = 2*pi*G*m2**3*sin(i)**3 / (P*(m1+m2)**2*(1-e**2)**(3/2)), evaluated in SI
    numer = 2*np.pi*G*(m2*Msun) * companion_frac**2 * sin_inc**3
    denom = period*86400 *  (1 - ecc**2)**(3/2)
    Kstar = (numer / denom)**(1/3)
    return Kstar/1000 # km/s

def lnL_with_covariance(y_vec, mu_vec, Sigma_mat):
    '''
    calculates the log-likelihood of a vector y_vec, given a multivariate Gaussian likelihood
    characterized by a mean mu_vec and a covariance matrix Sigma_mat:
        lnL = -0.5*(y - mu)^T Sigma^-1 (y - mu) - 0.5*ln(det(Sigma)) - N/2*ln(2*pi)
    y_vec: 1d array of length N, point at which to evaluate
    mu_vec: 1d array of length N, mean
    Sigma_mat: (N, N) array, covariance matrix
    returns: float; nan if det(Sigma_mat) is not positive
    '''
    resid = np.asarray(y_vec, dtype = float) - np.asarray(mu_vec, dtype = float)
    Sigma_mat = np.asarray(Sigma_mat, dtype = float)
    # solve Sigma x = r rather than forming the explicit inverse: same quadratic form,
    # better conditioned
    chi2 = np.dot(resid, np.linalg.solve(Sigma_mat, resid))
    # slogdet returns ln|det| directly, so a product of many small variances cannot
    # underflow to zero the way det() can
    sign, logdet = np.linalg.slogdet(Sigma_mat)
    if sign <= 0:
        # not a valid covariance matrix; callers treat non-finite values as rejection
        return np.nan
    return -0.5*chi2 - 0.5*logdet - len(resid)/2*np.log(2*np.pi)
    

def ln_likelihood(theta, fit_rvs = False):
    '''
    Log-probability used by the sampler: Gaia astrometric likelihood, Gaussian prior on the
    luminous star's mass, flat prior on the systemic velocity, and optionally the RV likelihood.
    theta = ra, dec, parallax, pmra, pmdec, period, ecc, inc_deg, omega_deg, w_deg, t_peri, v_com, m2, m1
        omega_deg feeds the Thiele-Innes constants as the node angle (Omega); w_deg is the
        periastron angle (omega) used both there and in the RV curve. Angles are in degrees,
        masses in Msun, period and t_peri in days, v_com in km/s.
    fit_rvs: bool; whether to include the RVs or not
    returns: float; -inf whenever the result would be non-finite or v_com is outside [0, 100]
    '''
    ra, dec, parallax, pmra, pmdec, period, ecc, inc_deg, omega_deg, w_deg, t_peri, v_com, m2, m1 = theta 

    # put a flat prior on gamma to avoid walkers running to infinity.
    # Checked first so that rejected samples skip the expensive Kepler solve below.
    if np.abs(v_com - 50) > 50:
        return -np.inf

    inc, omega, w = inc_deg*np.pi/180, omega_deg*np.pi/180, w_deg*np.pi/180

    if fit_rvs:
        Kstar_kms = get_Kstar_kms(period = period, inc_deg = inc_deg, m1 = m1, m2 = m2, ecc = ecc)
        # RV epochs are shifted by jd_ref so they share the zero point of t_peri
        rv_pred = get_radial_velocities(t = jds_BH1-jd_ref, P = period, T_p = t_peri, ecc = ecc, K = Kstar_kms,
                                        omega = w, gamma = v_com)
        # independent Gaussian errors; the constant normalisation term is omitted
        lnL = -0.5*np.sum((rv_pred - rvs_BH1)**2/rv_errs_BH1**2)
    else:
        lnL = 0
    
    # Thiele-Innes constants predicted by this set of orbital elements, to be compared
    # with the values in mu_vec_BH1
    a0_mas = get_a0_mas(period = period, m1 = m1, m2 = m2, parallax = parallax)
    A_pred = a0_mas*( np.cos(w)*np.cos(omega) - np.sin(w)*np.sin(omega)*np.cos(inc) )
    B_pred = a0_mas*( np.cos(w)*np.sin(omega) + np.sin(w)*np.cos(omega)*np.cos(inc) )
    F_pred = -a0_mas*( np.sin(w)*np.cos(omega) + np.cos(w)*np.sin(omega)*np.cos(inc) )
    G_pred = -a0_mas*( np.sin(w)*np.sin(omega) - np.cos(w)*np.cos(omega)*np.cos(inc) )
            
    # same parameter order as mu_vec_BH1 / cov_mat_BH1
    y_vec = np.array([ra, dec, parallax, pmra, pmdec, A_pred, B_pred, F_pred, G_pred, ecc, period, t_peri])
    lnL += lnL_with_covariance(y_vec = y_vec, mu_vec = mu_vec_BH1, Sigma_mat = cov_mat_BH1)
    lnL += -0.5*(m1 - m1_obs)**2/m1_err**2 # prior on the mass of the luminous star
    
    # nan/inf (e.g. from unphysical parameters) is mapped to a hard rejection
    if np.isfinite(lnL):
        return lnL
    else:
        return -np.inf




# In[6]:


# first the "no-RVs" fit

# ndim matches the 14 entries of theta unpacked in ln_likelihood
ndim = 14
nwalkers = 64

# starting point, in theta order:
# ra, dec, parallax, pmra, pmdec, period, ecc, inc_deg, omega_deg, w_deg, t_peri, v_com, m2, m1
p0 = np.array([262.17120816,  -0.58109202,   2.09545159,  -7.70205044,
       -25.85042075, 185.76565789,   0.48893589, 126, 97.8, 12.8, -1.1, 46.6, 9.62, 0.93])
# one copy of the starting point per walker, each perturbed by a relative scatter of 1e-10
p0 = np.tile(p0, (nwalkers, 1))
p0 += p0 * 1.0e-10 * np.random.normal(size=(nwalkers, ndim))
# args=[False] is passed to ln_likelihood as fit_rvs, i.e. the RV term is left out
sampler_norvs = emcee.EnsembleSampler(nwalkers, ndim, ln_likelihood, args=[False])

# 5000 step burn-in
state = sampler_norvs.run_mcmc(p0, 5000)

# discard the burn-in samples, then run 3000 more steps starting from the final burn-in state
sampler_norvs.reset()
state = sampler_norvs.run_mcmc(state, 3000)


# In[7]:


# next, the "with-rvs" fit. This will take a few minutes because predicting RVs is moderately expensive
# (in production runs we use a compiled C function to speed up the RV prediction)
# nwalkers, ndim and the initial positions p0 are reused from the "no-RVs" cell above;
# args=[True] is passed to ln_likelihood as fit_rvs, switching the RV term on.
sampler_withrvs = emcee.EnsembleSampler(nwalkers, ndim, ln_likelihood, args=[True])

# 5000 step burn-in
state = sampler_withrvs.run_mcmc(p0, 5000)

# discard the burn-in samples, then run 3000 more steps starting from the final burn-in state
sampler_withrvs.reset()
state = sampler_withrvs.run_mcmc(state, 3000)


# In[8]:


# axis labels for the nine parameters shown in the corner plot, i.e. theta without its
# first five entries: period, ecc, inc, Omega, omega, t_peri, gamma, m2, m1
labels = [r'$P_{\rm orb}\,\,[\rm days]$', r'$\rm ecc$', r'$\rm inc\,\,[\rm deg]$', r'$\Omega\,\,[\rm deg]$', 
          r'$\omega\,\,[\rm deg]$', r'$T_{p}\,\,[\rm days]$', r'$\gamma\,\,[\rm km\,s^{-1}]$', 
          r'$M_2\,\,[M_{\odot}]$', r'$M_{\star}\,\,[M_{\odot}]$']

# now compare the fits
# .T[5:].T drops the first five columns of each flattened chain (ra, dec, parallax, pmra, pmdec)
# NOTE(review): `flatchain` looks like the legacy emcee accessor; newer emcee exposes
# get_chain(flat=True) -- confirm against the installed version.
half_sampler = sampler_withrvs.flatchain.T[5:].T    
half_sampler_norvs = sampler_norvs.flatchain.T[5:].T

# astrometry-only posterior in cyan; the astrometry+RV posterior is overplotted in black
# on the same figure via fig=fig
fig = corner.corner(half_sampler_norvs, labels=labels, show_titles=True, plot_datapoints = False, 
                    plot_density = False, color = 'c')
fig = corner.corner(half_sampler, labels = labels, show_titles=True, plot_datapoints = False, 
                    plot_density = False, color = 'k', fig = fig, label_kwargs = {'fontsize': 20})
# empty line plots exist only to create legend entries; the legend is attached to fig.axes[4]
fig.axes[4].plot([], [], 'c', label = r'$\rm astrometry\,\,only$')
fig.axes[4].plot([], [], 'k', label = r'$\rm astrometry+RVs$')
fig.axes[4].legend(loc = 'upper right', frameon = False, fontsize=30)
    


# In[9]:


# posterior predictive check for the "with RVs" solution. 
from astropy.time import Time
# one 1d array of posterior samples per parameter, in theta order
ra, dec, parallax, pmra, pmdec, period, ecc, inc_deg, omega_deg, w_deg, t_peri, \
    v_com, m2, m1 = sampler_withrvs.flatchain.T

# dense time grid (Julian dates) extending 150 days beyond the first and last RV epoch
t_grid = np.linspace(np.min(jds_BH1) - 150, np.max(jds_BH1) + 150, 3000)
# indices of 50 posterior samples, drawn uniformly with replacement
randints = np.random.randint(0, len(period), 50)

# one colour per instrument, in the order of the name list looped over below
colors = ['#e41a1c', '#feb24c', '#e7298a', '#66a61e',  '#a6761d',  '#377eb8', '#756bb1']
# top panel: full time baseline; bottom panel: zoom on 2022.3-2022.9 (limits set below)
f, ax = plt.subplots(2, 1, figsize = (14, 10))
plt.subplots_adjust(hspace=0.2)
# overplot the RV curve predicted by each selected posterior sample on both panels
for i in randints:
    Kstar_kms = get_Kstar_kms(period = period[i], inc_deg = inc_deg[i], m1 = m1[i], m2 = m2[i], ecc = ecc[i])
    rv_pred = get_radial_velocities(t = t_grid-jd_ref, P = period[i], T_p = t_peri[i], ecc = ecc[i], K = Kstar_kms,
                                        omega = w_deg[i]*np.pi/180, gamma = v_com[i])    
    ax[0].plot(Time(t_grid, format= 'jd').decimalyear, rv_pred, 'k', alpha=0.1, rasterized=True)
    ax[1].plot(Time(t_grid, format= 'jd').decimalyear, rv_pred, 'k', alpha=0.1, rasterized=True)


# measured RVs with error bars, grouped and coloured by instrument
for i, name in enumerate(['LAMOST',  'MagE', 'GMOS', 'XSHOOTER', 'FEROS', 'HIRES', 'ESI']):
    m = rv_names_BH1 == name
    ax[0].errorbar(Time(jds_BH1[m], format='jd').decimalyear, rvs_BH1[m], yerr= rv_errs_BH1[m], fmt='o', color = colors[i], label = name)
    ax[1].errorbar(Time(jds_BH1[m], format='jd').decimalyear, rvs_BH1[m], yerr= rv_errs_BH1[m], fmt='o', color = colors[i], label = name)
ax[0].set_xlabel(r'$\rm year$', fontsize=20)
ax[1].set_xlabel(r'$\rm year$', fontsize=20)
ax[0].set_ylabel(r'$\rm RV\,\,[\rm km\,s^{-1}]$', fontsize=20)
ax[1].set_ylabel(r'$\rm RV\,\,[\rm km\,s^{-1}]$', fontsize=20)
ax[1].legend(loc = 'upper left')
ax[0].legend(loc = 'upper left', ncol=4)
ax[1].set_xlim(2022.3, 2022.9)
ax[0].set_xlim(2017, 2023)
# explicit ticks and labels every 0.1 yr on the zoomed panel
ax[1].set_xticks([2022.3, 2022.4, 2022.5, 2022.6, 2022.7, 2022.8, 2022.9])
ax[1].set_xticklabels([2022.3, 2022.4, 2022.5, 2022.6, 2022.7, 2022.8, 2022.9])
ax[0].set_ylim(-10, 165)
ax[1].set_ylim(-10, 165)

import matplotlib.patches as patches
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

def custom_mark_inset(axA, axB, fc='None', ec='k'):
    '''
    Outline on axA the data region currently displayed by axB, and join the two
    axes with a connector line at each of axB's x-limits.
    axA: matplotlib axes that receives the rectangle and the connectors
    axB: matplotlib axes whose current x/y limits define the region
    fc: face color of the rectangle
    ec: edge color of the rectangle
    returns: (rectangle, left connector, right connector), as returned by axA.add_patch
    '''
    x_lo, x_hi = axB.get_xlim()
    y_lo, y_hi = axB.get_ylim()
    outline = patches.Rectangle((x_lo, y_lo), x_hi - x_lo, y_hi - y_lo, fc=fc, ec=ec)
    pp = axA.add_patch(outline)
    connectors = []
    for x_edge in (x_lo, x_hi):
        # runs from (x_edge, lower y-limit) in axA's data coordinates
        # to (x_edge, upper y-limit) in axB's data coordinates
        link = patches.ConnectionPatch(
            xyA=(x_edge, y_lo), xyB=(x_edge, y_hi),
            coordsA='data', coordsB='data',
            axesA=axA, axesB=axB, linewidth=1, color='0.2')
        connectors.append(axA.add_patch(link))
    p1, p2 = connectors
    return pp, p1, p2
# outline the zoomed panel's (ax[1]) data range on the full-baseline panel (ax[0])
pp, p1, p2 = custom_mark_inset(ax[0], ax[1])


# In[10]:


# posterior predictive check for the "no RVs" solution. 
from astropy.time import Time
# one 1d array of posterior samples per parameter, in theta order
ra, dec, parallax, pmra, pmdec, period, ecc, inc_deg, omega_deg, w_deg, t_peri, \
    v_com, m2, m1 = sampler_norvs.flatchain.T

# dense time grid (Julian dates) extending 150 days beyond the first and last RV epoch
t_grid = np.linspace(np.min(jds_BH1) - 150, np.max(jds_BH1) + 150, 3000)
# indices of 50 posterior samples, drawn uniformly with replacement
randints = np.random.randint(0, len(period), 50)

# one colour per instrument, in the order of the name list looped over below
colors = ['#e41a1c', '#feb24c', '#e7298a', '#66a61e',  '#a6761d',  '#377eb8', '#756bb1']
# top panel: full time baseline; bottom panel: zoom on 2022.3-2022.9 (limits set below)
f, ax = plt.subplots(2, 1, figsize = (14, 10))
plt.subplots_adjust(hspace=0.2)
# overplot the RV curve predicted by each selected posterior sample on both panels.
# gamma is fixed at 46.6 (the v_com entry of the starting point p0) rather than the sampled
# v_com[i]; presumably because the astrometry-only fit constrains v_com only through its
# flat prior -- TODO confirm.
for i in randints:
    Kstar_kms = get_Kstar_kms(period = period[i], inc_deg = inc_deg[i], m1 = m1[i], m2 = m2[i], ecc = ecc[i])
    rv_pred = get_radial_velocities(t = t_grid-jd_ref, P = period[i], T_p = t_peri[i], ecc = ecc[i], K = Kstar_kms,
                                        omega = w_deg[i]*np.pi/180, gamma = 46.6)    
    ax[0].plot(Time(t_grid, format= 'jd').decimalyear, rv_pred, 'c', alpha=0.1, rasterized=True)
    ax[1].plot(Time(t_grid, format= 'jd').decimalyear, rv_pred, 'c', alpha=0.1, rasterized=True)


# measured RVs with error bars, grouped and coloured by instrument
for i, name in enumerate(['LAMOST',  'MagE', 'GMOS', 'XSHOOTER', 'FEROS', 'HIRES', 'ESI']):
    m = rv_names_BH1 == name
    ax[0].errorbar(Time(jds_BH1[m], format='jd').decimalyear, rvs_BH1[m], yerr= rv_errs_BH1[m], fmt='o', color = colors[i], label = name)
    ax[1].errorbar(Time(jds_BH1[m], format='jd').decimalyear, rvs_BH1[m], yerr= rv_errs_BH1[m], fmt='o', color = colors[i], label = name)
ax[0].set_xlabel(r'$\rm year$', fontsize=20)
ax[1].set_xlabel(r'$\rm year$', fontsize=20)
ax[0].set_ylabel(r'$\rm RV\,\,[\rm km\,s^{-1}]$', fontsize=20)
ax[1].set_ylabel(r'$\rm RV\,\,[\rm km\,s^{-1}]$', fontsize=20)
ax[1].legend(loc = 'upper left')
ax[0].legend(loc = 'upper left', ncol=4)
ax[1].set_xlim(2022.3, 2022.9)
ax[0].set_xlim(2017, 2023)
# explicit ticks and labels every 0.1 yr on the zoomed panel
ax[1].set_xticks([2022.3, 2022.4, 2022.5, 2022.6, 2022.7, 2022.8, 2022.9])
ax[1].set_xticklabels([2022.3, 2022.4, 2022.5, 2022.6, 2022.7, 2022.8, 2022.9])
# the y-range is wider here (195) than in the "with RVs" figure (165)
ax[0].set_ylim(-10, 195)
ax[1].set_ylim(-10, 195)

import matplotlib.patches as patches
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

def custom_mark_inset(axA, axB, fc='None', ec='k'):
    '''
    Outline on axA the data region currently displayed by axB, and join the two
    axes with a connector line at each of axB's x-limits.
    axA: matplotlib axes that receives the rectangle and the connectors
    axB: matplotlib axes whose current x/y limits define the region
    fc: face color of the rectangle
    ec: edge color of the rectangle
    returns: (rectangle, left connector, right connector), as returned by axA.add_patch
    '''
    x_lo, x_hi = axB.get_xlim()
    y_lo, y_hi = axB.get_ylim()
    outline = patches.Rectangle((x_lo, y_lo), x_hi - x_lo, y_hi - y_lo, fc=fc, ec=ec)
    pp = axA.add_patch(outline)
    connectors = []
    for x_edge in (x_lo, x_hi):
        # runs from (x_edge, lower y-limit) in axA's data coordinates
        # to (x_edge, upper y-limit) in axB's data coordinates
        link = patches.ConnectionPatch(
            xyA=(x_edge, y_lo), xyB=(x_edge, y_hi),
            coordsA='data', coordsB='data',
            axesA=axA, axesB=axB, linewidth=1, color='0.2')
        connectors.append(axA.add_patch(link))
    p1, p2 = connectors
    return pp, p1, p2
# outline the zoomed panel's (ax[1]) data range on the full-baseline panel (ax[0])
pp, p1, p2 = custom_mark_inset(ax[0], ax[1])


# In[ ]: