Neural network model to predict DIC¶

Created by Ivan Lima on Tue Jan 17 2023 19:23:16 -0500

In [1]:
%matplotlib inline
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os, datetime, warnings
print('Last updated on {}'.format(datetime.datetime.now().ctime()))
Last updated on Sun Feb  5 13:18:56 2023
In [2]:
import sns_settings
sns.set_context('paper')
pd.options.display.max_columns = 50
warnings.filterwarnings('ignore')

Load DIC bottle data¶

In [3]:
df_bottle_dic = pd.read_csv('data/bottle_data_DIC_prepared.csv', parse_dates=['Date'],
                            index_col=0, na_values=['<undefined>',-9999.])
df_bottle_dic = df_bottle_dic.loc[df_bottle_dic.Oxygen_flag.isin([2, 6])]
df_bottle_dic = df_bottle_dic.loc[df_bottle_dic.Oxygen.notnull()]
df_bottle_dic['log_Chl'] = np.log(df_bottle_dic.Chl)
df_bottle_dic['log_KD490'] = np.log(df_bottle_dic.KD490)

Select input features and target variable¶

In [4]:
features = ['Depth', 'Temperature', 'Salinity', 'Oxygen', 'pCO2_atm', 'ADT', 'SST_hires', 'log_KD490']
target = ['DIC']
varlist = features + target
fg = sns.pairplot(df_bottle_dic, vars=varlist, hue='Season', plot_kws={'alpha':0.7}, diag_kind='hist')
In [5]:
data = df_bottle_dic[varlist]
corr_mat = data.corr()
fig, ax = plt.subplots(figsize=(7,7))
_ = sns.heatmap(corr_mat, ax=ax, cmap='vlag', center=0, square=True, annot=True, annot_kws={'fontsize':9})
_ = ax.set_title('Correlation')

Split data into training and test sets¶

In [6]:
from sklearn.model_selection import train_test_split, cross_val_score

data = df_bottle_dic[features + target + ['Season']].dropna()

X = data[features].values
y = data[target].values

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=data.Season.values, random_state=77)
X.shape, X_train.shape, X_test.shape, y_train.shape, y_test.shape
Out[6]:
((3970, 8), (2977, 8), (993, 8), (2977, 1), (993, 1))

Train Neural Network regression¶

In [7]:
import tensorflow as tf
from tensorflow import keras

keras.utils.set_random_seed(42) # make things reproducible
n_hidden = 256 # number of nodes in hidden layers
alpha=0.01

model = keras.models.Sequential([
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden, input_shape=X_train.shape[1:]),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(y_train.shape[1])
])

early_stopping_cb = keras.callbacks.EarlyStopping(patience=20, restore_best_weights=True)
model.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam())
history = model.fit(X_train, y_train, epochs=700, verbose=2, validation_split=0.2, callbacks=[early_stopping_cb])
2023-02-05 13:19:57.878647: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Epoch 1/700
75/75 - 1s - loss: 4301052.5000 - val_loss: 4142640.0000 - 1s/epoch - 16ms/step
Epoch 2/700
75/75 - 0s - loss: 4272121.0000 - val_loss: 3999604.5000 - 208ms/epoch - 3ms/step
Epoch 3/700
75/75 - 0s - loss: 4216856.5000 - val_loss: 3994105.5000 - 208ms/epoch - 3ms/step
Epoch 4/700
75/75 - 0s - loss: 4131298.0000 - val_loss: 3949729.0000 - 213ms/epoch - 3ms/step
Epoch 5/700
75/75 - 0s - loss: 4015520.7500 - val_loss: 3862589.7500 - 210ms/epoch - 3ms/step
Epoch 6/700
75/75 - 0s - loss: 3872733.2500 - val_loss: 3750626.7500 - 208ms/epoch - 3ms/step
Epoch 7/700
75/75 - 0s - loss: 3706083.0000 - val_loss: 3600402.2500 - 215ms/epoch - 3ms/step
Epoch 8/700
75/75 - 0s - loss: 3519330.2500 - val_loss: 3412134.2500 - 208ms/epoch - 3ms/step
Epoch 9/700
75/75 - 0s - loss: 3315900.0000 - val_loss: 3220617.0000 - 215ms/epoch - 3ms/step
Epoch 10/700
75/75 - 0s - loss: 3099819.2500 - val_loss: 3001288.5000 - 215ms/epoch - 3ms/step
Epoch 11/700
75/75 - 0s - loss: 2874579.2500 - val_loss: 2760121.0000 - 209ms/epoch - 3ms/step
Epoch 12/700
75/75 - 0s - loss: 2643835.0000 - val_loss: 2522569.2500 - 205ms/epoch - 3ms/step
Epoch 13/700
75/75 - 0s - loss: 2411119.5000 - val_loss: 2285000.5000 - 208ms/epoch - 3ms/step
Epoch 14/700
75/75 - 0s - loss: 2179730.7500 - val_loss: 2067624.7500 - 199ms/epoch - 3ms/step
Epoch 15/700
75/75 - 0s - loss: 1952585.0000 - val_loss: 1832976.7500 - 197ms/epoch - 3ms/step
Epoch 16/700
75/75 - 0s - loss: 1732681.2500 - val_loss: 1621207.2500 - 198ms/epoch - 3ms/step
Epoch 17/700
75/75 - 0s - loss: 1522108.7500 - val_loss: 1416498.6250 - 200ms/epoch - 3ms/step
Epoch 18/700
75/75 - 0s - loss: 1323307.7500 - val_loss: 1220637.6250 - 199ms/epoch - 3ms/step
Epoch 19/700
75/75 - 0s - loss: 1137731.3750 - val_loss: 1039171.1250 - 208ms/epoch - 3ms/step
Epoch 20/700
75/75 - 0s - loss: 966973.1875 - val_loss: 886933.5000 - 210ms/epoch - 3ms/step
Epoch 21/700
75/75 - 0s - loss: 811651.2500 - val_loss: 739620.5625 - 215ms/epoch - 3ms/step
Epoch 22/700
75/75 - 0s - loss: 672437.1875 - val_loss: 601959.3750 - 208ms/epoch - 3ms/step
Epoch 23/700
75/75 - 0s - loss: 549439.8125 - val_loss: 488426.4062 - 211ms/epoch - 3ms/step
Epoch 24/700
75/75 - 0s - loss: 442531.1562 - val_loss: 394424.0000 - 206ms/epoch - 3ms/step
Epoch 25/700
75/75 - 0s - loss: 350930.2188 - val_loss: 312412.0938 - 205ms/epoch - 3ms/step
Epoch 26/700
75/75 - 0s - loss: 273789.0312 - val_loss: 236741.2344 - 206ms/epoch - 3ms/step
Epoch 27/700
75/75 - 0s - loss: 210101.6875 - val_loss: 179171.1406 - 217ms/epoch - 3ms/step
Epoch 28/700
75/75 - 0s - loss: 158210.1094 - val_loss: 133281.4688 - 204ms/epoch - 3ms/step
Epoch 29/700
75/75 - 0s - loss: 117071.5469 - val_loss: 100188.4531 - 199ms/epoch - 3ms/step
Epoch 30/700
75/75 - 0s - loss: 85004.8047 - val_loss: 71474.6797 - 196ms/epoch - 3ms/step
Epoch 31/700
75/75 - 0s - loss: 60397.3398 - val_loss: 49187.9961 - 202ms/epoch - 3ms/step
Epoch 32/700
75/75 - 0s - loss: 42211.6172 - val_loss: 34480.4297 - 207ms/epoch - 3ms/step
Epoch 33/700
75/75 - 0s - loss: 28830.6055 - val_loss: 22966.5527 - 203ms/epoch - 3ms/step
Epoch 34/700
75/75 - 0s - loss: 19346.0469 - val_loss: 14648.2334 - 204ms/epoch - 3ms/step
Epoch 35/700
75/75 - 0s - loss: 12795.6367 - val_loss: 9636.8887 - 209ms/epoch - 3ms/step
Epoch 36/700
75/75 - 0s - loss: 8295.2812 - val_loss: 6862.2290 - 216ms/epoch - 3ms/step
Epoch 37/700
75/75 - 0s - loss: 5374.4155 - val_loss: 3738.5437 - 213ms/epoch - 3ms/step
Epoch 38/700
75/75 - 0s - loss: 3416.1621 - val_loss: 2329.8206 - 210ms/epoch - 3ms/step
Epoch 39/700
75/75 - 0s - loss: 2242.4575 - val_loss: 1475.1674 - 214ms/epoch - 3ms/step
Epoch 40/700
75/75 - 0s - loss: 1493.3206 - val_loss: 1122.0057 - 211ms/epoch - 3ms/step
Epoch 41/700
75/75 - 0s - loss: 1052.3135 - val_loss: 751.3673 - 206ms/epoch - 3ms/step
Epoch 42/700
75/75 - 0s - loss: 750.0993 - val_loss: 601.1572 - 213ms/epoch - 3ms/step
Epoch 43/700
75/75 - 0s - loss: 631.6475 - val_loss: 425.9563 - 205ms/epoch - 3ms/step
Epoch 44/700
75/75 - 0s - loss: 612.9835 - val_loss: 451.7712 - 206ms/epoch - 3ms/step
Epoch 45/700
75/75 - 0s - loss: 569.8580 - val_loss: 360.0299 - 210ms/epoch - 3ms/step
Epoch 46/700
75/75 - 0s - loss: 581.5479 - val_loss: 341.1945 - 215ms/epoch - 3ms/step
Epoch 47/700
75/75 - 0s - loss: 486.9858 - val_loss: 364.0634 - 197ms/epoch - 3ms/step
Epoch 48/700
75/75 - 0s - loss: 499.4584 - val_loss: 350.9095 - 200ms/epoch - 3ms/step
Epoch 49/700
75/75 - 0s - loss: 478.2133 - val_loss: 346.8938 - 198ms/epoch - 3ms/step
Epoch 50/700
75/75 - 0s - loss: 514.4147 - val_loss: 340.2244 - 202ms/epoch - 3ms/step
Epoch 51/700
75/75 - 0s - loss: 517.6034 - val_loss: 318.6063 - 202ms/epoch - 3ms/step
Epoch 52/700
75/75 - 0s - loss: 535.0484 - val_loss: 356.8185 - 196ms/epoch - 3ms/step
Epoch 53/700
75/75 - 0s - loss: 457.0552 - val_loss: 337.6630 - 197ms/epoch - 3ms/step
Epoch 54/700
75/75 - 0s - loss: 531.2640 - val_loss: 322.6632 - 200ms/epoch - 3ms/step
Epoch 55/700
75/75 - 0s - loss: 510.0555 - val_loss: 331.3161 - 200ms/epoch - 3ms/step
Epoch 56/700
75/75 - 0s - loss: 467.1401 - val_loss: 409.5538 - 201ms/epoch - 3ms/step
Epoch 57/700
75/75 - 0s - loss: 490.1585 - val_loss: 318.1010 - 201ms/epoch - 3ms/step
Epoch 58/700
75/75 - 0s - loss: 514.7844 - val_loss: 377.9467 - 205ms/epoch - 3ms/step
Epoch 59/700
75/75 - 0s - loss: 556.1674 - val_loss: 367.2943 - 202ms/epoch - 3ms/step
Epoch 60/700
75/75 - 0s - loss: 544.2922 - val_loss: 305.4519 - 198ms/epoch - 3ms/step
Epoch 61/700
75/75 - 0s - loss: 452.5020 - val_loss: 311.5174 - 201ms/epoch - 3ms/step
Epoch 62/700
75/75 - 0s - loss: 528.6728 - val_loss: 356.4193 - 200ms/epoch - 3ms/step
Epoch 63/700
75/75 - 0s - loss: 542.7446 - val_loss: 503.5054 - 202ms/epoch - 3ms/step
Epoch 64/700
75/75 - 0s - loss: 483.4487 - val_loss: 311.5475 - 204ms/epoch - 3ms/step
Epoch 65/700
75/75 - 0s - loss: 454.4919 - val_loss: 352.5090 - 201ms/epoch - 3ms/step
Epoch 66/700
75/75 - 0s - loss: 503.8775 - val_loss: 376.5328 - 210ms/epoch - 3ms/step
Epoch 67/700
75/75 - 0s - loss: 451.4953 - val_loss: 328.0042 - 206ms/epoch - 3ms/step
Epoch 68/700
75/75 - 0s - loss: 438.5162 - val_loss: 322.8679 - 205ms/epoch - 3ms/step
Epoch 69/700
75/75 - 0s - loss: 480.0949 - val_loss: 343.3659 - 204ms/epoch - 3ms/step
Epoch 70/700
75/75 - 0s - loss: 496.3212 - val_loss: 306.0331 - 204ms/epoch - 3ms/step
Epoch 71/700
75/75 - 0s - loss: 464.2759 - val_loss: 323.9979 - 204ms/epoch - 3ms/step
Epoch 72/700
75/75 - 0s - loss: 482.0170 - val_loss: 320.0938 - 200ms/epoch - 3ms/step
Epoch 73/700
75/75 - 0s - loss: 453.5753 - val_loss: 305.2671 - 207ms/epoch - 3ms/step
Epoch 74/700
75/75 - 0s - loss: 490.5117 - val_loss: 349.7020 - 207ms/epoch - 3ms/step
Epoch 75/700
75/75 - 0s - loss: 447.0436 - val_loss: 349.7736 - 206ms/epoch - 3ms/step
Epoch 76/700
75/75 - 0s - loss: 467.6292 - val_loss: 326.6443 - 206ms/epoch - 3ms/step
Epoch 77/700
75/75 - 0s - loss: 436.6336 - val_loss: 297.8269 - 213ms/epoch - 3ms/step
Epoch 78/700
75/75 - 0s - loss: 496.3705 - val_loss: 343.8814 - 206ms/epoch - 3ms/step
Epoch 79/700
75/75 - 0s - loss: 487.2369 - val_loss: 297.7052 - 201ms/epoch - 3ms/step
Epoch 80/700
75/75 - 0s - loss: 425.9744 - val_loss: 296.2877 - 205ms/epoch - 3ms/step
Epoch 81/700
75/75 - 0s - loss: 530.9971 - val_loss: 311.2696 - 210ms/epoch - 3ms/step
Epoch 82/700
75/75 - 0s - loss: 457.5061 - val_loss: 345.8722 - 201ms/epoch - 3ms/step
Epoch 83/700
75/75 - 0s - loss: 510.9530 - val_loss: 369.6526 - 209ms/epoch - 3ms/step
Epoch 84/700
75/75 - 0s - loss: 473.9654 - val_loss: 340.1275 - 204ms/epoch - 3ms/step
Epoch 85/700
75/75 - 0s - loss: 437.6663 - val_loss: 369.4005 - 205ms/epoch - 3ms/step
Epoch 86/700
75/75 - 0s - loss: 428.2475 - val_loss: 337.7111 - 203ms/epoch - 3ms/step
Epoch 87/700
75/75 - 0s - loss: 504.7291 - val_loss: 322.0513 - 200ms/epoch - 3ms/step
Epoch 88/700
75/75 - 0s - loss: 485.7505 - val_loss: 354.6273 - 198ms/epoch - 3ms/step
Epoch 89/700
75/75 - 0s - loss: 486.8000 - val_loss: 319.9012 - 198ms/epoch - 3ms/step
Epoch 90/700
75/75 - 0s - loss: 479.6288 - val_loss: 299.9872 - 203ms/epoch - 3ms/step
Epoch 91/700
75/75 - 0s - loss: 468.4537 - val_loss: 321.7562 - 207ms/epoch - 3ms/step
Epoch 92/700
75/75 - 0s - loss: 474.4741 - val_loss: 368.4492 - 206ms/epoch - 3ms/step
Epoch 93/700
75/75 - 0s - loss: 479.3322 - val_loss: 309.3318 - 204ms/epoch - 3ms/step
Epoch 94/700
75/75 - 0s - loss: 559.3271 - val_loss: 309.1451 - 204ms/epoch - 3ms/step
Epoch 95/700
75/75 - 0s - loss: 459.7544 - val_loss: 352.1678 - 212ms/epoch - 3ms/step
Epoch 96/700
75/75 - 0s - loss: 407.6223 - val_loss: 352.5784 - 207ms/epoch - 3ms/step
Epoch 97/700
75/75 - 0s - loss: 503.9336 - val_loss: 344.6494 - 205ms/epoch - 3ms/step
Epoch 98/700
75/75 - 0s - loss: 454.4256 - val_loss: 329.2439 - 206ms/epoch - 3ms/step
Epoch 99/700
75/75 - 0s - loss: 470.0044 - val_loss: 338.0790 - 201ms/epoch - 3ms/step
Epoch 100/700
75/75 - 0s - loss: 415.4849 - val_loss: 353.1748 - 205ms/epoch - 3ms/step

Save trained model¶

In [8]:
model.save('models/nn_regression_dic_all_vars.h5')

Learning curve¶

In [9]:
df_history = pd.DataFrame(history.history)
df_history.index.name = 'epoch'
df_history = df_history.reset_index()
df_history.to_csv('results/nn_regression_history_dic_all_vars.csv')

fig, ax = plt.subplots(figsize=(6, 6))
_ = sns.lineplot(x=df_history.epoch-0.5, y='loss', data=df_history, ax=ax, label='training set')
_ = sns.lineplot(x=df_history.epoch, y='val_loss', data=df_history, ax=ax, label='validation set')
_ = ax.set(ylabel = 'MSE')
# _ = ax.set(yscale='log')

MSE & $R^2$¶

In [10]:
from sklearn.metrics import r2_score

print('MSE on training set = {:.2f}'.format(model.evaluate(X_train, y_train, verbose=0)))
print('MSE on test set     = {:.2f}\n'.format(model.evaluate(X_test, y_test, verbose=0)))

y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

print('R squared on training set = {:.3f}'.format(r2_score(y_train, y_pred_train)))
print('R squared on test set     = {:.3f}'.format(r2_score(y_test, y_pred_test)))
MSE on training set = 201.72
MSE on test set     = 237.44

R squared on training set = 0.969
R squared on test set     = 0.963
In [11]:
fig, ax = plt.subplots(figsize=(6,6))
_ = sns.scatterplot(x=y_test.ravel(), y=y_pred_test.ravel(), ax=ax)
_ = ax.set(xlabel='observed DIC', ylabel='predicted DIC', title='Test dataset')
_ = ax.axis('equal')
In [12]:
# save test set features, target & predictions
df_test = pd.DataFrame(np.c_[X_test, y_test, y_pred_test], columns = features + ['DIC observed', 'DIC predicted'])
df_test['DIC residuals'] = df_test['DIC observed'] - df_test['DIC predicted']
df_test.to_csv('results/bottle_data_test_dic_all_vars.csv')

K-fold cross-validation¶

In [13]:
from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True, random_state=42)
score_vals = [] # store score values

nn_reg = keras.models.Sequential([
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden, input_shape=X_train.shape[1:]),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(y_train.shape[1])
])
nn_reg.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam())

for k, (train_idx, test_idx) in enumerate(kf.split(X_train)):
    X_tr, X_te = X_train[train_idx], X_train[test_idx]
    y_tr, y_te = y_train[train_idx], y_train[test_idx]
    history_cv = nn_reg.fit(X_tr, y_tr, epochs=700, verbose=0, validation_split=0.2, callbacks=[early_stopping_cb])
    y_pred = nn_reg.predict(X_te)
    score = r2_score(y_te, y_pred)
    score_vals.append(score)
    print('Fold {} test set R squared: {:.3f}'.format(k+1, score))

scores = np.array(score_vals)
print('\nBest R squared:  {:.3f}'.format(scores.max()))
print('Worst R squared: {:.3f}'.format(scores.min()))
print('Mean R squared:  {:.3f}'.format(scores.mean()))
Fold 1 test set R squared: 0.955
Fold 2 test set R squared: 0.964
Fold 3 test set R squared: 0.961
Fold 4 test set R squared: 0.968
Fold 5 test set R squared: 0.959

Best R squared:  0.968
Worst R squared: 0.955
Mean R squared:  0.961

Confidence interval for predictions¶

In [14]:
from mapie.regression import MapieRegressor
from keras.wrappers.scikit_learn import KerasRegressor
# from scikeras.wrappers import KerasRegressor

def build_model():
    model = keras.models.Sequential([
        keras.layers.BatchNormalization(),
        keras.layers.Dense(n_hidden, input_shape=X_train.shape[1:]),
        keras.layers.LeakyReLU(alpha=alpha),
        keras.layers.BatchNormalization(),
        keras.layers.Dense(n_hidden),
        keras.layers.LeakyReLU(alpha=alpha),
        keras.layers.BatchNormalization(),
        keras.layers.Dense(y_train.shape[1])
    ])
    model.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam())
    return model

# np.random.seed(42) # fix random seed for reproducibility
# estimator = KerasRegressor(build_fn=build_model, nb_epoch=1000)

# mapie_reg = MapieRegressor(estimator, method='plus', cv=5, agg_function='mean', n_jobs=-1)
# mapie_reg.fit(X_train, y_train.ravel())
# y_test_pred2, y_test_pi = mapie_reg.predict(X_test, alpha=0.05)
# df_interval = pd.DataFrame(
#     {
#         'observed': y_test.ravel(),
#         'predicted': y_test_pred2,
#         'pred_lower_bound': y_test_pi[:,0,0],
#         'pred_upper_bound':  y_test_pi[:,1,0]}
# )