test_diagnostics.py (prophet-0.7) | : | test_diagnostics.py (prophet-1.0) | ||
---|---|---|---|---|
skipping to change at line 20 | skipping to change at line 20 | |||
import itertools | import itertools | |||
import os | import os | |||
from unittest import TestCase | from unittest import TestCase | |||
from unittest.mock import patch | from unittest.mock import patch | |||
import numpy as np | import numpy as np | |||
import pandas as pd | import pandas as pd | |||
import datetime | import datetime | |||
from fbprophet import Prophet | from prophet import Prophet | |||
from fbprophet import diagnostics | from prophet import diagnostics | |||
DATA_all = pd.read_csv( | DATA_all = pd.read_csv( | |||
os.path.join(os.path.dirname(__file__), 'data.csv'), parse_dates=['ds'] | os.path.join(os.path.dirname(__file__), 'data.csv'), parse_dates=['ds'] | |||
) | ) | |||
DATA = DATA_all.head(100) | DATA = DATA_all.head(100) | |||
class CustomParallelBackend: | class CustomParallelBackend: | |||
def map(self, func, *iterables): | def map(self, func, *iterables): | |||
results = [func(*args) for args in zip(*iterables)] | results = [func(*args) for args in zip(*iterables)] | |||
return results | return results | |||
skipping to change at line 99 | skipping to change at line 99 | |||
mock_predict = pd.DataFrame({'ds':pd.date_range(start='2012-09-17', peri ods=3), | mock_predict = pd.DataFrame({'ds':pd.date_range(start='2012-09-17', peri ods=3), | |||
'yhat':np.arange(16, 19), | 'yhat':np.arange(16, 19), | |||
'yhat_lower':np.arange(15, 18), | 'yhat_lower':np.arange(15, 18), | |||
'yhat_upper': np.arange(17, 20), | 'yhat_upper': np.arange(17, 20), | |||
'y': np.arange(16.5, 19.5), | 'y': np.arange(16.5, 19.5), | |||
'cutoff': [datetime.date(2012, 9, 15)]*3}) | 'cutoff': [datetime.date(2012, 9, 15)]*3}) | |||
# cross validation with 3 and 7 forecasts | # cross validation with 3 and 7 forecasts | |||
for args, forecasts in ((['4 days', '10 days', '115 days'], 3), | for args, forecasts in ((['4 days', '10 days', '115 days'], 3), | |||
(['4 days', '4 days', '115 days'], 7)): | (['4 days', '4 days', '115 days'], 7)): | |||
with patch('fbprophet.diagnostics.single_cutoff_forecast') as mock_f unc: | with patch('prophet.diagnostics.single_cutoff_forecast') as mock_fun c: | |||
mock_func.return_value = mock_predict | mock_func.return_value = mock_predict | |||
df_cv = diagnostics.cross_validation(m, *args) | df_cv = diagnostics.cross_validation(m, *args) | |||
# check single forecast function called expected number of times | # check single forecast function called expected number of times | |||
self.assertEqual(diagnostics.single_cutoff_forecast.call_count, | self.assertEqual(diagnostics.single_cutoff_forecast.call_count, | |||
forecasts) | forecasts) | |||
def test_cross_validation_logistic_or_flat_growth(self): | def test_cross_validation_logistic_or_flat_growth(self): | |||
params = (x for x in ['logistic', 'flat']) | params = (x for x in ['logistic', 'flat']) | |||
for growth in params: | for growth in params: | |||
with self.subTest(i=growth): | with self.subTest(i=growth): | |||
skipping to change at line 190 | skipping to change at line 190 | |||
def test_performance_metrics(self): | def test_performance_metrics(self): | |||
m = Prophet() | m = Prophet() | |||
m.fit(self.__df) | m.fit(self.__df) | |||
df_cv = diagnostics.cross_validation( | df_cv = diagnostics.cross_validation( | |||
m, horizon='4 days', period='10 days', initial='90 days') | m, horizon='4 days', period='10 days', initial='90 days') | |||
# Aggregation level none | # Aggregation level none | |||
df_none = diagnostics.performance_metrics(df_cv, rolling_window=-1) | df_none = diagnostics.performance_metrics(df_cv, rolling_window=-1) | |||
self.assertEqual( | self.assertEqual( | |||
set(df_none.columns), | set(df_none.columns), | |||
{'horizon', 'coverage', 'mae', 'mape', 'mdape', 'mse', 'rmse'}, | {'horizon', 'coverage', 'mae', 'mape', 'mdape', 'mse', 'rmse', 'smap e'}, | |||
) | ) | |||
self.assertEqual(df_none.shape[0], 16) | self.assertEqual(df_none.shape[0], 16) | |||
# Aggregation level 0 | # Aggregation level 0 | |||
df_0 = diagnostics.performance_metrics(df_cv, rolling_window=0) | df_0 = diagnostics.performance_metrics(df_cv, rolling_window=0) | |||
self.assertEqual(len(df_0), 4) | self.assertEqual(len(df_0), 4) | |||
self.assertEqual(len(df_0['horizon'].unique()), 4) | self.assertEqual(len(df_0['horizon'].unique()), 4) | |||
# Aggregation level 0.2 | # Aggregation level 0.2 | |||
df_horizon = diagnostics.performance_metrics(df_cv, rolling_window=0.2) | df_horizon = diagnostics.performance_metrics(df_cv, rolling_window=0.2) | |||
self.assertEqual(len(df_horizon), 4) | self.assertEqual(len(df_horizon), 4) | |||
self.assertEqual(len(df_horizon['horizon'].unique()), 4) | self.assertEqual(len(df_horizon['horizon'].unique()), 4) | |||
End of changes. 3 change blocks. | ||||
4 lines changed or deleted | 4 lines changed or added |