"Fossies" - the Fresh Open Source Software Archive  

Source code changes of the file "python/prophet/models.py" between
prophet-1.1.tar.gz and prophet-1.1.1.tar.gz

About: Prophet is a tool for producing high quality forecasts for time series data that has multiple seasonality with linear or non-linear growth.

models.py  (prophet-1.1):models.py  (prophet-1.1.1)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from tempfile import mkdtemp
from typing import Tuple from typing import Tuple
from collections import OrderedDict from collections import OrderedDict
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
import os
import pickle
import pkg_resources import pkg_resources
import platform import platform
import logging import logging
logger = logging.getLogger('prophet.models') logger = logging.getLogger('prophet.models')
PLATFORM = "unix" PLATFORM = "win" if platform.platform().startswith("Win") else "unix"
if platform.platform().startswith("Win"):
PLATFORM = "win"
class IStanBackend(ABC): class IStanBackend(ABC):
def __init__(self): def __init__(self):
self.model = self.load_model() self.model = self.load_model()
self.stan_fit = None self.stan_fit = None
self.newton_fallback = True self.newton_fallback = True
def set_options(self, **kwargs): def set_options(self, **kwargs):
""" """
Specify model options as kwargs. Specify model options as kwargs.
skipping to change at line 95 skipping to change at line 90
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data) (stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
if 'inits' not in kwargs and 'init' in kwargs: if 'inits' not in kwargs and 'init' in kwargs:
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
args = dict( args = dict(
data=stan_data, data=stan_data,
inits=stan_init, inits=stan_init,
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS', algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
iter=int(1e4), iter=int(1e4),
output_dir = mkdtemp(),
) )
args.update(kwargs) args.update(kwargs)
try: try:
self.stan_fit = self.model.optimize(**args) self.stan_fit = self.model.optimize(**args)
except RuntimeError as e: except RuntimeError as e:
# Fall back on Newton # Fall back on Newton
if self.newton_fallback and args['algorithm'] != 'Newton': if not self.newton_fallback or args['algorithm'] == 'Newton':
logger.warning(
'Optimization terminated abnormally. Falling back to Newton.
'
)
args['algorithm'] = 'Newton'
self.stan_fit = self.model.optimize(**args)
else:
raise e raise e
logger.warning('Optimization terminated abnormally. Falling back to
Newton.')
args['algorithm'] = 'Newton'
self.stan_fit = self.model.optimize(**args)
params = self.stan_to_dict_numpy( params = self.stan_to_dict_numpy(
self.stan_fit.column_names, self.stan_fit.optimized_params_np) self.stan_fit.column_names, self.stan_fit.optimized_params_np)
for par in params: for par in params:
params[par] = params[par].reshape((1, -1)) params[par] = params[par].reshape((1, -1))
return params return params
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data) (stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
if 'inits' not in kwargs and 'init' in kwargs: if 'inits' not in kwargs and 'init' in kwargs:
skipping to change at line 193 skipping to change at line 183
import numpy as np import numpy as np
output = OrderedDict() output = OrderedDict()
prev = None prev = None
start = 0 start = 0
end = 0 end = 0
two_dims = len(data.shape) > 1 two_dims = len(data.shape) > 1
for cname in column_names: for cname in column_names:
if "." in cname: parsed = cname.split(".") if "." in cname else cname.split("[")
parsed = cname.split(".")
else:
parsed = cname.split("[")
curr = parsed[0] curr = parsed[0]
if prev is None: if prev is None:
prev = curr prev = curr
if curr != prev: if curr != prev:
if prev in output: if prev in output:
raise RuntimeError( raise RuntimeError(
"Found repeated column name" "Found repeated column name"
) )
if two_dims: if two_dims:
output[prev] = np.array(data[:, start:end]) output[prev] = np.array(data[:, start:end])
else: else:
output[prev] = np.array(data[start:end]) output[prev] = np.array(data[start:end])
prev = curr prev = curr
start = end start = end
end += 1 end += 1
else:
end += 1
if prev in output: if prev in output:
raise RuntimeError( raise RuntimeError(
"Found repeated column name" "Found repeated column name"
) )
if two_dims: if two_dims:
output[prev] = np.array(data[:, start:end]) output[prev] = np.array(data[:, start:end])
else: else:
output[prev] = np.array(data[start:end]) output[prev] = np.array(data[start:end])
return output return output
class StanBackendEnum(Enum): class StanBackendEnum(Enum):
CMDSTANPY = CmdStanPyBackend CMDSTANPY = CmdStanPyBackend
@staticmethod @staticmethod
def get_backend_class(name: str) -> IStanBackend: def get_backend_class(name: str) -> IStanBackend:
try: try:
return StanBackendEnum[name].value return StanBackendEnum[name].value
except KeyError as e: except KeyError as e:
raise ValueError("Unknown stan backend: {}".format(name)) from e raise ValueError(f"Unknown stan backend: {name}") from e
 End of changes. 9 change blocks. 
25 lines changed or deleted 8 lines changed or added

Home  |  About  |  Features  |  All  |  Newest  |  Dox  |  Diffs  |  RSS Feeds  |  Screenshots  |  Comments  |  Imprint  |  Privacy  |  HTTP(S)