Neural network model to predict TA¶

Created by Ivan Lima on Mon Jan 23 2023 22:41:50 -0500

This version of the neural network model does not include dissolved oxygen as an input feature.

In [1]:
%matplotlib inline
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os, datetime, warnings
print('Last updated on {}'.format(datetime.datetime.now().ctime()))
Last updated on Sun Feb  5 15:52:03 2023
In [2]:
import sns_settings
sns.set_context('paper')
pd.options.display.max_columns = 50
warnings.filterwarnings('ignore')

Load TA bottle data¶

In [3]:
df_bottle_ta = pd.read_csv('data/bottle_data_TA_prepared.csv', parse_dates=['Date'], index_col=0, na_values=['<undefined>',-9999.])
df_bottle_ta['log_Chl'] = np.log(df_bottle_ta.Chl)
df_bottle_ta['log_KD490'] = np.log(df_bottle_ta.KD490)

Select input features and target variable¶

In [4]:
features = ['Depth', 'Temperature', 'Salinity', 'pCO2_atm', 'ADT', 'SST_hires', 'log_KD490']
target = ['TALK']
varlist = features + target
fg = sns.pairplot(df_bottle_ta, vars=varlist, hue='Season', plot_kws={'alpha':0.7}, diag_kind='hist')
In [5]:
data = df_bottle_ta[varlist]
corr_mat = data.corr()
fig, ax = plt.subplots(figsize=(7,7))
_ = sns.heatmap(corr_mat, ax=ax, cmap='vlag', center=0, square=True, annot=True, annot_kws={'fontsize':9})
_ = ax.set_title('Correlation')

Split data into training and test sets¶

In [6]:
from sklearn.model_selection import train_test_split, cross_val_score

data = df_bottle_ta[features + target + ['Season']].dropna()

X = data[features].values
y = data[target].values

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=data.Season.values, random_state=77)
X.shape, X_train.shape, X_test.shape, y_train.shape, y_test.shape
Out[6]:
((4151, 7), (3113, 7), (1038, 7), (3113, 1), (1038, 1))
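Because the split stratifies on Season, each season should appear in roughly the same proportion in the training and test sets. A minimal sanity check along those lines (not part of the original run, reusing the data frame and random_state defined above) could look like this:

# Split the data frame itself with the same settings and compare the
# seasonal proportions in the full, training and test sets.
train_df, test_df = train_test_split(data, stratify=data.Season.values, random_state=77)
print(data.Season.value_counts(normalize=True).round(3))
print(train_df.Season.value_counts(normalize=True).round(3))
print(test_df.Season.value_counts(normalize=True).round(3))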

Train Neural Network regression¶

In [7]:
import tensorflow as tf
from tensorflow import keras

keras.utils.set_random_seed(42) # make things reproducible
n_hidden = 256 # number of nodes in hidden layers
alpha = 0.01   # negative slope of the LeakyReLU activations

model = keras.models.Sequential([
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden, input_shape=X_train.shape[1:]),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(y_train.shape[1])
])

early_stopping_cb = keras.callbacks.EarlyStopping(patience=20, restore_best_weights=True)
model.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam())
history = model.fit(X_train, y_train, epochs=700, verbose=2, validation_split=0.2, callbacks=[early_stopping_cb])
Epoch 1/700
78/78 - 1s - loss: 5058212.0000 - val_loss: 4755282.0000 - 1s/epoch - 16ms/step
Epoch 2/700
78/78 - 0s - loss: 5024661.0000 - val_loss: 4625778.0000 - 214ms/epoch - 3ms/step
Epoch 3/700
78/78 - 0s - loss: 4959320.5000 - val_loss: 4672549.0000 - 218ms/epoch - 3ms/step
...
Epoch 87/700
78/78 - 0s - loss: 317.8509 - val_loss: 91.0121 - 227ms/epoch - 3ms/step
Epoch 88/700
78/78 - 0s - loss: 312.4104 - val_loss: 108.2114 - 229ms/epoch - 3ms/step
Epoch 89/700
78/78 - 0s - loss: 255.5267 - val_loss: 87.7622 - 221ms/epoch - 3ms/step

Save trained model¶

In [8]:
model.save('models/nn_regression_ta_noO2.h5')
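To reuse the trained network later without refitting, the saved HDF5 file can be loaded back; a minimal sketch (same path as above):

# Reload the saved model and confirm it reproduces the test-set error.
reloaded = keras.models.load_model('models/nn_regression_ta_noO2.h5')
print(reloaded.evaluate(X_test, y_test, verbose=0))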

Learning curve¶

In [9]:
df_history = pd.DataFrame(history.history)
df_history.index.name = 'epoch'
df_history = df_history.reset_index()
df_history.to_csv('results/nn_regression_history_ta_noO2.csv')

fig, ax = plt.subplots(figsize=(6, 6))
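# training loss is averaged over each epoch while val_loss is computed at the
# end of it, so the training curve is shifted half an epoch to the left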
_ = sns.lineplot(x=df_history.epoch-0.5, y='loss', data=df_history, ax=ax, label='training set')
_ = sns.lineplot(x=df_history.epoch, y='val_loss', data=df_history, ax=ax, label='validation set')
_ = ax.set(ylabel = 'MSE')
# _ = ax.set(yscale='log')

MSE & $R^2$¶

In [10]:
from sklearn.metrics import r2_score

print('MSE on training set = {:.2f}'.format(model.evaluate(X_train, y_train, verbose=0)))
print('MSE on test set     = {:.2f}\n'.format(model.evaluate(X_test, y_test, verbose=0)))

y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

print('R squared on training set = {:.3f}'.format(r2_score(y_train, y_pred_train)))
print('R squared on test set     = {:.3f}'.format(r2_score(y_test, y_pred_test)))
MSE on training set = 74.67
MSE on test set     = 93.41

R squared on training set = 0.987
R squared on test set     = 0.985
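Because the loss is a mean squared error in squared TA units, its square root is easier to interpret as a typical prediction error. A small follow-up (not part of the original run; units are those of the TALK column):

# RMSE in TA units; sqrt(93.41) is roughly 9.7 for the test set above
rmse_test = np.sqrt(model.evaluate(X_test, y_test, verbose=0))
print('RMSE on test set = {:.2f}'.format(rmse_test))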
In [11]:
fig, ax = plt.subplots(figsize=(6,6))
_ = sns.scatterplot(x=y_test.ravel(), y=y_pred_test.ravel(), ax=ax)
_ = ax.set(xlabel='observed TA', ylabel='predicted TA', title='Test dataset')
_ = ax.axis('equal')
In [12]:
# save test set features, target & predictions
df_test = pd.DataFrame(np.c_[X_test, y_test, y_pred_test], columns = features + ['TA observed', 'TA predicted'])
df_test['TA residuals'] = df_test['TA observed'] - df_test['TA predicted']
df_test.to_csv('results/bottle_data_test_ta_noO2.csv')
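The saved residuals make it straightforward to look for structure in the errors; a minimal sketch using the df_test frame built above:

# Residuals vs. observed TA for the test set; a flat band around zero
# indicates no obvious bias across the TA range.
fig, ax = plt.subplots(figsize=(6, 4))
_ = sns.scatterplot(x='TA observed', y='TA residuals', data=df_test, ax=ax, alpha=0.5)
_ = ax.axhline(0, color='k', linewidth=0.8)
_ = ax.set(title='Test dataset')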

K-fold cross-validation¶

In [13]:
from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True, random_state=42)
score_vals = [] # store score values

nn_reg = keras.models.Sequential([
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden, input_shape=X_train.shape[1:]),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(y_train.shape[1])
])
nn_reg.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam())

for k, (train_idx, test_idx) in enumerate(kf.split(X_train)):
    X_tr, X_te = X_train[train_idx], X_train[test_idx]
    y_tr, y_te = y_train[train_idx], y_train[test_idx]
    history_cv = nn_reg.fit(X_tr, y_tr, epochs=700, verbose=0, validation_split=0.2, callbacks=[early_stopping_cb])
    y_pred = nn_reg.predict(X_te)
    score = r2_score(y_te, y_pred)
    score_vals.append(score)
    print('Fold {} test set R squared: {:.3f}'.format(k+1, score))

scores = np.array(score_vals)
print('\nBest R squared:  {:.3f}'.format(scores.max()))
print('Worst R squared: {:.3f}'.format(scores.min()))
print('Mean R squared:  {:.3f}'.format(scores.mean()))
Fold 1 test set R squared: 0.975
Fold 2 test set R squared: 0.986
Fold 3 test set R squared: 0.986
Fold 4 test set R squared: 0.984
Fold 5 test set R squared: 0.986

Best R squared:  0.986
Worst R squared: 0.975
Mean R squared:  0.983
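Note that nn_reg is built and compiled once and then refit in every fold, so each fold starts from the weights left by the previous one and the scores above reflect cumulative rather than fully independent training. If independent folds are preferred, a fresh model can be created inside the loop; a minimal sketch reusing the architecture and settings defined above:

def make_model():
    # Same architecture as nn_reg, rebuilt with freshly initialized weights.
    m = keras.models.Sequential([
        keras.layers.BatchNormalization(),
        keras.layers.Dense(n_hidden, input_shape=X_train.shape[1:]),
        keras.layers.LeakyReLU(alpha=alpha),
        keras.layers.BatchNormalization(),
        keras.layers.Dense(n_hidden),
        keras.layers.LeakyReLU(alpha=alpha),
        keras.layers.BatchNormalization(),
        keras.layers.Dense(y_train.shape[1])
    ])
    m.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam())
    return m

for k, (train_idx, test_idx) in enumerate(kf.split(X_train)):
    model_k = make_model()  # fresh weights for each fold
    model_k.fit(X_train[train_idx], y_train[train_idx], epochs=700, verbose=0,
                validation_split=0.2, callbacks=[early_stopping_cb])
    r2 = r2_score(y_train[test_idx], model_k.predict(X_train[test_idx]))
    print('Fold {} test set R squared: {:.3f}'.format(k + 1, r2))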