Source code for mvtk.interprenet

import jax
import itertools
import public

from jax.example_libraries import optimizers, stax
from jax.nn.initializers import glorot_normal, normal


@public.add
def monotonic_constraint(weights):
    """Monotonicity constraint on weights."""
    return abs(weights)

@public.add
def lipschitz_constraint(weights):
    """Lipschitz constraint on weights.

    https://arxiv.org/abs/1811.05381
    """
    return weights / abs(weights).sum(0)

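# Illustrative sketch (not part of mvtk's API): these constraints are applied to
# the raw parameters on every forward pass by ``ConstrainedDense`` below.
# ``monotonic_constraint`` replaces each weight with its absolute value, so every
# effective weight is non-negative; ``lipschitz_constraint`` rescales each column
# of the weight matrix to unit L1 norm, bounding the layer's Lipschitz constant
# (https://arxiv.org/abs/1811.05381). ``_example_constraints`` is a hypothetical
# helper added here only for illustration.
def _example_constraints():
    W = jax.random.normal(jax.random.PRNGKey(0), (3, 2))
    W_mono = monotonic_constraint(W)  # every entry is non-negative
    W_lip = lipschitz_constraint(W)   # each column sums to 1 in absolute value
    return W_mono, W_lip
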
@public.add
def identity(weights):
    return weights

@public.add
def clip(x, eps=2**-16):
    return jax.numpy.clip(x, eps, 1 - eps)

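# Illustrative sketch (not part of mvtk's API): ``clip`` keeps values a small
# epsilon away from 0 and 1 so that downstream logarithms (e.g. in
# ``cross_entropy`` below) never see exactly 0 or 1. ``_example_clip`` is a
# hypothetical helper added here only for illustration.
def _example_clip():
    probs = jax.numpy.array([0.0, 0.5, 1.0])
    return clip(probs)  # [2**-16, 0.5, 1 - 2**-16]
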
@public.add
def ConstrainedDense(constraint):
    """Layer constructor function for a constrained dense (fully-connected) layer.

    Args:
        constraint (function): Transformation to be applied to weights
    """

    def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
        def init_fun(rng, input_shape):
            output_shape = input_shape[:-1] + (out_dim,)
            k1, k2 = jax.random.split(rng)
            W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
            return output_shape, (W, b)

        def apply_fun(params, inputs, **kwargs):
            W, b = params
            return jax.numpy.dot(inputs, constraint(W)) + b

        return init_fun, apply_fun

    return Dense

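# Illustrative sketch (not part of mvtk's API): a ``ConstrainedDense`` layer
# follows the stax ``(init_fun, apply_fun)`` protocol. The constraint is applied
# to the raw weight matrix at every forward pass, so the stored parameters stay
# unconstrained while the effective weights are, e.g., non-negative.
# ``_example_constrained_dense`` is a hypothetical helper for illustration only.
def _example_constrained_dense():
    init_fun, apply_fun = ConstrainedDense(monotonic_constraint)(out_dim=4)
    output_shape, (W, b) = init_fun(jax.random.PRNGKey(0), (-1, 3))
    x = jax.numpy.ones((5, 3))
    y = apply_fun((W, b), x)  # uses abs(W), so outputs never decrease in x
    return output_shape, y.shape  # ((-1, 4), (5, 4))
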
@public.add
def SortLayer(axis=-1):
    """Sort layer used for lipschitz preserving nonlinear activation function.

    https://arxiv.org/abs/1811.05381
    """

    def init_fun(rng, input_shape):
        output_shape = input_shape
        return output_shape, ()

    def apply_fun(params, inputs, **kwargs):
        return jax.numpy.sort(inputs, axis=axis)

    return init_fun, apply_fun

@public.add
def net(layers, linear, activation):
    return stax.serial(
        *itertools.chain(*((linear(layer), activation()) for layer in layers))
    )

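# Illustrative sketch (not part of mvtk's API): ``net`` chains one constrained
# dense layer and one activation per entry in ``layers``. Combining
# ``lipschitz_constraint`` with ``SortLayer`` gives a Lipschitz-bounded network
# in the style of https://arxiv.org/abs/1811.05381. ``_example_net`` is a
# hypothetical helper for illustration only.
def _example_net():
    init_fun, apply_fun = net(
        [8, 8], ConstrainedDense(lipschitz_constraint), SortLayer
    )
    output_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 3))
    x = jax.numpy.ones((5, 3))
    return apply_fun(params, x).shape  # (5, 8)
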
@public.add
def thread(fns):
    def composition(x):
        for fn in fns:
            x = fn(x)
        return x

    return composition

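# Illustrative sketch (not part of mvtk's API): ``thread`` composes a group of
# weight constraints, so a single layer can be simultaneously monotonic and
# Lipschitz-bounded. ``constrained_model`` below uses it to turn each frozenset
# of constraints into one weight transformation. ``_example_thread`` is a
# hypothetical helper for illustration only.
def _example_thread():
    both = thread([monotonic_constraint, lipschitz_constraint])
    W = jax.random.normal(jax.random.PRNGKey(0), (3, 2))
    return both(W)  # non-negative entries, each column summing to 1
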
def partialnet(linear, activation):
    def init_fun(rng, input_shape, n_hyper, layers):
        params = []
        apply_funs = []
        for layer in layers:
            rng, layer_rng = jax.random.split(rng)
            layer_init, apply_layer = linear(layer)
            input_shape, layer_params = layer_init(layer_rng, input_shape)
            d1, d2 = input_shape
            d2 *= n_hyper
            input_shape = (d1, d2)
            params.append(layer_params)
            apply_funs.append(apply_layer)

            layer_init, apply_layer = activation()
            input_shape, layer_params = layer_init(layer_rng, input_shape)
            input_shape = (d1, d2)
            params.append(layer_params)
            apply_funs.append(apply_layer)

        def apply_net(multiparams, inputs):
            # (network 1 output), (network 2 output), ...
            for params, fn in zip(zip(*multiparams), apply_funs):
                inputs = jax.numpy.concatenate(
                    tuple(fn(param, inputs) for param in params), axis=1
                )
            return inputs

        return params, apply_net

    return init_fun

@public.add
def sigmoid_clip(inputs, eps=2**-16):
    return jax.scipy.special.expit(clip(inputs, eps))

@public.add
def constrained_model(
    constraints,
    get_layers=lambda input_shape: [input_shape + 1] * 2,
    output_shape=1,
    preprocess=identity,
    postprocess=sigmoid_clip,
    rng=jax.random.PRNGKey(0),
):
    """Create a neural network with groups of constraints assigned to each
    feature. A separate constrained neural network is generated for each group
    of constraints. Each feature is fed into exactly one of these neural
    networks (the one that matches its assigned group of constraints). The
    outputs of these constrained neural networks are concatenated and fed into
    one final neural network that obeys the union of all constraints applied.

    Args:
        constraints (list): List of sets of constraints (one frozenset of
            constraints for each feature)
        get_layers (function): Returns shape of constrained neural network
            given size of input (i.e. the number of features that will be fed
            into it).
        preprocess: Preprocessing function applied to the feature vector
            before it is sent through any neural networks. This can be useful
            for adjusting signs for monotonic neural networks or scales for
            lipschitz ones.
        postprocess: Final activation applied to output of neural network.
        rng: jax PRNGKey

    Returns:
        init_params, model
    """
    union = set()
    groups = {}
    for i, constraint in enumerate(constraints):
        union |= constraint
        if constraint not in groups:
            groups[constraint] = [i]
        else:
            groups[constraint].append(i)
    nets = []
    catted_size = 0
    for constraint, idx in groups.items():
        init_net, apply_net = net(
            get_layers(len(idx)), ConstrainedDense(thread(constraint)), SortLayer
        )
        rng, new_rng = jax.random.split(rng)
        suboutput_shape, params = init_net(new_rng, (-1, len(idx)))
        catted_size += suboutput_shape[1]
        nets.append((params, apply_net))
    params1, apply_nets = zip(*nets)
    init_net, apply_net2 = stax.serial(
        net(get_layers(catted_size), ConstrainedDense(thread(union)), SortLayer),
        ConstrainedDense(thread(union))(output_shape),
    )
    rng, new_rng = jax.random.split(rng)
    output_shape, params2 = init_net(new_rng, (-1, catted_size))
    params = (params1, params2)
    groups = {key: jax.numpy.asarray(value) for key, value in groups.items()}

    def apply_net_pipeline(params, inputs):
        inputs = preprocess(inputs)
        params1, params2 = params
        return postprocess(
            apply_net2(
                params2,
                jax.numpy.concatenate(
                    tuple(
                        apply_net(p, inputs[:, idx])
                        for p, apply_net, idx in zip(
                            params1, apply_nets, groups.values()
                        )
                    ),
                    axis=1,
                ),
            )
        )

    return params, apply_net_pipeline

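# Illustrative sketch (not part of mvtk's API): build a model over three
# features where the first two are both monotonic and Lipschitz-constrained
# and the third is only Lipschitz-constrained. Features sharing a frozenset of
# constraints are routed through the same sub-network.
# ``_example_constrained_model`` is a hypothetical helper for illustration only.
def _example_constrained_model():
    constraints = [
        frozenset([monotonic_constraint, lipschitz_constraint]),
        frozenset([monotonic_constraint, lipschitz_constraint]),
        frozenset([lipschitz_constraint]),
    ]
    params, model = constrained_model(constraints)
    X = jax.random.normal(jax.random.PRNGKey(1), (10, 3))
    return model(params, X)  # model scores with shape (10, 1)
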
@public.add
def cross_entropy(y, y_pred):
    # Mean binary cross entropy; minimized during training.
    return -(y * jax.numpy.log(y_pred) + (1 - y) * jax.numpy.log(1 - y_pred)).mean()

@public.add
def parameterized_loss(loss, net):
    def _(params, batch):
        X, y = batch
        return loss(y, net(params, X))

    return _

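# Illustrative sketch (not part of mvtk's API): ``parameterized_loss`` closes
# over a model to turn the (labels, predictions) loss into a function of
# (params, batch) that can be differentiated with ``jax.grad`` during training.
# ``_example_loss`` is a hypothetical helper for illustration only.
def _example_loss():
    params, model = constrained_model([frozenset([monotonic_constraint])])
    loss = parameterized_loss(cross_entropy, model)
    X = jax.random.normal(jax.random.PRNGKey(2), (8, 1))
    y = jax.numpy.ones((8, 1))
    # Gradients share the tree structure of ``params``.
    return jax.grad(loss)(params, (X, y))
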
@public.add
def batch_generator(X, y, balance=False):
    assert len(X) == len(y)
    if balance:
        weights = jax.numpy.empty(len(y))
        p = jax.numpy.mean(y)
        weights = weights.at[y == 1].set(1 / p)
        weights = weights.at[y == 0].set(1 / (1 - p))
        weights /= weights.sum()
        weights = jax.numpy.clip(weights, 0, 1)
    else:
        weights = None
    N = len(X)

    def _(batch_size, rng=jax.random.PRNGKey(0), replace=False):
        while True:
            rng, new_rng = jax.random.split(rng)
            idx = jax.random.choice(
                new_rng, N, shape=(batch_size,), p=weights, replace=replace
            )
            yield X[idx], y[idx]

    return _

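# Illustrative sketch (not part of mvtk's API): ``batch_generator`` is curried;
# the outer call fixes the data set (and optional class balancing) and the
# inner call yields an endless stream of minibatches. ``_example_batches`` is a
# hypothetical helper for illustration only.
def _example_batches():
    X = jax.random.normal(jax.random.PRNGKey(3), (100, 2))
    y = (X[:, :1] > 0).astype("float32")
    batches = batch_generator(X, y)(batch_size=16)
    X_batch, y_batch = next(batches)
    return X_batch.shape, y_batch.shape  # ((16, 2), (16, 1))
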
@public.add
def train(
    train,
    test,
    net,
    metric,
    loss_fn=cross_entropy,
    mini_batch_size=32,
    num_epochs=64,
    step_size=0.01,
    track=1,
):
    """Train interpretable neural network.

    This routine will check accuracy using ``metric`` every ``track`` epochs.
    The model parameters with the highest accuracy are returned.

    Args:
        train (tuple): (X, y), each a ``jax.numpy.array`` of type ``float``.
        test (tuple): (X, y), each a ``jax.numpy.array`` of type ``float``.
        net (tuple): (init_params, model), a jax model returned by
            ``constrained_model``.
        metric (function): function of two jax arrays: ground truth and
            predictions. Returns ``float`` representing performance metric.
        loss_fn (function): function of two jax arrays: ground truth and
            predictions. Returns ``float`` representing loss.
        mini_batch_size (int): Size of minibatches drawn from train for
            stochastic gradient descent
        num_epochs (int): Number of epochs to train
        step_size (float): Step size used for stochastic gradient descent
        track (int): Number of epochs between metric checks

    Returns:
        best params
    """
    mini_batches = batch_generator(*train)(mini_batch_size)
    params, apply_net = net
    loss = parameterized_loss(loss_fn, apply_net)
    opt_init, opt_update, get_params = optimizers.adam(step_size)

    @jax.jit
    def update(i, opt_state, batch):
        # The minibatch is drawn outside the jitted function so each call sees
        # fresh data rather than a batch baked in at trace time.
        return opt_update(
            i, jax.grad(loss)(get_params(opt_state), batch), opt_state
        )

    opt_state = opt_init(params)
    best_performance = -jax.numpy.inf
    best_params = params
    for epoch in range(num_epochs):
        opt_state = update(epoch, opt_state, next(mini_batches))
        if epoch and not epoch % track:
            X, y = test
            params = get_params(opt_state)
            performance = metric(y, apply_net(params, X))
            if performance > best_performance:
                best_performance = performance
                best_params = params
            # print(epoch, best_performance)
    return best_params

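# Illustrative sketch (not part of mvtk's API): end-to-end training of a small
# monotonic model on synthetic data. ``neg_mse`` is a hypothetical metric (a
# negative mean squared error, so higher is better, as ``train`` expects), and
# ``_example_train`` is a hypothetical helper for illustration only.
def _example_train():
    X = jax.random.normal(jax.random.PRNGKey(4), (256, 2))
    y = (X.sum(1, keepdims=True) > 0).astype("float32")
    model = constrained_model([frozenset([monotonic_constraint])] * 2)

    def neg_mse(y_true, y_pred):
        return -((y_true - y_pred) ** 2).mean()

    best_params = train((X[:192], y[:192]), (X[192:], y[192:]), model, neg_mse)
    params, apply_net = model
    return apply_net(best_params, X[192:])  # held-out scores, shape (64, 1)
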
@public.add
def plot(
    model,
    data,
    feature,
    N=256,
    n_interp=1024,
    fig=None,
    rng=jax.random.PRNGKey(0),
):
    r"""`Individual Conditional Expectation plot (blue)
    <https://christophm.github.io/interpretable-ml-book/ice.html>`_ and
    `Partial Dependence Plot (red)
    <https://christophm.github.io/interpretable-ml-book/pdp.html>`_.

    Args:
        model (function): function from data (as dataframe) to scores
        data: ``pandas.DataFrame`` of features
        feature (string): Feature to examine
        N (int): size of sample from ``data`` to consider for averages and
            conditional expectation plots. Determines number of blue lines and
            sample size for averaging to create red line.
        n_interp (int): Number of values of ``feature`` to evaluate along the
            x-axis. Randomly chosen from unique values of this feature within
            ``data``.
        fig: Matplotlib figure. Defaults to ``None``.
        rng (PRNGKey): Jax ``PRNGKey``

    Returns:
        matplotlib figure"""
    import matplotlib.pyplot as plt

    if fig is None:
        fig = plt.gcf()
    plt.clf()
    plt.title(feature)
    plt.ylabel("Model Score")
    plt.xlabel(feature)
    data = data.sort_values([feature])
    rng, new_rng = jax.random.split(rng)
    all_values_idx = list(
        jax.random.choice(new_rng, len(data), shape=(N,), replace=False)
    )
    all_values = data.values[all_values_idx]
    _, unique_feature_idx = jax.numpy.unique(data[feature].values, return_index=True)
    nunique = len(unique_feature_idx)
    n_interp = min(nunique, n_interp)
    feature_idx = list(
        unique_feature_idx[
            jax.random.choice(rng, nunique, shape=(n_interp,), replace=False).sort()
        ]
    )
    feature_values = data[feature].values[feature_idx]
    rest = jax.numpy.asarray(
        [i for i, column in enumerate(data.columns) if column != feature],
        dtype="int32",
    )
    all_scores = []
    data_values = jax.numpy.asarray(data.values[feature_idx])
    for replacement in all_values[:, rest]:
        fixed_values = data_values.at[jax.numpy.index_exp[:, rest]].set(replacement)
        scores = model(fixed_values)
        all_scores.append(scores)
        plt.plot(feature_values, scores, "b", alpha=0.125)
    plt.plot(feature_values, jax.numpy.asarray(all_scores).mean(0), "r", linewidth=2.0)
    return fig

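# Illustrative sketch (not part of mvtk's API): wrap a model so that it accepts
# a raw feature matrix, then draw ICE curves (blue) and the partial dependence
# curve (red) for one column of a ``pandas.DataFrame``. ``_example_plot`` is a
# hypothetical helper for illustration only; it assumes ``pandas`` and
# ``matplotlib`` are installed.
def _example_plot():
    import pandas as pd

    X = jax.random.normal(jax.random.PRNGKey(5), (128, 2))
    data = pd.DataFrame(jax.device_get(X), columns=["x0", "x1"])
    params, apply_net = constrained_model([frozenset([monotonic_constraint])] * 2)
    return plot(
        lambda values: apply_net(params, values), data, "x0", N=32, n_interp=64
    )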