Source code for mvtk.bias_variance.estimators.sklearn_estimator_wrapper

from . import EstimatorWrapper


class SciKitLearnEstimatorWrapper(EstimatorWrapper):
    r"""Wrap a Scikit-Learn estimator behind the ``EstimatorWrapper`` interface.

    Calls to :meth:`fit` and :meth:`predict` are forwarded unchanged to the
    wrapped estimator, including any extra keyword arguments.
    """

    def __init__(self, estimator):
        r"""Create a wrapper for a Scikit-Learn estimator

        Args:
            estimator: Scikit-Learn estimator instance
        """
        # Stored by reference (no copy is made), so fitting through the
        # wrapper mutates the estimator object that was passed in.
        self.estimator = estimator

    def fit(self, X, y, **kwargs):
        r"""Train the estimator

        Args:
            X: features
            y: ground truth labels
            kwargs (optional): kwargs for use in training

        Returns:
            self
        """
        self.estimator.fit(X, y, **kwargs)
        # Return the wrapper itself (not the underlying estimator) so calls
        # can be chained on the wrapper.
        return self

    def predict(self, X, **kwargs):
        r"""Get predictions from the estimator

        Args:
            X: features
            kwargs (optional): kwargs for use in predicting

        Returns:
            the value returned by the wrapped estimator's ``predict`` call
        """
        return self.estimator.predict(X, **kwargs)