models.py (prophet-1.0) | : | models.py (prophet-1.1) | ||
---|---|---|---|---|
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | |||
# Copyright (c) Facebook, Inc. and its affiliates. | # Copyright (c) Facebook, Inc. and its affiliates. | |||
# This source code is licensed under the MIT license found in the | # This source code is licensed under the MIT license found in the | |||
# LICENSE file in the root directory of this source tree. | # LICENSE file in the root directory of this source tree. | |||
from __future__ import absolute_import, division, print_function | from __future__ import absolute_import, division, print_function | |||
from abc import abstractmethod, ABC | from abc import abstractmethod, ABC | |||
from tempfile import mkdtemp | ||||
from typing import Tuple | from typing import Tuple | |||
from collections import OrderedDict | from collections import OrderedDict | |||
from enum import Enum | from enum import Enum | |||
from pathlib import Path | from pathlib import Path | |||
import os | ||||
import pickle | import pickle | |||
import pkg_resources | import pkg_resources | |||
import os | import platform | |||
import logging | import logging | |||
logger = logging.getLogger('prophet.models') | logger = logging.getLogger('prophet.models') | |||
PLATFORM = "unix" | ||||
if platform.platform().startswith("Win"): | ||||
PLATFORM = "win" | ||||
class IStanBackend(ABC): | class IStanBackend(ABC): | |||
def __init__(self): | def __init__(self): | |||
self.model = self.load_model() | self.model = self.load_model() | |||
self.stan_fit = None | self.stan_fit = None | |||
self.newton_fallback = True | self.newton_fallback = True | |||
def set_options(self, **kwargs): | def set_options(self, **kwargs): | |||
""" | """ | |||
Specify model options as kwargs. | Specify model options as kwargs. | |||
* newton_fallback [bool]: whether to fallback to Newton if L-BFGS fails | * newton_fallback [bool]: whether to fallback to Newton if L-BFGS fails | |||
skipping to change at line 54 | skipping to change at line 60 | |||
pass | pass | |||
@abstractmethod | @abstractmethod | |||
def fit(self, stan_init, stan_data, **kwargs) -> dict: | def fit(self, stan_init, stan_data, **kwargs) -> dict: | |||
pass | pass | |||
@abstractmethod | @abstractmethod | |||
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: | def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: | |||
pass | pass | |||
@staticmethod | ||||
@abstractmethod | ||||
def build_model(target_dir, model_dir): | ||||
pass | ||||
class CmdStanPyBackend(IStanBackend): | class CmdStanPyBackend(IStanBackend): | |||
CMDSTAN_VERSION = "2.26.1" | ||||
def __init__(self): | ||||
import cmdstanpy | ||||
# this must be set before super.__init__() for load_model to work on Win | ||||
dows | ||||
local_cmdstan = pkg_resources.resource_filename( | ||||
"prophet", f"stan_model/cmdstan-{self.CMDSTAN_VERSION}" | ||||
) | ||||
if Path(local_cmdstan).exists(): | ||||
cmdstanpy.set_cmdstan_path(local_cmdstan) | ||||
super().__init__() | ||||
@staticmethod | @staticmethod | |||
def get_type(): | def get_type(): | |||
return StanBackendEnum.CMDSTANPY.name | return StanBackendEnum.CMDSTANPY.name | |||
@staticmethod | ||||
def build_model(target_dir, model_dir): | ||||
from shutil import copy | ||||
import cmdstanpy | ||||
model_name = 'prophet.stan' | ||||
target_name = 'prophet_model.bin' | ||||
sm = cmdstanpy.CmdStanModel( | ||||
stan_file=os.path.join(model_dir, model_name)) | ||||
sm.compile() | ||||
copy(sm.exe_file, os.path.join(target_dir, target_name)) | ||||
def load_model(self): | def load_model(self): | |||
import cmdstanpy | import cmdstanpy | |||
model_file = pkg_resources.resource_filename( | model_file = pkg_resources.resource_filename( | |||
'prophet', | 'prophet', | |||
'stan_model/prophet_model.bin', | 'stan_model/prophet_model.bin', | |||
) | ) | |||
return cmdstanpy.CmdStanModel(exe_file=model_file) | return cmdstanpy.CmdStanModel(exe_file=model_file) | |||
def fit(self, stan_init, stan_data, **kwargs): | def fit(self, stan_init, stan_data, **kwargs): | |||
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data) | (stan_init, stan_data) = self.prepare_data(stan_init, stan_data) | |||
if 'algorithm' not in kwargs: | ||||
kwargs['algorithm'] = 'Newton' if stan_data['T'] < 100 else 'LBFGS' | if 'inits' not in kwargs and 'init' in kwargs: | |||
iterations = int(1e4) | kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] | |||
args = dict( | ||||
data=stan_data, | ||||
inits=stan_init, | ||||
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS', | ||||
iter=int(1e4), | ||||
output_dir = mkdtemp(), | ||||
) | ||||
args.update(kwargs) | ||||
try: | try: | |||
self.stan_fit = self.model.optimize(data=stan_data, | self.stan_fit = self.model.optimize(**args) | |||
inits=stan_init, | ||||
iter=iterations, | ||||
**kwargs) | ||||
except RuntimeError as e: | except RuntimeError as e: | |||
# Fall back on Newton | # Fall back on Newton | |||
if self.newton_fallback and kwargs['algorithm'] != 'Newton': | if self.newton_fallback and args['algorithm'] != 'Newton': | |||
logger.warning( | logger.warning( | |||
'Optimization terminated abnormally. Falling back to Newton. ' | 'Optimization terminated abnormally. Falling back to Newton. ' | |||
) | ) | |||
kwargs['algorithm'] = 'Newton' | args['algorithm'] = 'Newton' | |||
self.stan_fit = self.model.optimize(data=stan_data, | self.stan_fit = self.model.optimize(**args) | |||
inits=stan_init, | ||||
iter=iterations, | ||||
**kwargs) | ||||
else: | else: | |||
raise e | raise e | |||
params = self.stan_to_dict_numpy( | params = self.stan_to_dict_numpy( | |||
self.stan_fit.column_names, self.stan_fit.optimized_params_np) | self.stan_fit.column_names, self.stan_fit.optimized_params_np) | |||
for par in params: | for par in params: | |||
params[par] = params[par].reshape((1, -1)) | params[par] = params[par].reshape((1, -1)) | |||
return params | return params | |||
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: | def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: | |||
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data) | (stan_init, stan_data) = self.prepare_data(stan_init, stan_data) | |||
if 'inits' not in kwargs and 'init' in kwargs: | ||||
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] | ||||
args = dict( | ||||
data=stan_data, | ||||
inits=stan_init, | ||||
) | ||||
if 'chains' not in kwargs: | if 'chains' not in kwargs: | |||
kwargs['chains'] = 4 | kwargs['chains'] = 4 | |||
iter_half = samples // 2 | iter_half = samples // 2 | |||
kwargs['iter_sampling'] = iter_half | ||||
if 'iter_warmup' not in kwargs: | if 'iter_warmup' not in kwargs: | |||
kwargs['iter_warmup'] = iter_half | kwargs['iter_warmup'] = iter_half | |||
self.stan_fit = self.model.sample(data=stan_data, | args.update(kwargs) | |||
inits=stan_init, | ||||
iter_sampling=iter_half, | self.stan_fit = self.model.sample(**args) | |||
**kwargs) | ||||
res = self.stan_fit.draws() | res = self.stan_fit.draws() | |||
(samples, c, columns) = res.shape | (samples, c, columns) = res.shape | |||
res = res.reshape((samples * c, columns)) | res = res.reshape((samples * c, columns)) | |||
params = self.stan_to_dict_numpy(self.stan_fit.column_names, res) | params = self.stan_to_dict_numpy(self.stan_fit.column_names, res) | |||
for par in params: | for par in params: | |||
s = params[par].shape | s = params[par].shape | |||
if s[1] == 1: | if s[1] == 1: | |||
params[par] = params[par].reshape((s[0],)) | params[par] = params[par].reshape((s[0],)) | |||
skipping to change at line 166 | skipping to change at line 177 | |||
's_m': data['s_m'].tolist(), | 's_m': data['s_m'].tolist(), | |||
'X': data['X'].to_numpy().tolist(), | 'X': data['X'].to_numpy().tolist(), | |||
'sigmas': data['sigmas'] | 'sigmas': data['sigmas'] | |||
} | } | |||
cmdstanpy_init = { | cmdstanpy_init = { | |||
'k': init['k'], | 'k': init['k'], | |||
'm': init['m'], | 'm': init['m'], | |||
'delta': init['delta'].tolist(), | 'delta': init['delta'].tolist(), | |||
'beta': init['beta'].tolist(), | 'beta': init['beta'].tolist(), | |||
'sigma_obs': 1 | 'sigma_obs': init['sigma_obs'] | |||
} | } | |||
return (cmdstanpy_init, cmdstanpy_data) | return (cmdstanpy_init, cmdstanpy_data) | |||
@staticmethod | @staticmethod | |||
def stan_to_dict_numpy(column_names: Tuple[str, ...], data: 'np.array'): | def stan_to_dict_numpy(column_names: Tuple[str, ...], data: 'np.array'): | |||
import numpy as np | import numpy as np | |||
output = OrderedDict() | output = OrderedDict() | |||
prev = None | prev = None | |||
skipping to change at line 216 | skipping to change at line 227 | |||
if prev in output: | if prev in output: | |||
raise RuntimeError( | raise RuntimeError( | |||
"Found repeated column name" | "Found repeated column name" | |||
) | ) | |||
if two_dims: | if two_dims: | |||
output[prev] = np.array(data[:, start:end]) | output[prev] = np.array(data[:, start:end]) | |||
else: | else: | |||
output[prev] = np.array(data[start:end]) | output[prev] = np.array(data[start:end]) | |||
return output | return output | |||
class PyStanBackend(IStanBackend): | ||||
@staticmethod | ||||
def get_type(): | ||||
return StanBackendEnum.PYSTAN.name | ||||
@staticmethod | ||||
def build_model(target_dir, model_dir): | ||||
import pystan | ||||
model_name = 'prophet.stan' | ||||
target_name = 'prophet_model.pkl' | ||||
with open(os.path.join(model_dir, model_name)) as f: | ||||
model_code = f.read() | ||||
sm = pystan.StanModel(model_code=model_code) | ||||
with open(os.path.join(target_dir, target_name), 'wb') as f: | ||||
pickle.dump(sm, f, protocol=pickle.HIGHEST_PROTOCOL) | ||||
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: | ||||
args = dict( | ||||
data=stan_data, | ||||
init=lambda: stan_init, | ||||
iter=samples, | ||||
) | ||||
args.update(kwargs) | ||||
self.stan_fit = self.model.sampling(**args) | ||||
out = {} | ||||
for par in self.stan_fit.model_pars: | ||||
out[par] = self.stan_fit[par] | ||||
# Shape vector parameters | ||||
if par in ['delta', 'beta'] and len(out[par].shape) < 2: | ||||
out[par] = out[par].reshape((-1, 1)) | ||||
return out | ||||
def fit(self, stan_init, stan_data, **kwargs) -> dict: | ||||
args = dict( | ||||
data=stan_data, | ||||
init=lambda: stan_init, | ||||
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS', | ||||
iter=1e4, | ||||
) | ||||
args.update(kwargs) | ||||
try: | ||||
self.stan_fit = self.model.optimizing(**args) | ||||
except RuntimeError as e: | ||||
# Fall back on Newton | ||||
if self.newton_fallback and args['algorithm'] != 'Newton': | ||||
logger.warning( | ||||
'Optimization terminated abnormally. Falling back to Newton. | ||||
' | ||||
) | ||||
args['algorithm'] = 'Newton' | ||||
self.stan_fit = self.model.optimizing(**args) | ||||
else: | ||||
raise e | ||||
params = {} | ||||
for par in self.stan_fit.keys(): | ||||
params[par] = self.stan_fit[par].reshape((1, -1)) | ||||
return params | ||||
def load_model(self): | ||||
"""Load compiled Stan model""" | ||||
model_file = pkg_resources.resource_filename( | ||||
'prophet', | ||||
'stan_model/prophet_model.pkl', | ||||
) | ||||
with Path(model_file).open('rb') as f: | ||||
return pickle.load(f) | ||||
class StanBackendEnum(Enum): | class StanBackendEnum(Enum): | |||
PYSTAN = PyStanBackend | ||||
CMDSTANPY = CmdStanPyBackend | CMDSTANPY = CmdStanPyBackend | |||
@staticmethod | @staticmethod | |||
def get_backend_class(name: str) -> IStanBackend: | def get_backend_class(name: str) -> IStanBackend: | |||
try: | try: | |||
return StanBackendEnum[name].value | return StanBackendEnum[name].value | |||
except KeyError as e: | except KeyError as e: | |||
raise ValueError("Unknown stan backend: {}".format(name)) from e | raise ValueError("Unknown stan backend: {}".format(name)) from e | |||
End of changes. 17 change blocks. | ||||
110 lines changed or deleted | 48 lines changed or added |