plot.py (prophet-1.0) | : | plot.py (prophet-1.1) | ||
---|---|---|---|---|
skipping to change at line 41 | skipping to change at line 41 | |||
logger.error('Importing matplotlib failed. Plotting will not work.') | logger.error('Importing matplotlib failed. Plotting will not work.') | |||
try: | try: | |||
import plotly.graph_objs as go | import plotly.graph_objs as go | |||
from plotly.subplots import make_subplots | from plotly.subplots import make_subplots | |||
except ImportError: | except ImportError: | |||
logger.error('Importing plotly failed. Interactive plots will not work.') | logger.error('Importing plotly failed. Interactive plots will not work.') | |||
def plot( | def plot( | |||
m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y', | m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y', | |||
figsize=(10, 6) | figsize=(10, 6), include_legend=False | |||
): | ): | |||
"""Plot the Prophet forecast. | """Plot the Prophet forecast. | |||
Parameters | Parameters | |||
---------- | ---------- | |||
m: Prophet model. | m: Prophet model. | |||
fcst: pd.DataFrame output of m.predict. | fcst: pd.DataFrame output of m.predict. | |||
ax: Optional matplotlib axes on which to plot. | ax: Optional matplotlib axes on which to plot. | |||
uncertainty: Optional boolean to plot uncertainty intervals, which will | uncertainty: Optional boolean to plot uncertainty intervals, which will | |||
only be done if m.uncertainty_samples > 0. | only be done if m.uncertainty_samples > 0. | |||
plot_cap: Optional boolean indicating if the capacity should be shown | plot_cap: Optional boolean indicating if the capacity should be shown | |||
in the figure, if available. | in the figure, if available. | |||
xlabel: Optional label name on X-axis | xlabel: Optional label name on X-axis | |||
ylabel: Optional label name on Y-axis | ylabel: Optional label name on Y-axis | |||
figsize: Optional tuple width, height in inches. | figsize: Optional tuple width, height in inches. | |||
include_legend: Optional boolean to add legend to the plot. | ||||
Returns | Returns | |||
------- | ------- | |||
A matplotlib figure. | A matplotlib figure. | |||
""" | """ | |||
if ax is None: | if ax is None: | |||
fig = plt.figure(facecolor='w', figsize=figsize) | fig = plt.figure(facecolor='w', figsize=figsize) | |||
ax = fig.add_subplot(111) | ax = fig.add_subplot(111) | |||
else: | else: | |||
fig = ax.get_figure() | fig = ax.get_figure() | |||
fcst_t = fcst['ds'].dt.to_pydatetime() | fcst_t = fcst['ds'].dt.to_pydatetime() | |||
ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.') | ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.', | |||
ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2') | label='Observed data points') | |||
ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2', label='Forecast') | ||||
if 'cap' in fcst and plot_cap: | if 'cap' in fcst and plot_cap: | |||
ax.plot(fcst_t, fcst['cap'], ls='--', c='k') | ax.plot(fcst_t, fcst['cap'], ls='--', c='k', label='Maximum capacity') | |||
if m.logistic_floor and 'floor' in fcst and plot_cap: | if m.logistic_floor and 'floor' in fcst and plot_cap: | |||
ax.plot(fcst_t, fcst['floor'], ls='--', c='k') | ax.plot(fcst_t, fcst['floor'], ls='--', c='k', label='Minimum capacity') | |||
if uncertainty and m.uncertainty_samples: | if uncertainty and m.uncertainty_samples: | |||
ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'], | ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'], | |||
color='#0072B2', alpha=0.2) | color='#0072B2', alpha=0.2, label='Uncertainty interval' ) | |||
# Specify formatting to workaround matplotlib issue #12925 | # Specify formatting to workaround matplotlib issue #12925 | |||
locator = AutoDateLocator(interval_multiples=False) | locator = AutoDateLocator(interval_multiples=False) | |||
formatter = AutoDateFormatter(locator) | formatter = AutoDateFormatter(locator) | |||
ax.xaxis.set_major_locator(locator) | ax.xaxis.set_major_locator(locator) | |||
ax.xaxis.set_major_formatter(formatter) | ax.xaxis.set_major_formatter(formatter) | |||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2) | ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2) | |||
ax.set_xlabel(xlabel) | ax.set_xlabel(xlabel) | |||
ax.set_ylabel(ylabel) | ax.set_ylabel(ylabel) | |||
if include_legend: | ||||
ax.legend() | ||||
fig.tight_layout() | fig.tight_layout() | |||
return fig | return fig | |||
def plot_components( | def plot_components( | |||
m, fcst, uncertainty=True, plot_cap=True, weekly_start=0, yearly_start=0, | m, fcst, uncertainty=True, plot_cap=True, weekly_start=0, yearly_start=0, | |||
figsize=None | figsize=None | |||
): | ): | |||
"""Plot the Prophet forecast components. | """Plot the Prophet forecast components. | |||
Will plot whichever are available of: trend, holidays, weekly | Will plot whichever are available of: trend, holidays, weekly | |||
skipping to change at line 460 | skipping to change at line 464 | |||
if trend: | if trend: | |||
artists.append(ax.plot(fcst['ds'], fcst['trend'], c=cp_color)) | artists.append(ax.plot(fcst['ds'], fcst['trend'], c=cp_color)) | |||
signif_changepoints = m.changepoints[ | signif_changepoints = m.changepoints[ | |||
np.abs(np.nanmean(m.params['delta'], axis=0)) >= threshold | np.abs(np.nanmean(m.params['delta'], axis=0)) >= threshold | |||
] if len(m.changepoints) > 0 else [] | ] if len(m.changepoints) > 0 else [] | |||
for cp in signif_changepoints: | for cp in signif_changepoints: | |||
artists.append(ax.axvline(x=cp, c=cp_color, ls=cp_linestyle)) | artists.append(ax.axvline(x=cp, c=cp_color, ls=cp_linestyle)) | |||
return artists | return artists | |||
def plot_cross_validation_metric( | def plot_cross_validation_metric( | |||
df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6), color='b' | df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6), color='b', | |||
point_color='gray' | ||||
): | ): | |||
"""Plot a performance metric vs. forecast horizon from cross validation. | """Plot a performance metric vs. forecast horizon from cross validation. | |||
Cross validation produces a collection of out-of-sample model predictions | Cross validation produces a collection of out-of-sample model predictions | |||
that can be compared to actual values, at a range of different horizons | that can be compared to actual values, at a range of different horizons | |||
(distance from the cutoff). This computes a specified performance metric | (distance from the cutoff). This computes a specified performance metric | |||
for each prediction, and aggregated over a rolling window with horizon. | for each prediction, and aggregated over a rolling window with horizon. | |||
This uses prophet.diagnostics.performance_metrics to compute the metrics. | This uses prophet.diagnostics.performance_metrics to compute the metrics. | |||
Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'. | Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'. | |||
skipping to change at line 530 | skipping to change at line 535 | |||
10 ** 3, | 10 ** 3, | |||
1., | 1., | |||
] | ] | |||
for i, dt in enumerate(dts): | for i, dt in enumerate(dts): | |||
if np.timedelta64(1, dt) < np.timedelta64(tick_w, 'ns'): | if np.timedelta64(1, dt) < np.timedelta64(tick_w, 'ns'): | |||
break | break | |||
x_plt = df_none['horizon'].astype('timedelta64[ns]').astype(np.int64) / floa t(dt_conversions[i]) | x_plt = df_none['horizon'].astype('timedelta64[ns]').astype(np.int64) / floa t(dt_conversions[i]) | |||
x_plt_h = df_h['horizon'].astype('timedelta64[ns]').astype(np.int64) / float (dt_conversions[i]) | x_plt_h = df_h['horizon'].astype('timedelta64[ns]').astype(np.int64) / float (dt_conversions[i]) | |||
ax.plot(x_plt, df_none[metric], '.', alpha=0.1, c=color) | ax.plot(x_plt, df_none[metric], '.', alpha=0.1, c=point_color) | |||
ax.plot(x_plt_h, df_h[metric], '-', c=color) | ax.plot(x_plt_h, df_h[metric], '-', c=color) | |||
ax.grid(True) | ax.grid(True) | |||
ax.set_xlabel('Horizon ({})'.format(dt_names[i])) | ax.set_xlabel('Horizon ({})'.format(dt_names[i])) | |||
ax.set_ylabel(metric) | ax.set_ylabel(metric) | |||
return fig | return fig | |||
def plot_plotly(m, fcst, uncertainty=True, plot_cap=True, trend=False, changepoi nts=False, | def plot_plotly(m, fcst, uncertainty=True, plot_cap=True, trend=False, changepoi nts=False, | |||
changepoints_threshold=0.01, xlabel='ds', ylabel='y', figsize=(9 00, 600)): | changepoints_threshold=0.01, xlabel='ds', ylabel='y', figsize=(9 00, 600)): | |||
"""Plot the Prophet forecast with Plotly offline. | """Plot the Prophet forecast with Plotly offline. | |||
End of changes. 9 change blocks. | ||||
8 lines changed or deleted | 13 lines changed or added |