Source code for mvtk.supervisor.divergence.metrics

__all__ = []

import itertools
import jax
import numpy
import scipy.spatial
import scipy.special
import public

from .nn import Approximator, NormalizedLinear
from .generators import fdiv_data_stream, js_data_stream
from collections import Counter
from functools import partial
from jax.example_libraries import optimizers
from mvtk.supervisor.utils import parallel, split
from .utils import arrayify


@public.add
def calc_div_variational(data_stream, loss, model_generator=Approximator, summary=""):
    r"""Calculate an :math:`f`-divergence or integral probability metric using
    a variational or hybrid variational estimator. Variational estimates will
    generally (but not always, thanks to the training procedure!) be a lower
    bound on the true value.

    Args:
        data_stream (generator): The data stream generator.
        loss (function): Loss function that takes as arguments the model
            outputs. Returns a scalar.
        model_generator: A function that takes a Jax ``PRNGKey`` and the
            number of dimensions of the support and returns a `Jax model
            <https://jax.readthedocs.io/en/latest/jax.example_libraries.stax.html>`_
            to be used for variational approximations. The function this model
            is trained to approximate is sometimes known as the *witness
            function*--especially when dealing with `integral probability
            metrics <http://www.gatsby.ucl.ac.uk/~gretton/papers/montreal19.pdf>`_.
            Specifically, the function returns a tuple that contains the
            initial parameters and a function that maps those parameters and
            the model inputs to the model outputs. Defaults to
            :meth:`supervisor.divergence.Approximator`.
        summary (string): Summary of divergence to appear in docstring.

    Returns:
        function for computing divergence"""

    def calc_div(
        *sample_distributions,
        categorical_columns=tuple(),
        model_generator_kwargs={},
        loss_kwargs={},
        nprng=None,
        batch_size=16,
        num_batches=128,
        num_epochs=4,
        effective_sample_size=None,
        train_test_split=0.75,
        step_size=0.0125
    ):
        r"""
        Args:
            *sample_distributions (list): Sample distributions. A numpy array,
                pandas dataframe, or pandas series, or a list of numpy arrays,
                dataframes, or series. If it is a list, samples will be drawn
                from each element of the list. For example, ``[[batch1,
                batch2, batch3], [batch4, batch5], [batch6, batch7]]``.
                Assuming ``batch1`` came from distribution :math:`p_1`,
                ``batch2`` from :math:`p_2`, etc, this function will simulate
                a system in which a latent `N=3` sided die roll determines
                whether to draw a sample from
                :math:`\frac{p_1 + p_2 + p_3}{3}`,
                :math:`\frac{p_4 + p_5}{2}`, or :math:`\frac{p_6 + p_7}{2}`.
                The outermost list is typically a singleton.
            model_generator_kwargs (optional): Dictionary of optional kwargs
                to pass to model_generator. ``width`` and ``depth`` are
                useful. See :meth:`supervisor.divergence.Approximator` for
                more details.
            loss_kwargs (optional): Dictionary of optional kwargs to pass to
                the loss function. ``weights`` is commonly used for
                reweighting expectations. See `hybrid estimation
                <supervisor_user_guide.rst#hybrid-estimation>`__ for details.
            categorical_columns (optional): List of indices of columns which
                should be treated as categorical.
            nprng (optional): Numpy ``RandomState``.
            batch_size (int): mini batch size. Defaults to 16.
            num_batches (int): number of batches per epoch. Defaults to 128.
            num_epochs (int): number of epochs to train for. Defaults to 4.
            effective_sample_size (optional): Size of the subsample over which
                epoch losses are computed. This determines how large a sample
                the divergence is computed over.
            train_test_split (optional): If not None, specifies the proportion
                of samples devoted to training as opposed to validation. If
                None, no split is used. Defaults to 0.75.
            step_size (float): step size for the Adam optimizer.

        Returns:
            Estimate of divergence."""
        if nprng is None:
            nprng = numpy.random.RandomState(1)
        sample_distributions = tuple(map(arrayify, sample_distributions))
        if train_test_split is None:
            training_samples = validation_samples = sample_distributions
        else:
            training_samples, validation_samples = zip(
                *(
                    zip(
                        *(
                            split(sample, train_ratio=train_test_split, nprng=nprng)
                            for sample in sample_distribution
                        )
                    )
                    for sample_distribution in sample_distributions
                )
            )
        mini_batches = data_stream(
            nprng,
            batch_size,
            training_samples,
            categorical_columns=categorical_columns,
        )
        if effective_sample_size is None:
            effective_sample_size = num_batches * batch_size
        large_batches = data_stream(
            nprng,
            effective_sample_size,
            validation_samples,
            categorical_columns=categorical_columns,
        )
        input_size = next(next(mini_batches)[0].values().__iter__()).shape[1]
        init_params, approximator = model_generator(
            input_size, **model_generator_kwargs
        )
        key_to_index = (
            {
                key: index
                for index, key in enumerate(
                    get_density_estimators(categorical_columns, *sample_distributions)
                )
            }
            if categorical_columns
            else {tuple(): 0}
        )
        opt_init, opt_update, get_params = optimizers.adam(step_size)
        opt_state = opt_init({key: init_params for key in key_to_index})

        def _loss(params, batch):
            return loss(
                (
                    jax.numpy.vstack(
                        approximator(params[key], sample)
                        for key, sample in samples.items()
                    )
                    for samples in batch
                ),
                **loss_kwargs
            )

        @jax.jit
        def update(i, opt_state, batch):
            params = get_params(opt_state)
            return opt_update(i, jax.grad(_loss)(params, batch), opt_state)

        itercount = itertools.count()
        best_loss = numpy.inf
        for epoch in range(num_epochs):
            for _ in range(num_batches):
                opt_state = update(next(itercount), opt_state, next(mini_batches))
            params = get_params(opt_state)
            epoch_loss = _loss(params, next(large_batches))
            if numpy.isnan(epoch_loss):
                raise ValueError(
                    """The loss is NaN. Make sure floating point arithmetic
                    makes sense for your data."""
                )
            if epoch_loss < best_loss:
                # Track the best expected loss over all epochs. This is where
                # a biased sample could push the estimate of the expected loss
                # higher than the true :math:`f`-divergence, but the trade off
                # for fast convergence given a low epoch number is well worth
                # it.
                best_loss = epoch_loss
        return numpy.clip(-best_loss, 0, numpy.inf)  # clip just above zero

    calc_div.__doc__ = summary + calc_div.__doc__
    return calc_div

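# Usage sketch (illustrative, not part of the original module). The function
# returned by ``calc_div_variational`` is called directly on samples from the
# distributions being compared; here it is built from ``ipm_loss`` (defined
# below), and the shapes, seed, and ``num_epochs`` are assumptions chosen to
# keep the example small:
#
#     >>> import numpy
#     >>> nprng = numpy.random.RandomState(0)
#     >>> sample_p = nprng.normal(0, 1, size=(1000, 2))
#     >>> sample_q = nprng.normal(1, 1, size=(1000, 2))
#     >>> calc_custom = calc_div_variational(fdiv_data_stream, ipm_loss)
#     >>> calc_custom(sample_p, sample_q, nprng=nprng, num_epochs=2)  # scalar >= 0
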
@public.add
def fdiv_loss(convex_conjugate):
    """General template for :math:`f`-divergence losses given a convex
    conjugate.

    Args:
        convex_conjugate: The convex conjugate of the function, :math:`f`.
    """

    def loss(batch, weights=(1, 1)):
        r"""
        Args:
            batch: pair of minibatches drawn from each sample
            weights: Provides an alternative means of reweighting
                minibatches. See `hybrid estimation
                <supervisor_user_guide.rst#hybrid-estimation>`__ for
                details."""
        input1, input2 = batch
        batch_loss = (
            convex_conjugate(input2).mean() * weights[1] - input1.mean() * weights[0]
        )
        # print(batch_loss)
        return batch_loss

    return loss

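# Sketch of composing ``fdiv_loss`` with ``calc_div_variational`` to build a
# new estimator (this particular estimator is not shipped by the module). For
# the KL divergence, :math:`f(x) = x\log(x)` has convex conjugate
# :math:`f^{*}(y) = e^{y - 1}`; that conjugate is a standard textbook result
# assumed here, not taken from this source file:
#
#     >>> calc_kl_variational = calc_div_variational(
#     ...     fdiv_data_stream,
#     ...     fdiv_loss(lambda y: jax.numpy.exp(y - 1)),
#     ...     summary="KL divergence (variational estimate)",
#     ... )
#     >>> # calc_kl_variational(sample_p, sample_q) then returns a scalar estimate.
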
@public.add
def ipm_loss(batch):
    """Integral probability metric loss."""
    input1, input2 = batch
    batch_loss = input2.mean() - input1.mean()
    return -jax.numpy.abs(batch_loss)

calc_tv = calc_div_variational(
    fdiv_data_stream,
    fdiv_loss(lambda x: x),
    model_generator=partial(Approximator, activation=lambda x: jax.numpy.tanh(x) / 2),
    summary=r"""Total variation - :math:`f`-divergence form

    :math:`\frac{1}{2}\int dx \vert p\left(x\right) - q\left(x\right) \vert =
    \sup\limits_{f : \|f\|_\infty \le \frac{1}{2}}
    \mathbb{E}_{x \sim p}\left[f(x)\right] -
    \mathbb{E}_{x^\prime \sim q}\left[f(x^\prime)\right]`

    https://arxiv.org/abs/1606.00709""",
)
__all__.append("calc_tv")

calc_em = calc_div_variational(
    fdiv_data_stream,
    ipm_loss,
    model_generator=partial(Approximator, linear=NormalizedLinear, residual=False),
    summary=r"""Wasserstein-1 (Earth Mover's) metric

    :math:`\int dxdx^\prime d(x, x^\prime)\gamma(x, x^\prime)`

    , with

    :math:`d(x, x^\prime)=\|x - x^\prime\|_1`

    subject to constraints

    :math:`\int dx^\prime\gamma(x, x^\prime) = p(x)`

    :math:`\int dx\gamma(x, x^\prime) = q(x^\prime)`

    Via Kantorovich-Rubinstein duality, this is equivalent to

    :math:`\sup\limits_{f \in \mathcal{F}} \vert
    \mathbb{E}_{x \sim p}\left[f(x)\right] -
    \mathbb{E}_{x^\prime \sim q}\left[f(x^\prime)\right] \vert`

    , with

    :math:`\mathcal{F} = \{f: \|f(x) - f(x^\prime)\|_1 \le \|x - x^\prime\|_1 \}`

    According to `Joel Tropp's thesis section 4.3.1
    <http://users.cms.caltech.edu/~jtropp/papers/Tro04-Topics-Sparse.pdf>`_,
    the operator norm of a linear transformation from an :math:`L^1` metric
    space to an :math:`L^1` metric space is bounded above by the :math:`L^1`
    norms of its columns. This is realized by normalizing the weight columns
    with an :math:`L^1` norm and excluding residual connections before
    applying them.""",
)
__all__.append("calc_em")

calc_js = calc_div_variational(
    js_data_stream,
    fdiv_loss(lambda y: (jax.numpy.log(-1 / (y * numpy.log(2))) - 1) / numpy.log(2)),
    model_generator=partial(Approximator, activation=lambda y: -jax.numpy.exp(y)),
    summary=r"""Jensen-Shannon divergence calculator

    :math:`f(x) = -\log_2(x)`

    :math:`f^{*}(y) = \sup\limits_x \left[xy - f(x)\right]`

    :math:`\frac{d}{dx}\left[xy - f(x)\right] = 0`

    :math:`x = \frac{-1}{y\log(2)}`

    :math:`f^{*}(y) = -\frac{\log\left(-y\log(2)\right) + 1}{\log(2)}`

    Note that the domain of this function (when assumed to be real valued) is
    naturally :math:`y < 0`.""",
)
__all__.append("calc_js")

calc_hl = calc_div_variational(
    fdiv_data_stream,
    fdiv_loss(lambda y: -1 / (4 * y) - 1),
    model_generator=partial(Approximator, activation=lambda y: -abs(y)),
    summary=r"""Hellinger distance calculator

    :math:`f(x) = 1 - \sqrt{x}`

    :math:`f^{*}(y) = \sup\limits_x\left[xy - f(x)\right]`

    :math:`\frac{d}{dx}\left[xy - f(x)\right] = 0`

    :math:`x = \frac{1}{4y ^ 2}`

    :math:`f^{*}(y) = \frac{1}{2\vert y \vert} + \frac{1}{4y} - 1`

    Since the `Fenchel–Moreau theorem
    <https://en.wikipedia.org/wiki/Fenchel%E2%80%93Moreau_theorem>`_ requires
    the convex conjugate to be lower semicontinuous for biconjugacy to hold,
    we take :math:`y < 0`. This in turn simplifies the expression of
    :math:`f^{*}` to

    :math:`f^{*}(y) = -\frac{1}{4y} - 1`""",
)
__all__.append("calc_hl")

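# Usage sketch for the prebuilt calculators above (illustrative; the shapes,
# seed, and sample sizes are assumptions for the example, and the returned
# values are stochastic estimates):
#
#     >>> import numpy
#     >>> nprng = numpy.random.RandomState(0)
#     >>> sample_p = nprng.normal(0, 1, size=(2000, 2))
#     >>> sample_q = nprng.normal(0.5, 1, size=(2000, 2))
#     >>> calc_tv(sample_p, sample_q, nprng=nprng)  # total variation estimate in [0, 1]
#     >>> calc_js(sample_p, sample_q, nprng=nprng)  # Jensen-Shannon estimate in [0, 1]
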
@public.add
def histogram(data):
    histogram = {}
    N = len(data)
    for key, count in Counter(map(tuple, data)).items():
        histogram[key] = count / N
    return histogram

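# Example of the histogram format (illustrative): each row of the input is
# hashed as a tuple and mapped to its relative frequency.
#
#     >>> import numpy
#     >>> histogram(numpy.array([[0, 1], [0, 1], [1, 1], [0, 0]]))
#     >>> # roughly {(0, 1): 0.5, (1, 1): 0.25, (0, 0): 0.25}
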
def join_keys(dictionaries):
    keys = set()
    for dictionary in dictionaries:
        keys |= dictionary.keys()
    return keys

@public.add
def average_histograms(histograms):
    avg = {}
    N = len(tuple(histograms))
    for key in join_keys(histograms):
        p = 0
        for histogram in histograms:
            if key in histogram:
                p += histogram[key]
        avg[key] = p / N
    return avg

@public.add
def cat_histograms(histograms):
    histogram = {}
    for key in join_keys(histograms):
        # Note: ``histogram`` inside the generator expression is the
        # comprehension variable (each input histogram), not the dict being
        # built above.
        histogram[key] = tuple(
            histogram[key] if key in histogram else 0 for histogram in histograms
        )
    return histogram

@public.add
def get_density_estimators(categorical_columns, *sample_distributions):
    return cat_histograms(
        tuple(
            average_histograms(
                tuple(
                    histogram(samples[:, categorical_columns])
                    for samples in sample_distribution
                )
            )
            for sample_distribution in sample_distributions
        )
    )

@public.add
def metric_from_density(metric, *densities):
    return metric(*numpy.asarray(densities).T)

@public.add
def calc_mle(metric, *sample_distributions):
    sample_distributions = tuple(map(arrayify, sample_distributions))
    categorical_columns = numpy.arange(
        sample_distributions[0][0].shape[1], dtype="int"
    )
    densities = get_density_estimators(
        categorical_columns, *sample_distributions
    ).values()
    return metric_from_density(metric, *densities)

@public.add
def calc_hl_density(density_p, density_q):
    r"""Hellinger distance calculated from histograms.

    Hellinger distance is defined as
    :math:`\sqrt{\frac{1}{2}\sum\limits_{x\in\mathcal{X}}\left(\sqrt{p(x)} - \sqrt{q(x)}\right)^2}`.

    Args:
        density_p (list): probability mass function of p
        density_q (list): probability mass function of q"""
    return numpy.sqrt(((numpy.sqrt(density_p) - numpy.sqrt(density_q)) ** 2).sum() / 2)

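# Worked example (illustrative densities over a three-element support):
#
#     >>> import numpy
#     >>> p = numpy.array([0.5, 0.5, 0.0])
#     >>> q = numpy.array([0.25, 0.25, 0.5])
#     >>> calc_hl_density(p, q)  # approximately 0.54
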
@public.add
def calc_hl_mle(sample_distribution_p, sample_distribution_q):
    r"""Hellinger distance calculated via histogram based density estimators.

    Hellinger distance is defined as
    :math:`\sqrt{\frac{1}{2}\sum\limits_{x\in\mathcal{X}}\left(\sqrt{p(x)} - \sqrt{q(x)}\right)^2}`.

    Args:
        sample_distribution_p (list): A numpy array, pandas dataframe, or
            pandas series, or a list of numpy arrays, dataframes, or series.
            If it is a list, samples will be drawn from each element of the
            list. For example, ``[[batch1, batch2, batch3], [batch4, batch5],
            [batch6, batch7]]``. Assuming ``batch1`` came from distribution
            :math:`p_1`, ``batch2`` from :math:`p_2`, etc, this function will
            simulate a system in which a latent `N=3` sided die roll
            determines whether to draw a sample from
            :math:`\frac{p_1 + p_2 + p_3}{3}`, :math:`\frac{p_4 + p_5}{2}`, or
            :math:`\frac{p_6 + p_7}{2}`. The outermost list is typically a
            singleton.
        sample_distribution_q (list): As ``sample_distribution_p``, but for
            the second distribution."""
    return calc_mle(calc_hl_density, sample_distribution_p, sample_distribution_q)

@public.add
def calc_tv_density(density_p, density_q):
    r"""Total variation calculated from histograms.

    For two distributions, :math:`p` and :math:`q` defined over the same
    probability space, :math:`\mathcal{X}`, the total variation is defined as
    :math:`\frac{1}{2}\sum\limits_{x\in\mathcal{X}}\vert p(x) - q(x)\vert`.

    Args:
        density_p (list): probability mass function of p
        density_q (list): probability mass function of q"""
    return numpy.abs(density_p - density_q).sum() / 2

@public.add
def calc_tv_mle(sample_distribution_p, sample_distribution_q):
    r"""Total variation calculated via histogram based density estimators.
    All columns are assumed to be categorical.

    For two distributions, :math:`p` and :math:`q` defined over the same
    probability space, :math:`\mathcal{X}`, the total variation is defined as
    :math:`\frac{1}{2}\sum\limits_{x\in\mathcal{X}}\vert p(x) - q(x)\vert`.

    Args:
        sample_distribution_p (list): A numpy array, pandas dataframe, or
            pandas series, or a list of numpy arrays, dataframes, or series.
            If it is a list, samples will be drawn from each element of the
            list. For example, ``[[batch1, batch2, batch3], [batch4, batch5],
            [batch6, batch7]]``. Assuming ``batch1`` came from distribution
            :math:`p_1`, ``batch2`` from :math:`p_2`, etc, this function will
            simulate a system in which a latent `N=3` sided die roll
            determines whether to draw a sample from
            :math:`\frac{p_1 + p_2 + p_3}{3}`, :math:`\frac{p_4 + p_5}{2}`, or
            :math:`\frac{p_6 + p_7}{2}`. The outermost list is typically a
            singleton.
        sample_distribution_q (list): As ``sample_distribution_p``, but for
            the second distribution."""
    return calc_mle(calc_tv_density, sample_distribution_p, sample_distribution_q)

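# Usage sketch with purely categorical data (illustrative arrays; every column
# is treated as categorical by this estimator):
#
#     >>> import numpy
#     >>> sample_p = numpy.array([[0], [0], [1], [1]])
#     >>> sample_q = numpy.array([[0], [1], [1], [1]])
#     >>> calc_tv_mle(sample_p, sample_q)  # 0.25 for these histograms
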
@public.add
def calc_kl_density(density_p, density_q):
    r"""Kullback–Leibler (KL) divergence calculated from histograms.

    For two distributions, :math:`p` and :math:`q` defined over the same
    probability space, :math:`\mathcal{X}`, the KL divergence is defined as
    :math:`\sum\limits_{x\in\mathcal{X}}p(x)\log\left(\frac{p(x)}{q(x)}\right)`.

    Args:
        density_p (list): probability mass function of :math:`p`
        density_q (list): probability mass function of :math:`q`"""
    return numpy.log((density_p / density_q) ** density_p).sum()

@public.add
def calc_kl_mle(sample_distribution_p, sample_distribution_q):
    r"""Kullback–Leibler (KL) divergence calculated via histogram based
    density estimators.

    For two distributions, :math:`p` and :math:`q` defined over the same
    probability space, :math:`\mathcal{X}`, the KL divergence is defined as
    :math:`\sum\limits_{x\in\mathcal{X}}p(x)\log\left(\frac{p(x)}{q(x)}\right)`.

    Args:
        sample_distribution_p (list): A numpy array, pandas dataframe, or
            pandas series, or a list of numpy arrays, dataframes, or series.
            If it is a list, samples will be drawn from each element of the
            list. For example, ``[[batch1, batch2, batch3], [batch4, batch5],
            [batch6, batch7]]``. Assuming ``batch1`` came from distribution
            :math:`p_1`, ``batch2`` from :math:`p_2`, etc, this function will
            simulate a system in which a latent `N=3` sided die roll
            determines whether to draw a sample from
            :math:`\frac{p_1 + p_2 + p_3}{3}`, :math:`\frac{p_4 + p_5}{2}`, or
            :math:`\frac{p_6 + p_7}{2}`. The outermost list is typically a
            singleton.
        sample_distribution_q (list): As ``sample_distribution_p``, but for
            the second distribution."""
    return calc_mle(calc_kl_density, sample_distribution_p, sample_distribution_q)

@public.add
def calc_js_density(*densities):
    r"""Jensen-Shannon divergence calculated from histograms.

    For two distributions, :math:`p` and :math:`q` defined over the same
    probability space, :math:`\mathcal{X}`, the Jensen-Shannon divergence is
    defined as the average of the KL divergences between each probability
    mass function and the average of all probability mass functions being
    compared. This is well defined for more than two probability masses, and
    will be zero when all probability masses are identical, and 1 when they
    have pairwise disjoint support and the KL divergences are taken using a
    logarithmic base equal to the number of probability masses being
    compared. Typically, there will be only two probability mass functions,
    and the logarithmic base is therefore taken to be 2.

    Args:
        *densities (list): probability mass functions"""
    n = len(densities)
    mean = sum(densities) / n
    return sum(calc_kl_density(density, mean) for density in densities) / (
        n * numpy.log(n)
    )

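# Worked example (illustrative): identical masses give 0, and disjoint masses
# give 1 when the logarithmic base equals the number of distributions (2 here).
#
#     >>> import numpy
#     >>> calc_js_density(numpy.array([0.5, 0.5]), numpy.array([0.5, 0.5]))  # 0.0
#     >>> calc_js_density(numpy.array([1.0, 0.0]), numpy.array([0.0, 1.0]))  # 1.0
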
@public.add
def calc_js_mle(*sample_distributions):
    r"""Jensen-Shannon divergence calculated via histogram based density
    estimators.

    For two distributions, :math:`p` and :math:`q` defined over the same
    probability space, :math:`\mathcal{X}`, the Jensen-Shannon divergence is
    defined as the average of the KL divergences between each probability
    mass function and the average of all probability mass functions being
    compared. This is well defined for more than two probability masses, and
    will be zero when all probability masses are identical, and 1 when they
    have pairwise disjoint support and the KL divergences are taken using a
    logarithmic base equal to the number of probability masses being
    compared. Typically, there will be only two probability mass functions,
    and the logarithmic base is therefore taken to be 2.

    Args:
        *sample_distributions (list): A numpy array, pandas dataframe, or
            pandas series, or a list of numpy arrays, dataframes, or series.
            If it is a list, samples will be drawn from each element of the
            list. For example, ``[[batch1, batch2, batch3], [batch4, batch5],
            [batch6, batch7]]``. Assuming ``batch1`` came from distribution
            :math:`p_1`, ``batch2`` from :math:`p_2`, etc, this function will
            simulate a system in which a latent `N=3` sided die roll
            determines whether to draw a sample from
            :math:`\frac{p_1 + p_2 + p_3}{3}`, :math:`\frac{p_4 + p_5}{2}`, or
            :math:`\frac{p_6 + p_7}{2}`. The outermost list is typically a
            singleton."""
    return calc_mle(calc_js_density, *sample_distributions)

@public.add
def cal_div_knn(
    divergence,
    sample_distribution_p,
    sample_distribution_q,
    bias=lambda N, k: 0,
    num_samples=2048,
    categorical_columns=tuple(),
    nprng=numpy.random.RandomState(0),
    k=128,
):
    r""":math:`f`-divergence from knn density estimators.

    Args:
        divergence: :math:`f` that defines the :math:`f`-divergence.
        sample_distribution_p (list): A numpy array, pandas dataframe, or
            pandas series, or a list of numpy arrays, dataframes, or series.
            If it is a list, samples will be drawn from each element of the
            list. For example, ``[[batch1, batch2, batch3], [batch4, batch5],
            [batch6, batch7]]``. Assuming ``batch1`` came from distribution
            :math:`p_1`, ``batch2`` from :math:`p_2`, etc, this function will
            simulate a system in which a latent `N=3` sided die roll
            determines whether to draw a sample from
            :math:`\frac{p_1 + p_2 + p_3}{3}`, :math:`\frac{p_4 + p_5}{2}`, or
            :math:`\frac{p_6 + p_7}{2}`. The outermost list is typically a
            singleton.
        sample_distribution_q (list): As ``sample_distribution_p``, but for
            the second distribution.
        bias (function): function of the number of samples and number of
            nearest neighbors that compensates for the expected bias of the
            plugin estimator.
        num_samples (int, optional): Number of subsamples to take from each
            distribution on which to construct k-d trees and otherwise make
            computations. Defaults to 2048.
        k (int, optional): Number of nearest neighbors. As a rule of thumb,
            you should multiply this by two with every dimension past one.
            Defaults to 128."""
    sample_distribution_p = arrayify(sample_distribution_p)
    sample_distribution_q = arrayify(sample_distribution_q)
    p, q = next(
        fdiv_data_stream(
            nprng,
            num_samples,
            (sample_distribution_p, sample_distribution_q),
            categorical_columns=categorical_columns,
        )
    )

    def knn_ratio(ptree, qtree, x):
        d = max(qtree.query(x, k=k)[0])
        n = len(ptree.query_ball_point(x, d))
        return divergence(n / (k + 1))

    numerator = 0
    denominator = 0
    for key, conditional in q.items():
        denominator += len(conditional)
        if key not in p:
            continue
        qtree = scipy.spatial.cKDTree(conditional)
        ptree = scipy.spatial.cKDTree(p[key])
        numerator += numpy.sum(parallel(partial(knn_ratio, ptree, qtree), qtree.data))
    return max(0, numerator / denominator - bias(num_samples, k))

@public.add
def calc_tv_knn(sample_distribution_p, sample_distribution_q, **kwargs):
    r"""Total variation from knn density estimators.

    Args:
        sample_distribution_p (list): A numpy array, pandas dataframe, or
            pandas series, or a list of numpy arrays, dataframes, or series.
            If it is a list, samples will be drawn from each element of the
            list. For example, ``[[batch1, batch2, batch3], [batch4, batch5],
            [batch6, batch7]]``. Assuming ``batch1`` came from distribution
            :math:`p_1`, ``batch2`` from :math:`p_2`, etc, this function will
            simulate a system in which a latent `N=3` sided die roll
            determines whether to draw a sample from
            :math:`\frac{p_1 + p_2 + p_3}{3}`, :math:`\frac{p_4 + p_5}{2}`, or
            :math:`\frac{p_6 + p_7}{2}`. The outermost list is typically a
            singleton.
        sample_distribution_q (list): As ``sample_distribution_p``, but for
            the second distribution.
        num_samples (int, optional): Number of subsamples to take from each
            distribution on which to construct k-d trees and otherwise make
            computations. Defaults to 2048.
        k (int, optional): Number of nearest neighbors. As a rule of thumb,
            you should multiply this by two with every dimension past one.
            Defaults to 128."""

    def bias(N, k):
        def integral_no_p(p):
            return (
                (1 - p) ** (-k + N) * p**k
                - N * scipy.special.betainc(k, 1 - k + N, p)
            ) / (k - N)

        def integral_with_p(p):
            return scipy.special.betainc(k + 1, N - k, p)

        r0 = (k - 1) / N
        p_less = (
            integral_no_p(r0)
            - integral_with_p(r0)
            - (integral_no_p(0) - integral_with_p(0))
        )
        p_greater = (
            integral_with_p(1)
            - integral_no_p(1)
            - (integral_with_p(r0) - integral_no_p(r0))
        )
        return p_less + p_greater

    return cal_div_knn(
        lambda r: abs(1 - r),
        sample_distribution_p,
        sample_distribution_q,
        bias=bias,
        **kwargs
    )

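# Usage sketch for continuous data (illustrative shapes, seed, and
# hyperparameters; smaller ``num_samples`` and ``k`` than the defaults keep
# the example cheap):
#
#     >>> import numpy
#     >>> nprng = numpy.random.RandomState(0)
#     >>> sample_p = nprng.normal(0, 1, size=(4000, 1))
#     >>> sample_q = nprng.normal(1, 1, size=(4000, 1))
#     >>> calc_tv_knn(sample_p, sample_q, nprng=nprng, num_samples=1024, k=32)
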
@public.add
def balanced_binary_cross_entropy(y_true, y_pred):
    r"""Compute cross entropy loss while compensating for class imbalance.

    Args:
        y_true (array): Ground truth, binary or soft labels.
        y_pred (array): Array of model scores."""
    P = y_true.sum()
    N = len(y_true) - P
    return (
        scipy.special.rel_entr(y_true, y_pred).sum() / P
        + scipy.special.rel_entr(1 - y_true, 1 - y_pred).sum() / N
    ) / 2

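# Usage sketch (illustrative labels and scores): ``y_true`` holds binary (or
# soft) labels and ``y_pred`` holds model scores in (0, 1).
#
#     >>> import numpy
#     >>> y_true = numpy.array([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
#     >>> y_pred = numpy.array([0.9, 0.7, 0.2, 0.1, 0.3, 0.2])
#     >>> balanced_binary_cross_entropy(y_true, y_pred)
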
@public.add
def calc_tv_lower_bound(log_loss):
    r"""Lower bound of total variation.

    A model (not provided) must be trained to classify data as belonging to
    one of two datasets using log loss, ideally compensating for class
    imbalance during training. This function will compute a lower bound on
    the total variation of the two datasets the model was trained to
    distinguish using the loss from the validation set.

    Args:
        log_loss (float): Binary cross entropy loss with class imbalance
            compensated."""
    js0 = 1 - log_loss / numpy.log(2)
    return max(0, js0)

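# Usage sketch chaining the two helpers above (an assumed workflow, not code
# from the original module): train a classifier to tell dataset A from
# dataset B, score a held-out set, and convert the balanced validation loss
# into a total variation lower bound.
#
#     >>> # y_true: 1 for rows from dataset A, 0 for rows from dataset B
#     >>> # y_pred: the classifier's validation scores
#     >>> loss = balanced_binary_cross_entropy(y_true, y_pred)
#     >>> calc_tv_lower_bound(loss)  # lower bound on the total variation of A and B
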