serialize.py (prophet-1.1) | : | serialize.py (prophet-1.1.1) | ||
---|---|---|---|---|
skipping to change at line 13 | skipping to change at line 13 | |||
# This source code is licensed under the MIT license found in the | # This source code is licensed under the MIT license found in the | |||
# LICENSE file in the root directory of this source tree. | # LICENSE file in the root directory of this source tree. | |||
from __future__ import absolute_import, division, print_function | from __future__ import absolute_import, division, print_function | |||
from collections import OrderedDict | from collections import OrderedDict | |||
from copy import deepcopy | from copy import deepcopy | |||
from io import StringIO | from io import StringIO | |||
import json | import json | |||
from pathlib import Path | ||||
import numpy as np | import numpy as np | |||
import pandas as pd | import pandas as pd | |||
from prophet.forecaster import Prophet | from prophet.forecaster import Prophet | |||
from prophet import __version__ | ||||
about = {} | ||||
here = Path(__file__).parent.resolve() | ||||
with open(here / "__version__.py", "r") as f: | ||||
exec(f.read(), about) | ||||
SIMPLE_ATTRIBUTES = [ | SIMPLE_ATTRIBUTES = [ | |||
'growth', 'n_changepoints', 'specified_changepoints', 'changepoint_range', | 'growth', 'n_changepoints', 'specified_changepoints', 'changepoint_range', | |||
'yearly_seasonality', 'weekly_seasonality', 'daily_seasonality', | 'yearly_seasonality', 'weekly_seasonality', 'daily_seasonality', | |||
'seasonality_mode', 'seasonality_prior_scale', 'changepoint_prior_scale', | 'seasonality_mode', 'seasonality_prior_scale', 'changepoint_prior_scale', | |||
'holidays_prior_scale', 'mcmc_samples', 'interval_width', 'uncertainty_sampl es', | 'holidays_prior_scale', 'mcmc_samples', 'interval_width', 'uncertainty_sampl es', | |||
'y_scale', 'logistic_floor', 'country_holidays', 'component_modes' | 'y_scale', 'logistic_floor', 'country_holidays', 'component_modes' | |||
] | ] | |||
PD_SERIES = ['changepoints', 'history_dates', 'train_holiday_names'] | PD_SERIES = ['changepoints', 'history_dates', 'train_holiday_names'] | |||
skipping to change at line 103 | skipping to change at line 108 | |||
for k, v in fit_kwargs['init'].items(): | for k, v in fit_kwargs['init'].items(): | |||
if isinstance(v, np.ndarray): | if isinstance(v, np.ndarray): | |||
fit_kwargs['init'][k] = v.tolist() | fit_kwargs['init'][k] = v.tolist() | |||
elif isinstance(v, np.floating): | elif isinstance(v, np.floating): | |||
fit_kwargs['init'][k] = float(v) | fit_kwargs['init'][k] = float(v) | |||
model_dict['fit_kwargs'] = fit_kwargs | model_dict['fit_kwargs'] = fit_kwargs | |||
# Params (Dict[str, np.ndarray]) | # Params (Dict[str, np.ndarray]) | |||
model_dict['params'] = {k: v.tolist() for k, v in model.params.items()} | model_dict['params'] = {k: v.tolist() for k, v in model.params.items()} | |||
# Attributes that are skipped: stan_fit, stan_backend | # Attributes that are skipped: stan_fit, stan_backend | |||
model_dict['__prophet_version'] = __version__ | model_dict['__prophet_version'] = about["__version__"] | |||
return model_dict | return model_dict | |||
def model_to_json(model): | def model_to_json(model): | |||
"""Serialize a Prophet model to json string. | """Serialize a Prophet model to json string. | |||
Model must be fitted. Skips Stan objects that are not needed for predict. | Model must be fitted. Skips Stan objects that are not needed for predict. | |||
Can be deserialized with model_from_json. | Can be deserialized with model_from_json. | |||
Parameters | Parameters | |||
skipping to change at line 152 | skipping to change at line 157 | |||
if model_dict[attribute] is None: | if model_dict[attribute] is None: | |||
setattr(model, attribute, None) | setattr(model, attribute, None) | |||
else: | else: | |||
s = pd.read_json(StringIO(model_dict[attribute]), typ='series', orie nt='split') | s = pd.read_json(StringIO(model_dict[attribute]), typ='series', orie nt='split') | |||
if s.name == 'ds': | if s.name == 'ds': | |||
if len(s) == 0: | if len(s) == 0: | |||
s = pd.to_datetime(s) | s = pd.to_datetime(s) | |||
s = s.dt.tz_localize(None) | s = s.dt.tz_localize(None) | |||
setattr(model, attribute, s) | setattr(model, attribute, s) | |||
for attribute in PD_TIMESTAMP: | for attribute in PD_TIMESTAMP: | |||
setattr(model, attribute, pd.Timestamp.utcfromtimestamp(model_dict[attri bute])) | setattr(model, attribute, pd.Timestamp.utcfromtimestamp(model_dict[attri bute]).tz_localize(None)) | |||
for attribute in PD_TIMEDELTA: | for attribute in PD_TIMEDELTA: | |||
setattr(model, attribute, pd.Timedelta(seconds=model_dict[attribute])) | setattr(model, attribute, pd.Timedelta(seconds=model_dict[attribute])) | |||
for attribute in PD_DATAFRAME: | for attribute in PD_DATAFRAME: | |||
if model_dict[attribute] is None: | if model_dict[attribute] is None: | |||
setattr(model, attribute, None) | setattr(model, attribute, None) | |||
else: | else: | |||
df = pd.read_json(StringIO(model_dict[attribute]), typ='frame', orie nt='table', convert_dates=['ds']) | df = pd.read_json(StringIO(model_dict[attribute]), typ='frame', orie nt='table', convert_dates=['ds']) | |||
if attribute == 'train_component_cols': | if attribute == 'train_component_cols': | |||
# Special handling because of named index column | # Special handling because of named index column | |||
df.columns.name = 'component' | df.columns.name = 'component' | |||
End of changes. 4 change blocks. | ||||
3 lines changed or deleted | 8 lines changed or added |