This notebook adapts the scikit-learn example on support vector classification (https://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html#sphx-glr-auto-examples-classification-plot-digits-classification-py) to PyRCN to demonstrate how PyRCN can be used to classify handwritten digits.
The tutorial is based on numpy, scikit-learn and PyRCN.
import numpy as np
import time
from sklearn.base import clone
from sklearn.model_selection import train_test_split
from sklearn.model_selection import (
ParameterGrid, RandomizedSearchCV, cross_validate)
from scipy.stats import uniform, loguniform
from sklearn.metrics import make_scorer
from pyrcn.model_selection import SequentialSearchCV
from pyrcn.echo_state_network import ESNClassifier
from pyrcn.metrics import accuracy_score
from pyrcn.datasets import load_digits
The dataset is already part of scikit-learn and consists of 1797 8x8 images.
We use PyRCN's data loader, which is derived from scikit-learn's data loader and returns arrays of 8x8 sequences together with the corresponding labels.
X, y = load_digits(return_X_y=True, as_sequence=True)
print("Number of digits: {0}".format(len(X)))
print("Shape of digits {0}".format(X[0].shape))
Number of digits: 1797
Shape of digits (8, 8)
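To make the sequence view concrete, the following minimal sketch reshapes one flattened image into an 8x8 sequence. It assumes that each image row is treated as one time step with 8 features; `flat_image` is a hypothetical stand-in and not data from the loader.

```python
import numpy as np

# Hypothetical stand-in for one flattened digit (64 pixel values).
flat_image = np.arange(64, dtype=float)

# Read the image row by row: 8 time steps, each with 8 features
# (assumption: rows, not columns, form the time axis).
sequence = flat_image.reshape(8, 8)

for t, frame in enumerate(sequence[:2]):
    print(f"time step {t}: {frame}")
print("sequence shape:", sequence.shape)
```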
Afterwards, we split the dataset into training and test sets. We train the ESN using 80% of the digits and test it using the remaining images.
stratify = np.asarray([np.unique(yt) for yt in y]).flatten()
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, stratify=stratify, random_state=42)
X_tr = np.copy(X_train)
y_tr = np.copy(y_train)
X_te = np.copy(X_test)
y_te = np.copy(y_test)
for k, _ in enumerate(y_tr):
    y_tr[k] = np.repeat(y_tr[k], 8, 0)
for k, _ in enumerate(y_te):
    y_te[k] = np.repeat(y_te[k], 8, 0)
print("Number of digits in training set: {0}".format(len(X_train)))
print("Shape of digits in training set: {0}".format(X_train[0].shape))
print("Number of digits in test set: {0}".format(len(X_test)))
print("Shape of digits in test set: {0}".format(X_test[0].shape))
Number of digits in training set: 1437
Shape of digits in training set: (8, 8)
Number of digits in test set: 360
Shape of digits in test set: (8, 8)
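The cell above does two things with the labels: it builds a flat class vector for stratified splitting, and it repeats each label once per time step. The sketch below reproduces both steps on toy data. It assumes the labels are stored as an object array holding one length-1 array per sequence; `y_toy` and `y_rep` are hypothetical names.

```python
import numpy as np

# Toy labels in the assumed layout: one length-1 label array per sequence.
y_toy = np.empty(shape=(4,), dtype=object)
for k, label in enumerate([3, 7, 3, 1]):
    y_toy[k] = np.atleast_1d(label)

# Flat class vector, usable as the `stratify` argument of train_test_split.
stratify = np.asarray([np.unique(yt) for yt in y_toy]).flatten()

# Repeat each label once per time step (8 rows per digit), so that every
# time step of a sequence has its own target.
y_rep = np.copy(y_toy)
for k, _ in enumerate(y_rep):
    y_rep[k] = np.repeat(y_rep[k], 8, 0)

print(stratify)
print(y_rep[0])
```

The original array keeps one label per sequence, while the copy holds one label per time step.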
To develop an ESN model for digit recognition, we need to tune several hyper-parameters, e.g., input_scaling, spectral_radius, bias_scaling and leaky integration.
We follow the approach proposed in the introductory paper of PyRCN and optimize the hyper-parameters sequentially.
We define the search space for each step together with the type of search (a randomized search in this context).
Finally, we initialize an ESNClassifier with the desired decision strategy and with the initially fixed parameters.
initially_fixed_params = {
'hidden_layer_size': 50, 'input_activation': 'identity', 'k_in': 5,
'bias_scaling': 0.0, 'reservoir_activation': 'tanh', 'leakage': 1.0,
'bidirectional': False, 'k_rec': 10, 'continuation': False, 'alpha': 1e-5,
'random_state': 42, 'decision_strategy': "winner_takes_all"}
step1_esn_params = {'input_scaling': uniform(loc=1e-2, scale=1),
'spectral_radius': uniform(loc=0, scale=2)}
step2_esn_params = {'leakage': loguniform(1e-5, 1e0)}
step3_esn_params = {'bias_scaling': uniform(loc=0, scale=2)}
step4_esn_params = {'alpha': loguniform(1e-5, 1e0)}
kwargs_step1 = {'n_iter': 200, 'random_state': 42, 'verbose': 10, 'n_jobs': -1,
'scoring': make_scorer(accuracy_score)}
kwargs_step2 = {'n_iter': 50, 'random_state': 42, 'verbose': 1, 'n_jobs': -1,
'scoring': make_scorer(accuracy_score)}
kwargs_step3 = {'verbose': 1, 'n_jobs': -1,
'scoring': make_scorer(accuracy_score)}
kwargs_step4 = {'n_iter': 50, 'random_state': 42, 'verbose': 1, 'n_jobs': -1,
'scoring': make_scorer(accuracy_score)}
# The searches are defined similarly to the steps of a sklearn.pipeline.Pipeline:
searches = [('step1', RandomizedSearchCV, step1_esn_params, kwargs_step1),
('step2', RandomizedSearchCV, step2_esn_params, kwargs_step2),
('step3', RandomizedSearchCV, step3_esn_params, kwargs_step3),
('step4', RandomizedSearchCV, step4_esn_params, kwargs_step4)]
base_esn = ESNClassifier(**initially_fixed_params)
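To illustrate what the tuned hyper-parameters control, here is a minimal NumPy sketch of a generic leaky-integrator reservoir update. It is a textbook formulation under stated assumptions, not PyRCN's internal implementation; all values and the helper `reservoir_states` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(42)

n_inputs, n_hidden = 8, 50
input_scaling, spectral_radius = 0.5, 0.9  # illustrative values only
bias_scaling, leakage = 0.1, 0.3

# Random input and bias weights, scaled by their hyper-parameters.
W_in = input_scaling * rng.uniform(-1, 1, size=(n_hidden, n_inputs))
w_bias = bias_scaling * rng.uniform(-1, 1, size=n_hidden)

# Random recurrent weights, rescaled so that the largest absolute
# eigenvalue equals the requested spectral radius.
W_rec = rng.normal(size=(n_hidden, n_hidden))
W_rec *= spectral_radius / np.max(np.abs(np.linalg.eigvals(W_rec)))


def reservoir_states(sequence):
    """Return one hidden state per time step of a (T, n_inputs) sequence."""
    h = np.zeros(n_hidden)
    states = np.empty((len(sequence), n_hidden))
    for t, u in enumerate(sequence):
        pre_activation = W_in @ u + W_rec @ h + w_bias
        # Leaky integration: blend the previous state with the new activation.
        h = (1.0 - leakage) * h + leakage * np.tanh(pre_activation)
        states[t] = h
    return states


states = reservoir_states(rng.uniform(0, 16, size=(8, 8)))
print(states.shape)
```

With `leakage = 1.0`, as in the initially fixed parameters, the previous state no longer enters the blend directly and only acts through the recurrent weights.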
PyRCN provides SequentialSearchCV, which iterates through the list of searches defined above. It can be combined with any model selection tool from scikit-learn.
sequential_search = SequentialSearchCV(base_esn, searches=searches).fit(
X_tr, y_tr)
Fitting 5 folds for each of 200 candidates, totalling 1000 fits
Fitting 5 folds for each of 50 candidates, totalling 250 fits
Fitting 5 folds for each of 10 candidates, totalling 50 fits
Fitting 5 folds for each of 50 candidates, totalling 250 fits
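The idea behind the sequential search can be sketched in plain Python: each step samples candidates for its own parameters, keeps the best one, and freezes it before the next step starts. This is a simplified illustration and not the SequentialSearchCV implementation; `toy_score` is a hypothetical objective standing in for cross-validated accuracy, and `sequential_search` is a hypothetical helper.

```python
import random


def toy_score(params):
    """Hypothetical objective standing in for cross-validated accuracy."""
    return -((params["input_scaling"] - 0.3) ** 2
             + (params["spectral_radius"] - 1.0) ** 2
             + (params["leakage"] - 0.1) ** 2)


# Each step: (name, samplers for the parameters it tunes, number of draws).
steps = [
    ("step1", {"input_scaling": lambda r: r.uniform(0.01, 1.01),
               "spectral_radius": lambda r: r.uniform(0.0, 2.0)}, 200),
    # Log-uniform sampling between 1e-5 and 1.
    ("step2", {"leakage": lambda r: 10 ** r.uniform(-5, 0)}, 50),
]


def sequential_search(fixed_params, steps, seed=42):
    rng = random.Random(seed)
    best_params = dict(fixed_params)
    history = {}
    for name, samplers, n_iter in steps:
        best_score, best_candidate = float("-inf"), None
        for _ in range(n_iter):
            candidate = {key: draw(rng) for key, draw in samplers.items()}
            score = toy_score({**best_params, **candidate})
            if score > best_score:
                best_score, best_candidate = score, candidate
        # Freeze the winner of this step before moving on to the next one.
        best_params.update(best_candidate)
        history[name] = best_score
    return best_params, history


start = {"input_scaling": 1.0, "spectral_radius": 0.0, "leakage": 1.0}
best_params, history = sequential_search(start, steps)
print(best_params)
print(history)
```

Compared with a joint search over all parameters, this needs far fewer fits, at the price of ignoring interactions between parameters tuned in different steps.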
After the optimization, we extract the ESN with the final hyper-parameters.
base_esn = sequential_search.best_estimator_
base_esn.get_params()
{'bias_scaling': 1.500499867548985, 'bias_shift': 0.0, 'hidden_layer_size': 50, 'input_activation': 'identity', 'input_scaling': 0.016952130531190705, 'input_shift': 0.0, 'k_in': 5, 'predefined_bias_weights': None, 'predefined_input_weights': None, 'random_state': 42, 'sparsity': 0.2, 'bidirectional': False, 'k_rec': 10, 'leakage': 0.026373339933815243, 'predefined_recurrent_weights': None, 'reservoir_activation': 'tanh', 'spectral_radius': 1.0214946051551315, 'alpha': 1.2674255898937214e-05}
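The `alpha` value in the parameters above is the regularization strength tuned in the last step. Assuming it plays the usual role of a ridge-regression penalty on the linear readout, its effect can be sketched as follows. The data are random placeholders and `ridge_readout` is a hypothetical helper, not a PyRCN function.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: 200 reservoir states (50 units each) and one-hot
# targets for 10 classes.
H = rng.normal(size=(200, 50))
labels = rng.integers(0, 10, size=200)
Y = np.eye(10)[labels]


def ridge_readout(H, Y, alpha):
    """Solve (H^T H + alpha * I) W = H^T Y for the output weights W."""
    n_features = H.shape[1]
    return np.linalg.solve(H.T @ H + alpha * np.eye(n_features), H.T @ Y)


W_weak = ridge_readout(H, Y, alpha=1e-5)
W_strong = ridge_readout(H, Y, alpha=1e3)
print(np.linalg.norm(W_weak), np.linalg.norm(W_strong))
```

A larger `alpha` shrinks the output weights, which trades training fit for robustness.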
sequential_search.all_cv_results_["step4"]
{'mean_fit_time': array([2.6532074 , 2.78561602, 2.60994244, 2.60310826, 2.60697007, 2.64569302, 2.61672802, 2.65451388, 2.60384569, 2.65692558, 2.65675936, 2.82678428, 2.80295968, 2.59712319, 2.6227159 , 2.84535923, 2.58692093, 2.6632936 , 2.57469478, 2.56074982, 2.59788303, 2.588311 , 2.58742948, 2.59293227, 2.65385418, 2.59408164, 2.53729944, 2.56438246, 2.56678057, 2.5659205 , 2.57332883, 2.58767653, 2.5798398 , 2.57694831, 2.61877756, 2.71321197, 2.86083097, 2.80925045, 3.00275102, 2.93799133, 2.87340975, 2.7778234 , 2.76509657, 2.81323166, 2.90686755, 2.84235563, 2.84695301, 2.85544257, 2.98605399, 2.61845927]), 'std_fit_time': array([0.14858599, 0.12348577, 0.03306796, 0.08086142, 0.03734891, 0.0695286 , 0.08522908, 0.07412123, 0.03764825, 0.05583214, 0.10757 , 0.1773657 , 0.16576643, 0.06928259, 0.13255758, 0.19700425, 0.0754604 , 0.09600206, 0.07056154, 0.09768716, 0.08431972, 0.04275456, 0.04707553, 0.06725929, 0.05697506, 0.0577874 , 0.04623621, 0.03177068, 0.01874206, 0.05587149, 0.05580291, 0.04256553, 0.03625484, 0.04998939, 0.0781173 , 0.11202149, 0.16487397, 0.10515313, 0.22132628, 0.14726933, 0.03476583, 0.0293329 , 0.03798615, 0.09745271, 0.07042488, 0.04373445, 0.07296759, 0.08742276, 0.10569645, 0.22623521]), 'mean_score_time': array([0.45047684, 0.45714602, 0.4183434 , 0.44898777, 0.45384693, 0.46315026, 0.43620358, 0.46164112, 0.4391418 , 0.4501318 , 0.45846133, 0.43599329, 0.44054184, 0.46960073, 0.46084428, 0.49376464, 0.4295187 , 0.42602949, 0.43847399, 0.42765632, 0.43435025, 0.42474437, 0.44880104, 0.43577037, 0.43333535, 0.42483249, 0.43919163, 0.42167859, 0.42816215, 0.45055833, 0.41359997, 0.45228491, 0.44458857, 0.43079457, 0.44092717, 0.45220747, 0.47940674, 0.45597787, 0.47957048, 0.46685867, 0.44076443, 0.45997396, 0.46912503, 0.49164128, 0.4674027 , 0.49251075, 0.45494628, 0.51233678, 0.45821381, 0.34683628]), 'std_score_time': array([0.01720297, 0.04384708, 0.02255728, 0.03739307, 0.01739511, 0.02808892, 0.01232231, 0.03649716, 
0.02363613, 0.02076977, 0.042372 , 0.01075888, 0.00828188, 0.0283921 , 0.04988519, 0.03591959, 0.01403997, 0.00931752, 0.01194943, 0.01462739, 0.01634563, 0.01704705, 0.03384312, 0.03549023, 0.03455489, 0.0084412 , 0.01340478, 0.00975513, 0.01267895, 0.02809188, 0.0162326 , 0.03091601, 0.01952567, 0.00924466, 0.01567532, 0.01916811, 0.02552566, 0.00673686, 0.04554566, 0.03245472, 0.01283099, 0.03262466, 0.02411342, 0.02191526, 0.00963087, 0.03399785, 0.01723422, 0.08456728, 0.02803731, 0.0318799 ]), 'param_alpha': masked_array(data=[0.0007459343285726547, 0.566984951147885, 0.0457056309980145, 0.009846738873614562, 6.0268891286825045e-05, 6.025215736203858e-05, 1.951722464144947e-05, 0.21423021757741043, 0.010129197956845729, 0.034702669886504146, 1.2674255898937214e-05, 0.7072114131472232, 0.1452824663751603, 0.00011526449540315612, 8.111941985431919e-05, 8.260808399079601e-05, 0.00033205591037519585, 0.00420515645091387, 0.0014445251022763056, 0.0002858549394196192, 0.01146210740342503, 4.98275235707645e-05, 0.00028888383623653144, 0.0006789053271698484, 0.0019069966103000427, 0.08431013932082466, 9.962513222055098e-05, 0.0037253938395788847, 0.009163741808778781, 1.7070728830306648e-05, 0.010907475835157693, 7.122305833333864e-05, 2.1147447960615714e-05, 0.5551721685244719, 0.6732248920775336, 0.11015056790269633, 0.0003334792728637581, 3.0786517836196185e-05, 0.026373339933815243, 0.0015876781526923994, 4.075596440072869e-05, 0.0029914693021302154, 1.485739280627924e-05, 0.3520481045526039, 0.00019674328025306103, 0.020540519425388443, 0.00036187233309596217, 0.003984190594434687, 0.005414413211338522, 8.3998644459575e-05], mask=[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, 
False, False], fill_value='?', dtype=object), 'params': [{'alpha': 0.0007459343285726547}, {'alpha': 0.566984951147885}, {'alpha': 0.0457056309980145}, {'alpha': 0.009846738873614562}, {'alpha': 6.0268891286825045e-05}, {'alpha': 6.025215736203858e-05}, {'alpha': 1.951722464144947e-05}, {'alpha': 0.21423021757741043}, {'alpha': 0.010129197956845729}, {'alpha': 0.034702669886504146}, {'alpha': 1.2674255898937214e-05}, {'alpha': 0.7072114131472232}, {'alpha': 0.1452824663751603}, {'alpha': 0.00011526449540315612}, {'alpha': 8.111941985431919e-05}, {'alpha': 8.260808399079601e-05}, {'alpha': 0.00033205591037519585}, {'alpha': 0.00420515645091387}, {'alpha': 0.0014445251022763056}, {'alpha': 0.0002858549394196192}, {'alpha': 0.01146210740342503}, {'alpha': 4.98275235707645e-05}, {'alpha': 0.00028888383623653144}, {'alpha': 0.0006789053271698484}, {'alpha': 0.0019069966103000427}, {'alpha': 0.08431013932082466}, {'alpha': 9.962513222055098e-05}, {'alpha': 0.0037253938395788847}, {'alpha': 0.009163741808778781}, {'alpha': 1.7070728830306648e-05}, {'alpha': 0.010907475835157693}, {'alpha': 7.122305833333864e-05}, {'alpha': 2.1147447960615714e-05}, {'alpha': 0.5551721685244719}, {'alpha': 0.6732248920775336}, {'alpha': 0.11015056790269633}, {'alpha': 0.0003334792728637581}, {'alpha': 3.0786517836196185e-05}, {'alpha': 0.026373339933815243}, {'alpha': 0.0015876781526923994}, {'alpha': 4.075596440072869e-05}, {'alpha': 0.0029914693021302154}, {'alpha': 1.485739280627924e-05}, {'alpha': 0.3520481045526039}, {'alpha': 0.00019674328025306103}, {'alpha': 0.020540519425388443}, {'alpha': 0.00036187233309596217}, {'alpha': 0.003984190594434687}, {'alpha': 0.005414413211338522}, {'alpha': 8.3998644459575e-05}], 'split0_test_score': array([0.64236111, 0.37890625, 0.42751736, 0.51128472, 0.66666667, 0.66666667, 0.67100694, 0.390625 , 0.50954861, 0.44010417, 0.67491319, 0.37934028, 0.39366319, 0.66189236, 0.66536458, 0.66493056, 0.6484375 , 0.55902778, 0.61588542, 0.64973958, 
0.50173611, 0.66796875, 0.64930556, 0.64409722, 0.60590278, 0.40581597, 0.6640625 , 0.56857639, 0.51432292, 0.671875 , 0.50347222, 0.66666667, 0.67057292, 0.37890625, 0.37847222, 0.40017361, 0.6484375 , 0.66883681, 0.453125 , 0.61241319, 0.66710069, 0.58029514, 0.67274306, 0.38107639, 0.65451389, 0.46397569, 0.64713542, 0.56467014, 0.55251736, 0.66493056]), 'split1_test_score': array([0.64322917, 0.41666667, 0.4609375 , 0.52690972, 0.66970486, 0.66970486, 0.67881944, 0.43098958, 0.52473958, 0.47178819, 0.68098958, 0.41102431, 0.43967014, 0.66536458, 0.66840278, 0.66883681, 0.65321181, 0.57942708, 0.63064236, 0.65625 , 0.52039931, 0.671875 , 0.65538194, 0.64322917, 0.61848958, 0.44444444, 0.66710069, 0.58723958, 0.53211806, 0.67925347, 0.52300347, 0.66883681, 0.67708333, 0.41710069, 0.41189236, 0.44314236, 0.65321181, 0.67534722, 0.48263889, 0.62890625, 0.67491319, 0.59852431, 0.6796875 , 0.42361111, 0.66102431, 0.49435764, 0.65234375, 0.58246528, 0.56206597, 0.66883681]), 'split2_test_score': array([0.6141115 , 0.38022648, 0.43205575, 0.50827526, 0.65156794, 0.65156794, 0.65766551, 0.40374564, 0.50609756, 0.44337979, 0.66376307, 0.37412892, 0.41463415, 0.64329268, 0.64764808, 0.64764808, 0.62630662, 0.54703833, 0.5966899 , 0.62979094, 0.4956446 , 0.65243902, 0.62979094, 0.61716028, 0.58667247, 0.42116725, 0.64503484, 0.55662021, 0.51393728, 0.65897213, 0.50391986, 0.64982578, 0.65635889, 0.38066202, 0.37587108, 0.41811847, 0.62630662, 0.65374564, 0.45513937, 0.59320557, 0.65287456, 0.56445993, 0.66114983, 0.38850174, 0.6337108 , 0.46689895, 0.625 , 0.55182927, 0.53571429, 0.64764808]), 'split3_test_score': array([0.6271777 , 0.43031359, 0.47343206, 0.52439024, 0.64634146, 0.64634146, 0.6533101 , 0.44642857, 0.52351916, 0.47778746, 0.65679443, 0.42595819, 0.45121951, 0.64764808, 0.64372822, 0.64416376, 0.63719512, 0.56794425, 0.61367596, 0.64155052, 0.5152439 , 0.646777 , 0.64155052, 0.63022648, 0.60191638, 0.46341463, 0.64634146, 0.57317073, 0.5283101 , 0.6533101 , 
0.51872822, 0.64503484, 0.6533101 , 0.43074913, 0.42595819, 0.45949477, 0.63719512, 0.64982578, 0.48562718, 0.60932056, 0.64939024, 0.58188153, 0.65418118, 0.44033101, 0.64416376, 0.49433798, 0.63632404, 0.57012195, 0.55444251, 0.64416376]), 'split4_test_score': array([0.60409408, 0.41419861, 0.45862369, 0.50566202, 0.63545296, 0.63545296, 0.64067944, 0.42857143, 0.5043554 , 0.46602787, 0.64067944, 0.4033101 , 0.43423345, 0.62848432, 0.6337108 , 0.6337108 , 0.6184669 , 0.54268293, 0.58580139, 0.61890244, 0.50130662, 0.63675958, 0.61890244, 0.60670732, 0.58057491, 0.44163763, 0.62891986, 0.54834495, 0.50783972, 0.63937282, 0.50348432, 0.63414634, 0.6402439 , 0.41463415, 0.4033101 , 0.43728223, 0.6184669 , 0.63850174, 0.47212544, 0.58362369, 0.6380662 , 0.55879791, 0.6402439 , 0.42203833, 0.62369338, 0.47996516, 0.61672474, 0.54442509, 0.53222997, 0.63327526]), 'mean_test_score': array([0.62619471, 0.40406232, 0.45051327, 0.51530439, 0.65394678, 0.65394678, 0.66029629, 0.42007205, 0.51365206, 0.4598175 , 0.66342794, 0.39875236, 0.42668409, 0.64933641, 0.65177089, 0.651858 , 0.63672359, 0.55922407, 0.60853901, 0.6392467 , 0.50686611, 0.65516387, 0.63898628, 0.62828409, 0.59871122, 0.43529599, 0.65029187, 0.56679037, 0.51930562, 0.6605567 , 0.51052162, 0.65290209, 0.65951383, 0.40441045, 0.39910079, 0.43164229, 0.63672359, 0.65725144, 0.46973117, 0.60549385, 0.65646898, 0.57679176, 0.66160109, 0.41111172, 0.64342123, 0.47990708, 0.63550559, 0.56270234, 0.54739402, 0.65177089]), 'std_test_score': array([0.01540736, 0.02074472, 0.01771556, 0.00866923, 0.01277211, 0.01277211, 0.01340428, 0.02009865, 0.00872502, 0.01525516, 0.01415579, 0.01946625, 0.02031353, 0.01333222, 0.01318527, 0.01315612, 0.01304699, 0.01346361, 0.01566089, 0.01347457, 0.00934325, 0.01311266, 0.01318936, 0.01460465, 0.0136164 , 0.01991613, 0.01394397, 0.01345909, 0.00927742, 0.01401038, 0.00855507, 0.01316445, 0.01304424, 0.02085063, 0.0193278 , 0.0205572 , 0.01304699, 0.01327147, 0.01351821, 
0.01574781, 0.01307603, 0.01404518, 0.01387552, 0.02255094, 0.01354214, 0.0129622 , 0.01328559, 0.01342966, 0.01146787, 0.01327687]), 'rank_test_score': array([24, 48, 41, 34, 9, 9, 4, 45, 35, 40, 1, 50, 44, 16, 13, 12, 20, 31, 25, 18, 37, 8, 19, 23, 27, 42, 15, 29, 33, 3, 36, 11, 5, 47, 49, 43, 20, 6, 39, 26, 7, 28, 2, 46, 17, 38, 22, 30, 32, 13])}
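The raw result dictionary is hard to read. A small helper can sort the candidates by their mean test score. The sketch below uses an excerpt of three candidates copied from the output above; `rank_candidates` is a hypothetical helper name.

```python
# Excerpt of three candidates copied from the step4 results printed above.
cv_excerpt = {
    "params": [{"alpha": 0.0007459343285726547},
               {"alpha": 0.566984951147885},
               {"alpha": 1.2674255898937214e-05}],
    "mean_test_score": [0.62619471, 0.40406232, 0.66342794],
    "std_test_score": [0.01540736, 0.02074472, 0.01415579],
}


def rank_candidates(cv_results):
    """Return (params, mean, std) tuples sorted from best to worst mean."""
    rows = zip(cv_results["params"],
               cv_results["mean_test_score"],
               cv_results["std_test_score"])
    return sorted(rows, key=lambda row: row[1], reverse=True)


ranking = rank_candidates(cv_excerpt)
for params, mean, std in ranking:
    print(f"alpha={params['alpha']:.3e}  accuracy={mean:.3f} +/- {std:.3f}")
```

The same function should work on the full dictionary, since it only indexes the `params`, `mean_test_score` and `std_test_score` entries.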
Finally, we increase the reservoir size and compare uni- and bidirectional ESNs. Note that the cross-validation scores improve both with a larger reservoir and with the bidirectional mode.
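One way to picture the bidirectional mode is a reservoir that processes the sequence forwards and backwards and then concatenates both state sequences, which doubles the number of features seen by the readout. The sketch below follows that assumption; it is a generic illustration, and PyRCN's internals may differ. `run_reservoir` and `bidirectional_states` are hypothetical helpers.

```python
import numpy as np

rng = np.random.default_rng(42)
n_inputs, n_hidden = 8, 50

W_in = rng.uniform(-1, 1, size=(n_hidden, n_inputs))
W_rec = rng.normal(size=(n_hidden, n_hidden))
W_rec *= 0.9 / np.max(np.abs(np.linalg.eigvals(W_rec)))


def run_reservoir(sequence):
    """Collect one hidden state per time step."""
    h = np.zeros(n_hidden)
    states = []
    for u in sequence:
        h = np.tanh(W_in @ u + W_rec @ h)
        states.append(h)
    return np.asarray(states)


def bidirectional_states(sequence):
    forward = run_reservoir(sequence)
    # Process the reversed sequence, then flip back so rows align in time.
    backward = run_reservoir(sequence[::-1])[::-1]
    return np.concatenate((forward, backward), axis=1)


sequence = rng.uniform(0, 16, size=(8, 8))
print(run_reservoir(sequence).shape, bidirectional_states(sequence).shape)
```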
param_grid = {'hidden_layer_size': [50, 100, 200, 400, 500],
'bidirectional': [False, True]}
print("CV results\tFit time\tInference time\tAccuracy score\tSize[Bytes]")
for params in ParameterGrid(param_grid):
    esn_cv = cross_validate(clone(base_esn).set_params(**params), X=X_train, y=y_train,
                            scoring=make_scorer(accuracy_score), n_jobs=-1)
    t1 = time.time()
    esn = clone(base_esn).set_params(**params).fit(X_train, y_train)
    t_fit = time.time() - t1
    t1 = time.time()
    esn_par = clone(base_esn).set_params(**params).fit(X_train, y_train, n_jobs=-1)
    t_fit_par = time.time() - t1
    mem_size = esn.__sizeof__()
    t1 = time.time()
    acc_score = accuracy_score(y_test, esn.predict(X_test))
    t_inference = time.time() - t1
    print(f"{esn_cv}\t{t_fit}\t{t_inference}\t{acc_score}\t{mem_size}")
CV results Fit time Inference time Accuracy score Size[Bytes]
{'fit_time': array([1.20268822, 1.31980658, 1.17576218, 1.22178006, 1.24531698]), 'score_time': array([0.20427084, 0.1787076 , 0.19799089, 0.21658278, 0.17672634]), 'test_score': array([0.91319444, 0.91666667, 0.89547038, 0.91289199, 0.87108014])} 1.092219591140747 0.16528797149658203 0.925 29892
{'fit_time': array([1.34596086, 1.32152963, 1.32597208, 1.34012151, 1.38048863]), 'score_time': array([0.22114849, 0.2336669 , 0.17909884, 0.2270937 , 0.20272541]), 'test_score': array([0.94444444, 0.93402778, 0.93728223, 0.92334495, 0.91289199])} 1.1797807216644287 0.16232609748840332 0.9416666666666667 99092
{'fit_time': array([1.71792817, 1.72016811, 1.70596719, 1.65824962, 1.68945742]), 'score_time': array([0.23715806, 0.23139477, 0.24706268, 0.26664233, 0.25166321]), 'test_score': array([0.95138889, 0.94791667, 0.94773519, 0.94425087, 0.91986063])} 1.5785596370697021 0.2425389289855957 0.9611111111111111 357492
{'fit_time': array([4.3435142 , 4.2941072 , 4.47814965, 4.45681834, 4.38268828]), 'score_time': array([0.47901511, 0.49599099, 0.41987824, 0.42973256, 0.44941235]), 'test_score': array([0.94791667, 0.9375 , 0.94425087, 0.95470383, 0.92334495])} 2.862438201904297 0.22818946838378906 0.9638888888888889 1354292
{'fit_time': array([5.95386696, 6.16443706, 6.10848475, 6.10720062, 6.09413409]), 'score_time': array([0.45308208, 0.40109324, 0.41994405, 0.4118824 , 0.42954731]), 'test_score': array([0.94791667, 0.94791667, 0.94773519, 0.95470383, 0.92334495])} 3.914069175720215 0.3445422649383545 0.9611111111111111 2092692
{'fit_time': array([3.28861213, 3.14525843, 2.96652579, 3.16590023, 3.28161478]), 'score_time': array([0.70867753, 0.73422313, 0.88020873, 0.7299583 , 0.66681576]), 'test_score': array([0.94791667, 0.95138889, 0.94076655, 0.91986063, 0.8989547 ])} 1.941004991531372 0.3332488536834717 0.9083333333333333 98692
{'fit_time': array([3.2477181 , 3.34564877, 3.25544882, 3.29266453, 3.40397859]), 'score_time': array([0.66511965, 0.68270802, 0.64036489, 0.65441465, 0.61603808]), 'test_score': array([0.97222222, 0.95833333, 0.95818815, 0.92682927, 0.93379791])} 1.932464599609375 0.3669440746307373 0.9305555555555556 356692
{'fit_time': array([5.44010997, 5.43559408, 5.36100101, 5.35568452, 5.48375392]), 'score_time': array([0.6066854 , 0.59092546, 0.6019001 , 0.60729814, 0.56611323]), 'test_score': array([0.97569444, 0.96527778, 0.96864111, 0.95121951, 0.94425087])} 3.334501028060913 0.49420928955078125 0.9583333333333334 1352692
{'fit_time': array([14.56774807, 14.57170606, 14.46315455, 14.58913779, 14.54430723]), 'score_time': array([0.93357658, 0.9349041 , 0.92975593, 0.94193864, 0.91288638]), 'test_score': array([0.97916667, 0.96180556, 0.97560976, 0.95818815, 0.94425087])} 7.724725008010864 0.6274893283843994 0.9583333333333334 5264692
{'fit_time': array([20.9128716 , 20.93491912, 20.86150718, 21.00014853, 20.94348502]), 'score_time': array([0.81963921, 0.8159306 , 0.84687257, 0.80614424, 0.81322145]), 'test_score': array([0.97569444, 0.96527778, 0.97212544, 0.95121951, 0.95470383])} 11.71635103225708 0.8147547245025635 0.9694444444444444 8180692
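Each cross-validation dictionary in the output above holds one value per fold, so a mean and a standard deviation per entry are easier to compare than the raw arrays. The sketch below summarizes the first row of the output; `summarize` is a hypothetical helper, and the numbers are copied from that row.

```python
from statistics import mean, stdev

# Values copied from the first row of the output above.
first_row = {
    "fit_time": [1.20268822, 1.31980658, 1.17576218, 1.22178006, 1.24531698],
    "score_time": [0.20427084, 0.1787076, 0.19799089, 0.21658278, 0.17672634],
    "test_score": [0.91319444, 0.91666667, 0.89547038, 0.91289199, 0.87108014],
}


def summarize(cv_result):
    """Map each key to (mean, sample standard deviation) over the folds."""
    return {key: (mean(values), stdev(values))
            for key, values in cv_result.items()}


summary = summarize(first_row)
for key, (avg, spread) in summary.items():
    print(f"{key}: {avg:.4f} +/- {spread:.4f}")
```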
Alternatively, we can use the PyTorch implementation of the ESN model:
from pyrcn.nn import ESN
import torch
base_esn = ESN(input_size=8, hidden_size=50, num_layers=1, nonlinearity='tanh',
bias=True, input_scaling=0.016952130531190705,
spectral_radius=1.0214946051551315, bias_scaling=1.500499867548985,
bias_shift=0., input_sparsity=0.1, recurrent_sparsity=0.1,
bidirectional=False)
for x, y in zip(X_train, y_train):
    x = torch.Tensor(x).float()
    print(base_esn(x)[0].shape)
    break
torch.Size([8, 50])