Created by Ivan Lima on Tue Jan 17 2023 22:28:20 -0500
%matplotlib inline
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os, datetime, warnings
print('Last updated on {}'.format(datetime.datetime.now().ctime()))
Last updated on Sun Feb 5 15:48:44 2023
import sns_settings
sns.set_context('paper')
pd.options.display.max_columns = 50
warnings.filterwarnings('ignore')
df_bottle_ta = pd.read_csv('data/bottle_data_TA_prepared.csv', parse_dates=['Date'], index_col=0, na_values=['<undefined>',-9999.])
df_bottle_ta = df_bottle_ta.loc[df_bottle_ta.Oxygen_flag.isin([2, 6])]
df_bottle_ta = df_bottle_ta.loc[df_bottle_ta.Oxygen.notnull()]
df_bottle_ta['log_Chl'] = np.log(df_bottle_ta.Chl)
df_bottle_ta['log_KD490'] = np.log(df_bottle_ta.KD490)
features = ['Depth', 'Temperature', 'Salinity', 'Oxygen', 'pCO2_atm', 'ADT', 'SST_hires', 'log_KD490']
target = ['TALK']
varlist = features + target
fg = sns.pairplot(df_bottle_ta, vars=varlist, hue='Season', plot_kws={'alpha':0.7}, diag_kind='hist')
data = df_bottle_ta[varlist]
corr_mat = data.corr()
fig, ax = plt.subplots(figsize=(7,7))
_ = sns.heatmap(corr_mat, ax=ax, cmap='vlag', center=0, square=True, annot=True, annot_kws={'fontsize':9})
_ = ax.set_title('Correlation')
from sklearn.model_selection import train_test_split, cross_val_score
data = df_bottle_ta[features + target + ['Season']].dropna()
X = data[features].values
y = data[target].values
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=data.Season.values, random_state=77)
X.shape, X_train.shape, X_test.shape, y_train.shape, y_test.shape
((3797, 8), (2847, 8), (950, 8), (2847, 1), (950, 1))
import tensorflow as tf
from tensorflow import keras
keras.utils.set_random_seed(42) # make things reproducible
n_hidden = 256 # number of nodes in hidden layers
alpha = 0.01 # negative-input slope for the LeakyReLU activations
model = keras.models.Sequential([
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden, input_shape=X_train.shape[1:]),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(y_train.shape[1])
])
early_stopping_cb = keras.callbacks.EarlyStopping(patience=20, restore_best_weights=True)
model.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam())
history = model.fit(X_train, y_train, epochs=700, verbose=2, validation_split=0.2, callbacks=[early_stopping_cb])
Epoch 1/700  72/72 - 1s - loss: 5069441.5000 - val_loss: 4707083.5000 - 1s/epoch - 17ms/step
Epoch 2/700  72/72 - 0s - loss: 5040604.0000 - val_loss: 4592761.5000 - 191ms/epoch - 3ms/step
Epoch 3/700  72/72 - 0s - loss: 4985238.0000 - val_loss: 4566380.0000 - 200ms/epoch - 3ms/step
[... epochs 4-75 omitted: losses drop steadily from ~5e6 to a few hundred ...]
Epoch 76/700  72/72 - 0s - loss: 252.8293 - val_loss: 80.4848 - 189ms/epoch - 3ms/step
[... epochs 77-95 omitted: no further improvement in val_loss ...]
Epoch 96/700  72/72 - 0s - loss: 304.1493 - val_loss: 85.6108 - 199ms/epoch - 3ms/step
(early stopping halted training after epoch 96; restore_best_weights reverts the model to the epoch-76 weights, val_loss = 80.4848)
model.save('models/nn_regression_ta_all_vars.h5')
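The saved model can be reloaded later with Keras' load_model; a minimal sketch (the path matches the save call above, and y_pred_check is a hypothetical name):
# Sketch: reload the trained model from disk and reuse it for prediction
reloaded = keras.models.load_model('models/nn_regression_ta_all_vars.h5')
y_pred_check = reloaded.predict(X_test)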
df_history = pd.DataFrame(history.history)
df_history.index.name = 'epoch'
df_history = df_history.reset_index()
df_history.to_csv('results/nn_regression_history_ta_all_vars.csv')
fig, ax = plt.subplots(figsize=(6, 6))
_ = sns.lineplot(x=df_history.epoch-0.5, y='loss', data=df_history, ax=ax, label='training set') # shift left by half an epoch: training loss is averaged during each epoch, validation loss is computed at its end
_ = sns.lineplot(x=df_history.epoch, y='val_loss', data=df_history, ax=ax, label='validation set')
_ = ax.set(ylabel = 'MSE')
# _ = ax.set(yscale='log')
from sklearn.metrics import r2_score
print('MSE on training set = {:.2f}'.format(model.evaluate(X_train, y_train, verbose=0)))
print('MSE on test set = {:.2f}\n'.format(model.evaluate(X_test, y_test, verbose=0)))
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)
print('R squared on training set = {:.3f}'.format(r2_score(y_train, y_pred_train)))
print('R squared on test set = {:.3f}'.format(r2_score(y_test, y_pred_test)))
MSE on training set = 66.22
MSE on test set = 80.87

R squared on training set = 0.989
R squared on test set = 0.986
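Because the loss is a squared error, its square root gives the typical prediction error in the same units as TA; a small follow-up sketch:
# Sketch: report RMSE alongside the MSE values printed above
rmse_train = np.sqrt(model.evaluate(X_train, y_train, verbose=0))
rmse_test = np.sqrt(model.evaluate(X_test, y_test, verbose=0))
print('RMSE on training set = {:.2f}'.format(rmse_train))
print('RMSE on test set = {:.2f}'.format(rmse_test))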
fig, ax = plt.subplots(figsize=(6,6))
_ = sns.scatterplot(x=y_test.ravel(), y=y_pred_test.ravel(), ax=ax)
_ = ax.set(xlabel='observed TA', ylabel='predicted TA', title='Test dataset')
_ = ax.axis('equal')
# save test set features, target & predictions
df_test = pd.DataFrame(np.c_[X_test, y_test, y_pred_test], columns = features + ['TA observed', 'TA predicted'])
df_test['TA residuals'] = df_test['TA observed'] - df_test['TA predicted']
df_test.to_csv('results/bottle_data_test_ta_all_vars.csv')
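With the residuals stored in df_test, one quick diagnostic that could be added is a residual histogram (a sketch, not part of the original analysis):
# Sketch: distribution of TA residuals on the test set
fig, ax = plt.subplots(figsize=(6, 4))
_ = sns.histplot(df_test['TA residuals'], ax=ax)
_ = ax.set(xlabel='observed - predicted TA')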
from sklearn.model_selection import KFold
kf = KFold(n_splits=5, shuffle=True, random_state=42)
score_vals = [] # store score values
nn_reg = keras.models.Sequential([
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden, input_shape=X_train.shape[1:]),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(n_hidden),
    keras.layers.LeakyReLU(alpha=alpha),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(y_train.shape[1])
])
nn_reg.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam())
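Note that the same nn_reg instance (and its weights) is carried from one fold to the next in the loop below; if fully independent folds were desired, the network could be re-initialized each iteration, for example by cloning the untrained architecture (a hypothetical variant, not what is done here):
# Hypothetical per-fold re-initialization (not used in the loop below):
# fold_model = keras.models.clone_model(nn_reg)
# fold_model.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam())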
for k, (train_idx, test_idx) in enumerate(kf.split(X_train)):
    X_tr, X_te = X_train[train_idx], X_train[test_idx]
    y_tr, y_te = y_train[train_idx], y_train[test_idx]
    history_cv = nn_reg.fit(X_tr, y_tr, epochs=700, verbose=0, validation_split=0.2, callbacks=[early_stopping_cb])
    y_pred = nn_reg.predict(X_te)
    score = r2_score(y_te, y_pred)
    score_vals.append(score)
    print('Fold {} test set R squared: {:.3f}'.format(k+1, score))
scores = np.array(score_vals)
print('\nBest R squared: {:.3f}'.format(scores.max()))
print('Worst R squared: {:.3f}'.format(scores.min()))
print('Mean R squared: {:.3f}'.format(scores.mean()))
Fold 1 test set R squared: 0.987
Fold 2 test set R squared: 0.981
Fold 3 test set R squared: 0.988
Fold 4 test set R squared: 0.989
Fold 5 test set R squared: 0.985

Best R squared: 0.989
Worst R squared: 0.981
Mean R squared: 0.986
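The fold-to-fold spread can also be summarized with the standard deviation of the scores; a one-line sketch using the scores array defined above:
# Sketch: mean and spread of the cross-validation R squared values
print('R squared: {:.3f} +/- {:.3f}'.format(scores.mean(), scores.std()))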