interprenet
- ConstrainedDense(constraint)
Layer constructor function for a constrained dense (fully-connected) layer.
- Parameters:
constraint (function) – Transformation to be applied to weights
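The entry above does not show how the constraint is wired in. Here is a minimal sketch of the idea, assuming a stax-style (init_fun, apply_fun) layer pair as in jax.example_libraries.stax. The hypothetical constrained_dense below passes the raw weight matrix through constraint on every forward pass. With jnp.abs, the effective weights are nonnegative, so the layer is nondecreasing in every input. The out_dim argument and the initialization scale are assumptions, not the library's signature.

```python
import jax.numpy as jnp
from jax import random

def constrained_dense(out_dim, constraint):
    """Hypothetical stax-style dense layer whose weights pass through `constraint`."""
    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        W = random.normal(k1, (input_shape[-1], out_dim)) * 0.1
        b = random.normal(k2, (out_dim,)) * 0.1
        return input_shape[:-1] + (out_dim,), (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        # The constraint is applied at every forward pass, so the stored
        # parameters stay unconstrained and can be updated freely by SGD.
        return inputs @ constraint(W) + b

    return init_fun, apply_fun

# Nonnegative effective weights give a layer that is monotone in each input.
init_fun, apply_fun = constrained_dense(3, jnp.abs)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 4))
x = jnp.ones((2, 4))
y = apply_fun(params, x)
```

Applying the constraint inside apply_fun, rather than projecting the weights after each update, keeps the optimizer unaware of the constraint.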
- constrained_model(constraints, get_layers=<function <lambda>>, output_shape=1, preprocess=<function identity>, postprocess=<function sigmoid_clip>, rng=Array([0, 0], dtype=uint32))
Create a neural network with groups of constraints assigned to each feature. A separate constrained neural network is generated for each group of constraints. Each feature is fed into exactly one of these neural networks (the one that matches its assigned group of constraints). The outputs of these constrained neural networks are concatenated and fed into one final neural network that obeys the union of all applied constraints.
- Parameters:
constraints (list) – List of sets of constraints (one frozenset of constraints per feature).
get_layers (function) – Returns the shape of a constrained neural network given the size of its input (i.e., the number of features that will be fed into it).
preprocess – Preprocessing function applied to the feature vector before it is sent through any of the neural networks. This can be useful for adjusting signs for monotonic neural networks or scales for Lipschitz ones.
postprocess – Final activation applied to the output of the neural network.
rng – JAX PRNGKey
- Returns:
init_params, model
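The routing described above can be sketched as follows. This is a hypothetical illustration, not the library's implementation. The constraint label "monotonic", the single-layer sub-networks, and the rule that "monotonic" means nonnegative weights are all assumptions. Features that share a constraint set go through the same sub-network. The sub-network outputs are then concatenated, and the final layer applies the union of all constraint sets.

```python
import jax
import jax.numpy as jnp
from jax import random

def group_features(constraints):
    """Map each distinct constraint set to the indices of the features carrying it."""
    groups = {}
    for i, c in enumerate(constraints):
        groups.setdefault(c, []).append(i)
    return groups

def weight_transform(constraint_set):
    # Hypothetical rule: only "monotonic" is modeled, as nonnegative weights.
    return jnp.abs if "monotonic" in constraint_set else (lambda w: w)

def forward(params, X, groups):
    outs = []
    for (cset, idx), W in zip(groups.items(), params["sub"]):
        # Each feature enters exactly one sub-network, chosen by its constraint set.
        outs.append(jnp.tanh(X[:, jnp.asarray(idx)] @ weight_transform(cset)(W)))
    h = jnp.concatenate(outs, axis=1)
    union = frozenset().union(*groups)  # the final layer obeys every constraint present
    return jax.nn.sigmoid(h @ weight_transform(union)(params["final"]))

constraints = [frozenset({"monotonic"}), frozenset(), frozenset({"monotonic"})]
groups = group_features(constraints)
hidden = 2
keys = random.split(random.PRNGKey(0), len(groups) + 1)
params = {
    "sub": [random.normal(k, (len(idx), hidden))
            for k, idx in zip(keys[:-1], groups.values())],
    "final": random.normal(keys[-1], (hidden * len(groups), 1)),
}
X = random.uniform(random.PRNGKey(1), (5, len(constraints)))
y = forward(params, X, groups)
```

Every path from features 0 and 2 to the output uses nonnegative weights and increasing activations, so the output is nondecreasing in those features. Feature 1 is unconstrained.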
- plot(model, data, feature, N=256, n_interp=1024, fig=None, rng=Array([0, 0], dtype=uint32))
Individual Conditional Expectation plot (blue) and Partial Dependence Plot (red).
- Parameters:
model (function) – function from data (as dataframe) to scores
feature (string) – Feature to examine
N (int) – Size of the sample drawn from data to consider for averages and conditional expectation plots. Determines the number of blue lines and the sample size averaged to create the red line.
n_interp (int) – Number of values of feature to evaluate along the x-axis. Randomly chosen from the unique values of this feature within data.
fig – Matplotlib figure. Defaults to None.
rng (PRNGKey) – JAX PRNGKey
- Returns:
Matplotlib figure
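The quantities behind the plot can be computed without Matplotlib. This is a minimal sketch under two assumptions: the model takes a 2-D array rather than a DataFrame, and the feature is given by column index. Each ICE curve overwrites the chosen feature with every grid value for one sampled row, and the PDP is the mean of those curves. The function ice_and_pdp and the toy linear model are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random

def ice_and_pdp(model, X, feature_idx, grid):
    """ICE curves (one per sampled row) and their average, the partial dependence."""
    def predict_at(v):
        # Overwrite the chosen feature with v for every sampled row.
        return model(X.at[:, feature_idx].set(v))
    ice = jax.vmap(predict_at)(grid).T  # (n_rows, n_grid): the blue lines
    pdp = ice.mean(axis=0)              # (n_grid,): the red line
    return ice, pdp

model = lambda X: 2.0 * X[:, 0] + X[:, 1]  # toy stand-in for a trained model
keys = random.split(random.PRNGKey(0), 3)
data = random.normal(keys[0], (1000, 3))
# Mirror N and n_interp: sample rows, then pick grid values among unique feature values.
rows = random.choice(keys[1], data.shape[0], (256,), replace=False)
sample = data[rows]
grid = jnp.sort(random.choice(keys[2], jnp.unique(data[:, 0]), (64,), replace=False))
ice, pdp = ice_and_pdp(model, sample, 0, grid)
```

For this linear toy model, the PDP equals 2 * grid plus the mean of the second column of the sample. That gives an easy sanity check.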
- train(train, test, net, metric, loss_fn=<function cross_entropy>, mini_batch_size=32, num_epochs=64, step_size=0.01, track=1)
Train an interpretable neural network. This routine checks accuracy using metric every track epochs and returns the model parameters with the highest accuracy.
- Parameters:
train (tuple) – (X, y), each a jax.numpy.array of type float.
test (tuple) – (X, y), each a jax.numpy.array of type float.
net (tuple) – (init_params, model), a JAX model returned by constrained_model.
metric (function) – Function of two JAX arrays: ground truth and predictions. Returns a float representing the performance metric.
loss_fn (function) – Function of two JAX arrays: ground truth and predictions. Returns a float representing the loss.
mini_batch_size (int) – Size of the minibatches drawn from train for stochastic gradient descent
num_epochs (int) – Number of epochs to train
step_size (float) – Step size used for stochastic gradient descent
track (int) – Number of epochs between metric checks
- Returns:
Best parameters (the highest metric score seen during training)
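Here is a minimal sketch of the training loop described above. It assumes net is a (params, model) pair where model(params, X) returns probabilities. The docs do not say whether init_params is a parameter pytree or an initializer, so that is an assumption. The names train_sketch and accuracy, and this cross_entropy, are hypothetical stand-ins. The loop runs plain minibatch SGD, evaluates metric on the test split every track epochs, and keeps the best-scoring parameters.

```python
import jax
import jax.numpy as jnp
from jax import random

def cross_entropy(y, p, eps=1e-7):
    p = jnp.clip(p, eps, 1 - eps)
    return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))

def accuracy(y, p):
    return float(jnp.mean((p > 0.5) == (y > 0.5)))

def train_sketch(train, test, net, metric, loss_fn=cross_entropy, mini_batch_size=32,
                 num_epochs=64, step_size=0.01, track=1, rng=random.PRNGKey(0)):
    params, model = net  # assumed: (parameter pytree, model(params, X))
    X, y = train
    grad_fn = jax.jit(jax.grad(lambda p, xb, yb: loss_fn(yb, model(p, xb))))
    best_params, best_score = params, float("-inf")
    for epoch in range(1, num_epochs + 1):
        rng, sub = random.split(rng)
        perm = random.permutation(sub, X.shape[0])  # reshuffle every epoch
        for start in range(0, X.shape[0], mini_batch_size):
            idx = perm[start:start + mini_batch_size]
            grads = grad_fn(params, X[idx], y[idx])
            # Plain SGD step applied leaf by leaf across the parameter pytree.
            params = jax.tree_util.tree_map(lambda w, g: w - step_size * g, params, grads)
        if epoch % track == 0:  # check the metric every `track` epochs
            score = metric(test[1], model(params, test[0]))
            if score > best_score:
                best_params, best_score = params, score
    return best_params

keys = random.split(random.PRNGKey(42), 2)
X_train = random.normal(keys[0], (200, 2))
y_train = (X_train.sum(axis=1) > 0).astype(jnp.float32)
X_test = random.normal(keys[1], (100, 2))
y_test = (X_test.sum(axis=1) > 0).astype(jnp.float32)
model = lambda p, X: jax.nn.sigmoid(X @ p[0] + p[1])  # toy logistic model
net = ((jnp.zeros(2), jnp.zeros(())), model)
best = train_sketch((X_train, y_train), (X_test, y_test), net, accuracy,
                    step_size=0.5, num_epochs=20, track=5)
```

Returning the best checkpoint instead of the final parameters acts like early stopping with respect to metric.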