Source code for mvtk.supervisor.utils

import multiprocessing
import sys
import time
import pandas as pd
import numpy as np
import public

from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Collection
from typing import List
from itertools import combinations
from fastcore.imports import in_notebook

if in_notebook():
    from tqdm import tqdm_notebook as tqdm
else:
    from tqdm import tqdm


@public.add
def parallel(func, arr: Collection, max_workers=None, show_progress: bool = False):
    """
    NOTE: This code was adapted from the ``parallel`` function
    within Fastai's Fastcore library. Key differences include
    returning a list with order preserved.

    Run a function on a collection (list, set etc) of items using a
    thread pool.

    :param func: The function to run; called with one item at a time
    :param arr: The collection to run on. Any iterable is accepted; it is
        copied into a list up front so that its length is known even for
        generators and other one-shot iterators
    :param max_workers: How many workers to use. Will use
        multiprocessing.cpu_count() if this is not provided
    :param show_progress: Whether to display a progress bar on stdout
    :return: a list of the results, in the same order as the items of ``arr``
    :raises Exception: whatever ``func`` raises is re-raised by
        ``future.result()``
    """
    # Materialise once: submission consumes the iterable anyway, and a
    # one-shot iterator would otherwise leave the progress bar without a total.
    items = list(arr)
    if max_workers is None:
        max_workers = multiprocessing.cpu_count()

    progress_bar = None
    if show_progress:
        progress_bar = tqdm(total=len(items), smoothing=0, file=sys.stdout)

    # Pre-sized so each result lands at the position of its input item.
    results = [None] * len(items)
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            future_to_index = {
                ex.submit(func, item): index for index, item in enumerate(items)
            }
            for future in as_completed(future_to_index):
                results[future_to_index[future]] = future.result()
                if progress_bar is not None:
                    progress_bar.update()
    finally:
        # Always release the bar, including when a task raised.
        if progress_bar is not None:
            progress_bar.close()
    return results
@public.add
def column_indexes(df: pd.DataFrame, cols: List[str]):
    """
    Translate column names into their positions within a dataframe.

    Names that are not present in ``df`` are skipped, so the result may be
    shorter than ``cols``.

    :param df: The dataframe
    :param cols: a list of column names
    :return: The column indexes of the column names, in the order the
        names appear in ``cols``
    """
    positions = []
    for name in cols:
        if name in df:
            positions.append(df.columns.get_loc(name))
    return positions
def format_date(date_str, dateformat="%b%d"):
    """
    Render a date value as text.

    :param date_str: The value to format; it is parsed with ``pd.to_datetime``
    :param dateformat: The ``strftime`` pattern to apply, ``"%b%d"`` by default
    :return: The formatted date string
    """
    parsed = pd.to_datetime(date_str)
    formatted = datetime.strftime(parsed, dateformat)
    return formatted
@public.add
def compute_divergence_crosstabs(
    data, datecol=None, format=None, show_progress=True, divergence=None
):
    """Compute the divergence crosstabs.

    The data is grouped by date and the pairwise divergences between the
    groups are computed by ``compute_divergence_crosstabs_split``.

    :param data: The data to compute the divergences on
    :param datecol: The column representing the date. If None, will use
        the index, if the index is a datetimeindex
    :param format: A function applied to datecol values for formatting
        e.g. ``format_date``
    :param show_progress: Whether the progress bar will be shown
    :param divergence: The divergence function to use
    :return: The result of ``compute_divergence_crosstabs_split`` for the
        per-date subsets
    """
    group_by_index = datecol is None
    groups = list(data.groupby(data.index if group_by_index else datecol))

    # Built without ``zip(*groups)`` so that empty input yields empty lists
    # instead of failing to unpack.
    dates = [date for date, _ in groups]
    if group_by_index:
        # The dates live in the index, so there is no date column to strip.
        subsets = [subset for _, subset in groups]
    else:
        subsets = [subset.drop(columns=[datecol]) for _, subset in groups]

    return compute_divergence_crosstabs_split(
        subsets, dates, format, show_progress, divergence
    )
@public.add
def compute_divergence_crosstabs_split(
    subsets, dates, format=None, show_progress=True, divergence=None
):
    """Compute the divergence crosstabs.

    Every unordered pair of subsets is passed to the divergence function
    and the value is written symmetrically into a square matrix whose
    diagonal stays zero.

    :param subsets: The data to compute the divergences on
    :param dates: The list of dates for the subsets
    :param format: A function applied to datecol values for formatting
        e.g. ``format_date``
    :param show_progress: Whether the progress bar will be shown
    :param divergence: The divergence function to use; when falsy,
        ``mvtk.supervisor.divergence.calc_tv`` is used
    :return: A dataframe of pairwise divergences whose rows and columns
        are labelled with the (optionally formatted) dates
    """
    n_dates = len(dates)
    matrix = np.zeros((n_dates, n_dates))

    if not divergence:
        from mvtk.supervisor.divergence import calc_tv

        divergence = calc_tv

    def divergence_of_pair(pair):
        return divergence(*pair)

    # ``combinations`` emits index pairs and subset pairs in the same order,
    # so the two sequences line up element by element.
    pair_values = parallel(
        divergence_of_pair, combinations(subsets, 2), show_progress=show_progress
    )
    index_pairs = combinations(range(n_dates), 2)
    for (row, col), value in zip(index_pairs, pair_values):
        matrix[row, col] = value
        matrix[col, row] = value

    labels = [format(d) for d in dates] if format is not None else dates
    return pd.DataFrame(matrix, columns=labels, index=labels)
@public.add
def plot_divergence_crosstabs_3d(divergences):
    """Plot the divergences in 3d.

    One row of bars is drawn per key of ``divergences``, placed at that
    key's position along the y axis.

    :param divergences: A mapping-like object (for example a dataframe of
        pairwise divergences) providing ``keys()`` and whose values provide
        ``items()``
    :return: The matplotlib figure holding the 3d bar plot
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa F401

    figure = plt.figure()
    axes = figure.add_subplot(111, projection="3d")

    labels = list(divergences.keys())
    positions = range(len(labels))
    for depth, label in enumerate(labels):
        heights = [value for _, value in list(divergences[label].items())]
        axes.bar(positions, heights, depth, zdir="y", alpha=0.8)

    axes.set(
        xticks=positions, xticklabels=labels, yticks=positions, yticklabels=labels
    )
    return figure
@public.add
def split(x, train_ratio=0.5, nprng=np.random.RandomState(0)):
    """
    Shuffle ``x`` and cut it into two parts.

    :param x: The items to split. Objects exposing a ``shape`` attribute are
        reordered by indexing with a shuffled integer array of length
        ``x.shape[0]``. Anything else is shuffled through a slice copy
        (``x[:]``), so a list passed in is not reordered.
    :param train_ratio: Fraction of ``len(x)`` placed in the first part; the
        cut position is ``int(len(x) * train_ratio)``
    :param nprng: Random generator providing ``shuffle``. NOTE: the default
        instance is created once, when the function is defined, and is shared
        by every call that relies on it, so successive default calls advance
        the same generator.
    :return: A tuple ``(first, second)`` of the shuffled items before and
        after the cut position
    """
    cut = int(len(x) * train_ratio)
    if hasattr(x, "shape"):
        order = np.arange(x.shape[0])
        nprng.shuffle(order)
        shuffled = x[order]
    else:
        # ``shuffle`` works in place; operate on a copy so the caller's
        # sequence is not mutated as a side effect.
        shuffled = x[:]
        nprng.shuffle(shuffled)
    return shuffled[:cut], shuffled[cut:]