Skip to content

Commit

Permalink
[MNT] maintenance changes for AutoTBATS (#6400)
Browse files Browse the repository at this point in the history
This PR adds some minor enhancements in `StatsForecastAutoTBATS`:

1. supports new arguments added in newer `statsforecast` to control box
cox parameter range for automated selection
2. updates existing test parameters with use of more arguments for
higher coverage
3. removes internal use of deprecated `seasonal_periods` parameter
instead of consistent `season_length` parameter -> this is not a
breaking change as user interface is unchanged
  • Loading branch information
yarnabrina committed May 10, 2024
1 parent 823ea79 commit e383fed
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions sktime/forecasting/statsforecast.py
Expand Up @@ -655,17 +655,17 @@ class StatsForecastAutoTBATS(_GeneralisedStatsForecastAdapter):
Number of observations per unit of time. Ex: 24 Hourly data.
use_boxcox : bool (default=None)
Whether or not to use a Box-Cox transformation. By default tries both.
bc_lower_bound : float (default=0.0)
Lower bound for the Box-Cox transformation.
bc_upper_bound : float (default=1.5)
Upper bound for the Box-Cox transformation.
use_trend : bool (default=None)
Whether or not to use a trend component. By default tries both.
use_damped_trend : bool (default=None)
Whether or not to dampen the trend component. By default tries both.
use_arma_errors : bool (default=True)
Whether or not to use a ARMA errors.
Default is True and this evaluates both models.
bc_lower_bound : float (default=0.0)
Lower bound for the Box-Cox transformation.
bc_upper_bound : float (default=1.0)
Upper bound for the Box-Cox transformation.
See Also
--------
Expand Down Expand Up @@ -699,12 +699,16 @@ def __init__(
use_trend: Optional[bool] = None,
use_damped_trend: Optional[bool] = None,
use_arma_errors: bool = True,
bc_lower_bound: float = 0.0,
bc_upper_bound: float = 1.0,
):
self.seasonal_periods = seasonal_periods
self.use_boxcox = use_boxcox
self.use_trend = use_trend
self.use_damped_trend = use_damped_trend
self.use_arma_errors = use_arma_errors
self.bc_lower_bound = bc_lower_bound
self.bc_upper_bound = bc_upper_bound

super().__init__()

Expand All @@ -716,11 +720,13 @@ def _get_statsforecast_class(self):

def _get_statsforecast_params(self) -> dict:
return {
"seasonal_periods": self.seasonal_periods,
"season_length": self.seasonal_periods,
"use_boxcox": self.use_boxcox,
"use_trend": self.use_trend,
"use_damped_trend": self.use_damped_trend,
"use_arma_errors": self.use_arma_errors,
"bc_lower_bound": self.bc_lower_bound,
"bc_upper_bound": self.bc_upper_bound,
}

@classmethod
Expand All @@ -745,7 +751,21 @@ def get_test_params(cls, parameter_set="default"):
"""
del parameter_set # to avoid being detected as unused by `vulture` etc.

params = [{"seasonal_periods": 3}, {"seasonal_periods": [3, 12]}]
params = [
{
"seasonal_periods": 3,
"use_boxcox": True,
"bc_lower_bound": 0.25,
"bc_upper_bound": 0.75,
},
{
"seasonal_periods": [3, 12],
"use_boxcox": False,
"use_trend": True,
"use_damped_trend": True,
"use_arma_errors": False,
},
]

return params

Expand Down

0 comments on commit e383fed

Please sign in to comment.