Source code for hts.model.base

import logging
from typing import NamedTuple, Union

import numpy
import pandas
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.statespace.sarimax import SARIMAX

from hts._t import ModelT, NAryTreeT, TimeSeriesModelT, TransformT
from hts.core.exceptions import InvalidArgumentException
from hts.hierarchy import HierarchyTree
from hts.transforms import BoxCoxTransformer, FunctionTransformer

logger = logging.getLogger(__name__)


class TimeSeriesModel(TimeSeriesModelT):
    """Base class for the implementation of the underlying models.

    Inherits from scikit-learn base classes.

    Concrete models are expected to override :meth:`fit` and :meth:`predict`
    (both raise ``NotImplementedError`` here), and :meth:`create_model` for any
    model kind this base class does not build itself.
    """

    def __init__(
        self, kind: str, node: HierarchyTree, transform: TransformT = False, **kwargs
    ):
        """
        Parameters
        ----------
        kind : str
            One of `prophet`, `sarimax`, `auto-arima`, `holt-winters`
        node : HierarchyTree
            Node
        transform : Bool or NamedTuple
            ``False`` or ``None`` for no transformation, ``True`` for the default
            Box-Cox transform, or a named tuple exposing ``func`` and ``inv_func``
            callables for a custom transform.
        kwargs
            Keyword arguments to be passed to the model instantiation. See the
            documentation of each of the actual model implementations for a more
            comprehensive treatment

        Raises
        ------
        InvalidArgumentException
            If ``kind`` is not one of ``ModelT.names()``.
        ValueError
            If ``transform`` is not one of the accepted forms.
        """
        if kind not in ModelT.names():
            raise InvalidArgumentException(
                f'Model {kind} not valid. Pick one of: {" ".join(ModelT.names())}'
            )

        self.kind = kind
        self.node = node
        # The transformer must be set before the model is created: create_model
        # reads the transformed node data for some model kinds.
        self.transform_function = self._set_transform(transform=transform)
        self.model = self.create_model(**kwargs)
        # Populated later by _set_results_return_self.
        self.forecast = None
        self.residual = None
        self.mse = None

    def _set_transform(self, transform: TransformT):
        """Build the transformer object matching the ``transform`` argument.

        Returns an identity ``FunctionTransformer`` for ``False``/``None``, a
        ``BoxCoxTransformer`` for ``True``, and a ``FunctionTransformer`` wrapping
        ``transform.func`` / ``transform.inv_func`` for a tuple carrying both.

        Raises
        ------
        ValueError
            If a tuple lacks ``func`` or ``inv_func``, or if ``transform`` is of
            any other unsupported type.
        """
        if transform is False or transform is None:
            # NOTE(review): `_no_func` is not defined in this class body; it is
            # presumably provided elsewhere (e.g. by the parent class) - confirm.
            return FunctionTransformer(func=self._no_func, inv_func=self._no_func)

        if transform is True:
            return BoxCoxTransformer()

        if isinstance(transform, tuple):
            if not (hasattr(transform, "func") and hasattr(transform, "inv_func")):
                raise ValueError(
                    "If passing a NamedTuple, it must have a `func` and `inv_func` parameters"
                )
            return FunctionTransformer(
                func=transform.func, inv_func=transform.inv_func
            )

        raise ValueError(
            "Invalid transform passed. Use either `True` for default boxcox transform or "
            "a `NamedTuple(func: Callable, inv_func: Callable)` for custom transforms"
        )

    def _set_results_return_self(self, in_sample, y_hat):
        """Store forecast, residuals and MSE on the instance and return ``self``.

        Both ``in_sample`` and ``y_hat`` are passed through the transformer's
        ``inverse_transform`` and concatenated, in that order, into the ``yhat``
        column of ``self.forecast``.
        """
        in_sample = self.transform_function.inverse_transform(in_sample)
        y_hat = self.transform_function.inverse_transform(y_hat)
        self.forecast = pandas.DataFrame(
            {"yhat": numpy.concatenate([in_sample, y_hat])}
        )
        # NOTE(review): `in_sample` has been inverse-transformed here, while the
        # observed series it is compared against is the *transformed* data. With
        # a non-identity transform these may be on different scales - confirm
        # this is intended before relying on `residual` / `mse`.
        observed = self._get_transformed_data(as_series=True)
        self.residual = (in_sample - observed).values
        self.mse = numpy.mean(numpy.array(self.residual) ** 2)
        return self

    def _get_transformed_data(
        self, as_series: bool = False
    ) -> Union[pandas.DataFrame, pandas.Series]:
        """Return the node's own column after applying the transformer.

        Parameters
        ----------
        as_series : bool
            When true, wrap the transformed values in a ``pandas.Series``;
            otherwise return a single-column ``pandas.DataFrame`` whose column
            is named after ``self.node.key``.
        """
        key = self.node.key
        value = self.node.item
        transformed = self.transform_function.transform(value[key])
        if as_series:
            return pandas.Series(transformed)
        return pandas.DataFrame({key: transformed})

    def create_model(self, **kwargs):
        """Instantiate the underlying model for ``self.kind``.

        Parameters
        ----------
        kwargs
            Forwarded unchanged to the constructor of the selected model.

        Returns
        -------
        The model instance, or ``None`` when ``kind`` is auto-arima and
        ``pmdarima`` cannot be imported (an error is logged in that case).

        Raises
        ------
        NotImplementedError
            If ``self.kind`` is not one of the kinds built here; subclasses
            handling other kinds must override this method.
        """
        if self.kind == ModelT.holt_winters.name:
            data = self._get_transformed_data()
            model = ExponentialSmoothing(endog=data, **kwargs)

        elif self.kind == ModelT.auto_arima.name:
            # pmdarima is an optional dependency, hence the local import.
            try:
                from pmdarima import AutoARIMA
            except ImportError:  # pragma: no cover
                logger.error(
                    "pmdarima not installed, so auto_arima won't work. Exiting."
                    "Install it with: pip install scikit-hts[auto_arima]"
                )
                # Best-effort: the caller ends up with `self.model = None`.
                return None
            model = AutoARIMA(**kwargs)

        elif self.kind == ModelT.sarimax.name:
            as_df = self.node.item
            end = self._get_transformed_data(as_series=True)
            # Exogenous regressors are taken untransformed from the node's frame.
            ex = as_df[self.node.exogenous] if self.node.exogenous else None
            model = SARIMAX(endog=end, exog=ex, **kwargs)

        else:
            # Previously a bare `raise`, which outside an `except` block only
            # produces "RuntimeError: No active exception to reraise".
            # NotImplementedError is a RuntimeError subclass, so existing
            # handlers keep working while the message becomes actionable.
            raise NotImplementedError(
                f"Model kind {self.kind!r} is not created by the base class; "
                "override `create_model` in a subclass"
            )
        return model

    def fit(self, **fit_args) -> "TimeSeriesModel":
        """Fit the model. Must be implemented by subclasses."""
        raise NotImplementedError

    def predict(self, node: HierarchyTree, **predict_args):
        """Produce predictions for ``node``. Must be implemented by subclasses."""
        raise NotImplementedError

    def fit_predict(self, node: HierarchyTree, **kwargs):
        """Call :meth:`fit` with no arguments, then :meth:`predict` on ``node``.

        NOTE(review): ``kwargs`` is accepted but not forwarded to either call.
        """
        return self.fit().predict(node)