diagnostics.py (prophet-1.0) | : | diagnostics.py (prophet-1.1) | ||
---|---|---|---|---|

skipping to change at line 116 | skipping to change at line 116 | |||

func(*args) | func(*args) | |||

for args in zip(*iterables) | for args in zip(*iterables) | |||

] | ] | |||

return results | return results | |||

Returns | Returns | |||

------- | ------- | |||

A pd.DataFrame with the forecast, actual value and cutoff. | A pd.DataFrame with the forecast, actual value and cutoff. | |||

""" | """ | |||

if model.history is None: | ||||

raise Exception('Model has not been fit. Fitting the model provides cont | ||||

extual parameters for cross validation.') | ||||

df = model.history.copy().reset_index(drop=True) | df = model.history.copy().reset_index(drop=True) | |||

horizon = pd.Timedelta(horizon) | horizon = pd.Timedelta(horizon) | |||

predict_columns = ['ds', 'yhat'] | predict_columns = ['ds', 'yhat'] | |||

if model.uncertainty_samples: | if model.uncertainty_samples: | |||

predict_columns.extend(['yhat_lower', 'yhat_upper']) | predict_columns.extend(['yhat_lower', 'yhat_upper']) | |||

# Identify largest seasonality period | # Identify largest seasonality period | |||

period_max = 0. | period_max = 0. | |||

for s in model.seasonalities.values(): | for s in model.seasonalities.values(): | |||

skipping to change at line 171 | skipping to change at line 174 | |||

valid = {"threads", "processes", "dask"} | valid = {"threads", "processes", "dask"} | |||

if parallel == "threads": | if parallel == "threads": | |||

pool = concurrent.futures.ThreadPoolExecutor() | pool = concurrent.futures.ThreadPoolExecutor() | |||

elif parallel == "processes": | elif parallel == "processes": | |||

pool = concurrent.futures.ProcessPoolExecutor() | pool = concurrent.futures.ProcessPoolExecutor() | |||

elif parallel == "dask": | elif parallel == "dask": | |||

try: | try: | |||

from dask.distributed import get_client | from dask.distributed import get_client | |||

except ImportError as e: | except ImportError as e: | |||

raise ImportError("parallel='dask' requies the optional " | raise ImportError("parallel='dask' requires the optional " | |||

"dependency dask.") from e | "dependency dask.") from e | |||

pool = get_client() | pool = get_client() | |||

# delay df and model to avoid large objects in task graph. | # delay df and model to avoid large objects in task graph. | |||

df, model = pool.scatter([df, model]) | df, model = pool.scatter([df, model]) | |||

elif hasattr(parallel, "map"): | elif hasattr(parallel, "map"): | |||

pool = parallel | pool = parallel | |||

else: | else: | |||

msg = ("'parallel' should be one of {} for an instance with a " | msg = ("'parallel' should be one of {} for an instance with a " | |||

"'map' method".format(', '.join(valid))) | "'map' method".format(', '.join(valid))) | |||

raise ValueError(msg) | raise ValueError(msg) | |||

skipping to change at line 414 | skipping to change at line 417 | |||

w: Integer window size (number of elements). | w: Integer window size (number of elements). | |||

name: Name for metric in result dataframe | name: Name for metric in result dataframe | |||

Returns | Returns | |||

------- | ------- | |||

Dataframe with columns horizon and name, the rolling mean of x. | Dataframe with columns horizon and name, the rolling mean of x. | |||

""" | """ | |||

# Aggregate over h | # Aggregate over h | |||

df = pd.DataFrame({'x': x, 'h': h}) | df = pd.DataFrame({'x': x, 'h': h}) | |||

df2 = ( | df2 = ( | |||

df.groupby('h').agg(['mean', 'count']).reset_index().sort_values('h') | df.groupby('h').agg(['sum', 'count']).reset_index().sort_values('h') | |||

) | ) | |||

xm = df2['x']['mean'].values | xs = df2['x']['sum'].values | |||

ns = df2['x']['count'].values | ns = df2['x']['count'].values | |||

hs = df2['h'].values | hs = df2.h.values | |||

trailing_i = len(df2) - 1 | ||||

x_sum = 0 | ||||

n_sum = 0 | ||||

# We don't know output size but it is bounded by len(df2) | ||||

res_x = np.empty(len(df2)) | ||||

res_h = [] | ||||

res_x = [] | ||||

# Start from the right and work backwards | # Start from the right and work backwards | |||

i = len(hs) - 1 | for i in range(len(df2) - 1, -1, -1): | |||

while i >= 0: | x_sum += xs[i] | |||

# Construct a mean of at least w samples. | n_sum += ns[i] | |||

n = int(ns[i]) | while n_sum >= w: | |||

xbar = float(xm[i]) | ||||

j = i - 1 | ||||

while ((n < w) and j >= 0): | ||||

# Include points from the previous horizon. All of them if still | # Include points from the previous horizon. All of them if still | |||

# less than w, otherwise just enough to get to w. | # less than w, otherwise weight the mean by the difference | |||

n2 = min(w - n, ns[j]) | excess_n = n_sum - w | |||

xbar = xbar * (n / (n + n2)) + xm[j] * (n2 / (n + n2)) | excess_x = excess_n * xs[i] / ns[i] | |||

n += n2 | res_x[trailing_i] = (x_sum - excess_x)/ w | |||

j -= 1 | x_sum -= xs[trailing_i] | |||

if n < w: | n_sum -= ns[trailing_i] | |||

# Ran out of horizons before enough points. | trailing_i -= 1 | |||

break | ||||

res_h.append(hs[i]) | res_h = hs[(trailing_i + 1):] | |||

res_x.append(xbar) | res_x = res_x[(trailing_i + 1):] | |||

i -= 1 | ||||

res_h.reverse() | ||||

res_x.reverse() | ||||

return pd.DataFrame({'horizon': res_h, name: res_x}) | return pd.DataFrame({'horizon': res_h, name: res_x}) | |||

def rolling_median_by_h(x, h, w, name): | def rolling_median_by_h(x, h, w, name): | |||

"""Compute a rolling median of x, after first aggregating by h. | """Compute a rolling median of x, after first aggregating by h. | |||

Right-aligned. Computes a single median for each unique value of h. Each | Right-aligned. Computes a single median for each unique value of h. Each | |||

median is over at least w samples. | median is over at least w samples. | |||

For each h where there are fewer than w samples, we take samples from the pr evious h, | For each h where there are fewer than w samples, we take samples from the pr evious h, | |||

moving backwards. (In other words, we ~ assume that the x's are shuffled wit hin each h.) | moving backwards. (In other words, we ~ assume that the x's are shuffled wit hin each h.) | |||

End of changes. 8 change blocks. | ||||

26 lines changed or deleted | | 29 lines changed or added |