Ordinal regression

Ordinal regression aims to fit a model to some data \((X, Y)\), where \(Y\) is an ordinal variable. To do so, we use a VGP model with a specific likelihood (gpflow.likelihoods.Ordinal).

[1]:
import gpflow

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
# Default figure size for the notebook; individual cells may still pass their own figsize.
plt.rcParams["figure.figsize"] = (12, 6)

np.random.seed(123)  # for reproducibility
[2]:
# make a one-dimensional ordinal regression problem

# This function generates a set of inputs X,
# quantitative output f (latent) and ordinal values Y


def generate_data(num_data):
    """Sample a toy one-dimensional ordinal regression dataset.

    Args:
        num_data: number of input points to draw.

    Returns:
        A tuple ``(X, f, Y)`` where
        ``X`` holds ``num_data`` inputs drawn with ``np.random.rand``
        (shape ``(num_data, 1)``),
        ``f`` is one zero-mean Gaussian draw whose covariance is a
        squared-exponential kernel (lengthscale 0.1) evaluated on ``X``
        (shape ``(num_data, 1)``), and
        ``Y`` contains the ordinal labels obtained by rounding a scaled copy
        of ``f`` and shifting so the smallest label is 0 (``float64``).
    """
    # Inputs are drawn before the latent sample so the global NumPy random
    # stream is consumed in a fixed order.
    X = np.random.rand(num_data, 1)

    # Latent function values: a single multivariate-normal draw under the kernel.
    latent_kernel = gpflow.kernels.SquaredExponential(lengthscales=0.1)
    covariance = latent_kernel(X)
    zero_mean = np.zeros(num_data)
    f = np.random.multivariate_normal(mean=zero_mean, cov=covariance).reshape(-1, 1)

    # Discretise: scale by 3, round to integers, then re-base at zero.
    levels = np.round((f + f.min()) * 3)
    Y = np.asarray(levels - levels.min(), np.float64)

    return X, f, Y


# Reseed right before sampling so the dataset is reproducible
# (this replaces the seed of 123 set in the setup cell).
np.random.seed(1)
num_data = 20
X, f, Y = generate_data(num_data)

# Left-hand axis: latent function values at the sampled inputs.
plt.figure(figsize=(11, 6))
plt.plot(X, f, ".")
plt.ylabel("latent function value")

# Right-hand axis (twinx shares the x-axis): the ordinal labels derived from f.
plt.twinx()
plt.plot(X, Y, "kx", mew=1.5)
plt.ylabel("observed data value")
2022-03-18 10:06:49.283950: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-03-18 10:06:49.287347: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusolver.so.11'; dlerror: libcusolver.so.11: cannot open shared object file: No such file or directory
2022-03-18 10:06:49.287857: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2022-03-18 10:06:49.288614: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
[2]:
Text(0, 0.5, 'observed data value')
../../_images/notebooks_advanced_ordinal_regression_2_2.png
[3]:
# construct ordinal likelihood - bin_edges are unit-spaced and centred on zero,
# with one more edge than there are distinct labels in Y
# NOTE(review): presumably only one edge is needed between consecutive labels
# (K edges for labels 0..K) - confirm against the gpflow.likelihoods.Ordinal
# docs whether the extra edges built here are intended.
bin_edges = np.array(np.arange(np.unique(Y).size + 1), dtype=float)
bin_edges = bin_edges - bin_edges.mean()
likelihood = gpflow.likelihoods.Ordinal(bin_edges)

# build a model with this likelihood
m = gpflow.models.VGP(data=(X, Y), kernel=gpflow.kernels.Matern32(), likelihood=likelihood)

# fit the model
# The optimiser is capped at 100 iterations; the recorded run hit that cap
# (nit: 100, success: False), so the fit may not be fully converged.
opt = gpflow.optimizers.Scipy()
opt.minimize(m.training_loss, m.trainable_variables, options=dict(maxiter=100))
2022-03-18 10:06:49.456941: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
[3]:
      fun: 25.487470477487047
 hess_inv: <233x233 LbfgsInvHessProduct with dtype=float64>
      jac: array([-6.65482792e-02, -2.28011094e-02, -4.10331104e-02, -4.42627418e-02,
       -6.76994496e-02, -2.98011020e-02, -6.97511637e-02, -4.67811016e-04,
       -1.28365467e-02,  2.16799746e-02,  5.08941935e-03,  1.45547171e-02,
        6.88426919e-03, -2.41224436e-02, -9.48170239e-03, -2.19859908e-02,
       -2.05467585e-04, -1.22032380e-02, -2.20537165e-02, -1.67117284e-03,
        2.77558580e-04, -4.07676425e-04, -5.72485098e-09,  2.48512752e-06,
       -1.87613440e-09, -7.61132102e-06,  1.94440349e-11, -6.40226130e-03,
        1.98085263e-08,  2.41774976e-06, -2.43871182e-07, -2.07220555e-05,
        1.54185927e-04, -3.35571882e-03,  2.68769104e-04,  4.67488855e-05,
        1.87349895e-04, -1.87075004e-04, -1.52141097e-04,  3.49573448e-04,
       -2.59341003e-02,  7.38025918e-03, -3.48153857e-09,  1.01544739e-06,
       -2.16878450e-10,  4.48229323e-04,  1.26772760e-11, -1.31505057e-03,
        1.42452956e-08,  1.72877352e-07, -4.68453744e-07, -1.70397529e-05,
        9.67499997e-05,  2.67820614e-03, -1.21331960e-03, -1.31539205e-02,
        1.63548390e-03, -1.34677623e-03, -4.01958872e-04, -8.42956290e-04,
        1.30852862e-02,  3.10825813e-02, -1.37891848e-02, -5.81727766e-04,
       -5.26061617e-03, -2.84808995e-09, -1.51826941e-04, -6.49921089e-07,
       -8.71770858e-03,  4.66246333e-03, -2.30823654e-02, -8.98120884e-04,
       -7.03403632e-04, -5.79634199e-05, -4.16271734e-04,  6.33038296e-05,
       -7.47499160e-03, -3.56795892e-04,  2.50168079e-03,  3.24804117e-03,
       -2.16233481e-02, -1.00909264e-02, -6.66105808e-02, -1.95515266e-05,
       -8.15485894e-05, -2.65275252e-08, -2.20500837e-06,  3.82557557e-06,
       -1.25829646e-04,  7.31918817e-05, -3.69209769e-04, -1.79144839e-05,
       -1.29683386e-05,  1.74288971e-06, -6.48552173e-06,  1.63472619e-06,
       -1.16533766e-04, -5.29125623e-06,  4.07042889e-05, -1.78211942e-05,
        3.77690070e-03, -2.38226907e-02, -3.23549979e-02,  1.59620787e-02,
        2.43412582e-02, -1.65233061e-09,  4.12675097e-05, -3.23306484e-07,
        1.79293668e-02,  8.80523692e-04, -5.92348433e-03, -6.05228126e-04,
       -4.29061540e-03, -2.70107404e-05,  6.33272833e-04,  3.53920300e-04,
       -4.95085938e-03,  8.70060777e-04, -4.18799383e-03, -8.09859157e-06,
       -2.67041191e-02, -3.05997587e-02, -5.91385078e-02,  5.05747012e-02,
       -1.71623151e-02, -1.32902931e-03,  7.45973108e-11,  8.92101065e-05,
       -2.86595758e-08, -2.63708394e-06,  1.28404794e-06, -3.43501114e-06,
        2.12480794e-04, -1.21357176e-03,  3.46221253e-04,  7.34346924e-03,
       -3.13980980e-03, -1.21033260e-02,  2.05415073e-04, -4.64325996e-04,
       -2.25410284e-02, -1.01492933e-03,  5.84196436e-03, -8.31480497e-03,
        1.51740827e-02,  1.24359456e-02, -4.63549998e-03,  9.49243555e-09,
        1.48145850e-02, -7.98858002e-05,  3.50859581e-03,  8.56290952e-04,
       -3.97727785e-03,  7.05717959e-04,  1.40914591e-03,  1.11314470e-04,
        1.03747031e-02, -1.03133396e-03,  5.89783703e-03,  6.12019068e-03,
        1.31142733e-02, -6.94945065e-03, -8.77972016e-03,  3.47435536e-03,
       -1.32934317e-02,  7.23586436e-03, -2.55926494e-02,  1.14888923e-02,
        4.14225296e-07, -9.61651344e-05, -7.13228787e-06, -8.85671255e-04,
        3.20724486e-03, -1.14173652e-02,  4.38266654e-03,  1.30171502e-03,
        1.27924069e-03, -1.19806804e-03, -1.73452399e-03,  3.55347773e-03,
        6.82831128e-03, -1.08854421e-04,  3.09663294e-04, -1.50953628e-02,
       -4.94313579e-03,  7.44792614e-05,  5.97823614e-03,  7.36015811e-03,
       -8.48815128e-03, -5.85041430e-04,  5.19347404e-03, -2.70607876e-03,
       -2.18095111e-02,  3.16616151e-06,  3.82129358e-03,  1.86778718e-03,
       -2.04992385e-02,  5.59729280e-03,  1.25365618e-02, -3.92868360e-03,
        9.77589097e-04,  5.74138277e-04,  4.45949676e-04,  1.26726598e-02,
       -3.52411611e-04,  7.33145240e-04, -3.43104390e-04,  2.79855493e-03,
        1.18075865e-02, -8.33277793e-04,  8.20941777e-04,  1.62535752e-04,
        8.15731699e-04,  4.18787011e-05, -8.13929159e-05, -6.60589823e-05,
        1.62099712e-03, -1.60702687e-04, -7.80567068e-04, -2.65167430e-03,
        6.97977738e-03,  3.00029809e-03, -6.50974064e-04, -6.91986612e-02,
        1.45574820e-03, -2.02419559e-03,  3.70928132e-03, -6.58496841e-03,
       -7.01642799e-03, -3.52647198e-05,  1.25209562e-01, -8.32970706e-03,
       -1.00396882e-02])
  message: 'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
     nfev: 116
      nit: 100
     njev: 116
   status: 1
  success: False
        x: array([-1.99973507e+00, -2.36106014e-01, -8.88950231e-04,  7.62047389e-01,
       -1.98110852e-01, -1.59229729e+00,  7.93681577e-01, -3.86889635e-01,
       -1.13949872e-01,  3.78716109e-01, -8.98007865e-02,  3.46381679e-01,
       -7.61846764e-02,  1.59570909e+00, -2.61505459e-01,  8.54327746e-01,
        6.56880248e-03,  4.46001417e-01, -3.20365109e-01,  6.90522685e-04,
        9.88817853e-01, -2.80824927e-03,  6.18784910e-08,  1.28533951e-05,
        1.00558904e-08,  1.33405300e-04,  5.89551639e-11, -6.28875301e-02,
        7.68178603e-09, -1.38969408e-04, -1.17838966e-05, -7.71515631e-04,
        2.21555882e-03, -6.31938620e-02,  1.79725561e-03, -3.69556992e-03,
       -2.79571210e-03,  1.02877805e-05,  4.53908217e-06,  3.23323905e-05,
        2.24967549e-01,  9.45886629e-01,  3.60354837e-08,  6.94167470e-06,
        6.70474894e-09,  8.43115306e-03,  3.48936608e-11, -2.24840100e-02,
        7.58195183e-09, -4.44197992e-05, -4.05352864e-06, -2.51109800e-04,
        7.33512760e-04,  2.99273658e-02, -3.30685715e-02, -2.28129531e-01,
        1.38893404e-03, -7.87974067e-04,  1.12013359e-05, -9.02922037e-06,
       -1.92226109e-02,  1.41991779e-01,  6.14727350e-01, -9.80894462e-03,
       -1.31907731e-01, -4.48108569e-08, -6.29714658e-04, -1.33375023e-05,
        1.21348315e-02,  7.00032255e-02, -3.90308924e-01, -3.17501863e-02,
       -8.18888591e-03, -2.86406174e-04, -3.04797075e-05, -1.67377067e-07,
        1.74985874e-03, -5.42130485e-05,  2.23708887e-02,  4.28758482e-02,
       -5.42741033e-03,  3.53010317e-04,  1.37478718e-01,  9.99917312e-01,
       -2.04363826e-03, -4.50334169e-07, -9.60871806e-06,  3.04958051e-05,
        1.97229939e-04,  1.07912582e-03, -6.05900275e-03, -4.88684129e-04,
       -1.28209179e-04,  2.44421575e-05,  2.21559925e-07,  1.20153934e-05,
        2.74078342e-05, -7.17980902e-07,  3.48428003e-04, -1.75272621e-04,
       -1.52910078e-01,  3.75832737e-03, -1.23789086e-02,  1.69803943e-01,
        7.18771052e-01, -1.66881797e-08, -4.40753813e-03, -4.47482648e-06,
       -4.24396739e-01,  2.30615263e-02, -8.94045201e-02, -1.37519921e-02,
       -3.75861609e-03, -9.54115638e-05,  3.86714161e-05,  2.32820141e-05,
        7.92635445e-04,  3.85342568e-05,  4.30465894e-02,  1.76200960e-02,
        3.15223900e-02, -7.24675583e-04, -6.33868195e-02, -7.06778710e-02,
        1.09864108e-01,  5.73035709e-01, -2.04467126e-10,  1.13192405e-02,
       -1.83498169e-08,  2.11808827e-05,  2.17018938e-06,  1.15327185e-04,
       -2.48794531e-04, -2.39822377e-02, -1.74611803e-01,  4.90762416e-02,
       -4.37778536e-04, -3.00788756e-01, -9.22197262e-05,  2.30068757e-04,
       -1.29006345e-02, -1.19659731e-04, -1.25980727e-01,  3.06796476e-02,
       -1.41215997e-01,  3.20209507e-01,  1.62203867e-01, -2.34521968e-07,
        5.21908550e-02,  1.25365593e-03, -2.20292382e-03, -9.54802411e-04,
       -5.62928321e-04,  1.88365869e-05,  8.06665318e-05,  2.19577195e-05,
        1.89869022e-04,  2.91274858e-05, -1.13133935e-01,  8.99645495e-04,
        3.53489867e-02, -7.31890260e-04, -2.02204063e-02, -9.81245241e-02,
       -2.11319512e-01,  1.18283561e-01,  3.07491970e-01,  6.84416177e-01,
       -2.34986078e-06, -2.49204055e-03, -2.14507967e-04, -1.36727050e-02,
        3.84321071e-02, -4.85660167e-01,  6.70934600e-03, -3.72243739e-02,
       -6.34279486e-02,  7.17549017e-03, -3.42841616e-05,  3.32851329e-04,
       -2.12537167e-01,  4.21399542e-03, -9.05005700e-04, -2.67076624e-01,
       -9.12473222e-03,  1.82205014e-03,  5.99705061e-02,  4.20819043e-01,
        2.69432099e-01,  3.30855855e-02, -6.13277164e-02, -2.50656269e-02,
       -7.41799897e-03, -4.92839843e-05,  2.49770763e-04,  1.13665433e-04,
        9.36287131e-04,  3.12953803e-04, -2.73073300e-01,  2.97713887e-02,
       -8.59037069e-02,  5.73175826e-03,  3.08365434e-04,  8.21119838e-03,
       -6.95164931e-04,  2.05160640e-04, -1.22903045e-02, -5.97397972e-02,
        9.07757568e-01,  9.94613822e-01,  4.55107671e-02,  5.61538627e-03,
        1.55954925e-03, -2.06785274e-03,  2.30053720e-05, -1.28025065e-04,
       -5.98762243e-04,  3.59160814e-05,  5.40212362e-03, -2.59290903e-02,
       -2.08566145e-01, -4.35796479e-02, -2.78265262e-05, -1.18714097e-02,
       -1.90978292e-04, -4.28969217e-05, -2.18251400e-04,  5.09965450e-02,
        2.04370047e-01,  1.67917646e-01, -1.97907241e+00,  5.46706876e+00,
       -1.44968285e+00])
[4]:
# Plot the predictive mean of Y together with a +/- 2 standard deviation band,
# treating the predictive distribution as if it were Gaussian.
plt.figure(figsize=(11, 6))

# Pull the training data back out of the model as NumPy arrays.
X_data = m.data[0].numpy()
Y_data = m.data[1].numpy()

# Evaluate the prediction on 100 evenly spaced inputs spanning the training range.
Xtest = np.linspace(X_data.min(), X_data.max(), 100).reshape(-1, 1)
mu, var = m.predict_y(Xtest)
two_std = 2 * np.sqrt(var)

(line,) = plt.plot(Xtest, mu, lw=2)
col = line.get_color()
# Upper then lower band, dashed, in the same colour as the mean.
for bound in (mu + two_std, mu - two_std):
    plt.plot(Xtest, bound, "--", lw=2, color=col)
plt.plot(X_data, Y_data, "kx", mew=2)
[4]:
[<matplotlib.lines.Line2D at 0x7f8170469090>]
../../_images/notebooks_advanced_ordinal_regression_4_1.png
[5]:
## to see the predictive density, try predicting every possible discrete value for Y.
def pred_log_density(m):
    """Evaluate the model's predictive log density for every ordinal level.

    The inputs are 100 evenly spaced points spanning the module-level
    ``X_data``; the candidate labels are ``0, 1, ..., Y_data.max()``.

    Args:
        m: a model exposing ``predict_log_density((X, Y))``.

    Returns:
        The per-level results of ``predict_log_density`` stacked with
        ``np.vstack``, in increasing label order (presumably one row per
        level and one column per test input - confirm against the model's
        return shape).
    """
    grid = np.linspace(X_data.min(), X_data.max(), 100).reshape(-1, 1)
    levels = np.arange(Y_data.max() + 1)
    # Pair every test input with the same candidate label, one label at a time.
    per_level = [
        m.predict_log_density((grid, np.full_like(grid, level))) for level in levels
    ]
    return np.vstack(per_level)
[6]:
# Heat map of the predictive density of each ordinal level across the input range,
# with the training points drawn on top.
fig = plt.figure(figsize=(14, 6))
density = np.exp(pred_log_density(m))
# Vertical extent runs from -0.5 to max(Y) + 0.5 so integer levels sit mid-band.
image_extent = [X_data.min(), X_data.max(), -0.5, Y_data.max() + 0.5]
plt.imshow(
    density,
    interpolation="nearest",
    extent=image_extent,
    origin="lower",
    aspect="auto",
    cmap=plt.cm.viridis,
)
plt.colorbar()
# scalex/scaley=False keeps the axis limits set by the image.
plt.plot(X, Y, "kx", mew=2, scalex=False, scaley=False)
[6]:
[<matplotlib.lines.Line2D at 0x7f8170346aa0>]
../../_images/notebooks_advanced_ordinal_regression_6_1.png
[7]:
# Predictive density for a single input x=0.5
x_new = 0.5
# Candidate labels 0..max(Y_data) as a column vector.
Y_new = np.arange(Y_data.max() + 1).reshape(-1, 1)
# for predict_log_density x and y need to have the same number of rows,
# so x_new is repeated once per candidate label
X_new = np.full_like(Y_new, x_new)
log_dens_new = m.predict_log_density((X_new, Y_new))
dens_new = np.exp(log_dens_new)
fig = plt.figure(figsize=(8, 4))
plt.bar(x=Y_new.flatten(), height=dens_new.flatten())
[7]:
<BarContainer object of 8 artists>
../../_images/notebooks_advanced_ordinal_regression_7_1.png