diagnostics.py (prophet-1.0) | : | diagnostics.py (prophet-1.1) | ||
---|---|---|---|---|
skipping to change at line 116 | skipping to change at line 116 | |||
func(*args) | func(*args) | |||
for args in zip(*iterables) | for args in zip(*iterables) | |||
] | ] | |||
return results | return results | |||
Returns | Returns | |||
------- | ------- | |||
A pd.DataFrame with the forecast, actual value and cutoff. | A pd.DataFrame with the forecast, actual value and cutoff. | |||
""" | """ | |||
if model.history is None: | ||||
raise Exception('Model has not been fit. Fitting the model provides cont | ||||
extual parameters for cross validation.') | ||||
df = model.history.copy().reset_index(drop=True) | df = model.history.copy().reset_index(drop=True) | |||
horizon = pd.Timedelta(horizon) | horizon = pd.Timedelta(horizon) | |||
predict_columns = ['ds', 'yhat'] | predict_columns = ['ds', 'yhat'] | |||
if model.uncertainty_samples: | if model.uncertainty_samples: | |||
predict_columns.extend(['yhat_lower', 'yhat_upper']) | predict_columns.extend(['yhat_lower', 'yhat_upper']) | |||
# Identify largest seasonality period | # Identify largest seasonality period | |||
period_max = 0. | period_max = 0. | |||
for s in model.seasonalities.values(): | for s in model.seasonalities.values(): | |||
skipping to change at line 171 | skipping to change at line 174 | |||
valid = {"threads", "processes", "dask"} | valid = {"threads", "processes", "dask"} | |||
if parallel == "threads": | if parallel == "threads": | |||
pool = concurrent.futures.ThreadPoolExecutor() | pool = concurrent.futures.ThreadPoolExecutor() | |||
elif parallel == "processes": | elif parallel == "processes": | |||
pool = concurrent.futures.ProcessPoolExecutor() | pool = concurrent.futures.ProcessPoolExecutor() | |||
elif parallel == "dask": | elif parallel == "dask": | |||
try: | try: | |||
from dask.distributed import get_client | from dask.distributed import get_client | |||
except ImportError as e: | except ImportError as e: | |||
raise ImportError("parallel='dask' requies the optional " | raise ImportError("parallel='dask' requires the optional " | |||
"dependency dask.") from e | "dependency dask.") from e | |||
pool = get_client() | pool = get_client() | |||
# delay df and model to avoid large objects in task graph. | # delay df and model to avoid large objects in task graph. | |||
df, model = pool.scatter([df, model]) | df, model = pool.scatter([df, model]) | |||
elif hasattr(parallel, "map"): | elif hasattr(parallel, "map"): | |||
pool = parallel | pool = parallel | |||
else: | else: | |||
msg = ("'parallel' should be one of {} for an instance with a " | msg = ("'parallel' should be one of {} for an instance with a " | |||
"'map' method".format(', '.join(valid))) | "'map' method".format(', '.join(valid))) | |||
raise ValueError(msg) | raise ValueError(msg) | |||
skipping to change at line 414 | skipping to change at line 417 | |||
w: Integer window size (number of elements). | w: Integer window size (number of elements). | |||
name: Name for metric in result dataframe | name: Name for metric in result dataframe | |||
Returns | Returns | |||
------- | ------- | |||
Dataframe with columns horizon and name, the rolling mean of x. | Dataframe with columns horizon and name, the rolling mean of x. | |||
""" | """ | |||
# Aggregate over h | # Aggregate over h | |||
df = pd.DataFrame({'x': x, 'h': h}) | df = pd.DataFrame({'x': x, 'h': h}) | |||
df2 = ( | df2 = ( | |||
df.groupby('h').agg(['mean', 'count']).reset_index().sort_values('h') | df.groupby('h').agg(['sum', 'count']).reset_index().sort_values('h') | |||
) | ) | |||
xm = df2['x']['mean'].values | xs = df2['x']['sum'].values | |||
ns = df2['x']['count'].values | ns = df2['x']['count'].values | |||
hs = df2['h'].values | hs = df2.h.values | |||
trailing_i = len(df2) - 1 | ||||
x_sum = 0 | ||||
n_sum = 0 | ||||
# We don't know output size but it is bounded by len(df2) | ||||
res_x = np.empty(len(df2)) | ||||
res_h = [] | ||||
res_x = [] | ||||
# Start from the right and work backwards | # Start from the right and work backwards | |||
i = len(hs) - 1 | for i in range(len(df2) - 1, -1, -1): | |||
while i >= 0: | x_sum += xs[i] | |||
# Construct a mean of at least w samples. | n_sum += ns[i] | |||
n = int(ns[i]) | while n_sum >= w: | |||
xbar = float(xm[i]) | ||||
j = i - 1 | ||||
while ((n < w) and j >= 0): | ||||
# Include points from the previous horizon. All of them if still | # Include points from the previous horizon. All of them if still | |||
# less than w, otherwise just enough to get to w. | # less than w, otherwise weight the mean by the difference | |||
n2 = min(w - n, ns[j]) | excess_n = n_sum - w | |||
xbar = xbar * (n / (n + n2)) + xm[j] * (n2 / (n + n2)) | excess_x = excess_n * xs[i] / ns[i] | |||
n += n2 | res_x[trailing_i] = (x_sum - excess_x)/ w | |||
j -= 1 | x_sum -= xs[trailing_i] | |||
if n < w: | n_sum -= ns[trailing_i] | |||
# Ran out of horizons before enough points. | trailing_i -= 1 | |||
break | ||||
res_h.append(hs[i]) | res_h = hs[(trailing_i + 1):] | |||
res_x.append(xbar) | res_x = res_x[(trailing_i + 1):] | |||
i -= 1 | ||||
res_h.reverse() | ||||
res_x.reverse() | ||||
return pd.DataFrame({'horizon': res_h, name: res_x}) | return pd.DataFrame({'horizon': res_h, name: res_x}) | |||
def rolling_median_by_h(x, h, w, name): | def rolling_median_by_h(x, h, w, name): | |||
"""Compute a rolling median of x, after first aggregating by h. | """Compute a rolling median of x, after first aggregating by h. | |||
Right-aligned. Computes a single median for each unique value of h. Each | Right-aligned. Computes a single median for each unique value of h. Each | |||
median is over at least w samples. | median is over at least w samples. | |||
For each h where there are fewer than w samples, we take samples from the pr evious h, | For each h where there are fewer than w samples, we take samples from the pr evious h, | |||
moving backwards. (In other words, we ~ assume that the x's are shuffled wit hin each h.) | moving backwards. (In other words, we ~ assume that the x's are shuffled wit hin each h.) | |||
End of changes. 8 change blocks. | ||||
26 lines changed or deleted | 29 lines changed or added |