test_diagnostics.R (prophet-0.7) | : | test_diagnostics.R (prophet-1.0) | ||
---|---|---|---|---|
skipping to change at line 59 | skipping to change at line 59 | |||
df$cap <- 40 | df$cap <- 40 | |||
m <- prophet(df, growth = 'logistic') | m <- prophet(df, growth = 'logistic') | |||
df.cv <- cross_validation( | df.cv <- cross_validation( | |||
m, horizon = 1, units = "days", period = 1, initial = 140) | m, horizon = 1, units = "days", period = 1, initial = 140) | |||
expect_equal(length(unique(df.cv$cutoff)), 2) | expect_equal(length(unique(df.cv$cutoff)), 2) | |||
expect_true(all(df.cv$cutoff < df.cv$ds)) | expect_true(all(df.cv$cutoff < df.cv$ds)) | |||
df.merged <- dplyr::left_join(df.cv, m$history, by="ds") | df.merged <- dplyr::left_join(df.cv, m$history, by="ds") | |||
expect_equal(sum((df.merged$y.x - df.merged$y.y) ** 2), 0) | expect_equal(sum((df.merged$y.x - df.merged$y.y) ** 2), 0) | |||
}) | }) | |||
test_that("cross_validation_flat", { | ||||
skip_if_not(Sys.getenv('R_ARCH') != '/i386') | ||||
df <- DATA | ||||
m <- prophet(df, growth = 'flat') | ||||
df.cv <- cross_validation( | ||||
m, horizon = 1, units = "days", period = 1, initial = 140) | ||||
expect_equal(length(unique(df.cv$cutoff)), 2) | ||||
expect_true(all(df.cv$cutoff < df.cv$ds)) | ||||
df.merged <- dplyr::left_join(df.cv, m$history, by="ds") | ||||
expect_equal(sum((df.merged$y.x - df.merged$y.y) ** 2), 0) | ||||
}) | ||||
test_that("cross_validation_extra_regressors", { | test_that("cross_validation_extra_regressors", { | |||
skip_if_not(Sys.getenv('R_ARCH') != '/i386') | skip_if_not(Sys.getenv('R_ARCH') != '/i386') | |||
df <- DATA | df <- DATA | |||
df$extra <- seq(0, nrow(df) - 1) | df$extra <- seq(0, nrow(df) - 1) | |||
df$is_conditional_week <- seq(0, nrow(df) - 1) %/% 7 %% 2 | df$is_conditional_week <- seq(0, nrow(df) - 1) %/% 7 %% 2 | |||
m <- prophet() | m <- prophet() | |||
m <- add_seasonality(m, name = 'monthly', period = 30.5, fourier.order = 5) | m <- add_seasonality(m, name = 'monthly', period = 30.5, fourier.order = 5) | |||
m <- add_seasonality(m, name = 'conditional_weekly', period = 7, | m <- add_seasonality(m, name = 'conditional_weekly', period = 7, | |||
fourier.order = 3, prior.scale = 2., | fourier.order = 3, prior.scale = 2., | |||
condition.name = 'is_conditional_week') | condition.name = 'is_conditional_week') | |||
skipping to change at line 131 | skipping to change at line 143 | |||
test_that("performance_metrics", { | test_that("performance_metrics", { | |||
skip_if_not(Sys.getenv('R_ARCH') != '/i386') | skip_if_not(Sys.getenv('R_ARCH') != '/i386') | |||
m <- prophet(DATA) | m <- prophet(DATA) | |||
df_cv <- cross_validation( | df_cv <- cross_validation( | |||
m, horizon = 4, units = "days", period = 10, initial = 90) | m, horizon = 4, units = "days", period = 10, initial = 90) | |||
# Aggregation level none | # Aggregation level none | |||
df_none <- performance_metrics(df_cv, rolling_window = -1) | df_none <- performance_metrics(df_cv, rolling_window = -1) | |||
expect_true(all( | expect_true(all( | |||
sort(colnames(df_none)) | sort(colnames(df_none)) | |||
== sort(c('horizon', 'coverage', 'mae', 'mape', 'mse', 'rmse')) | == sort(c('horizon', 'mse', 'rmse', 'mae', 'mape', 'mdape', 'smape', 'covera ge')) | |||
)) | )) | |||
expect_equal(nrow(df_none), 16) | expect_equal(nrow(df_none), 16) | |||
# Aggregation level 0 | # Aggregation level 0 | |||
df_0 <- performance_metrics(df_cv, rolling_window = 0) | df_0 <- performance_metrics(df_cv, rolling_window = 0) | |||
expect_equal(nrow(df_0), 4) | expect_equal(nrow(df_0), 4) | |||
expect_equal(length(unique(df_0$h)), 4) | expect_equal(length(unique(df_0$h)), 4) | |||
# Aggregation level 0.2 | # Aggregation level 0.2 | |||
df_horizon <- performance_metrics(df_cv, rolling_window = 0.2) | df_horizon <- performance_metrics(df_cv, rolling_window = 0.2) | |||
expect_equal(nrow(df_horizon), 4) | expect_equal(nrow(df_horizon), 4) | |||
expect_equal(length(unique(df_horizon$horizon)), 4) | expect_equal(length(unique(df_horizon$horizon)), 4) | |||
skipping to change at line 164 | skipping to change at line 176 | |||
df_cv$y[1] <- 0. | df_cv$y[1] <- 0. | |||
df_horizon <- performance_metrics(df_cv, metrics = c('coverage', 'mape')) | df_horizon <- performance_metrics(df_cv, metrics = c('coverage', 'mape')) | |||
expect_true(all( | expect_true(all( | |||
sort(colnames(df_horizon)) == sort(c('coverage', 'horizon')) | sort(colnames(df_horizon)) == sort(c('coverage', 'horizon')) | |||
)) | )) | |||
df_horizon <- performance_metrics(df_cv, metrics = c('mape')) | df_horizon <- performance_metrics(df_cv, metrics = c('mape')) | |||
expect_null(df_horizon) | expect_null(df_horizon) | |||
# List of metrics containing non valid metrics | # List of metrics containing non valid metrics | |||
expect_error( | expect_error( | |||
performance_metrics(df_cv, metrics = c('mse', 'error_metric')), | performance_metrics(df_cv, metrics = c('mse', 'error_metric')), | |||
'Valid values for metrics are: mse, rmse, mae, mape, coverage' | 'Valid values for metrics are: mse, rmse, mae, mape, mdape, smape, coverage ' | |||
) | ) | |||
}) | }) | |||
test_that("rolling_mean", { | test_that("rolling_mean", { | |||
skip_if_not(Sys.getenv('R_ARCH') != '/i386') | skip_if_not(Sys.getenv('R_ARCH') != '/i386') | |||
x <- 0:9 | x <- 0:9 | |||
h <- 0:9 | h <- 0:9 | |||
df <- prophet:::rolling_mean_by_h(x=x, h=h, w=1, name='x') | df <- prophet:::rolling_mean_by_h(x=x, h=h, w=1, name='x') | |||
expect_equal(x, df$x) | expect_equal(x, df$x) | |||
expect_equal(h, df$horizon) | expect_equal(h, df$horizon) | |||
skipping to change at line 223 | skipping to change at line 235 | |||
expect_equal(c(7.), df$horizon) | expect_equal(c(7.), df$horizon) | |||
expect_equal(c(4.5), df$x) | expect_equal(c(4.5), df$x) | |||
}) | }) | |||
test_that("copy", { | test_that("copy", { | |||
skip_if_not(Sys.getenv('R_ARCH') != '/i386') | skip_if_not(Sys.getenv('R_ARCH') != '/i386') | |||
df <- DATA_all | df <- DATA_all | |||
df$cap <- 200. | df$cap <- 200. | |||
df$binary_feature <- c(rep(0, 255), rep(1, 255)) | df$binary_feature <- c(rep(0, 255), rep(1, 255)) | |||
inputs <- list( | inputs <- list( | |||
growth = c('linear', 'logistic'), | growth = c('linear', 'logistic', 'flat'), | |||
yearly.seasonality = c(TRUE, FALSE), | yearly.seasonality = c(TRUE, FALSE), | |||
weekly.seasonality = c(TRUE, FALSE), | weekly.seasonality = c(TRUE, FALSE), | |||
daily.seasonality = c(TRUE, FALSE), | daily.seasonality = c(TRUE, FALSE), | |||
holidays = c('null', 'insert_dataframe'), | holidays = c('null', 'insert_dataframe'), | |||
seasonality.mode = c('additive', 'multiplicative') | seasonality.mode = c('additive', 'multiplicative') | |||
) | ) | |||
products <- expand.grid(inputs) | products <- expand.grid(inputs) | |||
for (i in 1:length(products)) { | for (i in 1:length(products)) { | |||
if (products$holidays[i] == 'insert_dataframe') { | if (products$holidays[i] == 'insert_dataframe') { | |||
holidays <- data.frame(ds=c('2016-12-25'), holiday=c('x')) | holidays <- data.frame(ds=c('2016-12-25'), holiday=c('x')) | |||
End of changes. 4 change blocks. | ||||
3 lines changed or deleted | 15 lines changed or added |