Introduction to Interprenet

Interprenet builds constrained neural networks. It currently supports monotonicity and Lipschitz constraints. In this tutorial, we will show how to apply interpretable constraints to a model trained on the California housing dataset.
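Concretely, using the standard definitions (not Interprenet-specific notation): a model f is monotonically increasing in a feature x_i if raising x_i while holding the other features fixed can never lower the output, and L-Lipschitz in x_i if the output can change no faster than L times any change in x_i:

    f(..., x_i, ...) ≤ f(..., x_i', ...)   whenever x_i ≤ x_i'
    |f(..., x_i, ...) − f(..., x_i', ...)| ≤ L · |x_i − x_i'|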

[1]:
import pandas
from sklearn.datasets import fetch_california_housing

bunch = fetch_california_housing()
X = pandas.DataFrame(bunch['data'], columns=bunch['feature_names'])
y = bunch['target'] * 10 # express target variable in units of $10,000 vs $100,000 to avoid values < 1
X
[1]:
MedInc HouseAge AveRooms AveBedrms Population AveOccup Latitude Longitude
0 8.3252 41.0 6.984127 1.023810 322.0 2.555556 37.88 -122.23
1 8.3014 21.0 6.238137 0.971880 2401.0 2.109842 37.86 -122.22
2 7.2574 52.0 8.288136 1.073446 496.0 2.802260 37.85 -122.24
3 5.6431 52.0 5.817352 1.073059 558.0 2.547945 37.85 -122.25
4 3.8462 52.0 6.281853 1.081081 565.0 2.181467 37.85 -122.25
... ... ... ... ... ... ... ... ...
20635 1.5603 25.0 5.045455 1.133333 845.0 2.560606 39.48 -121.09
20636 2.5568 18.0 6.114035 1.315789 356.0 3.122807 39.49 -121.21
20637 1.7000 17.0 5.205543 1.120092 1007.0 2.325635 39.43 -121.22
20638 1.8672 18.0 5.329513 1.171920 741.0 2.123209 39.43 -121.32
20639 2.3886 16.0 5.254717 1.162264 1387.0 2.616981 39.37 -121.24

20640 rows × 8 columns

We can get a quick ballpark of which variables exhibit monotonic relationships with the target by looking at their correlations with it.

[2]:
import numpy

# pair each column with its correlation against the target, sort by absolute
# correlation (strongest first), then flip the pairs into a column -> correlation dict
correlations = dict(map(reversed,
         sorted(((numpy.corrcoef(X[column], y)[0, 1], column) for column in X.columns),
                key=lambda x: -abs(x[0]))))
correlations
[2]:
{'MedInc': 0.6880752079585482,
 'AveRooms': 0.15194828974145785,
 'Latitude': -0.14416027687465913,
 'HouseAge': 0.10562341249320988,
 'AveBedrms': -0.046700512969486935,
 'Longitude': -0.045966615117978504,
 'Population': -0.024649678888894896,
 'AveOccup': -0.02373741295613434}

See https://scikit-learn.org/stable/datasets/real_world.html#california-housing-dataset for details on the meaning of each column.

Comparing correlation to common sense, it seems the following variables do and should exhibit monotonic relationships with the target:

  • MedInc: median income in block group

  • AveRooms: average number of rooms per household

Furthermore, let's assume the rate of change of price with respect to MedInc, and the other Lipschitz-constrained features below, should be no greater than twice the absolute correlation with the target (AveRooms is left out here because it is not continuous).
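To see why a simple input rescaling enforces such a bound (a one-line derivation, assuming the constrained network g is 1-Lipschitz in each Lipschitz-constrained input), feed the network s · x in place of x, so that f(x) = g(s · x). Then

    |f(x) − f(x')| = |g(s · x) − g(s · x')| ≤ |s · x − s · x'| = |s| · |x − x'|

so choosing |s| = 2 · |correlation| caps the model's rate of change at exactly the bound assumed above. The get_scale helper below implements this rescaling.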

[3]:
import jax

from mvtk import interprenet

constraints = {
    'MedInc': frozenset([interprenet.monotonic_constraint, interprenet.lipschitz_constraint]),
    'AveRooms': frozenset([interprenet.monotonic_constraint]),
    'Latitude': frozenset([interprenet.monotonic_constraint]),
    'HouseAge': frozenset([interprenet.lipschitz_constraint]),
    'AveBedrms': frozenset([interprenet.monotonic_constraint]),
    'Longitude': frozenset(),
    'Population': frozenset([interprenet.lipschitz_constraint]),
    'AveOccup': frozenset([interprenet.lipschitz_constraint])}

def get_scale(column):
    scale = 1
    if interprenet.monotonic_constraint in constraints[column]:
        # monotonic decreasing features need their sign flipped
        # before entering the neural network
        scale *= numpy.sign(correlations[column])
    if interprenet.lipschitz_constraint in constraints[column]:
        # rescale before applying the Lipschitz constraint
        # to match the assumptions mentioned in the previous block
        scale *= 2 * abs(correlations[column])
    return scale

rescale = lambda x: x * numpy.asarray(list(map(get_scale, X.columns)))
init_params, model = interprenet.constrained_model(
    tuple(constraints[column] for column in X.columns),
    preprocess=rescale,
    postprocess=jax.nn.relu) # monotonicity- and Lipschitz-preserving way of eliminating negative prices
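As a quick sanity check, we can inspect the effective per-column scale factors: a negative value marks a monotonically decreasing feature whose sign gets flipped, and for Lipschitz-constrained features the magnitude is the effective bound on the rate of change.

# effective rescaling factor applied to each column before it enters the network
{column: get_scale(column) for column in X.columns}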
[4]:
import jax
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(jax.numpy.asarray(X.values),
                                                    jax.numpy.asarray(y.astype('float')),
                                                    test_size=0.2,
                                                    random_state=0)
[5]:
def mean_absolute_percentage_error(y, y_pred):
    # note: y stays well above zero here thanks to the earlier *10 target rescaling
    return (abs(y - y_pred) / y).mean()

trained_params = interprenet.train((X_train, y_train),
                                   (X_test, y_test),
                                   (init_params, model),
                                   # negated so that a higher metric means a lower error
                                   metric=lambda y, y_pred: -mean_absolute_percentage_error(y, y_pred),
                                   loss_fn=mean_absolute_percentage_error,
                                   num_epochs=32,
                                   step_size=0.001)
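As a unit check on the loss (a tiny worked example): predicting 1 when the true value is 2 gives |2 − 1| / 2 = 0.5, so the function reports fractions rather than percentages, and the test-set figure below reads the same way.

# MAPE here is a fraction, not a percent: this evaluates to 0.5 (a 50% error)
mean_absolute_percentage_error(jax.numpy.array([2.0]), jax.numpy.array([1.0]))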
[6]:
trained_model = lambda X: model(trained_params, X)
mean_absolute_percentage_error(y_test, trained_model(X_test))
[6]:
Array(0.45206466, dtype=float32)

No matter what the values of the other features are (blue lines), increasing MedInc always increases the predicted price, but never faster than the maximum rate set by its Lipschitz constraint.

[7]:
import matplotlib.pylab as plt

interprenet.plot(trained_model, X, 'MedInc')
plt.show()
[Figure: predicted price as a function of MedInc]
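Beyond eyeballing the plot, both properties can be spot-checked numerically. The sketch below (which assumes, per the earlier derivation, that the constrained network is 1-Lipschitz in its rescaled inputs) bumps MedInc for every test row and verifies that predictions never decrease and never rise faster than the 2 · |correlation| bound:

delta = 1.0  # bump every test row's MedInc by the same amount
i = list(X.columns).index('MedInc')
X_test_jnp = jax.numpy.asarray(X_test)
change = trained_model(X_test_jnp.at[:, i].add(delta)) - trained_model(X_test_jnp)
assert (change >= 0).all()  # monotonically increasing in MedInc
# rate of change capped at twice the absolute correlation (small float tolerance)
assert (change <= 2 * abs(correlations['MedInc']) * delta + 1e-4).all()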

No matter what the values of the other features are (blue lines), increasing Latitude always decreases the predicted price.

[8]:
import matplotlib.pylab as plt

interprenet.plot(trained_model, X, 'Latitude')
plt.show()
[Figure: predicted price as a function of Latitude]

HouseAge can either increase or decrease the predicted price (it carries no monotonicity constraint), but only at a bounded rate, since it does carry a Lipschitz constraint.

[9]:
import matplotlib.pylab as plt

interprenet.plot(trained_model, X, 'HouseAge')
plt.show()
[Figure: predicted price as a function of HouseAge]

AveRooms can only increase the predicted price. It carries no Lipschitz constraint, so that rate of increase is not bounded.

[10]:
import matplotlib.pylab as plt

interprenet.plot(trained_model, X, 'AveRooms')
plt.show()
[Figure: predicted price as a function of AveRooms]