models.py (prophet-1.1) | : | models.py (prophet-1.1.1) | ||
---|---|---|---|---|
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | |||
# Copyright (c) Facebook, Inc. and its affiliates. | # Copyright (c) Facebook, Inc. and its affiliates. | |||
# This source code is licensed under the MIT license found in the | # This source code is licensed under the MIT license found in the | |||
# LICENSE file in the root directory of this source tree. | # LICENSE file in the root directory of this source tree. | |||
from __future__ import absolute_import, division, print_function | from __future__ import absolute_import, division, print_function | |||
from abc import abstractmethod, ABC | from abc import abstractmethod, ABC | |||
from tempfile import mkdtemp | ||||
from typing import Tuple | from typing import Tuple | |||
from collections import OrderedDict | from collections import OrderedDict | |||
from enum import Enum | from enum import Enum | |||
from pathlib import Path | from pathlib import Path | |||
import os | ||||
import pickle | ||||
import pkg_resources | import pkg_resources | |||
import platform | import platform | |||
import logging | import logging | |||
logger = logging.getLogger('prophet.models') | logger = logging.getLogger('prophet.models') | |||
PLATFORM = "unix" | PLATFORM = "win" if platform.platform().startswith("Win") else "unix" | |||
if platform.platform().startswith("Win"): | ||||
PLATFORM = "win" | ||||
class IStanBackend(ABC): | class IStanBackend(ABC): | |||
def __init__(self): | def __init__(self): | |||
self.model = self.load_model() | self.model = self.load_model() | |||
self.stan_fit = None | self.stan_fit = None | |||
self.newton_fallback = True | self.newton_fallback = True | |||
def set_options(self, **kwargs): | def set_options(self, **kwargs): | |||
""" | """ | |||
Specify model options as kwargs. | Specify model options as kwargs. | |||
skipping to change at line 95 | skipping to change at line 90 | |||
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data) | (stan_init, stan_data) = self.prepare_data(stan_init, stan_data) | |||
if 'inits' not in kwargs and 'init' in kwargs: | if 'inits' not in kwargs and 'init' in kwargs: | |||
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] | kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] | |||
args = dict( | args = dict( | |||
data=stan_data, | data=stan_data, | |||
inits=stan_init, | inits=stan_init, | |||
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS', | algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS', | |||
iter=int(1e4), | iter=int(1e4), | |||
output_dir = mkdtemp(), | ||||
) | ) | |||
args.update(kwargs) | args.update(kwargs) | |||
try: | try: | |||
self.stan_fit = self.model.optimize(**args) | self.stan_fit = self.model.optimize(**args) | |||
except RuntimeError as e: | except RuntimeError as e: | |||
# Fall back on Newton | # Fall back on Newton | |||
if self.newton_fallback and args['algorithm'] != 'Newton': | if not self.newton_fallback or args['algorithm'] == 'Newton': | |||
logger.warning( | ||||
'Optimization terminated abnormally. Falling back to Newton. | ||||
' | ||||
) | ||||
args['algorithm'] = 'Newton' | ||||
self.stan_fit = self.model.optimize(**args) | ||||
else: | ||||
raise e | raise e | |||
logger.warning('Optimization terminated abnormally. Falling back to | ||||
Newton.') | ||||
args['algorithm'] = 'Newton' | ||||
self.stan_fit = self.model.optimize(**args) | ||||
params = self.stan_to_dict_numpy( | params = self.stan_to_dict_numpy( | |||
self.stan_fit.column_names, self.stan_fit.optimized_params_np) | self.stan_fit.column_names, self.stan_fit.optimized_params_np) | |||
for par in params: | for par in params: | |||
params[par] = params[par].reshape((1, -1)) | params[par] = params[par].reshape((1, -1)) | |||
return params | return params | |||
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: | def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: | |||
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data) | (stan_init, stan_data) = self.prepare_data(stan_init, stan_data) | |||
if 'inits' not in kwargs and 'init' in kwargs: | if 'inits' not in kwargs and 'init' in kwargs: | |||
skipping to change at line 193 | skipping to change at line 183 | |||
import numpy as np | import numpy as np | |||
output = OrderedDict() | output = OrderedDict() | |||
prev = None | prev = None | |||
start = 0 | start = 0 | |||
end = 0 | end = 0 | |||
two_dims = len(data.shape) > 1 | two_dims = len(data.shape) > 1 | |||
for cname in column_names: | for cname in column_names: | |||
if "." in cname: | parsed = cname.split(".") if "." in cname else cname.split("[") | |||
parsed = cname.split(".") | ||||
else: | ||||
parsed = cname.split("[") | ||||
curr = parsed[0] | curr = parsed[0] | |||
if prev is None: | if prev is None: | |||
prev = curr | prev = curr | |||
if curr != prev: | if curr != prev: | |||
if prev in output: | if prev in output: | |||
raise RuntimeError( | raise RuntimeError( | |||
"Found repeated column name" | "Found repeated column name" | |||
) | ) | |||
if two_dims: | if two_dims: | |||
output[prev] = np.array(data[:, start:end]) | output[prev] = np.array(data[:, start:end]) | |||
else: | else: | |||
output[prev] = np.array(data[start:end]) | output[prev] = np.array(data[start:end]) | |||
prev = curr | prev = curr | |||
start = end | start = end | |||
end += 1 | end += 1 | |||
else: | ||||
end += 1 | ||||
if prev in output: | if prev in output: | |||
raise RuntimeError( | raise RuntimeError( | |||
"Found repeated column name" | "Found repeated column name" | |||
) | ) | |||
if two_dims: | if two_dims: | |||
output[prev] = np.array(data[:, start:end]) | output[prev] = np.array(data[:, start:end]) | |||
else: | else: | |||
output[prev] = np.array(data[start:end]) | output[prev] = np.array(data[start:end]) | |||
return output | return output | |||
class StanBackendEnum(Enum): | class StanBackendEnum(Enum): | |||
CMDSTANPY = CmdStanPyBackend | CMDSTANPY = CmdStanPyBackend | |||
@staticmethod | @staticmethod | |||
def get_backend_class(name: str) -> IStanBackend: | def get_backend_class(name: str) -> IStanBackend: | |||
try: | try: | |||
return StanBackendEnum[name].value | return StanBackendEnum[name].value | |||
except KeyError as e: | except KeyError as e: | |||
raise ValueError("Unknown stan backend: {}".format(name)) from e | raise ValueError(f"Unknown stan backend: {name}") from e | |||
End of changes. 9 change blocks. | ||||
25 lines changed or deleted | 8 lines changed or added |