forecaster.py (prophet-1.0) | : | forecaster.py (prophet-1.1) | ||
---|---|---|---|---|
skipping to change at line 29 | skipping to change at line 29 | |||
from prophet.plot import (plot, plot_components) | from prophet.plot import (plot, plot_components) | |||
logger = logging.getLogger('prophet') | logger = logging.getLogger('prophet') | |||
logger.setLevel(logging.INFO) | logger.setLevel(logging.INFO) | |||
class Prophet(object): | class Prophet(object): | |||
"""Prophet forecaster. | """Prophet forecaster. | |||
Parameters | Parameters | |||
---------- | ---------- | |||
growth: String 'linear' or 'logistic' to specify a linear or logistic | growth: String 'linear', 'logistic' or 'flat' to specify a linear, logistic | |||
trend. | or | |||
flat trend. | ||||
changepoints: List of dates at which to include potential changepoints. If | changepoints: List of dates at which to include potential changepoints. If | |||
not specified, potential changepoints are selected automatically. | not specified, potential changepoints are selected automatically. | |||
n_changepoints: Number of potential changepoints to include. Not used | n_changepoints: Number of potential changepoints to include. Not used | |||
if input `changepoints` is supplied. If `changepoints` is not supplied, | if input `changepoints` is supplied. If `changepoints` is not supplied, | |||
then n_changepoints potential changepoints are selected uniformly from | then n_changepoints potential changepoints are selected uniformly from | |||
the first `changepoint_range` proportion of the history. | the first `changepoint_range` proportion of the history. | |||
changepoint_range: Proportion of history in which trend changepoints will | changepoint_range: Proportion of history in which trend changepoints will | |||
be estimated. Defaults to 0.8 for the first 80%. Not used if | be estimated. Defaults to 0.8 for the first 80%. Not used if | |||
`changepoints` is specified. | `changepoints` is specified. | |||
yearly_seasonality: Fit yearly seasonality. | yearly_seasonality: Fit yearly seasonality. | |||
skipping to change at line 536 | skipping to change at line 536 | |||
------- | ------- | |||
holiday_features: pd.DataFrame with a column for each holiday. | holiday_features: pd.DataFrame with a column for each holiday. | |||
prior_scale_list: List of prior scales for each holiday column. | prior_scale_list: List of prior scales for each holiday column. | |||
holiday_names: List of names of holidays | holiday_names: List of names of holidays | |||
""" | """ | |||
# Holds columns of our future matrix. | # Holds columns of our future matrix. | |||
expanded_holidays = defaultdict(lambda: np.zeros(dates.shape[0])) | expanded_holidays = defaultdict(lambda: np.zeros(dates.shape[0])) | |||
prior_scales = {} | prior_scales = {} | |||
# Makes an index so we can perform `get_loc` below. | # Makes an index so we can perform `get_loc` below. | |||
# Strip to just dates. | # Strip to just dates. | |||
row_index = pd.DatetimeIndex(dates.apply(lambda x: x.date())) | row_index = pd.DatetimeIndex(dates.dt.date) | |||
for _ix, row in holidays.iterrows(): | for row in holidays.itertuples(): | |||
dt = row.ds.date() | dt = row.ds.date() | |||
try: | try: | |||
lw = int(row.get('lower_window', 0)) | lw = int(getattr(row, 'lower_window', 0)) | |||
uw = int(row.get('upper_window', 0)) | uw = int(getattr(row, 'upper_window', 0)) | |||
except ValueError: | except ValueError: | |||
lw = 0 | lw = 0 | |||
uw = 0 | uw = 0 | |||
ps = float(row.get('prior_scale', self.holidays_prior_scale)) | ps = float(getattr(row, 'prior_scale', self.holidays_prior_scale)) | |||
if np.isnan(ps): | if np.isnan(ps): | |||
ps = float(self.holidays_prior_scale) | ps = float(self.holidays_prior_scale) | |||
if row.holiday in prior_scales and prior_scales[row.holiday] != ps: | if row.holiday in prior_scales and prior_scales[row.holiday] != ps: | |||
raise ValueError( | raise ValueError( | |||
'Holiday {holiday!r} does not have consistent prior ' | 'Holiday {holiday!r} does not have consistent prior ' | |||
'scale specification.'.format(holiday=row.holiday) | 'scale specification.'.format(holiday=row.holiday) | |||
) | ) | |||
if ps <= 0: | if ps <= 0: | |||
raise ValueError('Prior scale must be > 0') | raise ValueError('Prior scale must be > 0') | |||
prior_scales[row.holiday] = ps | prior_scales[row.holiday] = ps | |||
skipping to change at line 895 | skipping to change at line 895 | |||
group: List of components that form the group. | group: List of components that form the group. | |||
Returns | Returns | |||
------- | ------- | |||
Dataframe with components. | Dataframe with components. | |||
""" | """ | |||
new_comp = components[components['component'].isin(set(group))].copy() | new_comp = components[components['component'].isin(set(group))].copy() | |||
group_cols = new_comp['col'].unique() | group_cols = new_comp['col'].unique() | |||
if len(group_cols) > 0: | if len(group_cols) > 0: | |||
new_comp = pd.DataFrame({'col': group_cols, 'component': name}) | new_comp = pd.DataFrame({'col': group_cols, 'component': name}) | |||
components = components.append(new_comp) | components = pd.concat([components, new_comp]) | |||
return components | return components | |||
def parse_seasonality_args(self, name, arg, auto_disable, default_order): | def parse_seasonality_args(self, name, arg, auto_disable, default_order): | |||
"""Get number of fourier components for built-in seasonalities. | """Get number of fourier components for built-in seasonalities. | |||
Parameters | Parameters | |||
---------- | ---------- | |||
name: string name of the seasonality component. | name: string name of the seasonality component. | |||
arg: 'auto', True, False, or number of fourier components as provided. | arg: 'auto', True, False, or number of fourier components as provided. | |||
auto_disable: bool if seasonality should be disabled when 'auto'. | auto_disable: bool if seasonality should be disabled when 'auto'. | |||
skipping to change at line 1579 | skipping to change at line 1579 | |||
freq=freq) | freq=freq) | |||
dates = dates[dates > last_date] # Drop start if equals last_date | dates = dates[dates > last_date] # Drop start if equals last_date | |||
dates = dates[:periods] # Return correct number of periods | dates = dates[:periods] # Return correct number of periods | |||
if include_history: | if include_history: | |||
dates = np.concatenate((np.array(self.history_dates), dates)) | dates = np.concatenate((np.array(self.history_dates), dates)) | |||
return pd.DataFrame({'ds': dates}) | return pd.DataFrame({'ds': dates}) | |||
def plot(self, fcst, ax=None, uncertainty=True, plot_cap=True, | def plot(self, fcst, ax=None, uncertainty=True, plot_cap=True, | |||
xlabel='ds', ylabel='y', figsize=(10, 6)): | xlabel='ds', ylabel='y', figsize=(10, 6), include_legend=False): | |||
"""Plot the Prophet forecast. | """Plot the Prophet forecast. | |||
Parameters | Parameters | |||
---------- | ---------- | |||
fcst: pd.DataFrame output of self.predict. | fcst: pd.DataFrame output of self.predict. | |||
ax: Optional matplotlib axes on which to plot. | ax: Optional matplotlib axes on which to plot. | |||
uncertainty: Optional boolean to plot uncertainty intervals. | uncertainty: Optional boolean to plot uncertainty intervals. | |||
plot_cap: Optional boolean indicating if the capacity should be shown | plot_cap: Optional boolean indicating if the capacity should be shown | |||
in the figure, if available. | in the figure, if available. | |||
xlabel: Optional label name on X-axis | xlabel: Optional label name on X-axis | |||
ylabel: Optional label name on Y-axis | ylabel: Optional label name on Y-axis | |||
figsize: Optional tuple width, height in inches. | figsize: Optional tuple width, height in inches. | |||
include_legend: Optional boolean to add legend to the plot. | ||||
Returns | Returns | |||
------- | ------- | |||
A matplotlib figure. | A matplotlib figure. | |||
""" | """ | |||
return plot( | return plot( | |||
m=self, fcst=fcst, ax=ax, uncertainty=uncertainty, | m=self, fcst=fcst, ax=ax, uncertainty=uncertainty, | |||
plot_cap=plot_cap, xlabel=xlabel, ylabel=ylabel, | plot_cap=plot_cap, xlabel=xlabel, ylabel=ylabel, | |||
figsize=figsize | figsize=figsize, include_legend=include_legend | |||
) | ) | |||
def plot_components(self, fcst, uncertainty=True, plot_cap=True, | def plot_components(self, fcst, uncertainty=True, plot_cap=True, | |||
weekly_start=0, yearly_start=0, figsize=None): | weekly_start=0, yearly_start=0, figsize=None): | |||
"""Plot the Prophet forecast components. | """Plot the Prophet forecast components. | |||
Will plot whichever are available of: trend, holidays, weekly | Will plot whichever are available of: trend, holidays, weekly | |||
seasonality, and yearly seasonality. | seasonality, and yearly seasonality. | |||
Parameters | Parameters | |||
End of changes. 9 change blocks. | ||||
10 lines changed or deleted | 12 lines changed or added |