Our use case is to build, train, and evaluate a prediction model for sales analysis.
In this model, we feed the advertising budgets for digital, TV, radio, and newspaper to the model, and the model forecasts the expected sales.
The advertising dataset captures the sales revenue generated against advertising spend across several channels: digital, TV, radio, and newspapers.
# Import the necessary libraries
# For data loading, exploratory data analysis, and graphing
import pandas as pd # Pandas for data processing
import numpy as np # NumPy for mathematical functions
import matplotlib.pyplot as plt # Matplotlib for visualization tasks
import seaborn as sns # Seaborn: data visualization library built on Matplotlib
%matplotlib inline
import sklearn # ML tasks
from sklearn.model_selection import train_test_split # Split the dataset
from sklearn.metrics import mean_squared_error # Calculate Mean Squared Error
# Build the Network
from tensorflow import keras
from keras.models import Sequential
#from tensorflow.keras.models import Sequential
from keras.layers import Dense
#from keras.callbacks import EarlyStopping
# Next, you read the dataset into a Pandas dataframe.
url = 'https://github.com/LinkedInLearning/artificial-intelligence-foundations-neural-networks-4381282/blob/main/Advertising_2023.csv?raw=true'
advertising_df= pd.read_csv(url,index_col=0)
advertising_df.head(10)
| | digital | TV | radio | newspaper | sales |
|---|---|---|---|---|---|
| 1 | 345.15 | 156.0 | 37.8 | 69.2 | 22.1 |
| 2 | 66.75 | 46.0 | 39.3 | 45.1 | 10.4 |
| 3 | 25.80 | 18.3 | 45.9 | 69.3 | 9.3 |
| 4 | 227.25 | 145.1 | 41.3 | 58.5 | 18.5 |
| 5 | 271.20 | 165.2 | 10.8 | 58.4 | 12.9 |
| 6 | 13.05 | 8.7 | 48.9 | 75.0 | 7.2 |
| 7 | 86.25 | 57.5 | 32.8 | 23.5 | 11.8 |
| 8 | 180.30 | 120.2 | 19.6 | 11.6 | 13.2 |
| 9 | 12.90 | 8.6 | 2.1 | 1.0 | 4.8 |
| 10 | 299.70 | 199.8 | 2.6 | 21.2 | 10.6 |
# Pandas info() function is used to get a concise summary of the dataframe.
advertising_df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 1199 entries, 1 to 1197
Data columns (total 5 columns):
 #   Column     Non-Null Count  Dtype
---  ------     --------------  -----
 0   digital    1199 non-null   float64
 1   TV         1199 non-null   float64
 2   radio      1199 non-null   float64
 3   newspaper  1199 non-null   float64
 4   sales      1199 non-null   float64
dtypes: float64(5)
memory usage: 56.2 KB
### Get summary of statistics of the data
advertising_df.describe()
| | digital | TV | radio | newspaper | sales |
|---|---|---|---|---|---|
| count | 1199.000000 | 1199.00000 | 1199.000000 | 1199.000000 | 1199.000000 |
| mean | 135.472394 | 146.61985 | 23.240617 | 30.529942 | 14.005505 |
| std | 135.730821 | 85.61047 | 14.820827 | 21.712507 | 5.202804 |
| min | 0.300000 | 0.70000 | 0.000000 | 0.300000 | 1.600000 |
| 25% | 24.250000 | 73.40000 | 9.950000 | 12.800000 | 10.300000 |
| 50% | 64.650000 | 149.70000 | 22.500000 | 25.600000 | 12.900000 |
| 75% | 256.950000 | 218.50000 | 36.500000 | 45.100000 | 17.400000 |
| max | 444.600000 | 296.40000 | 49.600000 | 114.000000 | 27.000000 |
#shape of dataframe - 1199 rows, five columns
advertising_df.shape
(1199, 5)
Let's check for any null values.
# The isnull() method is used to check for null values in a DataFrame.
advertising_df.isnull().sum()
digital      0
TV           0
radio        0
newspaper    0
sales        0
dtype: int64
# check whether there are any NaN values
advertising_df.isnull().values.any()
False
Let's create some simple plots to check out the data!
# The heatmap is a way of representing the data in a 2-dimensional form. The data values are represented as colors in the graph.
# The goal of the heatmap is to provide a colored visual summary of information.
sns.heatmap(advertising_df.corr())
<Axes: >
## Another option is to plot the heatmap so that the values are shown.
plt.figure(figsize=(10,5))
sns.heatmap(advertising_df.corr(),annot=True,vmin=0,vmax=1,cmap='ocean')
<Axes: >
#create a correlation matrix
corr = advertising_df.corr()
plt.figure(figsize=(10, 5))
sns.heatmap(corr[(corr >= 0.5) | (corr <= -0.7)],
cmap='viridis', vmax=1.0, vmin=-1.0, linewidths=0.1,
annot=True, annot_kws={"size": 8}, square=True)
plt.tight_layout()
plt.show()
advertising_df.corr()
| | digital | TV | radio | newspaper | sales |
|---|---|---|---|---|---|
| digital | 1.000000 | 0.474256 | 0.041316 | 0.048023 | 0.380101 |
| TV | 0.474256 | 1.000000 | 0.055697 | 0.055579 | 0.781824 |
| radio | 0.041316 | 0.055697 | 1.000000 | 0.353096 | 0.576528 |
| newspaper | 0.048023 | 0.055579 | 0.353096 | 1.000000 | 0.227039 |
| sales | 0.380101 | 0.781824 | 0.576528 | 0.227039 | 1.000000 |
### Visualize Correlation
# Generate a mask for the upper triangle
mask = np.zeros_like(advertising_df.corr(), dtype=bool)
mask[np.triu_indices_from(mask)] = True
# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(11, 9))
# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(advertising_df.corr(), mask=mask, cmap=cmap, vmax=.9, square=True, linewidths=.5, ax=ax)
<Axes: >
Since Sales is our target variable, we should identify which variable correlates the most with Sales.
As we can see, TV has the highest correlation with Sales. Let's visualize the relationship of variables using scatterplots.
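A quick way to confirm this ranking (a small convenience snippet, not part of the original notebook) is to sort each feature's correlation with the target directly:
# Rank each feature's correlation with sales, highest first
advertising_df.corr()['sales'].sort_values(ascending=False)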
# displot() is used for a univariate set of observations and visualizes it through a histogram, i.e. only one variable,
# and hence you choose one particular column of the dataset.
sns.displot(advertising_df['sales'])
<seaborn.axisgrid.FacetGrid at 0x7af11a8ad300>
Let's visualize the relationship of each variable with sales using scatterplots - separately.
Another way to view the linear relationships between variables is to use a "for loop" that does the same as above.
It seems there are no clear linear relationships between the predictors.
At this point, we know that the variable TV is the most likely to give a good prediction of Sales because of the high correlation and linearity of the two.
'''=== Show the linear relationship between each feature and sales. This shows how scattered
the points are and which features have more impact on the prediction of sales. ==='''
# visualize all variables against sales
from scipy import stats

# create the figure
plt.figure(figsize=(18, 18))

for i, col in enumerate(advertising_df.columns[0:4]):  # iterate over the feature columns (all but the last column, sales)
    plt.subplot(5, 3, i+1)  # three figures per row
    x = advertising_df[col]  # x-axis
    y = advertising_df['sales']  # y-axis
    plt.plot(x, y, 'o')

    # Create regression line
    plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x)), color='red')
    plt.xlabel(col)  # x-label
    plt.ylabel('sales')  # y-label
Concluding results after observing the graphs: the relationship between TV and Sales is strong and increases in a linear fashion; the relationship between radio and Sales is less strong; and the relationship between newspaper and Sales is weak.
Regression is a supervised machine learning process. It is similar to classification, but rather than predicting a label, you try to predict a continuous value. Linear regression defines the relationship between a target variable (y) and a set of predictive features (x). Simply stated, If you need to predict a number, then use regression.
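To make this concrete, here is a minimal illustration (an aside, not part of this notebook's neural-network workflow) using scikit-learn's LinearRegression on the same features:
# Illustrative baseline: ordinary least-squares regression on the advertising features
from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(advertising_df[['digital', 'TV', 'radio', 'newspaper']], advertising_df['sales'])
print(lin_reg.coef_, lin_reg.intercept_)  # one learned weight per feature, plus an intercept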
Let's now begin to train your regression model! You will first need to split up your data into an X array that contains the features to train on, and a y array with the target variable, in this case the sales column.
Next, let's define the features and label. Briefly, feature is input; label is output. This applies to both classification and regression problems.
X = advertising_df[['digital', 'TV', 'radio', 'newspaper']]
y = advertising_df['sales']
'''=== Normalize the features. Since the features have different ranges, it is best practice to
normalize/standardize them before using them in the model ==='''
#feature normalization
normalized_feature = keras.utils.normalize(X.values)
print(normalized_feature)
[[0.89211961 0.4032179  0.0977028  0.17886333]
 [0.66254734 0.45658693 0.39008405 0.44765371]
 [0.29009225 0.20576311 0.51609436 0.77920128]
 ...
 [0.06744611 0.99272247 0.05163843 0.08536149]
 [0.19480049 0.91868871 0.08898294 0.33188231]
 [0.06744611 0.99272247 0.05163843 0.08536149]]
Now let's split the data into a training and a test set. Note: best practice is to split into three sets - training, validation, and test (a sketch of that follows below).
By default, train_test_split uses a 75/25 ratio; here we pass test_size=0.4, giving a 60/40 split.
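For reference, a three-way split can be done with two calls to train_test_split; the proportions below (60/20/20) are illustrative, and the variable names are placeholders rather than the ones used in the rest of this notebook:
# Sketch: 60% train, 20% validation, 20% test
X_temp, X_te, y_temp, y_te = train_test_split(X, y, test_size=0.2, random_state=101)
X_tr, X_val, y_tr, y_val = train_test_split(X_temp, y_temp, test_size=0.25, random_state=101)  # 0.25 of the remaining 80% = 20%
print(X_tr.shape, X_val.shape, X_te.shape)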
# Import train_test_split function from sklearn.model_selection
from sklearn.model_selection import train_test_split
# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=101)
print(X_train.shape,X_test.shape, y_train.shape, y_test.shape )
(719, 4) (480, 4) (719,) (480,)
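As noted with keras.utils.normalize above, the features have very different ranges. An alternative sketch (not used in the rest of this notebook) is to standardize with scikit-learn's StandardScaler, fitting it on the training split only to avoid leaking test-set statistics:
# Alternative sketch: standardize features, fitting the scaler on the training data only
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # learn mean/std from the training split
X_test_scaled = scaler.transform(X_test)        # apply the same transform to the test split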
Because relatively few samples are available, we will be using a very small network with two hidden layers (4 and 3 units). In general, the less training data you have, the worse overfitting will be, and using a small network is one way to mitigate overfitting.
## Build the model (a three-layer network: two hidden layers plus one output layer)
model = Sequential()
model.add(Dense(4, input_dim=4, activation='relu'))  # first hidden layer: 4 units; input_dim=4 matches the four features
model.add(Dense(3, activation='relu'))               # second hidden layer: 3 units
model.add(Dense(1))                                  # output layer: a single continuous value (sales)
# Compile Model
model.compile(optimizer='adam', loss='mse',metrics=['mse'])
# Fit the Model
history = model.fit(X_train, y_train, validation_data = (X_test, y_test),
epochs = 32)
Epoch 1/32
23/23 [==============================] - 2s 20ms/step - loss: 6355.9458 - mse: 6355.9458 - val_loss: 5807.6504 - val_mse: 5807.6504
Epoch 2/32
23/23 [==============================] - 0s 10ms/step - loss: 4806.4727 - mse: 4806.4727 - val_loss: 4447.7969 - val_mse: 4447.7969
Epoch 3/32
23/23 [==============================] - 0s 9ms/step - loss: 3695.4465 - mse: 3695.4465 - val_loss: 3428.7375 - val_mse: 3428.7375
Epoch 4/32
23/23 [==============================] - 0s 8ms/step - loss: 2848.5457 - mse: 2848.5457 - val_loss: 2607.1289 - val_mse: 2607.1289
Epoch 5/32
23/23 [==============================] - 0s 7ms/step - loss: 2141.7068 - mse: 2141.7068 - val_loss: 1864.4614 - val_mse: 1864.4614
Epoch 6/32
23/23 [==============================] - 0s 6ms/step - loss: 1493.2397 - mse: 1493.2397 - val_loss: 1210.5009 - val_mse: 1210.5009
Epoch 7/32
23/23 [==============================] - 0s 12ms/step - loss: 950.9180 - mse: 950.9180 - val_loss: 668.1801 - val_mse: 668.1801
Epoch 8/32
23/23 [==============================] - 0s 12ms/step - loss: 535.9688 - mse: 535.9688 - val_loss: 365.2287 - val_mse: 365.2287
Epoch 9/32
23/23 [==============================] - 0s 11ms/step - loss: 318.2729 - mse: 318.2729 - val_loss: 231.7376 - val_mse: 231.7376
Epoch 10/32
23/23 [==============================] - 0s 8ms/step - loss: 230.9233 - mse: 230.9233 - val_loss: 193.1924 - val_mse: 193.1924
Epoch 11/32
23/23 [==============================] - 0s 7ms/step - loss: 200.3414 - mse: 200.3414 - val_loss: 181.3994 - val_mse: 181.3994
Epoch 12/32
23/23 [==============================] - 0s 9ms/step - loss: 185.2456 - mse: 185.2456 - val_loss: 171.7923 - val_mse: 171.7923
Epoch 13/32
23/23 [==============================] - 0s 8ms/step - loss: 173.6612 - mse: 173.6612 - val_loss: 162.9434 - val_mse: 162.9434
Epoch 14/32
23/23 [==============================] - 0s 9ms/step - loss: 164.4381 - mse: 164.4381 - val_loss: 155.8642 - val_mse: 155.8642
Epoch 15/32
23/23 [==============================] - 0s 8ms/step - loss: 156.3019 - mse: 156.3019 - val_loss: 149.8550 - val_mse: 149.8550
Epoch 16/32
23/23 [==============================] - 0s 11ms/step - loss: 149.5408 - mse: 149.5408 - val_loss: 144.7365 - val_mse: 144.7365
Epoch 17/32
23/23 [==============================] - 0s 10ms/step - loss: 143.6699 - mse: 143.6699 - val_loss: 140.9443 - val_mse: 140.9443
Epoch 18/32
23/23 [==============================] - 0s 15ms/step - loss: 138.9901 - mse: 138.9901 - val_loss: 136.9614 - val_mse: 136.9614
Epoch 19/32
23/23 [==============================] - 0s 15ms/step - loss: 134.6012 - mse: 134.6012 - val_loss: 133.7686 - val_mse: 133.7686
Epoch 20/32
23/23 [==============================] - 0s 19ms/step - loss: 130.8806 - mse: 130.8806 - val_loss: 130.9590 - val_mse: 130.9590
Epoch 21/32
23/23 [==============================] - 0s 20ms/step - loss: 127.4660 - mse: 127.4660 - val_loss: 128.1422 - val_mse: 128.1422
Epoch 22/32
23/23 [==============================] - 0s 10ms/step - loss: 123.8663 - mse: 123.8663 - val_loss: 125.6087 - val_mse: 125.6087
Epoch 23/32
23/23 [==============================] - 0s 9ms/step - loss: 120.4207 - mse: 120.4207 - val_loss: 122.6469 - val_mse: 122.6469
Epoch 24/32
23/23 [==============================] - 0s 13ms/step - loss: 116.7625 - mse: 116.7625 - val_loss: 119.3021 - val_mse: 119.3021
Epoch 25/32
23/23 [==============================] - 0s 9ms/step - loss: 112.5525 - mse: 112.5525 - val_loss: 115.5973 - val_mse: 115.5973
Epoch 26/32
23/23 [==============================] - 0s 15ms/step - loss: 107.6662 - mse: 107.6662 - val_loss: 110.8775 - val_mse: 110.8775
Epoch 27/32
23/23 [==============================] - 0s 9ms/step - loss: 101.2922 - mse: 101.2922 - val_loss: 103.4047 - val_mse: 103.4047
Epoch 28/32
23/23 [==============================] - 0s 12ms/step - loss: 91.9364 - mse: 91.9364 - val_loss: 91.9141 - val_mse: 91.9141
Epoch 29/32
23/23 [==============================] - 0s 8ms/step - loss: 80.0998 - mse: 80.0998 - val_loss: 79.4774 - val_mse: 79.4774
Epoch 30/32
23/23 [==============================] - 0s 9ms/step - loss: 66.3701 - mse: 66.3701 - val_loss: 61.2325 - val_mse: 61.2325
Epoch 31/32
23/23 [==============================] - 0s 8ms/step - loss: 49.9562 - mse: 49.9562 - val_loss: 43.2350 - val_mse: 43.2350
Epoch 32/32
23/23 [==============================] - 0s 9ms/step - loss: 34.8417 - mse: 34.8417 - val_loss: 31.0812 - val_mse: 31.0812
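The commented-out EarlyStopping import near the top hints at another common way to curb overfitting besides keeping the network small: stop training once the validation loss stops improving. A minimal sketch (the patience value and the epoch budget here are only illustrative):
# Sketch: stop training when val_loss has not improved for 5 consecutive epochs
from keras.callbacks import EarlyStopping
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
history_es = model.fit(X_train, y_train, validation_data=(X_test, y_test),
                       epochs=100, callbacks=[early_stop])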
Once we've run data through the model, we can call .summary() on the model to get a high-level summary of our network.
#inspect the model
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 4) 20 dense_1 (Dense) (None, 3) 15 dense_2 (Dense) (None, 1) 4 ================================================================= Total params: 39 Trainable params: 39 Non-trainable params: 0 _________________________________________________________________
model.evaluate(X_test, y_test)[1]
15/15 [==============================] - 0s 3ms/step - loss: 31.0812 - mse: 31.0812
31.08115005493164
Running .fit (or .fit_generator) returns a History object which collects all the events recorded during training. You can plot the training and validation curves for the model loss and mse by accessing these elements of the History object.
MSE_COLS = ["mse", "val_mse"]
pd.DataFrame(history.history)[MSE_COLS].plot()
<Axes: >
You can add more 'flavor' to the graph by making it bigger and adding labels and names, as shown below.
## Plot a graph of model loss (MSE) on the training and validation data
plt.figure(figsize=(15,8))
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss (MSE) on Training and Validation Data')
plt.ylabel('Loss - Mean Squared Error')
plt.xlabel('Epoch')
plt.legend(['Train Loss', 'Val Loss'], loc='upper right')
plt.show()
'''=== Predict the sales ==='''
# predict SALES using the test data
test_predictions = model.predict(X_test).flatten()
print(test_predictions)
15/15 [==============================] - 0s 2ms/step [31.417734 21.220354 4.7196593 26.852045 16.403692 34.580494 20.758804 21.125383 18.414904 16.656893 6.99268 18.21966 4.8979464 9.0408535 21.9296 12.806722 18.01036 15.848582 7.797866 16.173466 9.119615 7.0110044 6.602976 14.834849 9.256564 3.559621 7.9958944 17.242098 13.51514 19.36989 9.369475 17.497568 21.569561 20.004793 11.4648075 22.048296 17.40079 23.607897 14.169772 10.9558325 11.268399 15.120288 3.8518872 14.260594 8.536372 14.463492 7.58078 17.40079 14.42262 7.5760717 6.713754 21.9296 8.604031 15.934314 4.695321 5.0222135 19.32795 20.942812 9.972568 13.3409605 15.139911 6.5433683 19.137207 15.174861 9.164914 19.908352 10.831012 14.373332 9.987815 14.764285 7.4007397 21.604433 7.204871 20.466913 12.563774 18.223497 23.103373 16.319998 17.497568 14.114753 20.655489 6.8951597 10.644683 15.720672 16.938118 17.671738 11.63856 10.355195 4.3273115 6.410767 21.544268 19.36989 7.4390297 5.310464 9.329084 10.222337 3.1649032 26.852045 28.403513 4.3525066 5.6985235 15.636931 7.946986 7.956064 9.057344 14.764285 6.0379877 5.98932 8.051751 5.288837 6.602976 7.841485 8.069265 22.08629 8.240805 20.151026 5.6985235 26.918314 17.077515 9.400619 17.497568 21.085949 6.8951597 7.4020967 13.199025 7.963542 7.993642 6.510568 4.359139 28.281649 13.444003 7.797866 10.337239 7.5432014 19.264431 15.37106 6.2916727 11.078197 10.293751 18.648272 6.510568 12.563774 7.7173758 19.32795 3.1649032 6.149497 9.1664915 30.761736 6.9206405 8.053128 15.848582 12.339154 4.120268 12.319783 22.321661 12.339154 12.563774 17.461306 7.723098 18.669308 20.151026 16.656893 20.967234 9.057344 6.8077383 16.582981 6.3122845 7.2171154 7.2173805 24.412577 17.391472 7.402831 13.671414 11.782079 7.5760717 16.656893 27.938532 22.181067 7.721855 25.166918 5.574973 8.428559 16.747652 27.272703 9.246051 9.256564 7.496188 14.328734 4.0990367 15.414613 7.7408857 13.918242 9.762905 15.139911 10.143068 5.6985235 14.463492 13.951335 4.120268 17.807077 4.0375853 17.08114 15.256427 10.089583 6.817773 15.174861 4.3273115 1.4163067 8.586685 12.37456 13.4068 17.382402 17.092617 20.655489 14.310363 11.287998 6.85508 6.676526 2.9108717 5.574973 8.051751 11.268399 15.256427 6.9206405 17.382402 6.676526 12.806722 16.582981 9.01859 11.44165 7.963542 6.817773 10.644683 11.462901 16.801722 6.562961 3.0752811 28.281649 25.166918 13.823487 15.761867 7.58078 7.834547 28.683146 7.834547 6.410767 4.3273115 7.5760717 14.125691 9.280607 18.525116 9.267183 5.539628 12.722218 5.1130857 6.072568 7.872463 16.801722 9.693255 13.396647 17.08114 20.32375 6.85508 21.945993 13.3409605 21.946344 19.00577 16.935167 18.621302 5.288837 12.339154 18.74473 20.60451 8.962937 6.562961 7.721855 8.148759 12.562656 14.761786 20.653154 17.077515 9.912691 7.9652596 6.8205976 7.872463 15.553493 19.458017 7.9652615 7.509945 9.246051 13.0507765 11.775274 9.762905 15.414613 15.202596 13.6534605 4.5267553 1.4720614 3.6611114 19.99608 21.373169 13.263423 2.6658518 6.475648 6.294458 28.6508 20.812492 11.350326 5.0215697 18.278322 21.058418 6.2225237 16.039553 14.328734 22.08629 3.559621 15.644687 18.669308 9.400615 28.937754 14.105758 17.908709 6.5618324 18.01036 12.319783 7.7173758 13.895797 14.463492 12.493129 6.0379877 6.3122845 7.4390297 8.570477 18.154806 9.649221 20.967234 16.493227 11.44165 21.946344 12.717574 11.573382 18.154806 6.979954 7.7173758 11.019221 6.294458 20.466913 19.199402 4.512447 9.119615 17.081125 4.037586 10.767676 16.29941 10.831012 9.604157 12.920837 11.951186 17.242098 6.6709585 4.8965936 8.358804 7.121706 
6.5433683 3.7456455 8.358804 9.040542 6.17523 25.166918 16.598705 13.927377 15.720672 14.886226 4.695321 13.120873 5.539628 6.8205976 6.9206405 1.4163067 15.139911 11.070965 6.1766205 23.607897 9.652304 14.419147 11.183952 8.042495 18.851282 28.6508 9.008142 12.272616 11.947124 9.552407 18.21966 6.979954 8.536372 13.950852 23.607897 7.204871 9.329084 17.6393 27.938532 18.670609 18.223497 27.16662 22.212885 10.074108 22.104092 4.8979464 7.2171154 8.962937 7.5321503 9.2947035 3.908699 9.987815 2.6658518 31.417734 13.926873 28.6508 7.2173805 6.2862463 17.6393 27.938532 20.653154 9.552407 6.2020087 20.930025 11.350326 16.173466 15.37106 7.723098 7.496188 4.695321 15.80786 20.930025 20.653154 18.278322 7.4477644 6.3325458 15.636931 21.058418 8.132774 18.669308 5.164823 17.918097 16.231375 7.9652615 10.337241 6.517433 15.681098 10.401207 6.294458 19.99608 16.055796 11.331473 22.186472 18.71135 22.104092 7.7777843 17.382402 16.403692 11.019221 11.039078 9.677306 24.90838 7.0110044 22.104092 6.517433 22.321661 17.242098 21.569561 19.137207 19.36989 20.32375 14.485332 15.80786 19.137207 15.519931 3.559621 14.275462 ]
# show the true value and predicted value in dataframe
true_predicted = pd.DataFrame(list(zip(y_test, test_predictions)),
columns=['True Value','Predicted Value'])
true_predicted.head(6) # Show first six rows
| | True Value | Predicted Value |
|---|---|---|
| 0 | 26.2 | 31.417734 |
| 1 | 19.0 | 21.220354 |
| 2 | 12.8 | 4.719659 |
| 3 | 20.8 | 26.852045 |
| 4 | 16.9 | 16.403692 |
| 5 | 23.8 | 34.580494 |
Visualize the prediction.
# visualize the prediction using a diagonal line
y = test_predictions #y-axis
x = y_test #x-axis
fig, ax = plt.subplots(figsize=(10,6)) # create figure
ax.scatter(x,y) #scatter plots for x,y
ax.set(xlim=(0,55), ylim=(0, 55)) #set limit
ax.plot(ax.get_xlim(), ax.get_ylim(), color ='red') # draw 45 degree diagonal in figure
plt.xlabel('True Values')
plt.ylabel('Predicted values')
plt.title('Evaluation Result')
plt.show()
Show the fit of the regression on the dataset. The regression line (blue) is fit to the true test values against the predicted values; since most of the red dots lie on or close to the line, we can say the model has produced a good fit.
#Accuracy of linear regression on the dataset
plt.figure(figsize=(10,5))
sns.regplot(x=y_test,y=test_predictions,scatter_kws={'color':'red'})
<Axes: xlabel='sales'>
Step 6 - Predict on the test data and compute evaluation metrics. The first line of code predicts on the train data, while the second line prints the RMSE value on the train data. The same is repeated in the third and fourth lines of code, which predict and print the RMSE value on the test data.
pred_train= model.predict(X_train)
print(np.sqrt(mean_squared_error(y_train,pred_train)))
pred= model.predict(X_test)
print(np.sqrt(mean_squared_error(y_test,pred)))
23/23 [==============================] - 0s 2ms/step
5.427096996448226
15/15 [==============================] - 0s 9ms/step
5.575046908767078
Evaluation of the model performance: the output above shows that the RMSE, which is our evaluation metric, was about 5.43 thousand for the train data and 5.58 thousand for the test data. Ideally, the lower the RMSE value, the better the model performance. However, in contrast to accuracy, RMSE is not straightforward to interpret on its own, as we have to consider the unit, which in our case is thousands.
Here are three common evaluation metrics for regression problems:
Mean Absolute Error (MAE) is the mean of the absolute value of the errors:
$$\frac 1n\sum_{i=1}^n|y_i-\hat{y}_i|$$
Mean Squared Error (MSE) is the mean of the squared errors:
$$\frac 1n\sum_{i=1}^n(y_i-\hat{y}_i)^2$$
Root Mean Squared Error (RMSE) is the square root of the mean of the squared errors:
$$\sqrt{\frac 1n\sum_{i=1}^n(y_i-\hat{y}_i)^2}$$
Comparing these metrics:
All of these are loss functions, because you want to minimize them.
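As a sketch of how these could be computed for the test predictions above (mean_absolute_error is an extra scikit-learn import not used elsewhere in this notebook):
# Sketch: compute MAE, MSE, and RMSE on the test-set predictions
from sklearn.metrics import mean_absolute_error, mean_squared_error
mae = mean_absolute_error(y_test, test_predictions)
mse = mean_squared_error(y_test, test_predictions)
rmse = np.sqrt(mse)
print('MAE:', mae, 'MSE:', mse, 'RMSE:', rmse)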