import jax
import itertools
import public
from jax.example_libraries import optimizers, stax
from jax.nn.initializers import glorot_normal, normal

@public.add
def monotonic_constraint(weights):
    """Monotonicity constraint on weights."""
    return abs(weights) 
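
# A minimal sketch of what the constraint does (the array below is
# hypothetical): mapping every weight to its absolute value leaves a dense
# layer with only non-negative weights, so its outputs are non-decreasing in
# each input.
#
#     W = jax.numpy.array([[-1.0, 2.0], [3.0, -4.0]])
#     monotonic_constraint(W)
#     # Array([[1., 2.], [3., 4.]], dtype=float32)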

@public.add
def lipschitz_constraint(weights):
    """Lipschitz constraint on weights.
    https://arxiv.org/abs/1811.05381
    """
    return weights / abs(weights).sum(0) 
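
# A minimal sketch (hypothetical array): each column of the weight matrix is
# rescaled so its absolute entries sum to 1, which bounds how much the linear
# map can stretch its input (see the paper linked above).
#
#     W = jax.numpy.array([[-1.0, 2.0], [3.0, -4.0]])
#     lipschitz_constraint(W)
#     # Array([[-0.25,  0.3333], [ 0.75, -0.6667]], dtype=float32)
#     abs(lipschitz_constraint(W)).sum(0)
#     # Array([1., 1.], dtype=float32)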

@public.add
def identity(weights):
    """No-op constraint: return weights unchanged."""
    return weights

@public.add
def clip(x, eps=2**-16):
    """Clip ``x`` into ``[eps, 1 - eps]`` to keep probabilities away from 0 and 1."""
    return jax.numpy.clip(x, eps, 1 - eps)

@public.add
def ConstrainedDense(constraint):
    """Layer constructor function for a constrained dense (fully-connected)
    layer.
    Args:
        constraint (function): Transformation to be applied to weights
    """
    def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
        def init_fun(rng, input_shape):
            output_shape = input_shape[:-1] + (out_dim,)
            k1, k2 = jax.random.split(rng)
            W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
            return output_shape, (W, b)
        def apply_fun(params, inputs, **kwargs):
            W, b = params
            return jax.numpy.dot(inputs, constraint(W)) + b
        return init_fun, apply_fun
    return Dense 
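
# A minimal usage sketch (shapes and the key below are hypothetical): the
# returned layer behaves like ``stax.Dense`` except that ``constraint`` is
# applied to the weight matrix on every forward pass.
#
#     init_fun, apply_fun = ConstrainedDense(monotonic_constraint)(4)
#     out_shape, layer_params = init_fun(jax.random.PRNGKey(0), (-1, 3))
#     y = apply_fun(layer_params, jax.numpy.ones((2, 3)))  # shape (2, 4)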

@public.add
def SortLayer(axis=-1):
    """Sort layer used for lipschitz preserving nonlinear activation function.
    https://arxiv.org/abs/1811.05381
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape
        return output_shape, ()
    def apply_fun(params, inputs, **kwargs):
        return jax.numpy.sort(inputs, axis=axis)
    return init_fun, apply_fun 
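
# A minimal sketch (hypothetical input): sorting activations within each row is
# a permutation, so it preserves norms and hence the Lipschitz bound enforced
# by the constrained linear layers.
#
#     _, apply_sort = SortLayer()
#     apply_sort((), jax.numpy.array([[3.0, 1.0, 2.0]]))
#     # Array([[1., 2., 3.]], dtype=float32)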

@public.add
def net(layers, linear, activation):
    """Alternate a ``linear`` layer (one per entry of ``layers``) with ``activation``."""
    return stax.serial(
        *itertools.chain(*((linear(layer), activation()) for layer in layers))
    ) 

@public.add
def thread(fns):
    """Compose the functions in ``fns`` left to right into a single function."""
    def composition(x):
        for fn in fns:
            x = fn(x)
        return x
    return composition 

def partialnet(linear, activation):
    """Network template evaluated in parallel copies: each layer is applied once
    per parameter set in ``multiparams`` and the copies' outputs are
    concatenated along the feature axis."""
    def init_fun(rng, input_shape, n_hyper, layers):
        params = []
        apply_funs = []
        for layer in layers:
            rng, layer_rng = jax.random.split(rng)
            layer_init, apply_layer = linear(layer)
            input_shape, layer_params = layer_init(layer_rng, input_shape)
            d1, d2 = input_shape
            d2 *= n_hyper
            input_shape = (d1, d2)
            params.append(layer_params)
            apply_funs.append(apply_layer)
            layer_init, apply_layer = activation()
            input_shape, layer_params = layer_init(layer_rng, input_shape)
            input_shape = (d1, d2)
            params.append(layer_params)
            apply_funs.append(apply_layer)
        def apply_net(multiparams, inputs):
            # (network 1 output), (network 2 output), ...
            for params, fn in zip(zip(*multiparams), apply_funs):
                inputs = jax.numpy.concatenate(
                    tuple(fn(param, inputs) for param in params), axis=1
                )
            return inputs
        return params, apply_net
    return init_fun

@public.add
def sigmoid_clip(inputs, eps=2**-16):
    """Sigmoid whose output is clipped to ``[eps, 1 - eps]`` so downstream
    ``log`` calls stay finite (clip the output, not the raw logits)."""
    return clip(jax.scipy.special.expit(inputs), eps)

@public.add
def constrained_model(
    constraints,
    get_layers=lambda input_shape: [input_shape + 1] * 2,
    output_shape=1,
    preprocess=identity,
    postprocess=sigmoid_clip,
    rng=jax.random.PRNGKey(0),
):
    """Create a neural network with groups of constraints assigned to each
    feature. Separate constrained neural networks are generated for each group
    of constraints. Each feature is fed into exactly one of these neural
    networks (the one that matches its assigned group of constraints). The
    output of these constrained neural networks are concatenated and fed into
    one final neural network that obeys the union of all constraints applied.
    Args:
        constraints (list): List sets of constraints (one frozenset of constraints
            for each feature)
        get_layers (function): Returns shape of constrained neural network
            given size of input (i.e. the number of features that will be fed into
            it).
        preprocess: Preprocessing function to be applied to feature vector
            before being sent through any neural networks. This can be useful for
            adjusting signs for monotonic neural networks or scales for lipschitz
            ones.
        postprocess: Final activation applied to output of neural network.
        rng: jax PRNGKey
    Returns:
        init_params, model
    """
    union = set()
    groups = {}
    for i, constraint in enumerate(constraints):
        union |= constraint
        if constraint not in groups:
            groups[constraint] = [i]
        else:
            groups[constraint].append(i)
    nets = []
    catted_size = 0
    for constraint, idx in groups.items():
        init_net, apply_net = net(
            get_layers(len(idx)), ConstrainedDense(thread(constraint)), SortLayer
        )
        rng, new_rng = jax.random.split(rng)
        suboutput_shape, params = init_net(new_rng, (-1, len(idx)))
        catted_size += suboutput_shape[1]
        nets.append((params, apply_net))
    params1, apply_nets = zip(*nets)
    init_net, apply_net2 = stax.serial(
        net(get_layers(catted_size), ConstrainedDense(thread(union)), SortLayer),
        ConstrainedDense(thread(union))(output_shape),
    )
    rng, new_rng = jax.random.split(rng)
    output_shape, params2 = init_net(new_rng, (-1, catted_size))
    params = (params1, params2)
    groups = {key: jax.numpy.asarray(value) for key, value in groups.items()}
    def apply_net_pipeline(params, inputs):
        inputs = preprocess(inputs)
        params1, params2 = params
        return postprocess(
            apply_net2(
                params2,
                jax.numpy.concatenate(
                    tuple(
                        apply_net(p, inputs[:, idx])
                        for p, apply_net, idx in zip(
                            params1, apply_nets, groups.values()
                        )
                    ),
                    axis=1,
                ),
            )
        )
    return params, apply_net_pipeline 
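
# A minimal sketch of building a model (the feature layout and constraint
# assignments below are hypothetical): feature 0 is left unconstrained via
# ``identity`` while features 1 and 2 share a monotonic + Lipschitz group, so
# they are routed into the same constrained sub-network.
#
#     constraints = [
#         frozenset({identity}),
#         frozenset({monotonic_constraint, lipschitz_constraint}),
#         frozenset({monotonic_constraint, lipschitz_constraint}),
#     ]
#     params, model = constrained_model(constraints)
#     scores = model(params, jax.numpy.ones((8, 3)))  # shape (8, 1)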

@public.add
def cross_entropy(y, y_pred):
    """Mean binary cross-entropy (negative log-likelihood); lower is better."""
    return -(y * jax.numpy.log(y_pred) + (1 - y) * jax.numpy.log(1 - y_pred)).mean()

@public.add
def parameterized_loss(loss, net):
    """Bind ``loss`` and ``net`` into a single function of ``(params, batch)``."""
    def _(params, batch):
        X, y = batch
        return loss(y, net(params, X))
    return _ 

@public.add
def batch_generator(X, y, balance=False):
    """Return a factory for infinite minibatch generators over ``(X, y)``.
    When ``balance`` is true, positive and negative examples are weighted so
    each class is sampled with equal total probability."""
    assert len(X) == len(y)
    if balance:
        weights = jax.numpy.empty(len(y))
        p = jax.numpy.mean(y)
        weights = weights.at[y == 1].set(1 / p)
        weights = weights.at[y == 0].set(1 / (1 - p))
        weights /= weights.sum()
        weights = jax.numpy.clip(weights, 0, 1)
    else:
        weights = None
    N = len(X)
    def _(batch_size, rng=jax.random.PRNGKey(0), replace=False):
        while True:
            rng, new_rng = jax.random.split(rng)
            idx = jax.random.choice(
                new_rng, N, shape=(batch_size,), p=weights, replace=replace
            )
            yield X[idx], y[idx]
    return _ 

@public.add
def train(
    train,
    test,
    net,
    metric,
    loss_fn=cross_entropy,
    mini_batch_size=32,
    num_epochs=64,
    step_size=0.01,
    track=1,
):
    """Train interpretable neural network. This routine will check accuracy
    using ``metric`` every ``track`` epochs. The model parameters with the
    highest accuracy are returned.
    Args:
        train (tuple): (X, y), each ``jax.numpy.array`` of type ``float``.
        test (tuple): (X, y), each ``jax.numpy.array`` of type ``float``.
        net (tuple): (init_params, model) a jax model returned by
            ``constrained_model``.
        metric (function): function of two jax arrays: ground truth and
            predictions. Returns ``float`` representing performance metric.
        loss_fn (function): function of two jax arrays: ground truth and
            predictions. Returns ``float`` representing loss.
        mini_batch_size (int): Size of the minibatches drawn from ``train`` for
            stochastic gradient descent
        num_epochs (int): Number of optimization steps (one minibatch update
            per step)
        step_size (float): Step size used for stochastic gradient descent
        track (int): Number of steps between metric checks
    Returns:
        best params
    """
    mini_batches = batch_generator(*train)(mini_batch_size)
    params, apply_net = net
    loss = parameterized_loss(loss_fn, apply_net)
    opt_init, opt_update, get_params = optimizers.adam(step_size)

    @jax.jit
    def update(i, opt_state, batch):
        # The minibatch is passed in explicitly; drawing it via ``next`` inside
        # the jitted function would bake one traced batch into the compiled step.
        return opt_update(
            i, jax.grad(loss)(get_params(opt_state), batch), opt_state
        )

    opt_state = opt_init(params)
    best_performance = -jax.numpy.inf
    best_params = params
    for epoch in range(num_epochs):
        opt_state = update(epoch, opt_state, next(mini_batches))
        if epoch and not epoch % track:
            X, y = test
            params = get_params(opt_state)
            performance = metric(y, apply_net(params, X))
            if performance > best_performance:
                best_performance = performance
                best_params = params
            # print(epoch, best_performance)
    return best_params 
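
# A minimal end-to-end sketch (the synthetic data, constraints, and accuracy
# metric below are hypothetical):
#
#     X = jax.random.normal(jax.random.PRNGKey(1), (256, 3))
#     y = (X.sum(1) > 0).astype(jax.numpy.float32).reshape(-1, 1)
#     constraints = [frozenset({monotonic_constraint})] * 3
#     model = constrained_model(constraints)
#     accuracy = lambda y_true, y_pred: ((y_pred > 0.5) == (y_true > 0.5)).mean()
#     best_params = train((X, y), (X, y), model, accuracy, num_epochs=128)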

@public.add
def plot(
    model,
    data,
    feature,
    N=256,
    n_interp=1024,
    fig=None,
    rng=jax.random.PRNGKey(0),
):
    r"""`Individual Conditional Expectation plot
    (blue) <https://christophm.github.io/interpretable-ml-book/ice.html>`_ and
    `Partial Dependence Plot
    (red) <https://christophm.github.io/interpretable-ml-book/pdp.html>`_.
    Args:
        model (function): Function mapping rows of feature values (as an
            array) to scores
        data (DataFrame): pandas ``DataFrame`` of feature values to sample from
        feature (string): Feature to examine
        N (int): Size of the sample from ``data`` used for averages and
            conditional expectation plots. Determines the number of blue lines
            and the sample size averaged to create the red line.
        n_interp (int): Number of values of ``feature`` to evaluate along the
            x-axis. Randomly chosen from the unique values of this feature
            within ``data``.
        fig: Matplotlib figure. Defaults to ``None``.
        rng (PRNGKey): Jax ``PRNGKey``
    Returns:
        matplotlib figure"""
    import matplotlib.pyplot as plt

    if fig is None:
        fig = plt.gcf()
    plt.clf()
    plt.title(feature)
    plt.ylabel("Model Score")
    plt.xlabel(feature)
    data = data.sort_values([feature])
    rng, new_rng = jax.random.split(rng)
    all_values_idx = list(
        jax.random.choice(new_rng, len(data), shape=(N,), replace=False)
    )
    all_values = data.values[all_values_idx]
    _, unique_feature_idx = jax.numpy.unique(data[feature].values, return_index=True)
    nunique = len(unique_feature_idx)
    n_interp = min(nunique, n_interp)
    feature_idx = list(
        unique_feature_idx[
            jax.random.choice(rng, nunique, shape=(n_interp,), replace=False).sort()
        ]
    )
    feature_values = data[feature].values[feature_idx]
    rest = jax.numpy.asarray(
        [i for i, column in enumerate(data.columns) if column != feature], dtype="int32"
    )
    all_scores = []
    data_values = jax.numpy.asarray(data.values[feature_idx])
    for replacement in all_values[:, rest]:
        fixed_values = data_values.at[jax.numpy.index_exp[:, rest]].set(replacement)
        scores = model(fixed_values)
        all_scores.append(scores)
        plt.plot(feature_values, scores, "b", alpha=0.125)
    plt.plot(feature_values, jax.numpy.asarray(all_scores).mean(0), "r", linewidth=2.0)
    return fig
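
# A minimal sketch of producing an ICE/PDP figure (the dataframe, column names,
# and model below are hypothetical):
#
#     import numpy, pandas
#     df = pandas.DataFrame(
#         numpy.asarray(jax.random.normal(jax.random.PRNGKey(2), (256, 3))),
#         columns=["f0", "f1", "f2"],
#     )
#     params, apply_model = constrained_model([frozenset({monotonic_constraint})] * 3)
#     fig = plot(lambda values: apply_model(params, values), df, "f1", N=64)
#     fig.savefig("f1_ice_pdp.png")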