diagnostics.R (prophet-0.7) | : | diagnostics.R (prophet-1.0) | ||
---|---|---|---|---|
skipping to change at line 113 | skipping to change at line 113 | |||
initial.dt <- max( | initial.dt <- max( | |||
as.difftime(3 * horizon, units = units), | as.difftime(3 * horizon, units = units), | |||
seasonality.dt | seasonality.dt | |||
) | ) | |||
}else { | }else { | |||
initial.dt <- as.difftime(initial, units = units) | initial.dt <- as.difftime(initial, units = units) | |||
} | } | |||
cutoffs <- generate_cutoffs(df, horizon.dt, initial.dt, period.dt) | cutoffs <- generate_cutoffs(df, horizon.dt, initial.dt, period.dt) | |||
}else{ | }else{ | |||
cutoffs <- set_date(ds=cutoffs) | cutoffs <- set_date(ds=cutoffs) | |||
# Validation | ||||
if (min(cutoffs) <= min(df$ds)) { | ||||
stop('Minimum cutoff value is not strictly greater than min date in histor | ||||
y') | ||||
} | ||||
end_date_minus_horizon <- max(df$ds) - horizon.dt | ||||
if (max(cutoffs) > end_date_minus_horizon) { | ||||
stop('Maximum cutoff value is greater than end date minus horizon') | ||||
} | ||||
initial.dt <- cutoffs[1] - min(df$ds) | initial.dt <- cutoffs[1] - min(df$ds) | |||
} | } | |||
# Check if the initial window (that is, the amount of time between the | # Check if the initial window (that is, the amount of time between the | |||
# start of the history and the first cutoff) is less than the | # start of the history and the first cutoff) is less than the | |||
# maximum seasonality period | # maximum seasonality period | |||
if (initial.dt < seasonality.dt) { | if (initial.dt < seasonality.dt) { | |||
warning(paste0('Seasonality has period of ', period.max, ' days which ', | warning(paste0('Seasonality has period of ', period.max, ' days which ', | |||
'is larger than initial window. Consider increasing initial.')) | 'is larger than initial window. Consider increasing initial.')) | |||
} | } | |||
predicts <- data.frame() | predicts <- data.frame() | |||
for (i in 1:length(cutoffs)) { | for (i in 1:length(cutoffs)) { | |||
# Copy the model | df.c <- single_cutoff_forecast(df, model, cutoffs[i], horizon.dt, predict_co | |||
cutoff <- cutoffs[i] | lumns) | |||
m <- prophet_copy(model, cutoff) | ||||
# Train model | ||||
history.c <- dplyr::filter(df, ds <= cutoff) | ||||
if (nrow(history.c) < 2) { | ||||
stop('Less than two datapoints before cutoff. Increase initial window.') | ||||
} | ||||
fit.args <- c(list(m=m, df=history.c), model$fit.kwargs) | ||||
m <- do.call(fit.prophet, fit.args) | ||||
# Calculate yhat | ||||
df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon.dt) | ||||
# Get the columns for the future dataframe | ||||
columns <- 'ds' | ||||
if (m$growth == 'logistic') { | ||||
columns <- c(columns, 'cap') | ||||
if (m$logistic.floor) { | ||||
columns <- c(columns, 'floor') | ||||
} | ||||
} | ||||
columns <- c(columns, names(m$extra_regressors)) | ||||
for (name in names(m$seasonalities)) { | ||||
condition.name = m$seasonalities[[name]]$condition.name | ||||
if (!is.null(condition.name)) { | ||||
columns <- c(columns, condition.name) | ||||
} | ||||
} | ||||
future <- df.predict[columns] | ||||
yhat <- stats::predict(m, future) | ||||
# Merge yhat, y, and cutoff. | ||||
df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds") | ||||
df.c <- df.c[c(predict_columns, "y")] | ||||
df.c <- dplyr::select(df.c, y, predict_columns) | ||||
df.c$cutoff <- cutoff | ||||
predicts <- rbind(predicts, df.c) | predicts <- rbind(predicts, df.c) | |||
} | } | |||
return(predicts) | return(predicts) | |||
} | } | |||
#' Forecast for a single cutoff. | ||||
#' Used in cross_validation function when evaluating for multiple cutoffs. | ||||
#' | ||||
#' @param df Dataframe with history for cutoff. | ||||
#' @param model Prophet model object. | ||||
#' @param cutoff Datetime of cutoff. | ||||
#' @param horizon.dt timediff forecast horizon. | ||||
#' @param predict_columns Array of names of columns to be returned in output. | ||||
#' | ||||
#' @return Dataframe with forecast, actual value, and cutoff. | ||||
#' | ||||
#' @keywords internal | ||||
single_cutoff_forecast <- function(df, model, cutoff, horizon.dt, predict_column | ||||
s){ | ||||
m <- prophet_copy(model, cutoff) | ||||
# Train model | ||||
history.c <- dplyr::filter(df, ds <= cutoff) | ||||
if (nrow(history.c) < 2) { | ||||
stop('Less than two datapoints before cutoff. Increase initial window.') | ||||
} | ||||
fit.args <- c(list(m=m, df=history.c), model$fit.kwargs) | ||||
m <- do.call(fit.prophet, fit.args) | ||||
# Calculate yhat | ||||
df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon.dt) | ||||
# Get the columns for the future dataframe | ||||
columns <- 'ds' | ||||
if (m$growth == 'logistic') { | ||||
columns <- c(columns, 'cap') | ||||
if (m$logistic.floor) { | ||||
columns <- c(columns, 'floor') | ||||
} | ||||
} | ||||
columns <- c(columns, names(m$extra_regressors)) | ||||
for (name in names(m$seasonalities)) { | ||||
condition.name = m$seasonalities[[name]]$condition.name | ||||
if (!is.null(condition.name)) { | ||||
columns <- c(columns, condition.name) | ||||
} | ||||
} | ||||
future <- df.predict[columns] | ||||
yhat <- stats::predict(m, future) | ||||
# Merge yhat, y, and cutoff. | ||||
df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds") | ||||
df.c <- df.c[c(predict_columns, "y")] | ||||
df.c <- dplyr::select(df.c, y, predict_columns) | ||||
df.c$cutoff <- cutoff | ||||
return(df.c) | ||||
} | ||||
#' Copy Prophet object. | #' Copy Prophet object. | |||
#' | #' | |||
#' @param m Prophet model object. | #' @param m Prophet model object. | |||
#' @param cutoff Date, possibly as string. Changepoints are only retained if | #' @param cutoff Date, possibly as string. Changepoints are only retained if | |||
#' changepoints <= cutoff. | #' changepoints <= cutoff. | |||
#' | #' | |||
#' @return An unfitted Prophet model object with the same parameters as the | #' @return An unfitted Prophet model object with the same parameters as the | |||
#' input model. | #' input model. | |||
#' | #' | |||
#' @keywords internal | #' @keywords internal | |||
skipping to change at line 225 | skipping to change at line 249 | |||
#' Compute performance metrics from cross-validation results. | #' Compute performance metrics from cross-validation results. | |||
#' | #' | |||
#' Computes a suite of performance metrics on the output of cross-validation. | #' Computes a suite of performance metrics on the output of cross-validation. | |||
#' By default the following metrics are included: | #' By default the following metrics are included: | |||
#' 'mse': mean squared error, | #' 'mse': mean squared error, | |||
#' 'rmse': root mean squared error, | #' 'rmse': root mean squared error, | |||
#' 'mae': mean absolute error, | #' 'mae': mean absolute error, | |||
#' 'mape': mean percent error, | #' 'mape': mean percent error, | |||
#' 'mdape': median percent error, | #' 'mdape': median percent error, | |||
#' 'smape': symmetric mean absolute percentage error, | ||||
#' 'coverage': coverage of the upper and lower intervals | #' 'coverage': coverage of the upper and lower intervals | |||
#' | #' | |||
#' A subset of these can be specified by passing a list of names as the | #' A subset of these can be specified by passing a list of names as the | |||
#' `metrics` argument. | #' `metrics` argument. | |||
#' | #' | |||
#' Metrics are calculated over a rolling window of cross validation | #' Metrics are calculated over a rolling window of cross validation | |||
#' predictions, after sorting by horizon. Averaging is first done within each | #' predictions, after sorting by horizon. Averaging is first done within each | |||
#' value of the horizon, and then across horizons as needed to reach the | #' value of the horizon, and then across horizons as needed to reach the | |||
#' window size. The size of that window (number of simulated forecast points) | #' window size. The size of that window (number of simulated forecast points) | |||
#' is determined by the rolling_window argument, which specifies a proportion | #' is determined by the rolling_window argument, which specifies a proportion | |||
skipping to change at line 249 | skipping to change at line 274 | |||
#' points. The results are set to the right edge of the window. | #' points. The results are set to the right edge of the window. | |||
#' | #' | |||
#' If rolling_window < 0, then metrics are computed at each datapoint with no | #' If rolling_window < 0, then metrics are computed at each datapoint with no | |||
#' averaging (i.e., 'mse' will actually be squared error with no mean). | #' averaging (i.e., 'mse' will actually be squared error with no mean). | |||
#' | #' | |||
#' The output is a dataframe containing column 'horizon' along with columns | #' The output is a dataframe containing column 'horizon' along with columns | |||
#' for each of the metrics computed. | #' for each of the metrics computed. | |||
#' | #' | |||
#' @param df The dataframe returned by cross_validation. | #' @param df The dataframe returned by cross_validation. | |||
#' @param metrics An array of performance metrics to compute. If not provided, | #' @param metrics An array of performance metrics to compute. If not provided, | |||
#' will use c('mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage'). | #' will use c('mse', 'rmse', 'mae', 'mape', 'mdape', 'smape', 'coverage'). | |||
#' @param rolling_window Proportion of data to use in each rolling window for | #' @param rolling_window Proportion of data to use in each rolling window for | |||
#' computing the metrics. Should be in [0, 1] to average. | #' computing the metrics. Should be in [0, 1] to average. | |||
#' | #' | |||
#' @return A dataframe with a column for each metric, and column 'horizon'. | #' @return A dataframe with a column for each metric, and column 'horizon'. | |||
#' | #' | |||
#' @export | #' @export | |||
performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) { | performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) { | |||
valid_metrics <- c('mse', 'rmse', 'mae', 'mape', 'coverage') | valid_metrics <- c('mse', 'rmse', 'mae', 'mape', 'mdape', 'smape', 'coverage') | |||
if (is.null(metrics)) { | if (is.null(metrics)) { | |||
metrics <- valid_metrics | metrics <- valid_metrics | |||
} | } | |||
if ((!('yhat_lower' %in% colnames(df)) | !('yhat_upper' %in% colnames(df))) & ('coverage' %in% metrics)){ | if ((!('yhat_lower' %in% colnames(df)) | !('yhat_upper' %in% colnames(df))) & ('coverage' %in% metrics)){ | |||
metrics <- metrics[metrics != 'coverage'] | metrics <- metrics[metrics != 'coverage'] | |||
} | } | |||
if (length(metrics) != length(unique(metrics))) { | if (length(metrics) != length(unique(metrics))) { | |||
stop('Input metrics must be an array of unique values.') | stop('Input metrics must be an array of unique values.') | |||
} | } | |||
skipping to change at line 500 | skipping to change at line 525 | |||
#' | #' | |||
#' @keywords internal | #' @keywords internal | |||
mdape <- function(df, w) { | mdape <- function(df, w) { | |||
ape <- abs((df$y - df$yhat) / df$y) | ape <- abs((df$y - df$yhat) / df$y) | |||
if (w < 0) { | if (w < 0) { | |||
return(data.frame(horizon = df$horizon, mdape = ape)) | return(data.frame(horizon = df$horizon, mdape = ape)) | |||
} | } | |||
return(rolling_median_by_h(x = ape, h = df$horizon, w = w, name = 'mdape')) | return(rolling_median_by_h(x = ape, h = df$horizon, w = w, name = 'mdape')) | |||
} | } | |||
#' Symmetric mean absolute percentage error | ||||
#' based on Chen and Yang (2004) formula | ||||
#' | ||||
#' @param df Cross-validation results dataframe. | ||||
#' @param w Aggregation window size. | ||||
#' | ||||
#' @return Array of symmetric mean absolute percent errors. | ||||
#' | ||||
#' @keywords internal | ||||
smape <- function(df, w) { | ||||
sape <- abs(df$y - df$yhat) / ((abs(df$y) + abs(df$yhat)) / 2) | ||||
if (w < 0) { | ||||
return(data.frame(horizon = df$horizon, smape = sape)) | ||||
} | ||||
return(rolling_mean_by_h(x = sape, h = df$horizon, w = w, name = 'smape')) | ||||
} | ||||
#' Coverage | #' Coverage | |||
#' | #' | |||
#' @param df Cross-validation results dataframe. | #' @param df Cross-validation results dataframe. | |||
#' @param w Aggregation window size. | #' @param w Aggregation window size. | |||
#' | #' | |||
#' @return Array of coverages | #' @return Array of coverages | |||
#' | #' | |||
#' @keywords internal | #' @keywords internal | |||
coverage <- function(df, w) { | coverage <- function(df, w) { | |||
is_covered <- (df$y >= df$yhat_lower) & (df$y <= df$yhat_upper) | is_covered <- (df$y >= df$yhat_lower) & (df$y <= df$yhat_upper) | |||
End of changes. 7 change blocks. | ||||
36 lines changed or deleted | 81 lines changed or added |