#!/usr/bin/env python
# coding: utf-8
#
# # Optimal planetary landing
#
#
#
# $$
# \newcommand{\eg}{{\it e.g.}}
# \newcommand{\ie}{{\it i.e.}}
# \newcommand{\argmin}{\operatornamewithlimits{argmin}}
# \newcommand{\mc}{\mathcal}
# \newcommand{\mb}{\mathbb}
# \newcommand{\mf}{\mathbf}
# \newcommand{\minimize}{{\text{minimize}}}
# \newcommand{\diag}{{\text{diag}}}
# \newcommand{\cond}{{\text{cond}}}
# \newcommand{\rank}{{\text{rank }}}
# \newcommand{\range}{{\mathcal{R}}}
# \newcommand{\null}{{\mathcal{N}}}
# \newcommand{\tr}{{\text{trace}}}
# \newcommand{\dom}{{\text{dom}}}
# \newcommand{\dist}{{\text{dist}}}
# \newcommand{\R}{\mathbf{R}}
# \newcommand{\SM}{\mathbf{S}}
# \newcommand{\ball}{\mathcal{B}}
# \newcommand{\bmat}[1]{\begin{bmatrix}#1\end{bmatrix}}
# \newcommand{\loss}{\ell}
# \newcommand{\eloss}{\mc{L}}
# \newcommand{\abs}[1]{| #1 |}
# \newcommand{\norm}[1]{\| #1 \|}
# \newcommand{\tp}{T}
# $$
# __ASE3001: Computational Experiments for Aerospace Engineering, Inha University.__
#
# _Jong-Han Kim (jonghank@inha.ac.kr)_
#
# _Jiwoo Choi (jiwoochoi@inha.edu)_
#
#
# In this problem, we consider the soft-landing problem for a planetary lander.
#
# Consider the following equations of motion in the ENU (East-North-Up) frame:
#
# $$
# \begin{aligned}
# \dot{p} &= v \\
# \dot{v} &= u - \gamma v +g \\
# \end{aligned}
# $$
# with
# $$
# \begin{aligned}
# p &= (p_e, p_n, p_u)\\
# v &= (v_e, v_n, v_u) \\
# g&=(0,0,-g).
# \end{aligned}
# $$
#
# where $p$ and $v$ are the position and velocity of the vehicle, and $u$ is the commanded acceleration of the vehicle. The gravitational acceleration vector is $g = (0,0,-g)$, where the scalar $g$ is the magnitude of gravity, and $\gamma$ is the damping coefficient. The control input $u$ is produced by a set of thrusters mounted on the vehicle. The objective is to find the control input sequence $u_0, \dots, u_{N-1}$ that drives the vehicle from the specified initial condition to $p=0$ and $v=0$ at time step $t=N$.
#
#
# The above system can be discretized with step size $h$ (`dt` in the code below) by applying a forward Euler step to the velocity and the trapezoidal rule to the position:
#
# $$
# \begin{aligned}
# v_{t+1} &= v_t + h\left( u_t - \gamma v_t + g \right) \\
# &= \left(1-\gamma h\right) v_t + h u_t + hg \\
# p_{t+1} &= p_t + \frac{h}{2}\left( v_t + v_{t+1} \right) \\
# &= p_t + \frac{h}{2}\left( v_t + \left(1-\gamma h\right) v_t + h u_t + hg \right) \\
# &= p_t + \left(h-\frac{1}{2}\gamma h^2\right) v_t + \frac{1}{2} h^2 u_t + \frac{1}{2} h^2 g
# \end{aligned}
# $$
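#
# As a minimal sketch (the helper name `discretize` is hypothetical and not part of the assignment), the update above can be collected into the form $x_{t+1} = Ax_t + Bu_t + b$ with the state ordered as $x = (p, v)$. The check at the end compares one step of the matrix form against the two-stage update (Euler step for $v$, trapezoidal rule for $p$), using an arbitrary test input.
#
```python
import numpy as np

def discretize(h, gamma, g_vec):
    """Return A (6x6), B (6x3), b (6,) for the state x = (p, v)."""
    I3, Z3 = np.eye(3), np.zeros((3, 3))
    A = np.block([[I3, (h - 0.5 * gamma * h**2) * I3],
                  [Z3, (1 - gamma * h) * I3]])
    B = np.vstack([0.5 * h**2 * I3, h * I3])
    b = np.concatenate([0.5 * h**2 * g_vec, h * g_vec])
    return A, B, b

# Values mirror the parameter cell of this notebook (h = T / N = 0.06)
h, gamma = 6 / 100, 1e-5
g_vec = np.array([0.0, 0.0, -9.8])
A, B, b = discretize(h, gamma, g_vec)

# One step of the matrix form from an arbitrary state and input
x0 = np.array([10.0, -5.0, 40.0, 0.0, 0.0, -1.0])
u0 = np.array([0.5, -0.2, 9.0])
x1 = A @ x0 + B @ u0 + b

# Reference: Euler step for v, then trapezoidal rule for p
p0, v0 = x0[:3], x0[3:]
v1 = v0 + h * (u0 - gamma * v0 + g_vec)
p1 = p0 + 0.5 * h * (v0 + v1)
print(np.allclose(x1, np.concatenate([p1, v1])))
```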
#
# Writing the state as $x_t = (p_t, v_t) \in \R^6$, the discretized dynamics take the form $x_{t+1} = Ax_t + Bu_t + b$. Given the initial and desired states, the problem of minimizing the control effort while satisfying the dynamics can then be written as
#
# $$
# \begin{aligned}
# \text{minimize} \quad & \sum_{t=0}^{N-1} \left\|u_t \right\|_2^2 \\
# \text{subject to} \quad & x_{t+1} = Ax_t + Bu_t+b , \quad t=0,\dots,N-1,\\
# &x_0 = x_\text{init},\\
# &x_N = x_\text{des},
# \end{aligned}
# $$
#
# where $N$ is the horizon length, i.e., the number of time steps over which the control inputs are optimized.
#
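# Since the dynamics and boundary conditions are linear in the states and inputs and the objective is quadratic, this optimal control problem can be addressed with a (weighted) least-squares approach.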
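#
# For comparison, the hard-constrained problem above is itself a least-norm problem. Unrolling the dynamics gives $x_N = A^N x_\text{init} + \sum_{t=0}^{N-1} A^{N-1-t}(Bu_t + b)$, so the terminal condition is a linear equation $Cu = d$ in the stacked input, and the minimum-norm input can be computed directly. The sketch below does this with `np.linalg.lstsq`, which returns the minimum-norm solution of an underdetermined full-rank system. It is an alternative route, not the penalty formulation developed in the problems below, and the names `C`, `d`, `x_free`, and `u_ln` are illustrative.
#
```python
import numpy as np

# Parameters as in the notebook's parameter cell
N, T, gamma = 100, 6, 1e-5
h = T / N
g_vec = np.array([0.0, 0.0, -9.8])
x_init = np.array([10.0, -5.0, 40.0, 0.0, 0.0, -1.0])
x_des = np.zeros(6)

# Discretized dynamics x_{t+1} = A x_t + B u_t + b, with x = (p, v)
I3, Z3 = np.eye(3), np.zeros((3, 3))
A = np.block([[I3, (h - 0.5 * gamma * h**2) * I3], [Z3, (1 - gamma * h) * I3]])
B = np.vstack([0.5 * h**2 * I3, h * I3])
b = np.concatenate([0.5 * h**2 * g_vec, h * g_vec])

# Free response: x_N when u_t = 0 for all t
x_free = x_init.copy()
for _ in range(N):
    x_free = A @ x_free + b

# C = [A^{N-1} B, A^{N-2} B, ..., B] maps the stacked inputs to x_N
C = np.zeros((6, 3 * N))
Ak = np.eye(6)
for t in range(N - 1, -1, -1):
    C[:, 3 * t:3 * t + 3] = Ak @ B  # A^(N-1-t) B
    Ak = A @ Ak

# Minimum-norm input satisfying C u = x_des - x_free
d = x_des - x_free
u_ln = np.linalg.lstsq(C, d, rcond=None)[0].reshape(N, 3)

# Simulate with the computed inputs to confirm the terminal condition
x = x_init.copy()
for t in range(N):
    x = A @ x + B @ u_ln[t] + b
print(np.linalg.norm(x - x_des))
```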
# In[ ]:
import numpy as np
import numpy.linalg as lg
import scipy.sparse as sp
import scipy.sparse.linalg as sla
import matplotlib.pyplot as plt
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from dataclasses import dataclass
from tqdm import tqdm
# In[ ]:
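N = 100                       # number of time steps
T = 6                         # final time [s]
dt = T / N                    # step size h [s]
gamma = 1e-5                  # damping coefficient
g = np.array([0, 0, -9.8])    # gravitational acceleration vector [m/s^2]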
x_init = np.array([10,-5,40, 0,0,-1]) # initial pe, pn, pu, ve, vn, vu at t=0
x_des = np.array([0,0,0,0,0,0]) # desired pe, pn, pu, ve, vn, vu at t=N
#
#
# ---
#
#
#
# _**(Problem 1)**_ Construct the matrices $A$ and $B$ and the vector $b$ in terms of the given parameters. Define $x = (p_e, p_n, p_u, v_e, v_n, v_u)$ and $u=(u_e, u_n, u_u)$.
# In[ ]:
## your code here ##
#
#
# ---
#
#
#
# _**(Problem 2)**_
# The original problem can be approximated by replacing the hard constraints with quadratic penalty terms (with $x_0 = x_\text{init}$ fixed), which gives a sum of squares:
#
# $$
# \underset{x_1,\dots,x_N, u_0,\dots,u_{N-1}}{\text{minimize}} \ \|x_N-x_\text{des}\|^2 + \sum_{t=0}^{N-1}\left\| x_{t+1}-Ax_{t}-Bu_{t}-b\right\|^2 + \sum_{t=0}^{N-1}\left\|Q_u u_t\right\|^2.
# $$
#
# Stacking the states and the control inputs as $x = (x_1,x_2, \dots, x_N)$ and $u = (u_0, u_1, \dots, u_{N-1})$, and defining $y = (x, u) \in \R^{9N}$, the problem above can be expressed as
#
# $$
# \underset{y}{\text{minimize}} \ \left\|Gy - c\right\|^2 + \sum_{t=0}^{N-1}\left\|Q_u u_t\right\|^2,
# $$
#
# which can in turn be written as a single least-squares problem,
#
# $$
# \underset{y}{\text{minimize}} \ \left\|\tilde{G}y - \tilde{c}\right\|^2 ,
# $$
#
# where $Q_u$ is a weighting matrix that regulates the size of the control inputs; it is given in the code cell below.
#
#
#
# Construct $G$, $\tilde{G}$, $c$, and $\tilde{c}$ in code. They are assembled from $A$, $B$, $b$, $Q_u$, the identity matrix $I$, and the boundary states $x_\text{init}$ and $x_\text{des}$.
#
# Note that the state variables $x_1, \dots, x_N$ appear in the above formulation. In other words, the decision variables of the above problem include not only the control inputs but also the state variables.
#
# Try using the `scipy.sparse` module with indexing where possible.
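#
# One possible sparse construction is sketched below, assuming the dynamics residuals are stacked in time order followed by the terminal residual (other row orderings give the same solution). It builds $G$ and $c$ from Kronecker products (`sp.kron`, and `sp.eye` with an offset diagonal), appends the block $[\,0 \;\; I_N \otimes Q_u\,]$ with a zero right-hand side to obtain $\tilde{G}$ and $\tilde{c}$, and solves the normal equations $\tilde{G}^T\tilde{G}y = \tilde{G}^T\tilde{c}$ as a closed-form reference. Variable names such as `G_tilde` and `y_star` are illustrative.
#
```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as sla

# Parameters as in the notebook (n states, m inputs per time step)
N, T, gamma = 100, 6, 1e-5
h = T / N
n, m = 6, 3
g_vec = np.array([0.0, 0.0, -9.8])
x_init = np.array([10.0, -5.0, 40.0, 0.0, 0.0, -1.0])
x_des = np.zeros(n)
Q_u = np.diag([3e-2] * 3)

# Discretized dynamics x_{t+1} = A x_t + B u_t + b
I3, Z3 = np.eye(3), np.zeros((3, 3))
A = np.block([[I3, (h - 0.5 * gamma * h**2) * I3], [Z3, (1 - gamma * h) * I3]])
B = np.vstack([0.5 * h**2 * I3, h * I3])
b = np.concatenate([0.5 * h**2 * g_vec, h * g_vec])

# Dynamics residuals x_{t+1} - A x_t - B u_t - b, t = 0..N-1.
# Block row t has I on x_{t+1} and -A on x_t (absent for t = 0, where x_0 is known).
Gx = sp.eye(n * N) - sp.kron(sp.eye(N, k=-1), A)
Gu = -sp.kron(sp.eye(N), B)
c_dyn = np.tile(b, N)
c_dyn[:n] += A @ x_init          # the known x_0 moves to the right-hand side

# Terminal residual x_N - x_des
G_term = sp.hstack([sp.csr_matrix((n, n * (N - 1))), sp.eye(n),
                    sp.csr_matrix((n, m * N))])

G = sp.vstack([sp.hstack([Gx, Gu]), G_term]).tocsr()
c = np.concatenate([c_dyn, x_des])

# Input penalty rows (I_N kron Q_u) u with zero right-hand side
G_pen = sp.hstack([sp.csr_matrix((m * N, n * N)), sp.kron(sp.eye(N), Q_u)])
G_tilde = sp.vstack([G, G_pen]).tocsr()
c_tilde = np.concatenate([c, np.zeros(m * N)])

# Closed-form solution of the normal equations G~^T G~ y = G~^T c~
y_star = sla.spsolve((G_tilde.T @ G_tilde).tocsc(), G_tilde.T @ c_tilde)
x_star = y_star[:n * N].reshape(N, n)   # rows: x_1, ..., x_N
u_star = y_star[n * N:].reshape(N, m)   # rows: u_0, ..., u_{N-1}
print(x_star[-1])  # terminal state; generally not exactly x_des, since constraints are only penalized
```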
# In[ ]:
w_u = 3e-2
Q_u = np.diag([w_u]*3)
# In[ ]:
## your code here ##
#
#
# ---
#
#
#
# _**(Problem 3)**_
# Find the optimal solution $y^*$ using gradient descent, and compare it with the closed-form solution obtained from the normal equations. Use the initial value `yp_init` specified in the code cell below to speed up convergence.
#
# Display the optimal trajectories and the optimal control inputs. Also, check the convergence profile by examining how fast the magnitude of the gradient vectors decreases as the iterations proceed.
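#
# A minimal, generic sketch of fixed-step gradient descent for $\|\tilde{G}y - \tilde{c}\|^2$ follows, using the gradient $2\tilde{G}^T(\tilde{G}y - \tilde{c})$ (dropping the factor 2 only rescales the step size). With this scaling, a fixed step converges when it is below $1/\sigma_\text{max}(\tilde{G})^2$. The function name `gradient_descent` and its arguments are hypothetical; the demo uses a small random problem and checks the result against `np.linalg.lstsq`.
#
```python
import numpy as np

def gradient_descent(G, c, y0, lr, eps, max_iter):
    """Minimize ||G y - c||^2 with a fixed step size.

    Stops when the gradient norm drops below eps or after max_iter steps.
    Returns the final iterate and the history of gradient norms.
    """
    y = np.array(y0, dtype=float)
    grad_norms = []
    for _ in range(max_iter):
        grad = 2.0 * (G.T @ (G @ y - c))   # gradient of the squared residual
        grad_norms.append(np.linalg.norm(grad))
        if grad_norms[-1] < eps:
            break
        y = y - lr * grad
    return y, grad_norms

# Demo on a small, well-conditioned random least-squares problem
rng = np.random.default_rng(0)
G_demo = rng.standard_normal((30, 10))
c_demo = rng.standard_normal(30)
lr_demo = 0.5 / np.linalg.norm(G_demo, 2) ** 2   # half the stability limit
y_gd, grads = gradient_descent(G_demo, c_demo, np.zeros(10), lr_demo, 1e-8, 100000)
y_ls = np.linalg.lstsq(G_demo, c_demo, rcond=None)[0]
print(len(grads), np.linalg.norm(y_gd - y_ls))
```

# To apply it to the landing problem, pass the sparse $\tilde{G}$ and $\tilde{c}$ from Problem 2 together with `yp_init.ravel()`, `lr`, `eps_crt`, and `epochs`. Flattening matters: `yp_init` is a column array, and mixing a column `y` with a 1-D `c` would silently broadcast to a 2-D array. The returned list of gradient norms is the kind of history that the `plt.semilogy(grads)` cell plots.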
# In[ ]:
xp_init = np.linspace(x_init, x_des, num=N+1, endpoint=True)
yp_init = np.hstack([xp_init[1:].reshape(-1), *[g]*N])[:,None] # initial guess: interpolated states, u_t = g
lr = 1e-1          # step size (learning rate)
eps_crt = 1e-5     # stopping tolerance on the gradient norm
epochs = 100000    # maximum number of iterations
# In[ ]:
## your code here ##
# In[ ]:
plt.figure(figsize=(8,6), dpi=100)
plt.semilogy(grads)
plt.xlabel("Iteration")
plt.title("Magnitude of the gradient")
plt.grid()
# In[ ]:
## your code here ##