import multiprocessing
import sys
import time
import pandas as pd
import numpy as np
import public
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Collection
from typing import List
from itertools import combinations
from fastcore.imports import in_notebook
if in_notebook():
    from tqdm import tqdm_notebook as tqdm
else:
    from tqdm import tqdm
[docs]@public.add
def parallel(func, arr: Collection, max_workers=None, show_progress: bool = False):
    """
    NOTE: This code was adapted from the ``parallel`` function
        within Fastai's Fastcore library. Key differences include
        returning a list with order preserved.
    Run a function on a collection (list, set, generator etc) of items
    using a thread pool.
    :param func: The function to run; called once per item with that item
    :param arr: The collection to run on. It is iterated exactly once, so
        one-shot iterators such as generators are supported.
    :param max_workers: How many worker threads to use. Will use
        multiprocessing.cpu_count() if this is not provided
    :param show_progress: Whether to display a tqdm progress bar on stdout
    :return: a list of the results, in the same order as the items of ``arr``
    :raises: whatever ``func`` raises, re-raised from the first failing
        call to complete
    """
    # Materialize once: generators can then be counted for the progress bar
    # total, and every item is submitted up front anyway.
    items = list(arr)
    max_workers = multiprocessing.cpu_count() if max_workers is None else max_workers
    results = [None] * len(items)
    progress_bar = (
        tqdm(total=len(items), smoothing=0, file=sys.stdout) if show_progress else None
    )
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            future_to_index = {
                ex.submit(func, item): index for index, item in enumerate(items)
            }
            # Futures finish in arbitrary order; place each result at the
            # position of its input so the output order matches ``arr``.
            for future in as_completed(future_to_index):
                results[future_to_index[future]] = future.result()
                if progress_bar is not None:
                    progress_bar.update()
    finally:
        # Always finalize the bar, even if ``func`` raised.
        if progress_bar is not None:
            progress_bar.close()
    return results
[docs]@public.add
def column_indexes(df: pd.DataFrame, cols: List[str]):
    """Look up the positions of named columns in a dataframe.

    Names in ``cols`` that are not columns of ``df`` are skipped silently.

    :param df: The dataframe
    :param cols: a list of column names
    :return: The ``df.columns.get_loc`` result for every name in ``cols``
        that is present in ``df``, in the order given by ``cols``
    """
    columns = df.columns
    return [columns.get_loc(name) for name in cols if name in columns]
def format_date(date_str, dateformat="%b%d"):
    """Parse a date-like value with pandas and render it as a string.

    :param date_str: Any value ``pd.to_datetime`` can parse
    :param dateformat: ``strftime`` format string used for rendering
    :return: The formatted date string
    """
    return datetime.strftime(pd.to_datetime(date_str), dateformat)
[docs]@public.add
def compute_divergence_crosstabs(
    data, datecol=None, format=None, show_progress=True, divergence=None
):
    """Compute the divergence crosstabs.

    ``data`` is grouped by date, and the divergence between every pair of
    date groups is computed by ``compute_divergence_crosstabs_split``.

    :param data: The data to compute the divergences on
    :param datecol: The column representing the date. If None, the index
        is used (intended for a datetimeindex; this is not checked)
    :param format: A function applied to datecol values for formatting
        e.g. ``format_date``
    :param show_progress: Whether the progress bar will be shown
    :param divergence: The divergence function to use
    :return: The result of ``compute_divergence_crosstabs_split``: a square
        DataFrame of pairwise divergences labelled by the (formatted) dates
    """
    if datecol is None:
        # Group on the index values. The index is not a column, so there is
        # nothing to drop from the per-date groups.
        grouped = data.groupby(data.index)
        drop_columns = []
    else:
        grouped = data.groupby(datecol)
        # The date column is only the grouping key; keep it out of the data
        # the divergence is computed on.
        drop_columns = [datecol]
    dates = []
    subsets = []
    for date, subset in grouped:
        dates.append(date)
        subsets.append(subset.drop(columns=drop_columns))
    return compute_divergence_crosstabs_split(
        subsets, dates, format, show_progress, divergence
    )
[docs]@public.add
def compute_divergence_crosstabs_split(
    subsets, dates, format=None, show_progress=True, divergence=None
):
    """Compute the divergence crosstabs.

    Every unordered pair of subsets is compared once and the value is
    mirrored, so the returned matrix is symmetric; the diagonal stays 0.

    :param subsets: The data to compute the divergences on, one entry per
        date (any iterable; it is consumed once)
    :param dates: The list of dates for the subsets, aligned with ``subsets``
    :param format: A function applied to each date for formatting the
        row/column labels, e.g. ``format_date``
    :param show_progress: Whether the progress bar will be shown
    :param divergence: The divergence function to use; called as
        ``divergence(subset_a, subset_b)``. Defaults to
        ``mvtk.supervisor.divergence.calc_tv``
    :return: A square ``pd.DataFrame`` of pairwise divergences whose index
        and columns are the (formatted) dates
    :raises ValueError: if the number of subsets differs from the number
        of dates
    """
    subsets = list(subsets)
    n_dates = len(dates)
    if len(subsets) != n_dates:
        # A mismatch would silently pair divergence values with the wrong
        # (i, j) cells, so refuse it outright.
        raise ValueError(
            f"Got {len(subsets)} subsets but {n_dates} dates; "
            "they must be aligned one-to-one"
        )
    if divergence is None:
        from mvtk.supervisor.divergence import calc_tv

        divergence = calc_tv

    def compute_divergence(args):
        return divergence(*args)

    # Build the index pairs and the data pairs from the same list so they
    # cannot drift apart; a list also gives the progress bar a known total.
    index_pairs = list(combinations(range(n_dates), 2))
    values = parallel(
        compute_divergence,
        [(subsets[i], subsets[j]) for i, j in index_pairs],
        show_progress=show_progress,
    )
    divergences = np.zeros((n_dates, n_dates))
    for (i, j), value in zip(index_pairs, values):
        divergences[i, j] = divergences[j, i] = value
    formatted = dates if format is None else [format(d) for d in dates]
    return pd.DataFrame(divergences, columns=formatted, index=formatted)
[docs]@public.add
def plot_divergence_crosstabs_3d(divergences):
    """Plot the divergences as a 3d bar chart.

    One row of bars is drawn per key of ``divergences``, with bar heights
    taken from that key's values.

    :param divergences: A mapping of mappings, e.g. the ``pd.DataFrame``
        returned by ``compute_divergence_crosstabs``
    :return: The matplotlib figure
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa F401

    labels = list(divergences.keys())
    positions = range(len(labels))
    figure = plt.figure()
    axes = figure.add_subplot(111, projection="3d")
    # Each key gets its own row of bars, offset by its position along y.
    for depth, label in enumerate(labels):
        heights = [value for _, value in divergences[label].items()]
        axes.bar(positions, heights, depth, zdir="y", alpha=0.8)
    axes.set(
        xticks=positions, xticklabels=labels, yticks=positions, yticklabels=labels
    )
    return figure
[docs]@public.add
def split(x, train_ratio=0.5, nprng=np.random.RandomState(0)):
    """Randomly partition ``x`` into two parts.

    The first part holds ``int(len(x) * train_ratio)`` elements and the
    second holds the rest. ``x`` itself is never modified.

    :param x: The data to split. pandas objects are shuffled by row
        position, other objects with a ``shape`` (e.g. numpy arrays) along
        their first axis, and any other sequence is copied into a list.
    :param train_ratio: Fraction of elements that go into the first part
    :param nprng: Random state used for shuffling. NOTE: the default is one
        ``RandomState(0)`` created at import time and shared by every call
        that does not pass its own, so successive calls draw different
        permutations. Pass a fresh ``RandomState`` for a reproducible split.
    :return: A ``(first, second)`` tuple
    """
    i = int(len(x) * train_ratio)
    if hasattr(x, "iloc"):
        # pandas: ``x[idx]`` would select DataFrame *columns* by label, so
        # select rows by position instead.
        idx = np.arange(len(x))
        nprng.shuffle(idx)
        x = x.iloc[idx]
        return x.iloc[:i], x.iloc[i:]
    if hasattr(x, "shape"):
        idx = np.arange(x.shape[0])
        nprng.shuffle(idx)
        x = x[idx]
    else:
        # ``nprng.shuffle`` works in place; shuffle a copy so the caller's
        # sequence is left untouched.
        x = list(x)
        nprng.shuffle(x)
    return x[:i], x[i:]