import numpy as np import matplotlib.pyplot as plt n = 1000 # number of timesteps T = 50 # time will vary from 0 to T with step delt ts = np.linspace(0,T,n+1) delt = T/n gamma = .05 # damping, 0 is no damping A = np.zeros((4,4)) B = np.zeros((4,2)) C = np.zeros((2,4)) A[0,0] = 1 A[1,1] = 1 A[0,2] = (1-gamma*delt/2)*delt A[1,3] = (1-gamma*delt/2)*delt A[2,2] = 1 - gamma*delt A[3,3] = 1 - gamma*delt B[0,0] = delt**2/2 B[1,1] = delt**2/2 B[2,0] = delt B[3,1] = delt import scipy.sparse as ssp import scipy.sparse.linalg as sla x_0 = np.array([10, -20, 15, -5]) x_des = np.array([100, 50, 0, 0]) G = np.zeros((4,2*n)) for i in range(n): G[:, 2*i:2*(i+1)] = np.linalg.matrix_power(A,max(0,n-i-1))@B u_hat = sla.lsqr(G,x_des - np.linalg.matrix_power(A,n)@x_0)[0] u_vec = u_hat u_opt = u_vec.reshape(1000,2).T x = np.zeros((4,n+1)) x[:,0] = x_0 for t in range(n): x[:,t+1] = A.dot(x[:,t]) + B.dot(u_opt[:,t]) plt.figure(figsize=(14,9), dpi=100) plt.subplot(2,2,1) plt.plot(ts,x[0,:]) plt.xlabel('time') plt.ylabel('x position') plt.grid() plt.subplot(2,2,2) plt.plot(ts,x[1,:]) plt.xlabel('time') plt.ylabel('y position') plt.grid() plt.subplot(2,2,3) plt.plot(ts,x[2,:]) plt.xlabel('time') plt.ylabel('x velocity') plt.grid() plt.subplot(2,2,4) plt.plot(ts,x[3,:]) plt.xlabel('time') plt.ylabel('y velocity') plt.grid() plt.show() plt.figure(figsize=(14,9), dpi=100) plt.subplot(2,2,1) plt.plot(ts[:-1],u_opt[0,:]) plt.xlabel('time') plt.ylabel(r'$u_1$') plt.grid() plt.subplot(2,2,2) plt.plot(ts[:-1],u_opt[1,:]) plt.xlabel('time') plt.ylabel(r'$u_2$') plt.grid() plt.show() plt.figure(figsize=(14,9), dpi=100) plt.plot(x[0,:],x[1,:], label='Optimal trajectory') plt.plot(x_0[0], x_0[1], 'o', markersize=7, label='Initial position') plt.plot(x_des[0], x_des[1], '*', markersize=10, label='Target position') plt.title('Trajectory') plt.legend() for i in range(0,n-1,10): plt.arrow(x[0,i], x[1,i], 10*u_opt[0,i], 10*u_opt[1,i], head_width=1, width=0.2, fc='tab:red', ec='none') plt.axis('equal') plt.xlabel(r'$x$ position') plt.ylabel(r'$y$ position') plt.grid() plt.show()