prophet.R (prophet-0.7) | : | prophet.R (prophet-1.0) | ||
---|---|---|---|---|
skipping to change at line 23 | skipping to change at line 23 | |||
"country", "year" | "country", "year" | |||
)) | )) | |||
#' Prophet forecaster. | #' Prophet forecaster. | |||
#' | #' | |||
#' @param df (optional) Dataframe containing the history. Must have columns ds | #' @param df (optional) Dataframe containing the history. Must have columns ds | |||
#' (date type) and y, the time series. If growth is logistic, then df must | #' (date type) and y, the time series. If growth is logistic, then df must | |||
#' also have a column cap that specifies the capacity at each ds. If not | #' also have a column cap that specifies the capacity at each ds. If not | |||
#' provided, then the model object will be instantiated but not fit; use | #' provided, then the model object will be instantiated but not fit; use | |||
#' fit.prophet(m, df) to fit the model. | #' fit.prophet(m, df) to fit the model. | |||
#' @param growth String 'linear' or 'logistic' to specify a linear or logistic | #' @param growth String 'linear', 'logistic', or 'flat' to specify a linear, logistic | |||
#' trend. | ||||
#' or flat trend. | ||||
#' @param changepoints Vector of dates at which to include potential | #' @param changepoints Vector of dates at which to include potential | |||
#' changepoints. If not specified, potential changepoints are selected | #' changepoints. If not specified, potential changepoints are selected | |||
#' automatically. | #' automatically. | |||
#' @param n.changepoints Number of potential changepoints to include. Not used | #' @param n.changepoints Number of potential changepoints to include. Not used | |||
#' if input `changepoints` is supplied. If `changepoints` is not supplied, | #' if input `changepoints` is supplied. If `changepoints` is not supplied, | |||
#' then n.changepoints potential changepoints are selected uniformly from the | #' then n.changepoints potential changepoints are selected uniformly from the | |||
#' first `changepoint.range` proportion of df$ds. | #' first `changepoint.range` proportion of df$ds. | |||
#' @param changepoint.range Proportion of history in which trend changepoints | #' @param changepoint.range Proportion of history in which trend changepoints | |||
#' will be estimated. Defaults to 0.8 for the first 80%. Not used if | #' will be estimated. Defaults to 0.8 for the first 80%. Not used if | |||
#' `changepoints` is specified. | #' `changepoints` is specified. | |||
skipping to change at line 82 | skipping to change at line 82 | |||
#' @examples | #' @examples | |||
#' \dontrun{ | #' \dontrun{ | |||
#' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'), | #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'), | |||
#' y = sin(1:366/200) + rnorm(366)/10) | #' y = sin(1:366/200) + rnorm(366)/10) | |||
#' m <- prophet(history) | #' m <- prophet(history) | |||
#' } | #' } | |||
#' | #' | |||
#' @export | #' @export | |||
#' @importFrom dplyr "%>%" | #' @importFrom dplyr "%>%" | |||
#' @import Rcpp | #' @import Rcpp | |||
#' @rawNamespace import(RcppParallel, except = LdFlags) | ||||
#' @import rlang | #' @import rlang | |||
#' @useDynLib prophet, .registration = TRUE | #' @useDynLib prophet, .registration = TRUE | |||
prophet <- function(df = NULL, | prophet <- function(df = NULL, | |||
growth = 'linear', | growth = 'linear', | |||
changepoints = NULL, | changepoints = NULL, | |||
n.changepoints = 25, | n.changepoints = 25, | |||
changepoint.range = 0.8, | changepoint.range = 0.8, | |||
yearly.seasonality = 'auto', | yearly.seasonality = 'auto', | |||
weekly.seasonality = 'auto', | weekly.seasonality = 'auto', | |||
daily.seasonality = 'auto', | daily.seasonality = 'auto', | |||
skipping to change at line 157 | skipping to change at line 158 | |||
} | } | |||
#' Validates the inputs to Prophet. | #' Validates the inputs to Prophet. | |||
#' | #' | |||
#' @param m Prophet object. | #' @param m Prophet object. | |||
#' | #' | |||
#' @return The Prophet object. | #' @return The Prophet object. | |||
#' | #' | |||
#' @keywords internal | #' @keywords internal | |||
validate_inputs <- function(m) { | validate_inputs <- function(m) { | |||
if (!(m$growth %in% c('linear', 'logistic'))) { | if (!(m$growth %in% c('linear', 'logistic', 'flat'))) { | |||
stop("Parameter 'growth' should be 'linear' or 'logistic'.") | stop("Parameter 'growth' should be 'linear', 'logistic', or 'flat'.") | |||
} | } | |||
if ((m$changepoint.range < 0) | (m$changepoint.range > 1)) { | if ((m$changepoint.range < 0) | (m$changepoint.range > 1)) { | |||
stop("Parameter 'changepoint.range' must be in [0, 1]") | stop("Parameter 'changepoint.range' must be in [0, 1]") | |||
} | } | |||
if (!is.null(m$holidays)) { | if (!is.null(m$holidays)) { | |||
if (!(exists('holiday', where = m$holidays))) { | if (!(exists('holiday', where = m$holidays))) { | |||
stop('Holidays dataframe must have holiday field.') | stop('Holidays dataframe must have holiday field.') | |||
} | } | |||
if (!(exists('ds', where = m$holidays))) { | if (!(exists('ds', where = m$holidays))) { | |||
stop('Holidays dataframe must have ds field.') | stop('Holidays dataframe must have ds field.') | |||
} | } | |||
m$holidays$ds <- as.Date(m$holidays$ds) | m$holidays$ds <- as.Date(m$holidays$ds) | |||
if (any(is.na(m$holidays$ds)) | any(is.na(m$holidays$holiday))) { | ||||
stop('Found NA in the holidays dataframe.') | ||||
} | ||||
has.lower <- exists('lower_window', where = m$holidays) | has.lower <- exists('lower_window', where = m$holidays) | |||
has.upper <- exists('upper_window', where = m$holidays) | has.upper <- exists('upper_window', where = m$holidays) | |||
if (has.lower + has.upper == 1) { | if (has.lower + has.upper == 1) { | |||
stop(paste('Holidays must have both lower_window and upper_window,', | stop(paste('Holidays must have both lower_window and upper_window,', | |||
'or neither.')) | 'or neither.')) | |||
} | } | |||
if (has.lower) { | if (has.lower) { | |||
if(max(m$holidays$lower_window, na.rm=TRUE) > 0) { | if(max(m$holidays$lower_window, na.rm=TRUE) > 0) { | |||
stop('Holiday lower_window should be <= 0') | stop('Holiday lower_window should be <= 0') | |||
} | } | |||
skipping to change at line 242 | skipping to change at line 246 | |||
if(check_seasonalities & (!is.null(m$seasonalities[[name]]))){ | if(check_seasonalities & (!is.null(m$seasonalities[[name]]))){ | |||
stop("Name ", name, " already used for a seasonality.") | stop("Name ", name, " already used for a seasonality.") | |||
} | } | |||
if(check_regressors & (!is.null(m$seasonalities[[name]]))){ | if(check_regressors & (!is.null(m$seasonalities[[name]]))){ | |||
stop("Name ", name, " already used for an added regressor.") | stop("Name ", name, " already used for an added regressor.") | |||
} | } | |||
} | } | |||
#' Convert date vector | #' Convert date vector | |||
#' | #' | |||
#' Convert the date to POSIXct object | #' Convert the date to POSIXct object. Timezones are stripped and replaced | |||
#' with GMT. | ||||
#' | #' | |||
#' @param ds Date vector, can be consisted of characters | #' @param ds Date vector | |||
#' @param tz string time zone | ||||
#' | #' | |||
#' @return vector of POSIXct object converted from date | #' @return vector of POSIXct object converted from date | |||
#' | #' | |||
#' @keywords internal | #' @keywords internal | |||
set_date <- function(ds = NULL, tz = "GMT") { | set_date <- function(ds) { | |||
if (length(ds) == 0) { | if (length(ds) == 0) { | |||
return(NULL) | return(NULL) | |||
} | } | |||
if (is.factor(ds)) { | if (is.factor(ds)) { | |||
ds <- as.character(ds) | ds <- as.character(ds) | |||
} | } | |||
if (min(nchar(ds), na.rm=TRUE) < 12) { | # If a datetime, strip timezone and replace with GMT. | |||
ds <- as.POSIXct(ds, format = "%Y-%m-%d", tz = tz) | if (lubridate::is.instant(ds)) { | |||
} else { | ds <- as.POSIXct(lubridate::force_tz(ds, "GMT"), tz = "GMT") | |||
ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz) | } | |||
else { | ||||
# Assume it can be coerced into POSIXct | ||||
if (min(nchar(ds), na.rm=TRUE) < 12) { | ||||
ds <- as.POSIXct(ds, format = "%Y-%m-%d", tz = "GMT") | ||||
} else { | ||||
ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = "GMT") | ||||
} | ||||
} | } | |||
attr(ds, "tzone") <- tz | ||||
attr(ds, "tzone") <- "GMT" | ||||
return(ds) | return(ds) | |||
} | } | |||
#' Time difference between datetimes | #' Time difference between datetimes | |||
#' | #' | |||
#' Compute time difference of two POSIXct objects | #' Compute time difference of two POSIXct objects | |||
#' | #' | |||
#' @param ds1 POSIXct object | #' @param ds1 POSIXct object | |||
#' @param ds2 POSIXct object | #' @param ds2 POSIXct object | |||
#' @param units string units of difference, e.g. 'days' or 'secs'. | #' @param units string units of difference, e.g. 'days' or 'secs'. | |||
skipping to change at line 1056 | skipping to change at line 1068 | |||
period = 1, | period = 1, | |||
fourier.order = fourier.order, | fourier.order = fourier.order, | |||
prior.scale = m$seasonality.prior.scale, | prior.scale = m$seasonality.prior.scale, | |||
mode = m$seasonality.mode, | mode = m$seasonality.mode, | |||
condition.name = NULL | condition.name = NULL | |||
) | ) | |||
} | } | |||
return(m) | return(m) | |||
} | } | |||
#' Initialize linear growth. | #' Initialize flat growth. | |||
#' | ||||
#' Provides a strong initialization for flat growth by setting the | ||||
#' growth to 0 and calculating the offset parameter that passes the | ||||
#' function through the mean of the y_scaled values. | ||||
#' | ||||
#' @param df Data frame with columns ds (date), y_scaled (scaled time series), | ||||
#' and t (scaled time). | ||||
#' | ||||
#' @return A vector (k, m) with the rate (k) and offset (m) of the flat | ||||
#' growth function. | ||||
#' | ||||
#' @keywords internal | ||||
flat_growth_init <- function(df) { | ||||
# Initialize the rate | ||||
k <- 0 | ||||
# And the offset | ||||
m <- mean(df$y_scaled) | ||||
return(c(k, m)) | ||||
} | ||||
#' Initialize linear growth. | ||||
#' | #' | |||
#' Provides a strong initialization for linear growth by calculating the | #' Provides a strong initialization for linear growth by calculating the | |||
#' growth and offset parameters that pass the function through the first and | #' growth and offset parameters that pass the function through the first and | |||
#' last points in the time series. | #' last points in the time series. | |||
#' | #' | |||
#' @param df Data frame with columns ds (date), y_scaled (scaled time series), | #' @param df Data frame with columns ds (date), y_scaled (scaled time series), | |||
#' and t (scaled time). | #' and t (scaled time). | |||
#' | #' | |||
#' @return A vector (k, m) with the rate (k) and offset (m) of the linear | #' @return A vector (k, m) with the rate (k) and offset (m) of the linear | |||
#' growth function. | #' growth function. | |||
skipping to change at line 1180 | skipping to change at line 1213 | |||
dat <- list( | dat <- list( | |||
T = nrow(history), | T = nrow(history), | |||
K = ncol(seasonal.features), | K = ncol(seasonal.features), | |||
S = length(m$changepoints.t), | S = length(m$changepoints.t), | |||
y = history$y_scaled, | y = history$y_scaled, | |||
t = history$t, | t = history$t, | |||
t_change = array(m$changepoints.t), | t_change = array(m$changepoints.t), | |||
X = as.matrix(seasonal.features), | X = as.matrix(seasonal.features), | |||
sigmas = array(prior.scales), | sigmas = array(prior.scales), | |||
tau = m$changepoint.prior.scale, | tau = m$changepoint.prior.scale, | |||
trend_indicator = as.numeric(m$growth == 'logistic'), | trend_indicator = switch(m$growth, 'linear'=0, 'logistic'=1, 'flat'=2), | |||
s_a = array(component.cols$additive_terms), | s_a = array(component.cols$additive_terms), | |||
s_m = array(component.cols$multiplicative_terms) | s_m = array(component.cols$multiplicative_terms) | |||
) | ) | |||
# Run stan | # Run stan | |||
if (m$growth == 'linear') { | if (m$growth == 'linear') { | |||
dat$cap <- rep(0, nrow(history)) # Unused inside Stan | dat$cap <- rep(0, nrow(history)) # Unused inside Stan | |||
kinit <- linear_growth_init(history) | kinit <- linear_growth_init(history) | |||
} else { | } else if (m$growth == 'flat') { | |||
dat$cap <- rep(0, nrow(history)) # Unused inside Stan | ||||
kinit <- flat_growth_init(history) | ||||
} else if (m$growth == 'logistic') { | ||||
dat$cap <- history$cap_scaled # Add capacities to the Stan data | dat$cap <- history$cap_scaled # Add capacities to the Stan data | |||
kinit <- logistic_growth_init(history) | kinit <- logistic_growth_init(history) | |||
} | } | |||
if (exists(".prophet.stan.model", where = prophet_model_env)) { | if (exists(".prophet.stan.model", where = prophet_model_env)) { | |||
model <- get('.prophet.stan.model', envir = prophet_model_env) | model <- get('.prophet.stan.model', envir = prophet_model_env) | |||
} else { | } else { | |||
model <- stanmodels$prophet | model <- stanmodels$prophet | |||
} | } | |||
stan_init <- function() { | stan_init <- function() { | |||
list(k = kinit[1], | list(k = kinit[1], | |||
m = kinit[2], | m = kinit[2], | |||
delta = array(rep(0, length(m$changepoints.t))), | delta = array(rep(0, length(m$changepoints.t))), | |||
beta = array(rep(0, ncol(seasonal.features))), | beta = array(rep(0, ncol(seasonal.features))), | |||
sigma_obs = 1 | sigma_obs = 1 | |||
) | ) | |||
} | } | |||
if (min(history$y) == max(history$y)) { | if (min(history$y) == max(history$y) & | |||
(m$growth %in% c('linear', 'flat'))) { | ||||
# Nothing to fit. | # Nothing to fit. | |||
m$params <- stan_init() | m$params <- stan_init() | |||
m$params$sigma_obs <- 0. | m$params$sigma_obs <- 0. | |||
n.iteration <- 1. | n.iteration <- 1. | |||
} else if (m$mcmc.samples > 0) { | } else if (m$mcmc.samples > 0) { | |||
args <- list( | args <- list( | |||
object = model, | object = model, | |||
data = dat, | data = dat, | |||
init = stan_init, | init = stan_init, | |||
iter = m$mcmc.samples | iter = m$mcmc.samples | |||
skipping to change at line 1321 | skipping to change at line 1358 | |||
} | } | |||
if (object$logistic.floor) { | if (object$logistic.floor) { | |||
cols <- c(cols, 'floor') | cols <- c(cols, 'floor') | |||
} | } | |||
df <- df[cols] | df <- df[cols] | |||
df <- dplyr::bind_cols(df, seasonal.components, intervals) | df <- dplyr::bind_cols(df, seasonal.components, intervals) | |||
df$yhat <- df$trend * (1 + df$multiplicative_terms) + df$additive_terms | df$yhat <- df$trend * (1 + df$multiplicative_terms) + df$additive_terms | |||
return(df) | return(df) | |||
} | } | |||
#' Evaluate the flat trend function. | ||||
#' | ||||
#' @param t Vector of times on which the function is evaluated. | ||||
#' @param m Float initial offset. | ||||
#' | ||||
#' @return Vector y(t). | ||||
#' | ||||
#' @keywords internal | ||||
flat_trend <- function(t, m) { | ||||
y <- rep(m, length(t)) | ||||
return(y) | ||||
} | ||||
#' Evaluate the piecewise linear function. | #' Evaluate the piecewise linear function. | |||
#' | #' | |||
#' @param t Vector of times on which the function is evaluated. | #' @param t Vector of times on which the function is evaluated. | |||
#' @param deltas Vector of rate changes at each changepoint. | #' @param deltas Vector of rate changes at each changepoint. | |||
#' @param k Float initial rate. | #' @param k Float initial rate. | |||
#' @param m Float initial offset. | #' @param m Float initial offset. | |||
#' @param changepoint.ts Vector of changepoint times. | #' @param changepoint.ts Vector of changepoint times. | |||
#' | #' | |||
#' @return Vector y(t). | #' @return Vector y(t). | |||
#' | #' | |||
skipping to change at line 1395 | skipping to change at line 1445 | |||
#' | #' | |||
#' @keywords internal | #' @keywords internal | |||
predict_trend <- function(model, df) { | predict_trend <- function(model, df) { | |||
k <- mean(model$params$k, na.rm = TRUE) | k <- mean(model$params$k, na.rm = TRUE) | |||
param.m <- mean(model$params$m, na.rm = TRUE) | param.m <- mean(model$params$m, na.rm = TRUE) | |||
deltas <- colMeans(model$params$delta, na.rm = TRUE) | deltas <- colMeans(model$params$delta, na.rm = TRUE) | |||
t <- df$t | t <- df$t | |||
if (model$growth == 'linear') { | if (model$growth == 'linear') { | |||
trend <- piecewise_linear(t, deltas, k, param.m, model$changepoints.t) | trend <- piecewise_linear(t, deltas, k, param.m, model$changepoints.t) | |||
} else { | } else if (model$growth == 'flat') { | |||
trend <- flat_trend(t, param.m) | ||||
} else if (model$growth == 'logistic') { | ||||
cap <- df$cap_scaled | cap <- df$cap_scaled | |||
trend <- piecewise_logistic( | trend <- piecewise_logistic( | |||
t, cap, deltas, k, param.m, model$changepoints.t) | t, cap, deltas, k, param.m, model$changepoints.t) | |||
} | } | |||
return(trend * model$y.scale + df$floor) | return(trend * model$y.scale + df$floor) | |||
} | } | |||
#' Predict seasonality components, holidays, and added regressors. | #' Predict seasonality components, holidays, and added regressors. | |||
#' | #' | |||
#' @param m Prophet object. | #' @param m Prophet object. | |||
skipping to change at line 1595 | skipping to change at line 1647 | |||
# Sample deltas | # Sample deltas | |||
deltas.new <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda) | deltas.new <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda) | |||
# Combine with changepoints from the history | # Combine with changepoints from the history | |||
changepoint.ts <- c(model$changepoints.t, changepoint.ts.new) | changepoint.ts <- c(model$changepoints.t, changepoint.ts.new) | |||
deltas <- c(deltas, deltas.new) | deltas <- c(deltas, deltas.new) | |||
# Get the corresponding trend | # Get the corresponding trend | |||
if (model$growth == 'linear') { | if (model$growth == 'linear') { | |||
trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts) | trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts) | |||
} else { | } else if (model$growth == 'flat') { | |||
trend <- flat_trend(t, param.m) | ||||
} else if (model$growth == 'logistic') { | ||||
cap <- df$cap_scaled | cap <- df$cap_scaled | |||
trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts) | trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts) | |||
} | } | |||
return(trend * model$y.scale + df$floor) | return(trend * model$y.scale + df$floor) | |||
} | } | |||
#' Make dataframe with future dates for forecasting. | #' Make dataframe with future dates for forecasting. | |||
#' | #' | |||
#' @param m Prophet model object. | #' @param m Prophet model object. | |||
#' @param periods Int number of periods to forecast forward. | #' @param periods Int number of periods to forecast forward. | |||
End of changes. 16 change blocks. | ||||
19 lines changed or deleted | 74 lines changed or added |