# Configure SymPy's interactive printing with LaTeX output requested
# (use_latex=True) before any expression is displayed.
from sympy.interactive.printing import init_printing
init_printing(use_latex=True)
from sympy import Rational, pi
# Exact symbolic product of the rational 3/2 and pi. This is a bare
# expression: its value is not assigned or printed explicitly, so it is only
# shown when evaluated interactively (e.g. as the last line of a cell).
Rational(3,2)*pi

%%latex
\begin{align}
a_{00}\,x_0 + a_{01}\,x_1 + \ldots + a_{0,n-1}\,x_{n-1}   &= b_0 \\
a_{10}\,x_0  + a_{11}\,x_1 + \ldots + a_{1,n-1}\,x_{n-1}   &= b_1 \\
\vdots &  \\
a_{n-1,0}\,x_0 + a_{n-1,1}\,x_1 + \ldots + a_{n-1,n-1}\,x_{n-1} &= b_{n-1}\ .
\end{align}

# Load the IPython extension that provides the %%cython cell magic.
# NOTE(review): later IPython/Cython releases ship this extension as
# `%load_ext Cython` -- confirm against the installed versions.
%load_ext cythonmagic
import numpy as np
import scipy.linalg as la

%%cython
cimport cython
import numpy as np
cimport numpy as np

def cython_matprod(np.ndarray[double, ndim = 2] A, B):
    '''
    Multiply the matrix A by B, where B is either a vector or a matrix.

    This is the Python-callable entry point; the actual work is done by
    one of two typed helpers, selected by the dimensionality of B:

    * ``B.ndim == 1`` -> matvecprod(A, B)
    * anything else   -> matmatprod(A, B)

    The result of the selected helper is returned unchanged.
    '''
    # Anything that is not one-dimensional is handed to the matrix routine.
    if B.ndim != 1:
        return matmatprod(A, B)

    # B is a vector.
    return matvecprod(A, B)

@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[double, ndim=2] matmatprod(
    np.ndarray[double, ndim=2] A,
    np.ndarray[double, ndim=2] B):
    '''
    Matrix-matrix multiplication with explicit loops.

    Parameters
    ----------
    A : 2-D array of doubles, shape (A_n, A_m)
    B : 2-D array of doubles, shape (B_n, B_m)

    Returns
    -------
    C : 2-D array of doubles, shape (A_n, B_m), with
        C[i, j] = sum over k of A[i, k] * B[k, j].

    Raises
    ------
    AssertionError
        If the inner dimensions differ (A_m != B_n).
    '''
    cdef:
        # Py_ssize_t is the native type of array shapes and indices; a plain
        # C int could truncate the shape of a very large array. This also
        # matches the declarations used in matvecprod.
        Py_ssize_t i, j, k
        Py_ssize_t A_n = A.shape[0]
        Py_ssize_t A_m = A.shape[1]
        Py_ssize_t B_n = B.shape[0]
        Py_ssize_t B_m = B.shape[1]
        np.ndarray[double, ndim=2] C

    # Are matrices conformable?
    # Bounds checking is switched off for this function, so this check is
    # what keeps the index k inside both A and B in the loops below.
    assert A_m == B_n, \
        'Non-conformable shapes.'

    # Initialize the results matrix.
    C = np.zeros((A_n, B_m))
    for i in range(A_n):
        for j in range(B_m):
            for k in range(A_m):
                C[i, j] += A[i, k] * B[k, j]
    return C

@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[double, ndim=1] matvecprod(
    np.ndarray[double, ndim=2] A,
    np.ndarray[double, ndim=1] b):
    '''
    Matrix-vector multiplication with explicit loops.

    Parameters
    ----------
    A : 2-D array of doubles, shape (A_n, A_m)
    b : 1-D array of doubles, length b_n

    Returns
    -------
    c : 1-D array of doubles, length A_n, with
        c[i] = sum over k of A[i, k] * b[k].

    Raises
    ------
    AssertionError
        If the number of columns of A differs from the length of b.
    '''
    cdef:
        Py_ssize_t i, k
        Py_ssize_t A_n = A.shape[0]
        Py_ssize_t A_m = A.shape[1]
        Py_ssize_t b_n = b.shape[0]
        np.ndarray[double, ndim=1] c

    # Are the matrix and the vector conformable?
    # Bounds checking is switched off for this function, so this check is
    # what keeps the index k inside both A and b in the loop below.
    assert A_m == b_n, \
        'Non-conformable shapes.'

    # Initialize the result vector.
    c = np.zeros(A_n)
    for i in range(A_n):
        for k in range(b_n):
            c[i] += A[i, k] * b[k]
    return c

def python_matmatprod(A, B):
    '''
    Matrix-matrix multiplication using explicit Python loops.

    The triple loop is intentional: this function is the plain-Python
    counterpart of the compiled and NumPy versions it is compared with,
    so it must not be replaced by a vectorized call.

    Parameters
    ----------
    A : 2-D array, shape (A_n, A_m)
    B : 2-D array, shape (B_n, B_m)

    Returns
    -------
    C : 2-D array of shape (A_n, B_m) with
        C[i, j] = sum over k of A[i, k] * B[k, j].

    Raises
    ------
    AssertionError
        If the inner dimensions differ (A_m != B_n).
    '''
    A_n, A_m = A.shape
    B_n, B_m = B.shape
    assert A_m == B_n, "Non-conformable shapes."
    C = np.zeros((A_n, B_m))
    # range (rather than xrange) exists on both Python 2 and Python 3.
    for i in range(A_n):
        for j in range(B_m):
            for k in range(A_m):
                C[i, j] += A[i, k] * B[k, j]
    return C

# Small sample operands used to compare the implementations.
# A is 2x3
A = np.array([[2.0, 0.25, -1.0], 
              [3.0, 0.0 ,  5.0]])
# B is 3x2
B = np.array([[-3.0,  0.5], 
              [ 2.0,  1.5], 
              [ 4.0, -4.0]])
# C is 2x2 (defined here but not referenced by the calls visible below)
C = np.array([[1.0,  1.5], 
              [2.5, -1.0]])
# b is a one-dimensional vector of length 3
b = np.array([1.0, -2.0, 0.5])

# Print the same products computed by each implementation so the results can
# be compared by eye. These are Python 2 print statements.
print 'Cython:'
print '-------'
print "A x B =\n", cython_matprod(A, B), "\n"
print "A x b =\n", cython_matprod(A, b), "\n"
print 'Numpy dot:'
print '----------'
print "A x B =\n", np.dot(A, B), "\n"
print "A x b =\n", np.dot(A, b), "\n"
# Only the matrix-matrix case exists for the pure-Python version.
print 'Python loops:'
print '-------------'
print "A x B =\n", python_matmatprod(A, B), "\n"

# Time the same 2x3 by 3x2 product with each of the three implementations.
%timeit np.dot(A, B)

%timeit cython_matprod(A, B)

%timeit python_matmatprod(A, B)