serialize.py (prophet-1.0) | : | serialize.py (prophet-1.1) | ||
---|---|---|---|---|
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | |||
# Copyright (c) Facebook, Inc. and its affiliates. | # Copyright (c) Facebook, Inc. and its affiliates. | |||
# This source code is licensed under the MIT license found in the | # This source code is licensed under the MIT license found in the | |||
# LICENSE file in the root directory of this source tree. | # LICENSE file in the root directory of this source tree. | |||
from __future__ import absolute_import, division, print_function | from __future__ import absolute_import, division, print_function | |||
from collections import OrderedDict | from collections import OrderedDict | |||
from copy import deepcopy | from copy import deepcopy | |||
from io import StringIO | ||||
import json | import json | |||
import numpy as np | import numpy as np | |||
import pandas as pd | import pandas as pd | |||
from prophet.forecaster import Prophet | from prophet.forecaster import Prophet | |||
from prophet import __version__ | from prophet import __version__ | |||
SIMPLE_ATTRIBUTES = [ | SIMPLE_ATTRIBUTES = [ | |||
'growth', 'n_changepoints', 'specified_changepoints', 'changepoint_range', | 'growth', 'n_changepoints', 'specified_changepoints', 'changepoint_range', | |||
skipping to change at line 39 | skipping to change at line 40 | |||
PD_TIMESTAMP = ['start'] | PD_TIMESTAMP = ['start'] | |||
PD_TIMEDELTA = ['t_scale'] | PD_TIMEDELTA = ['t_scale'] | |||
PD_DATAFRAME = ['holidays', 'history', 'train_component_cols'] | PD_DATAFRAME = ['holidays', 'history', 'train_component_cols'] | |||
NP_ARRAY = ['changepoints_t'] | NP_ARRAY = ['changepoints_t'] | |||
ORDEREDDICT = ['seasonalities', 'extra_regressors'] | ORDEREDDICT = ['seasonalities', 'extra_regressors'] | |||
def model_to_json(model): | def model_to_dict(model): | |||
"""Serialize a Prophet model to json string. | """Convert a Prophet model to a dictionary suitable for JSON serialization. | |||
Model must be fitted. Skips Stan objects that are not needed for predict. | Model must be fitted. Skips Stan objects that are not needed for predict. | |||
Can be deserialized with model_from_json. | Can be reversed with model_from_dict. | |||
Parameters | Parameters | |||
---------- | ---------- | |||
model: Prophet model object. | model: Prophet model object. | |||
Returns | Returns | |||
------- | ------- | |||
json string that can be deserialized into a Prophet model. | dict that can be used to serialize a Prophet model as JSON or loaded back | |||
into a Prophet model. | ||||
""" | """ | |||
if model.history is None: | if model.history is None: | |||
raise ValueError( | raise ValueError( | |||
"This can only be used to serialize models that have already been fit." | "This can only be used to serialize models that have already been fit." | |||
) | ) | |||
model_json = { | model_dict = { | |||
attribute: getattr(model, attribute) for attribute in SIMPLE_ATTRIBUTES | attribute: getattr(model, attribute) for attribute in SIMPLE_ATTRIBUTES | |||
} | } | |||
# Handle attributes of non-core types | # Handle attributes of non-core types | |||
for attribute in PD_SERIES: | for attribute in PD_SERIES: | |||
if getattr(model, attribute) is None: | if getattr(model, attribute) is None: | |||
model_json[attribute] = None | model_dict[attribute] = None | |||
else: | else: | |||
model_json[attribute] = getattr(model, attribute).to_json( | model_dict[attribute] = getattr(model, attribute).to_json( | |||
orient='split', date_format='iso' | orient='split', date_format='iso' | |||
) | ) | |||
for attribute in PD_TIMESTAMP: | for attribute in PD_TIMESTAMP: | |||
model_json[attribute] = getattr(model, attribute).timestamp() | model_dict[attribute] = getattr(model, attribute).timestamp() | |||
for attribute in PD_TIMEDELTA: | for attribute in PD_TIMEDELTA: | |||
model_json[attribute] = getattr(model, attribute).total_seconds() | model_dict[attribute] = getattr(model, attribute).total_seconds() | |||
for attribute in PD_DATAFRAME: | for attribute in PD_DATAFRAME: | |||
if getattr(model, attribute) is None: | if getattr(model, attribute) is None: | |||
model_json[attribute] = None | model_dict[attribute] = None | |||
else: | else: | |||
model_json[attribute] = getattr(model, attribute).to_json(orient='table', index=False) | model_dict[attribute] = getattr(model, attribute).to_json(orient='table', index=False) | |||
for attribute in NP_ARRAY: | for attribute in NP_ARRAY: | |||
model_json[attribute] = getattr(model, attribute).tolist() | model_dict[attribute] = getattr(model, attribute).tolist() | |||
for attribute in ORDEREDDICT: | for attribute in ORDEREDDICT: | |||
model_json[attribute] = [ | model_dict[attribute] = [ | |||
list(getattr(model, attribute).keys()), | list(getattr(model, attribute).keys()), | |||
getattr(model, attribute), | getattr(model, attribute), | |||
] | ] | |||
# Other attributes with special handling | # Other attributes with special handling | |||
# fit_kwargs -> Transform any numpy types before serializing. | # fit_kwargs -> Transform any numpy types before serializing. | |||
# They do not need to be transformed back on deserializing. | # They do not need to be transformed back on deserializing. | |||
fit_kwargs = deepcopy(model.fit_kwargs) | fit_kwargs = deepcopy(model.fit_kwargs) | |||
if 'init' in fit_kwargs: | if 'init' in fit_kwargs: | |||
for k, v in fit_kwargs['init'].items(): | for k, v in fit_kwargs['init'].items(): | |||
if isinstance(v, np.ndarray): | if isinstance(v, np.ndarray): | |||
fit_kwargs['init'][k] = v.tolist() | fit_kwargs['init'][k] = v.tolist() | |||
elif isinstance(v, np.floating): | elif isinstance(v, np.floating): | |||
fit_kwargs['init'][k] = float(v) | fit_kwargs['init'][k] = float(v) | |||
model_json['fit_kwargs'] = fit_kwargs | model_dict['fit_kwargs'] = fit_kwargs | |||
# Params (Dict[str, np.ndarray]) | # Params (Dict[str, np.ndarray]) | |||
model_json['params'] = {k: v.tolist() for k, v in model.params.items()} | model_dict['params'] = {k: v.tolist() for k, v in model.params.items()} | |||
# Attributes that are skipped: stan_fit, stan_backend | # Attributes that are skipped: stan_fit, stan_backend | |||
model_json['__prophet_version'] = __version__ | model_dict['__prophet_version'] = __version__ | |||
return model_dict | ||||
def model_to_json(model): | ||||
"""Serialize a Prophet model to json string. | ||||
Model must be fitted. Skips Stan objects that are not needed for predict. | ||||
Can be deserialized with model_from_json. | ||||
Parameters | ||||
---------- | ||||
model: Prophet model object. | ||||
Returns | ||||
------- | ||||
json string that can be deserialized into a Prophet model. | ||||
""" | ||||
model_json = model_to_dict(model) | ||||
return json.dumps(model_json) | return json.dumps(model_json) | |||
def model_from_json(model_json): | def model_from_dict(model_dict): | |||
"""Deserialize a Prophet model from json string. | """Recreate a Prophet model from a dictionary. | |||
Deserializes models that were serialized with model_to_json. | Recreates models that were converted with model_to_dict. | |||
Parameters | Parameters | |||
---------- | ---------- | |||
model_json: Serialized model string | model_dict: Dictionary containing model, created with model_to_dict. | |||
Returns | Returns | |||
------- | ------- | |||
Prophet model. | Prophet model. | |||
""" | """ | |||
attr_dict = json.loads(model_json) | ||||
model = Prophet() # We will overwrite all attributes set in init anyway | model = Prophet() # We will overwrite all attributes set in init anyway | |||
# Simple types | # Simple types | |||
for attribute in SIMPLE_ATTRIBUTES: | for attribute in SIMPLE_ATTRIBUTES: | |||
setattr(model, attribute, attr_dict[attribute]) | setattr(model, attribute, model_dict[attribute]) | |||
for attribute in PD_SERIES: | for attribute in PD_SERIES: | |||
if attr_dict[attribute] is None: | if model_dict[attribute] is None: | |||
setattr(model, attribute, None) | setattr(model, attribute, None) | |||
else: | else: | |||
s = pd.read_json(attr_dict[attribute], typ='series', orient='split') | s = pd.read_json(StringIO(model_dict[attribute]), typ='series', orient='split') | |||
if s.name == 'ds': | if s.name == 'ds': | |||
if len(s) == 0: | if len(s) == 0: | |||
s = pd.to_datetime(s) | s = pd.to_datetime(s) | |||
s = s.dt.tz_localize(None) | s = s.dt.tz_localize(None) | |||
setattr(model, attribute, s) | setattr(model, attribute, s) | |||
for attribute in PD_TIMESTAMP: | for attribute in PD_TIMESTAMP: | |||
setattr(model, attribute, pd.Timestamp.utcfromtimestamp(attr_dict[attribute])) | setattr(model, attribute, pd.Timestamp.utcfromtimestamp(model_dict[attribute])) | |||
for attribute in PD_TIMEDELTA: | for attribute in PD_TIMEDELTA: | |||
setattr(model, attribute, pd.Timedelta(seconds=attr_dict[attribute])) | setattr(model, attribute, pd.Timedelta(seconds=model_dict[attribute])) | |||
for attribute in PD_DATAFRAME: | for attribute in PD_DATAFRAME: | |||
if attr_dict[attribute] is None: | if model_dict[attribute] is None: | |||
setattr(model, attribute, None) | setattr(model, attribute, None) | |||
else: | else: | |||
df = pd.read_json(attr_dict[attribute], typ='frame', orient='table', convert_dates=['ds']) | df = pd.read_json(StringIO(model_dict[attribute]), typ='frame', orient='table', convert_dates=['ds']) | |||
if attribute == 'train_component_cols': | if attribute == 'train_component_cols': | |||
# Special handling because of named index column | # Special handling because of named index column | |||
df.columns.name = 'component' | df.columns.name = 'component' | |||
df.index.name = 'col' | df.index.name = 'col' | |||
setattr(model, attribute, df) | setattr(model, attribute, df) | |||
for attribute in NP_ARRAY: | for attribute in NP_ARRAY: | |||
setattr(model, attribute, np.array(attr_dict[attribute])) | setattr(model, attribute, np.array(model_dict[attribute])) | |||
for attribute in ORDEREDDICT: | for attribute in ORDEREDDICT: | |||
key_list, unordered_dict = attr_dict[attribute] | key_list, unordered_dict = model_dict[attribute] | |||
od = OrderedDict() | od = OrderedDict() | |||
for key in key_list: | for key in key_list: | |||
od[key] = unordered_dict[key] | od[key] = unordered_dict[key] | |||
setattr(model, attribute, od) | setattr(model, attribute, od) | |||
# Other attributes with special handling | # Other attributes with special handling | |||
# fit_kwargs | # fit_kwargs | |||
model.fit_kwargs = attr_dict['fit_kwargs'] | model.fit_kwargs = model_dict['fit_kwargs'] | |||
# Params (Dict[str, np.ndarray]) | # Params (Dict[str, np.ndarray]) | |||
model.params = {k: np.array(v) for k, v in attr_dict['params'].items()} | model.params = {k: np.array(v) for k, v in model_dict['params'].items()} | |||
# Skipped attributes | # Skipped attributes | |||
model.stan_backend = None | model.stan_backend = None | |||
model.stan_fit = None | model.stan_fit = None | |||
return model | return model | |||
def model_from_json(model_json): | ||||
"""Deserialize a Prophet model from json string. | ||||
Deserializes models that were serialized with model_to_json. | ||||
Parameters | ||||
---------- | ||||
model_json: Serialized model string | ||||
Returns | ||||
------- | ||||
Prophet model. | ||||
""" | ||||
model_dict = json.loads(model_json) | ||||
return model_from_dict(model_dict) | ||||
End of changes. 32 change blocks. | ||||
32 lines changed or deleted | 51 lines changed or added |