test_serialize.py (prophet-0.7) | : | test_serialize.py (prophet-1.0) | ||
---|---|---|---|---|
skipping to change at line 18 | skipping to change at line 18 | |||
from __future__ import print_function | from __future__ import print_function | |||
from __future__ import unicode_literals | from __future__ import unicode_literals | |||
import json | import json | |||
import os | import os | |||
import sys | import sys | |||
from unittest import TestCase, skipUnless | from unittest import TestCase, skipUnless | |||
import numpy as np | import numpy as np | |||
import pandas as pd | import pandas as pd | |||
from fbprophet import Prophet | from prophet import Prophet | |||
from fbprophet.serialize import model_to_json, model_from_json, PD_SERIES, PD_DA | from prophet.serialize import model_to_json, model_from_json, PD_SERIES, PD_DATA | |||
TAFRAME | FRAME | |||
DATA = pd.read_csv( | DATA = pd.read_csv( | |||
os.path.join(os.path.dirname(__file__), 'data.csv'), | os.path.join(os.path.dirname(__file__), 'data.csv'), | |||
parse_dates=['ds'], | parse_dates=['ds'], | |||
) | ) | |||
class TestSerialize(TestCase): | class TestSerialize(TestCase): | |||
def test_simple_serialize(self): | def test_simple_serialize(self): | |||
m = Prophet() | m = Prophet() | |||
skipping to change at line 42 | skipping to change at line 42 | |||
df = DATA.head(N - days) | df = DATA.head(N - days) | |||
m.fit(df) | m.fit(df) | |||
future = m.make_future_dataframe(2, include_history=False) | future = m.make_future_dataframe(2, include_history=False) | |||
fcst = m.predict(future) | fcst = m.predict(future) | |||
model_str = model_to_json(m) | model_str = model_to_json(m) | |||
# Make sure json doesn't get too large in the future | # Make sure json doesn't get too large in the future | |||
self.assertTrue(len(model_str) < 200000) | self.assertTrue(len(model_str) < 200000) | |||
z = json.loads(model_str) | z = json.loads(model_str) | |||
self.assertEqual(z['__fbprophet_version'], '0.7.1') | self.assertEqual(z['__prophet_version'], '1.0') | |||
m2 = model_from_json(model_str) | m2 = model_from_json(model_str) | |||
# Check that m and m2 are equal | # Check that m and m2 are equal | |||
self.assertEqual(m.__dict__.keys(), m2.__dict__.keys()) | self.assertEqual(m.__dict__.keys(), m2.__dict__.keys()) | |||
for k, v in m.__dict__.items(): | for k, v in m.__dict__.items(): | |||
if k in ['stan_fit', 'stan_backend']: | if k in ['stan_fit', 'stan_backend']: | |||
continue | continue | |||
if k == 'params': | if k == 'params': | |||
self.assertEqual(v.keys(), m2.params.keys()) | self.assertEqual(v.keys(), m2.params.keys()) | |||
skipping to change at line 141 | skipping to change at line 141 | |||
self.assertTrue(m2.stan_backend is None) | self.assertTrue(m2.stan_backend is None) | |||
# Check that m2 makes the same forecast | # Check that m2 makes the same forecast | |||
future = m2.make_future_dataframe(periods=100, include_history=False) | future = m2.make_future_dataframe(periods=100, include_history=False) | |||
fcst2 = m2.predict(test) | fcst2 = m2.predict(test) | |||
self.assertTrue(np.array_equal(fcst['yhat'].values, fcst2['yhat'].values )) | self.assertTrue(np.array_equal(fcst['yhat'].values, fcst2['yhat'].values )) | |||
def test_backwards_compatibility(self): | def test_backwards_compatibility(self): | |||
old_versions = { | old_versions = { | |||
'0.6.1.dev0': 29.3669923968994, | '0.6.1.dev0': (29.3669923968994, 'fb'), | |||
'0.7.1': (29.282810844704414, 'fb'), | ||||
} | } | |||
for v, pred_val in old_versions.items(): | for v, (pred_val, v_str) in old_versions.items(): | |||
fname = os.path.join( | fname = os.path.join( | |||
os.path.dirname(__file__), | os.path.dirname(__file__), | |||
'serialized_model_v{}.json'.format(v) | 'serialized_model_v{}.json'.format(v) | |||
) | ) | |||
with open(fname, 'r') as fin: | with open(fname, 'r') as fin: | |||
model_str = json.load(fin) | model_str = json.load(fin) | |||
# Check that deserializes | # Check that deserializes | |||
m = model_from_json(model_str) | m = model_from_json(model_str) | |||
self.assertEqual(json.loads(model_str)['__fbprophet_version'], v) | self.assertEqual(json.loads(model_str)[f'__{v_str}prophet_version'], v) | |||
# Predict | # Predict | |||
future = m.make_future_dataframe(10) | future = m.make_future_dataframe(10) | |||
fcst = m.predict(future) | fcst = m.predict(future) | |||
self.assertAlmostEqual(fcst['yhat'].values[-1], pred_val) | self.assertAlmostEqual(fcst['yhat'].values[-1], pred_val) | |||
End of changes. 5 change blocks. | ||||
7 lines changed or deleted | 8 lines changed or added |