interprenet

ConstrainedDense(constraint)[source]

Layer constructor function for a constrained dense (fully-connected) layer.

Parameters:

constraint (function) – Transformation to be applied to weights
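
A minimal sketch of the idea (not necessarily the library's exact internals): the constraint transforms the raw weight matrix at application time, so gradient descent remains unconstrained in the raw parameters. Here jnp.abs stands in for any constraint function.

    import jax.numpy as jnp

    # Illustrative only: apply `constraint` to a dense layer's weights
    # during the forward pass, leaving the stored parameters unconstrained.
    def constrained_dense_apply(params, inputs, constraint=jnp.abs):
        W, b = params
        return jnp.dot(inputs, constraint(W)) + b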

SortLayer(axis=-1)[source]

Sort layer used as a Lipschitz-preserving nonlinear activation function.

https://arxiv.org/abs/1811.05381
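
Sorting is a permutation of its inputs, so it preserves norms and is 1-Lipschitz, unlike ReLU-style activations that can shrink gradient norm. A minimal sketch of the underlying operation (the axis choice mirrors the default axis=-1):

    import jax.numpy as jnp

    x = jnp.array([[3.0, -1.0, 2.0]])
    jnp.sort(x, axis=-1)  # [[-1., 2., 3.]]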

batch_generator(X, y, balance=False)[source]
clip(x, eps=1.52587890625e-05)[source]
constrained_model(constraints, get_layers=<function <lambda>>, output_shape=1, preprocess=<function identity>, postprocess=<function sigmoid_clip>, rng=Array([0, 0], dtype=uint32))[source]

Create a neural network with a group of constraints assigned to each feature. A separate constrained neural network is generated for each group of constraints, and each feature is fed into exactly one of these networks (the one matching its assigned constraint group). The outputs of these constrained networks are concatenated and fed into one final neural network that obeys the union of all constraints applied.

Parameters:
  • constraints (list) – List of constraint sets (one frozenset of constraints for each feature)

  • get_layers (function) – Returns the layer shape of a constrained neural network given its input size (i.e. the number of features that will be fed into it).

  • preprocess – Preprocessing function applied to the feature vector before it is sent through any neural network. This can be useful for adjusting signs for monotonic neural networks or scales for Lipschitz ones.

  • postprocess – Final activation applied to output of neural network.

  • rng – JAX PRNGKey

Returns:

init_params, model
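
A hypothetical usage sketch for a two-feature problem, with feature 0 constrained to be monotonic and feature 1 unconstrained (the commented call assumes the returned model maps (params, X) to scores):

    from jax import random
    from interprenet import constrained_model, monotonic_constraint

    # One frozenset of constraints per feature.
    constraints = [frozenset([monotonic_constraint]), frozenset()]
    init_params, model = constrained_model(constraints,
                                           rng=random.PRNGKey(0))
    # scores = model(init_params, X)  # X: float feature matrix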

cross_entropy(y, y_pred)[source]
identity(weights)[source]
lipschitz_constraint(weights)[source]

Lipschitz constraint on weights.

https://arxiv.org/abs/1811.05381
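
One common way to enforce this, following the paper above, is to bound the infinity norm of the weight matrix by rescaling any row whose absolute sum exceeds 1. This is an illustrative sketch of the idea, not necessarily interprenet's exact formula:

    import jax.numpy as jnp

    def lipschitz_project(W):
        # Rescale rows so the matrix infinity norm is at most 1,
        # making the linear map 1-Lipschitz in the infinity norm.
        row_sums = jnp.sum(jnp.abs(W), axis=1, keepdims=True)
        return W / jnp.maximum(1.0, row_sums)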

monotonic_constraint(weights)[source]

Monotonicity constraint on weights.
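
Non-negative weights make a dense layer monotonically non-decreasing in each input (given monotone activations). Taking the absolute value of the raw weights is one standard way to achieve this; interprenet's exact transformation is assumed here:

    import jax.numpy as jnp

    def monotone_weights(W):
        # Illustrative: map raw weights to non-negative values.
        return jnp.abs(W)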

net(layers, linear, activation)[source]
parameterized_loss(loss, net)[source]
plot(model, data, feature, N=256, n_interp=1024, fig=None, rng=Array([0, 0], dtype=uint32))[source]

Individual Conditional Expectation plot (blue) and Partial Dependence Plot (red).

Parameters:
  • model (function) – Function from data (as a DataFrame) to scores

  • data (DataFrame) – Dataset to sample from; must contain feature

  • feature (string) – Feature to examine

  • N (int) – Size of the sample from data used for the conditional-expectation and averaging computations. Determines the number of blue lines and the sample size for averaging to create the red line.

  • n_interp (int) – Number of values of feature to evaluate along the x-axis, randomly chosen from the unique values of this feature within data.

  • fig – Matplotlib figure. Defaults to None.

  • rng (PRNGKey) – JAX PRNGKey

Returns:

matplotlib figure
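
A hypothetical call, where scored_model maps a DataFrame to scores and df holds the raw features (both names are placeholders):

    from interprenet import plot

    fig = plot(scored_model, df, feature="age", N=128)
    fig.savefig("ice_pdp_age.png")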

sigmoid_clip(inputs, eps=1.52587890625e-05)[source]
thread(fns)[source]
train(train, test, net, metric, loss_fn=<function cross_entropy>, mini_batch_size=32, num_epochs=64, step_size=0.01, track=1)[source]

Train an interpretable neural network. This routine checks performance using metric every track epochs and returns the model parameters that achieved the best metric value.

Parameters:
  • train (tuple) – (X, y), each a jax.numpy array of type float.

  • test (tuple) – (X, y), each a jax.numpy array of type float.

  • net (tuple) – (init_params, model) a jax model returned by constrained_model.

  • metric (function) – function of two jax arrays: ground truth and predictions. Returns float representing performance metric.

  • loss_fn (function) – function of two jax arrays: ground truth and predictions. Returns float representing loss.

  • mini_batch_size (int) – Size of minibatches from train used for stochastic gradient descent

  • num_epochs (int) – Number of epochs to train

  • step_size (float) – Step size used for stochastic gradient descent

  • track (int) – Number of epochs between metric checks

Returns:

best params
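
A hypothetical end-to-end sketch tying train to constrained_model; the data arrays, the metric, and the constraint assignment are placeholders:

    import jax.numpy as jnp
    from interprenet import constrained_model, monotonic_constraint, train

    def accuracy(y, y_pred):
        # Fraction of correct binary predictions at a 0.5 threshold.
        return jnp.mean((y_pred > 0.5) == (y > 0.5))

    constraints = [frozenset([monotonic_constraint]), frozenset()]
    net = constrained_model(constraints)  # (init_params, model)
    best_params = train((X_train, y_train), (X_test, y_test), net,
                        metric=accuracy, num_epochs=32)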