#!/usr/bin/env python
# coding: utf-8

# ## Model and Related Functions

# First up is a callback for tracking loss history. This is helpful because
# we can use it to view how training is progressing. The specifics aren't too
# important, but it can be altered to trigger at different frequencies and on
# different events (like an epoch end vs. a batch end).
#
# NOTE(review): no imports appear in this file; `tf`, `plt`, `np`, `xr` and
# `clear_output` are presumably provided by the notebook session that runs it
# (e.g. via `%run`) -- confirm.

# In[69]:


class LossHistory(tf.keras.callbacks.Callback):
    """Keras callback that records per-batch metrics and periodically plots them.

    On every batch the ``mae``, ``mse`` and ``accuracy`` entries of the Keras
    ``logs`` dict are appended to the lists of the same name (``None`` when a
    metric is absent). Whenever ``batch % frequency == 0`` (``batch`` is the
    batch index Keras passes in), the notebook output is cleared, the latest
    values are printed, and the three histories are plotted side by side.

    Attributes:
        frequency: Plotting period, in batches.
        batch_counter: Number of batches seen so far.
        i: Next x-coordinate; advances by one per batch across ``fit`` calls.
        x, mae, mse, accuracy: Recorded histories (parallel lists).
        logs: Copies of the ``logs`` dicts seen on plotting batches.
        fig: Most recent progress figure (``None`` until the first plot).
    """

    def __init__(self, frequency=30):
        # The Keras base class initialises attributes (e.g. ``model``) that
        # the training loop relies on.
        super().__init__()
        self.frequency = frequency
        self.batch_counter = 0
        self.i = 0
        self.x = []
        self.mae = []
        self.mse = []
        self.accuracy = []
        self.logs = []
        self.fig = None

    # on_train_begin used to open an empty figure on every fit() call that
    # was never drawn or closed; the inherited no-op is used instead and
    # _plot owns its own figure.

    def on_batch_end(self, batch, logs=None):
        logs = {} if logs is None else logs
        self.batch_counter += 1
        self.x.append(self.i)
        self.i += 1
        self.mae.append(logs.get('mae'))
        self.mse.append(logs.get('mse'))
        self.accuracy.append(logs.get('accuracy'))

        if batch % self.frequency != 0:
            return
        # Store a copy: the dict is owned by Keras and may be reused.
        self.logs.append(dict(logs))
        clear_output(wait=True)
        print(f"MAE: {self.mae[-1]} \t\t MSE: {self.mse[-1]} \t\t Accuracy: {self.accuracy[-1]}")
        self._plot()

    def _plot(self):
        """Draw the MAE / MSE / accuracy histories in a 1x3 row of panels."""
        # (series, colour, y label, y limits, legend label)
        panels = [
            (self.mae, '#ff6347', 'Mean Absolute Error', [0., 100.], "mae"),
            (self.mse, '#6495ed', r'Mean Squared Error [$cm^2/s^2$]', [0., 1000.], None),
            (self.accuracy, '#3cb371', 'Model Accuracy', [0., 1.], None),
        ]
        fig = plt.figure(figsize=(18, 5))
        for position, (series, color, ylabel, ylim, label) in enumerate(panels, start=131):
            plt.subplot(position)
            plt.plot(self.x, series, color=color, label=label)
            # Highlight the most recent value.
            plt.plot(self.x[-1], series[-1], marker='o', markersize=10, color=color)
            if label is not None:
                plt.legend()
            plt.xlabel('batch')
            plt.ylabel(ylabel)
            plt.ylim(ylim)
        self.fig = fig
        plt.show()
        # Release the figure so repeated plotting does not accumulate figures.
        plt.close(fig)


# We define a custom loss function to take the MAE of multidimensional
# objects. (It is intended to filter out NaNs, which would solve the
# coastline problem; see the notes on `Grid_MAE` below.)
#
# When NaNs are present, this should give different results from using
# `.fillna(0)` on the training data, because it won't entrain zeros when the
# model takes a convolution. When no NaNs are present it is the standard MAE
# (the mean over every element).
#
# Why the first version "didn't work": it masked |y_true - y_pred| *after*
# the subtraction. `tf.where` still back-propagates through the unselected
# branch, and the gradient there is NaN, so NaNs leaked into training. The
# version below masks the operands *before* subtracting. It can only mask
# NaN targets/predictions, though: NaN *inputs* still poison the
# convolution's weight gradients, which is why tiles containing any NaN are
# dropped with `.dropna` in `_make_tiles`.

# In[70]:


class Grid_MAE(tf.keras.losses.Loss):
    """Mean absolute error over all grid cells, ignoring non-finite cells.

    Elements where either ``y_true`` or ``y_pred`` is NaN/inf are excluded
    from both the sum and the count. Returns a scalar: the mean absolute
    error over the valid elements of the batch, or 0 when no element is valid
    (the "everything is NaN" corner case).
    """

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        valid = tf.math.logical_and(tf.math.is_finite(y_true),
                                    tf.math.is_finite(y_pred))
        zeros = tf.zeros_like(y_pred)
        # Replace invalid operands before subtracting so no NaN enters the
        # graph on either the forward or the backward pass.
        abs_err = tf.math.abs(tf.where(valid, y_true, zeros)
                              - tf.where(valid, y_pred, zeros))
        n_valid = tf.math.reduce_sum(tf.cast(valid, abs_err.dtype))
        # divide_no_nan yields 0 instead of NaN when nothing is valid.
        return tf.math.divide_no_nan(tf.math.reduce_sum(abs_err), n_valid)


# The `get_model` function generates a neural network based on Sinha and
# Abernathey (2021), but offers some parameters to enable a broader class of
# neural networks of similar form.

# In[71]:


def get_model(halo_size, ds, sc, conv_dims, nfilters, conv_kernels, dense_layers):
    """Build and compile the two-input convolution + dense network.

    Args:
        halo_size: Cells trimmed from each tile edge by the Conv2D stack; the
            second input covers only the tile interior.
        ds: Unused; accepted so ``train`` can forward its arguments.
        sc: Configuration exposing ``conv_var``, ``input_var`` and ``target``
            (sequences of variable names).
        conv_dims: ``(nlon, nlat)`` size of each input tile.
        nfilters: Width of the first Conv2D layer; halved (integer division,
            never below 1) after every Conv2D and hidden Dense layer.
        conv_kernels: Kernel size of each Conv2D layer, in order.
        dense_layers: Number of hidden ReLU Dense layers.

    Returns:
        A compiled ``tf.keras.Model`` with inputs
        ``[(nlat, nlon, n_conv_var), (nlat - 2*halo, nlon - 2*halo, n_input_var)]``
        and one output channel per target variable.
    """
    nlons, nlats = conv_dims
    # The xarray tiles are laid out (nlat, nlon) while conv_dims is
    # (nlon, nlat), so shapes are listed lat-first to match the data.
    conv_init = tf.keras.Input(shape=(nlats, nlons, len(sc.conv_var)))
    last_layer = conv_init
    for kernel in conv_kernels:
        # Default 'valid' padding trims (kernel - 1) / 2 cells per edge.
        last_layer = tf.keras.layers.Conv2D(nfilters, kernel)(last_layer)
        nfilters = max(1, nfilters // 2)

    input_init = tf.keras.Input(
        shape=(nlats - 2 * halo_size, nlons - 2 * halo_size, len(sc.input_var)))
    last_layer = tf.keras.layers.concatenate([last_layer, input_init])
    last_layer = tf.keras.layers.LeakyReLU(alpha=0.3)(last_layer)

    for _ in range(dense_layers):
        last_layer = tf.keras.layers.Dense(nfilters, activation='relu')(last_layer)
        nfilters = max(1, nfilters // 2)

    output_layer = tf.keras.layers.Dense(len(sc.target))(last_layer)
    model = tf.keras.Model(inputs=[conv_init, input_init], outputs=output_layer)
    opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
    model.compile(loss=Grid_MAE(), optimizer=opt, metrics=['mae', 'mse', 'accuracy'])
    model.summary()
    return model


# ## Tiling helpers
# `train`, `test` and `predict` all cut the domain into the same tiles, so the
# windowing logic lives in one place.


def _halo_size(conv_kernels):
    """Return the number of cells the Conv2D stack trims from each tile edge.

    Each 'valid' convolution with an odd kernel ``k`` trims ``(k - 1) / 2``.
    """
    return int((np.sum(conv_kernels) - len(conv_kernels)) / 2)


def _make_tiles(ds, conv_dims, conv_kernels):
    """Cut ``ds`` into NaN-free ``conv_dims`` tiles stacked along ``input_batch``.

    Returns ``(tiles, halo_size)``. ``tiles`` has ``input_batch`` as its
    leading dim and the window dims renamed back to ``nlat`` / ``nlon``.
    Tiles are spaced so that their interiors (tile minus halo) abut.

    Raises:
        ValueError: if the convolution halo consumes the whole tile.
    """
    halo_size = _halo_size(conv_kernels)
    nlons, nlats = conv_dims
    lon_step = nlons - 2 * halo_size
    lat_step = nlats - 2 * halo_size
    if lon_step <= 0 or lat_step <= 0:
        raise ValueError(
            f"conv_dims {tuple(conv_dims)} leave no interior after a halo of "
            f"{halo_size} cells from conv_kernels {tuple(conv_kernels)}")
    # rolling(...).construct labels each window by its last index.
    # NOTE(review): starting at nlons/nlats (not nlons-1/nlats-1) skips the
    # first full window row/column; kept so tile positions match earlier runs.
    nlon_range = range(nlons, ds.sizes['nlon'], lon_step)
    nlat_range = range(nlats, ds.sizes['nlat'], lat_step)
    tiles = (
        ds
        .rolling({"nlat": nlats, "nlon": nlons})
        .construct({"nlat": "nlat_input", "nlon": "nlon_input"})
        .isel({'nlat': nlat_range, 'nlon': nlon_range})
        .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
        .rename_dims({'nlat_input': 'nlat', 'nlon_input': 'nlon'})
        .transpose('input_batch', ...)
        # Drop every tile with a NaN anywhere (e.g. land along the coast).
        .dropna('input_batch')
    )
    return tiles, halo_size


def _interior_indexers(halo_size, conv_dims):
    """Return ``isel`` indexers that strip the convolution halo from a tile.

    Needed because each convolution removes a halo: ``input_var`` and
    ``target`` fields must match the conv output size at the concat layer.
    """
    nlons, nlats = conv_dims
    return {'nlat': slice(halo_size, nlats - halo_size),
            'nlon': slice(halo_size, nlons - halo_size)}


def _to_model_array(tiles, names, indexers=None):
    """Stack ``tiles[name]`` for each name into an ``(input_batch, nlat, nlon, var)`` array."""
    fields = [tiles[name] if indexers is None else tiles[name].isel(indexers)
              for name in names]
    return (
        xr.merge(fields)
        .to_array('var')
        .transpose('input_batch', 'nlat', 'nlon', ..., 'var')
        .values
    )


def _iter_chunks(tiles, chunk_size):
    """Yield consecutive ``chunk_size`` slices along ``input_batch``, including a final partial one."""
    n_tiles = tiles.sizes['input_batch']
    for start in range(0, n_tiles, chunk_size):
        yield tiles.isel(input_batch=slice(start, start + chunk_size))


def _load_model(sc):
    """Load the model that ``train`` saved for this configuration."""
    return tf.keras.models.load_model('models/' + sc.name,
                                      custom_objects={'Grid_MAE': Grid_MAE})


# In[ ]:


def train(ds, sc, conv_dims=(3, 3), nfilters=80, conv_kernels=(3,), dense_layers=3,
          chunk_size=4096, batch_size=32):
    """Train a new model on tiles of ``ds`` and save it under ``models/``.

    The domain is cut into NaN-free tiles and shuffled once. The tiles are fed
    to ``model.fit`` ``chunk_size`` at a time (mini-batches of ``batch_size``),
    sharing one ``LossHistory`` callback. The model is saved to
    ``models/<sc.name>`` and the metric histories via ``np.savez`` to
    ``models/history_<sc.name>``.

    Args:
        ds: Dataset with ``nlat`` / ``nlon`` dims holding every variable named
            by ``sc``.
        sc: Configuration exposing ``conv_var``, ``input_var``, ``target`` and
            ``name``.
        conv_dims: ``(nlon, nlat)`` tile size.
        nfilters, conv_kernels, dense_layers: Forwarded to ``get_model``.
        chunk_size: Tiles passed to each ``model.fit`` call.
        batch_size: Mini-batch size within each ``fit`` call.

    Returns:
        ``(model, history)``.

    Raises:
        ValueError: if ``ds`` yields no NaN-free tiles.
    """
    tiles, halo_size = _make_tiles(ds, conv_dims, conv_kernels)
    n_tiles = tiles.sizes['input_batch']
    if n_tiles == 0:
        raise ValueError("no NaN-free tiles to train on")
    tiles = tiles.isel(input_batch=np.random.permutation(n_tiles))
    interior = _interior_indexers(halo_size, conv_dims)

    model = get_model(halo_size, ds=ds, sc=sc, conv_dims=conv_dims, nfilters=nfilters,
                      conv_kernels=conv_kernels, dense_layers=dense_layers)
    history = LossHistory()
    for chunk in _iter_chunks(tiles, chunk_size):
        x_conv = _to_model_array(chunk, sc.conv_var)
        x_input = _to_model_array(chunk, sc.input_var, interior)
        y_target = _to_model_array(chunk, sc.target, interior)
        clear_output(wait=True)
        model.fit([x_conv, x_input], y_target,
                  batch_size=batch_size,
                  verbose=0,
                  callbacks=[history])

    model.save('models/' + sc.name)
    np.savez('models/history_' + sc.name,
             losses=history.mae, mse=history.mse, accuracy=history.accuracy)
    return model, history


# In[ ]:


def test(ds, sc, conv_dims=(3, 3), conv_kernels=(3,)):
    """Evaluate the saved model for ``sc`` on every NaN-free tile of ``ds``.

    ``conv_dims`` / ``conv_kernels`` must match the values used in ``train``.
    Returns whatever ``model.evaluate`` returns.
    """
    tiles, halo_size = _make_tiles(ds, conv_dims, conv_kernels)
    interior = _interior_indexers(halo_size, conv_dims)
    model = _load_model(sc)
    x_conv = _to_model_array(tiles, sc.conv_var)
    x_input = _to_model_array(tiles, sc.input_var, interior)
    y_target = _to_model_array(tiles, sc.target, interior)
    return model.evaluate([x_conv, x_input], y_target)


# In[ ]:


def predict(ds, sc, conv_dims=(3, 3), conv_kernels=(3,)):
    """Run the saved model for ``sc`` on every NaN-free tile of ``ds``.

    ``conv_dims`` / ``conv_kernels`` must match the values used in ``train``.
    Returns the ``model.predict`` output: one interior-sized prediction per
    tile, in tile order.
    """
    tiles, halo_size = _make_tiles(ds, conv_dims, conv_kernels)
    interior = _interior_indexers(halo_size, conv_dims)
    model = _load_model(sc)
    x_conv = _to_model_array(tiles, sc.conv_var)
    x_input = _to_model_array(tiles, sc.input_var, interior)
    target = model.predict([x_conv, x_input])
    return target