Introduction to Interprenet
Interprenet builds constrained neural networks. We currently support monotonic and Lipschitz constraints. In this tutorial, we will show how to apply interpretable constraints to a model trained on the California housing dataset.
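Concretely (our notation, not the library's): a model $f$ is monotonically increasing in feature $x_i$ if

$$f(x_1, \ldots, x_i + \delta, \ldots, x_n) \ge f(x_1, \ldots, x_i, \ldots, x_n) \quad \text{for all } \delta \ge 0,$$

and $L$-Lipschitz in $x_i$ if

$$|f(x_1, \ldots, x_i + \delta, \ldots, x_n) - f(x_1, \ldots, x_i, \ldots, x_n)| \le L\,|\delta|,$$

i.e. the prediction can never change faster than $L$ times the change in the input.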
[1]:
import pandas
from sklearn.datasets import fetch_california_housing
bunch = fetch_california_housing()
X = pandas.DataFrame(bunch['data'], columns=bunch['feature_names'])
y = bunch['target'] * 10 # express target variable in units of $10,000 vs $100,000 to avoid values < 1
X
[1]:
| | MedInc | HouseAge | AveRooms | AveBedrms | Population | AveOccup | Latitude | Longitude |
|---|---|---|---|---|---|---|---|---|
| 0 | 8.3252 | 41.0 | 6.984127 | 1.023810 | 322.0 | 2.555556 | 37.88 | -122.23 |
| 1 | 8.3014 | 21.0 | 6.238137 | 0.971880 | 2401.0 | 2.109842 | 37.86 | -122.22 |
| 2 | 7.2574 | 52.0 | 8.288136 | 1.073446 | 496.0 | 2.802260 | 37.85 | -122.24 |
| 3 | 5.6431 | 52.0 | 5.817352 | 1.073059 | 558.0 | 2.547945 | 37.85 | -122.25 |
| 4 | 3.8462 | 52.0 | 6.281853 | 1.081081 | 565.0 | 2.181467 | 37.85 | -122.25 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 20635 | 1.5603 | 25.0 | 5.045455 | 1.133333 | 845.0 | 2.560606 | 39.48 | -121.09 |
| 20636 | 2.5568 | 18.0 | 6.114035 | 1.315789 | 356.0 | 3.122807 | 39.49 | -121.21 |
| 20637 | 1.7000 | 17.0 | 5.205543 | 1.120092 | 1007.0 | 2.325635 | 39.43 | -121.22 |
| 20638 | 1.8672 | 18.0 | 5.329513 | 1.171920 | 741.0 | 2.123209 | 39.43 | -121.32 |
| 20639 | 2.3886 | 16.0 | 5.254717 | 1.162264 | 1387.0 | 2.616981 | 39.37 | -121.24 |
20640 rows × 8 columns
We can get a rough sense of which variables exhibit monotonic relationships with the target by looking at correlations.
[2]:
import numpy
correlations = dict(map(reversed,
                        sorted(((numpy.corrcoef(X[column], y)[0, 1], column) for column in X.columns),
                               key=lambda x: -abs(x[0]))))
correlations
[2]:
{'MedInc': 0.6880752079585482,
'AveRooms': 0.15194828974145785,
'Latitude': -0.14416027687465913,
'HouseAge': 0.10562341249320988,
'AveBedrms': -0.046700512969486935,
'Longitude': -0.045966615117978504,
'Population': -0.024649678888894896,
'AveOccup': -0.02373741295613434}
See https://scikit-learn.org/stable/datasets/real_world.html#california-housing-dataset for details on the meaning of each column.
Comparing the correlations with common sense, the following variables both do and should exhibit monotonic relationships with the target:

- MedInc: median income in block group
- AveRooms: average number of rooms per household
Furthermore, let's assume it is reasonable to expect the rate of change of price with respect to MedInc to be no greater than twice its correlation with the target; AveRooms, the other monotonic candidate, is not continuous, so we leave its rate unconstrained. The code below applies the same rule of thumb, a Lipschitz bound of twice the correlation's magnitude, to every Lipschitz-constrained feature, so for MedInc the slope is capped at 2 × 0.688 ≈ 1.38.
[3]:
import jax
from mvtk import interprenet

constraints = {
    'MedInc': frozenset([interprenet.monotonic_constraint, interprenet.lipschitz_constraint]),
    'AveRooms': frozenset([interprenet.monotonic_constraint]),
    'Latitude': frozenset([interprenet.monotonic_constraint]),
    'HouseAge': frozenset([interprenet.lipschitz_constraint]),
    'AveBedrms': frozenset([interprenet.monotonic_constraint]),
    'Longitude': frozenset(),
    'Population': frozenset([interprenet.lipschitz_constraint]),
    'AveOccup': frozenset([interprenet.lipschitz_constraint])}

def get_scale(column):
    scale = 1
    if interprenet.monotonic_constraint in constraints[column]:
        # monotonic decreasing features need their sign flipped
        # before entering the neural network
        scale *= numpy.sign(correlations[column])
    if interprenet.lipschitz_constraint in constraints[column]:
        # rescale before applying the lipschitz constraint
        # to match assumptions mentioned in the previous block
        scale *= 2 * abs(correlations[column])
    return scale

rescale = lambda x: x * numpy.asarray(list(map(get_scale, X.columns)))

init_params, model = interprenet.constrained_model(
    tuple(constraints[column] for column in X.columns),
    preprocess=rescale,
    postprocess=jax.nn.relu)  # monotonicity and Lipschitz preserving way of eliminating negative prices
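As a quick sanity check (a hypothetical extra cell, not part of the original tutorial), we can inspect the effective scale applied to each feature. From the correlations printed above, MedInc should come out to sign(0.688) × 2 × 0.688 ≈ 1.38, monotonic-only features such as Latitude reduce to a sign flip (±1), and the unconstrained Longitude stays at 1:
[ ]:
# hypothetical inspection cell: the effective scale applied to each feature
{column: get_scale(column) for column in X.columns}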
[4]:
import jax
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(jax.numpy.asarray(X.values),
                                                    jax.numpy.asarray(y.astype('float')),
                                                    test_size=0.2,
                                                    random_state=0)
[5]:
def mean_absolute_percentage_error(y, y_pred):
    return (abs(y - y_pred) / y).mean()

trained_params = interprenet.train((X_train, y_train),
                                   (X_test, y_test),
                                   (init_params, model),
                                   metric=lambda y, y_pred: -mean_absolute_percentage_error(y, y_pred),
                                   loss_fn=mean_absolute_percentage_error,
                                   num_epochs=32,
                                   step_size=0.001)
[6]:
trained_model = lambda X: model(trained_params, X)
mean_absolute_percentage_error(y_test, trained_model(X_test))
[6]:
Array(0.45206466, dtype=float32)
No matter what the values of the other features are (blue lines), increasing MedInc always increases the predicted price, and never faster than its Lipschitz bound allows.
[7]:
import matplotlib.pylab as plt
interprenet.plot(trained_model, X, 'MedInc')
plt.show()
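We can also verify this numerically. The sketch below is a hypothetical extra cell (it assumes trained_model accepts a 2-D batch, as it does for X_test above): it sweeps MedInc over its observed range while pinning the other features at their medians, then checks the finite-difference slopes. Monotonicity says every slope should be nonnegative, and if the underlying network is 1-Lipschitz as we assumed when rescaling, no slope should exceed roughly 2 × 0.688 ≈ 1.38.
[ ]:
import jax.numpy as jnp

# hypothetical check: sweep MedInc, holding every other feature at its median
medians = jnp.asarray(X.median().values)
idx = list(X.columns).index('MedInc')
grid = jnp.linspace(X['MedInc'].min(), X['MedInc'].max(), 256)
sweep = jnp.tile(medians, (len(grid), 1)).at[:, idx].set(grid)

preds = trained_model(sweep).ravel()
slopes = jnp.diff(preds) / jnp.diff(grid)
print(slopes.min(), slopes.max())  # expect min >= 0 (monotonic) and max <= ~1.38 (Lipschitz)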
No matter what the values of the other features are (blue lines), increasing Latitude always decreases the predicted price.
[8]:
import matplotlib.pylab as plt
interprenet.plot(trained_model, X, 'Latitude')
plt.show()
HouseAge can either increase or decrease the predicted price (it carries no monotonic constraint), but only at a bounded rate, thanks to its Lipschitz constraint.
[9]:
import matplotlib.pylab as plt
interprenet.plot(trained_model, X, 'HouseAge')
plt.show()
AveRooms can only increase the predicted price; with no Lipschitz constraint, its rate of increase is unbounded.
[10]:
import matplotlib.pylab as plt
interprenet.plot(trained_model, X, 'AveRooms')
plt.show()