test_serialize.py (prophet-1.1) | : | test_serialize.py (prophet-1.1.1) | ||
---|---|---|---|---|
skipping to change at line 41 | skipping to change at line 41 | |||
N = DATA.shape[0] | N = DATA.shape[0] | |||
df = DATA.head(N - days) | df = DATA.head(N - days) | |||
m.fit(df) | m.fit(df) | |||
future = m.make_future_dataframe(2, include_history=False) | future = m.make_future_dataframe(2, include_history=False) | |||
fcst = m.predict(future) | fcst = m.predict(future) | |||
model_str = model_to_json(m) | model_str = model_to_json(m) | |||
# Make sure json doesn't get too large in the future | # Make sure json doesn't get too large in the future | |||
self.assertTrue(len(model_str) < 200000) | self.assertTrue(len(model_str) < 200000) | |||
z = json.loads(model_str) | ||||
self.assertEqual(z['__prophet_version'], '1.0') | ||||
m2 = model_from_json(model_str) | m2 = model_from_json(model_str) | |||
# Check that m and m2 are equal | # Check that m and m2 are equal | |||
self.assertEqual(m.__dict__.keys(), m2.__dict__.keys()) | self.assertEqual(m.__dict__.keys(), m2.__dict__.keys()) | |||
for k, v in m.__dict__.items(): | for k, v in m.__dict__.items(): | |||
if k in ['stan_fit', 'stan_backend']: | if k in ['stan_fit', 'stan_backend']: | |||
continue | continue | |||
if k == 'params': | if k == 'params': | |||
self.assertEqual(v.keys(), m2.params.keys()) | self.assertEqual(v.keys(), m2.params.keys()) | |||
for kk, vv in v.items(): | for kk, vv in v.items(): | |||
skipping to change at line 143 | skipping to change at line 140 | |||
# Check that m2 makes the same forecast | # Check that m2 makes the same forecast | |||
future = m2.make_future_dataframe(periods=100, include_history=False) | future = m2.make_future_dataframe(periods=100, include_history=False) | |||
fcst2 = m2.predict(test) | fcst2 = m2.predict(test) | |||
self.assertTrue(np.array_equal(fcst['yhat'].values, fcst2['yhat'].values )) | self.assertTrue(np.array_equal(fcst['yhat'].values, fcst2['yhat'].values )) | |||
def test_backwards_compatibility(self): | def test_backwards_compatibility(self): | |||
old_versions = { | old_versions = { | |||
'0.6.1.dev0': (29.3669923968994, 'fb'), | '0.6.1.dev0': (29.3669923968994, 'fb'), | |||
'0.7.1': (29.282810844704414, 'fb'), | '0.7.1': (29.282810844704414, 'fb'), | |||
'1.0.1': (29.282810844704414, ''), | ||||
} | } | |||
for v, (pred_val, v_str) in old_versions.items(): | for v, (pred_val, v_str) in old_versions.items(): | |||
fname = os.path.join( | fname = os.path.join( | |||
os.path.dirname(__file__), | os.path.dirname(__file__), | |||
'serialized_model_v{}.json'.format(v) | 'serialized_model_v{}.json'.format(v) | |||
) | ) | |||
with open(fname, 'r') as fin: | with open(fname, 'r') as fin: | |||
model_str = json.load(fin) | model_str = json.load(fin) | |||
# Check that deserializes | # Check that deserializes | |||
m = model_from_json(model_str) | m = model_from_json(model_str) | |||
End of changes. 2 change blocks. | ||||
3 lines changed or deleted | 1 lines changed or added |