In the first part, we generate nonlinear data and find a regression model for it by applying linear regression in a higher-dimensional feature space. The beauty of many machine learning algorithms lies in this idea of using (possibly high-dimensional) linear models to learn complex (possibly nonlinear) datasets.
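The idea can be made concrete with plain NumPy before turning to scikit-learn. The sketch below assumes the same recipe used for the data in this notebook (t = 1 + x + x^2 plus Gaussian noise with standard deviation 0.1), but with a fixed seed of our choosing. It maps each x to the feature vector (1, x, x^2) and solves an ordinary least-squares problem. The model is nonlinear in x but linear in the coefficients, which is why a linear solver is enough.

```python
import numpy as np

# Assumed data: same recipe as the notebook, with a fixed seed for reproducibility
rng = np.random.default_rng(0)
x = np.linspace(-2, 1, 100)
t = 1 + x + x**2 + rng.normal(0, 0.1, x.size)

# Feature map phi(x) = (1, x, x^2): nonlinear in x, but t ≈ Phi @ beta is linear in beta
Phi = np.column_stack((np.ones_like(x), x, x**2))

# Ordinary least squares in the 3-dimensional feature space
beta, *_ = np.linalg.lstsq(Phi, t, rcond=None)
print(beta)  # close to the true coefficients [1, 1, 1]
```

Nothing in the solver "knows" about the squaring; all the nonlinearity lives in how the columns of `Phi` were built.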
# We start by generating points according to a nonlinear (quadratic) relation
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = 6, 4

# Noiseless quadratic target t = 1 + x + x^2 on [-2, 1]
x = np.linspace(-2, 1, 100)
t = 1 + x + x**2
# Add Gaussian noise with standard deviation 0.1
t = t + np.random.normal(0, .1, len(x))

plt.scatter(x, t, color='r')
plt.show()
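The cell above draws the noise with `np.random.normal` and no seed, so every run produces a slightly different dataset and slightly different fitted coefficients. If reproducible results are wanted, one option is a seeded NumPy `Generator`. The helper `make_quadratic_data` below is hypothetical (it is not part of the notebook) and simply wraps the same recipe.

```python
import numpy as np

def make_quadratic_data(seed, n=100, noise=0.1):
    """Hypothetical helper: t = 1 + x + x^2 + Gaussian noise, using a seeded generator."""
    rng = np.random.default_rng(seed)
    x = np.linspace(-2, 1, n)
    t = 1 + x + x**2 + rng.normal(0, noise, n)
    return x, t

# Two calls with the same seed produce exactly the same noisy targets
x1, t1 = make_quadratic_data(seed=42)
x2, t2 = make_quadratic_data(seed=42)
print(np.array_equal(t1, t2))  # True
```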
# As a first illustration of the limitations of linear regression, we try to fit a linear model
# to the nonlinear data. As expected, a straight line cannot capture the curvature of the data.
from sklearn.linear_model import LinearRegression

# scikit-learn expects a 2D feature matrix of shape (n_samples, n_features)
x = x.reshape(-1, 1)
print(np.shape(x))  # (100, 1)

regr = LinearRegression()
regr.fit(x, t)

test_points = np.linspace(-2, 1, 100)
y_predicted = regr.predict(test_points.reshape(-1, 1))

plt.scatter(x, t, color='r')
plt.plot(test_points, y_predicted)
plt.show()
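One way to check that the straight-line fit fails systematically, not just because of noise, is to look at its residuals. The sketch below assumes the notebook's data recipe with a fixed seed (`default_rng(0)` is our choice). If the line were adequate, the mean squared error would be close to the noise variance 0.1^2 = 0.01 and the residuals would show no pattern. Instead, they are positive at both ends of the interval and negative in the middle, the U shape left over by the unmodelled x^2 term.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Assumed data: same recipe as the notebook, with a fixed seed
rng = np.random.default_rng(0)
x = np.linspace(-2, 1, 100).reshape(-1, 1)
t = 1 + x.ravel() + x.ravel()**2 + rng.normal(0, 0.1, 100)

line = LinearRegression().fit(x, t)
residuals = t - line.predict(x)
mse_line = np.mean(residuals**2)

print(f"MSE of straight line: {mse_line:.3f} (noise variance is 0.01)")
# Systematic sign pattern: above the line at both ends, below it in the middle
print(residuals[0] > 0, residuals[50] < 0, residuals[-1] > 0)
```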
# We then lift the data to 2D by adding one more feature, which gives a model of the form
# $y(X) = \beta_0 + \beta_1 X_1 + \beta_2 X_2$.
# To learn this new model, we need training data of the form $(t^i, X^i_1, X^i_2)$.
# Since we only generated data of the form $(t, X_1)$, we build the remaining feature $X_2$
# from the first one as $X_2 = X_1^2$. Note that the model is still linear: it does not know
# that we generated the new feature by squaring $X_1$. All it knows is that it now receives
# two features, $X_1$ and $X_2$.
X = np.hstack((x.reshape(-1, 1), x.reshape(-1, 1)**2))
#print(X)

regr = LinearRegression()
regr.fit(X, t)

# Predictions of the quadratic-feature model on the test grid
test_points = np.linspace(-2, 1, 100)
Testmat2 = np.hstack((test_points.reshape(-1, 1), test_points.reshape(-1, 1)**2))
y_predicted2 = regr.predict(Testmat2)

from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib

print(np.shape(X[:, 0]))
print(np.shape(X[:, 1]))
print(np.shape(t))

# Plot the data in the (X_1, X_2, t) space
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], t, color='r')
plt.show()
(100,)
(100,)
(100,)
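In the (X_1, X_2, t) space of the 3D scatter, the learned model is simply a plane, and the data points sit above the curve X_2 = X_1^2 in the (X_1, X_2) plane. The sketch below draws that plane next to the data. It also shows that scikit-learn's `PolynomialFeatures(degree=2, include_bias=False)` produces the same two columns as the hand-built `np.hstack` matrix. The fixed seed, the grid resolution, and the headless `Agg` backend are assumptions made so the sketch runs as a standalone script. In the notebook, drop the `matplotlib.use` line and call `plt.show()` instead of `plt.close(fig)`.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for running outside the notebook
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on older Matplotlib)
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

# Assumed data: same recipe as the notebook, with a fixed seed
rng = np.random.default_rng(0)
x = np.linspace(-2, 1, 100).reshape(-1, 1)
t = 1 + x.ravel() + x.ravel()**2 + rng.normal(0, 0.1, 100)

# PolynomialFeatures yields the columns [x, x^2], the same matrix as np.hstack((x, x**2));
# the intercept beta_0 is left to LinearRegression
X = PolynomialFeatures(degree=2, include_bias=False).fit_transform(x)
assert np.allclose(X, np.hstack((x, x**2)))

regr = LinearRegression().fit(X, t)

# The fitted model t = beta_0 + beta_1 X_1 + beta_2 X_2, evaluated on a grid over the data range
g1, g2 = np.meshgrid(np.linspace(-2, 1, 20), np.linspace(0, 4, 20))
plane = regr.intercept_ + regr.coef_[0] * g1 + regr.coef_[1] * g2

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(X[:, 0], X[:, 1], t, color="r")  # data: X_2 = X_1^2 exactly, t is noisy
ax.plot_surface(g1, g2, plane, alpha=0.3)   # the learned linear model is a plane
ax.set_xlabel("$X_1$")
ax.set_ylabel("$X_2$")
ax.set_zlabel("$t$")
plt.close(fig)  # in the notebook, use plt.show() instead
```

The fitted coefficients should come out close to the true values 1, 1 and 1, so the plane passes through the points even though, seen as a function of the original x alone, the fit is a parabola.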