"Fossies" - the Fresh Open Source Software Archive  

Source code changes of the file "aesara/scan/utils.py" between
aesara-rel-2.1.1.tar.gz and aesara-rel-2.1.2.tar.gz

About: Aesara is a Python library that allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays. It can use GPUs and perform efficient symbolic differentiation (formerly "Theano-PyMC"; a fork of the no longer developed original Theano library).

utils.py  (aesara-rel-2.1.1):utils.py  (aesara-rel-2.1.2)
skipping to change at line 18 skipping to change at line 18
"Pascal Lamblin " "Pascal Lamblin "
"Arnaud Bergeron " "Arnaud Bergeron "
"PyMC Developers " "PyMC Developers "
"Aesara Developers " "Aesara Developers "
) )
__copyright__ = "(c) 2010, Universite de Montreal" __copyright__ = "(c) 2010, Universite de Montreal"
import copy import copy
import logging import logging
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict, namedtuple
import numpy as np import numpy as np
from aesara import scalar as aes from aesara import scalar as aes
from aesara import tensor as aet from aesara import tensor as aet
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import ( from aesara.graph.basic import (
Constant, Constant,
Variable, Variable,
clone_replace, clone_replace,
skipping to change at line 859 skipping to change at line 859
for nw_x, x in zip(nw_inputs, inputs): for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x givens[x] = nw_x
allinputs = list(graph_inputs(outputs)) allinputs = list(graph_inputs(outputs))
for inp in allinputs: for inp in allinputs:
if isinstance(inp, Constant): if isinstance(inp, Constant):
givens[inp] = inp.clone() givens[inp] = inp.clone()
nw_outputs = clone_replace(outputs, replace=givens) nw_outputs = clone_replace(outputs, replace=givens)
return (nw_inputs, nw_outputs) return (nw_inputs, nw_outputs)
class scan_args: FieldInfo = namedtuple(
""" "FieldInfo", ("name", "agg_name", "index", "inner_index", "agg_index")
Parses the inputs and outputs of scan in an easy to manipulate format. )
""" def safe_index(lst, x):
try:
return lst.index(x)
except ValueError:
return None
def default_filter_scanargs(x):
return x.startswith("inner_") or x.startswith("outer_")
class ScanArgs:
"""Parses the inputs and outputs of `Scan` in an easy to manipulate format."
""
default_filter = default_filter_scanargs
nested_list_fields = ("inner_in_mit_mot", "inner_in_mit_sot", "inner_out_mit
_mot")
def __init__( def __init__(
self, outer_inputs, outer_outputs, _inner_inputs, _inner_outputs, info self, outer_inputs, outer_outputs, _inner_inputs, _inner_outputs, info
): ):
self.n_steps = outer_inputs[0] self.n_steps = outer_inputs[0]
rval = reconstruct_graph(_inner_inputs, _inner_outputs, "") rval = reconstruct_graph(_inner_inputs, _inner_outputs, "")
if info["as_while"]: if info["as_while"]:
self.cond = [rval[1][-1]] self.cond = [rval[1][-1]]
inner_outputs = rval[1][:-1] inner_outputs = rval[1][:-1]
else: else:
skipping to change at line 989 skipping to change at line 1002
"mode", "mode",
"destroy_map", "destroy_map",
"gpua", "gpua",
"as_while", "as_while",
"profile", "profile",
"allow_gc", "allow_gc",
): ):
if k in info: if k in info:
self.other_info[k] = info[k] self.other_info[k] = info[k]
@staticmethod
def from_node(node):
from aesara.scan.op import Scan
if not isinstance(node.op, Scan):
raise TypeError("{} is not a Scan node".format(node))
return ScanArgs(
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.
info
)
@classmethod
def create_empty(cls):
info = OrderedDict(
[
("n_seqs", 0),
("n_mit_mot", 0),
("n_mit_sot", 0),
("tap_array", []),
("n_sit_sot", 0),
("n_nit_sot", 0),
("n_shared_outs", 0),
("n_mit_mot_outs", 0),
("mit_mot_out_slices", []),
("truncate_gradient", -1),
("name", None),
("mode", None),
("destroy_map", OrderedDict()),
("gpua", False),
("as_while", False),
("profile", False),
("allow_gc", False),
]
)
res = cls([1], [], [], [], info)
res.n_steps = None
return res
@property
def n_nit_sot(self):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oou
t_mappings`
return self.info["n_nit_sot"]
@property
def inputs(self):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oou
t_mappings`
return self.inner_inputs
@property
def n_mit_mot(self):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oou
t_mappings`
return self.info["n_mit_mot"]
@property
def var_mappings(self):
from aesara.scan.op import Scan
return Scan.get_oinp_iinp_iout_oout_mappings(self)
@property
def field_names(self):
res = ["mit_mot_out_slices", "mit_mot_in_slices", "mit_sot_in_slices"]
res.extend(
[
attr
for attr in self.__dict__
if attr.startswith("inner_in")
or attr.startswith("inner_out")
or attr.startswith("outer_in")
or attr.startswith("outer_out")
or attr == "n_steps"
]
)
return res
@property @property
def inner_inputs(self): def inner_inputs(self):
return ( return (
self.inner_in_seqs self.inner_in_seqs
+ sum(self.inner_in_mit_mot, []) + sum(self.inner_in_mit_mot, [])
+ sum(self.inner_in_mit_sot, []) + sum(self.inner_in_mit_sot, [])
+ self.inner_in_sit_sot + self.inner_in_sit_sot
+ self.inner_in_shared + self.inner_in_shared
+ self.inner_in_non_seqs + self.inner_in_non_seqs
) )
skipping to change at line 1053 skipping to change at line 1140
+ [[-1]] * len(self.inner_in_sit_sot) + [[-1]] * len(self.inner_in_sit_sot)
), ),
n_sit_sot=len(self.outer_in_sit_sot), n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot), n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared), n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices), n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=self.mit_mot_out_slices, mit_mot_out_slices=self.mit_mot_out_slices,
**self.other_info, **self.other_info,
) )
def get_alt_field(self, var_info, alt_prefix):
"""Get the alternate input/output field for a given element of `ScanArgs
`.
For example, if `var_info` is in `ScanArgs.outer_out_sit_sot`, then
`get_alt_field(var_info, "inner_out")` returns the element corresponding
`var_info` in `ScanArgs.inner_out_sit_sot`.
Parameters
----------
var_info: TensorVariable or FieldInfo
The element for which we want the alternate
alt_prefix: str
The string prefix for the alternate field type. It can be one of
the following: "inner_out", "inner_in", "outer_in", and "outer_out".
Outputs
-------
TensorVariable
Returns the alternate variable.
"""
if not isinstance(var_info, FieldInfo):
var_info = self.find_among_fields(var_info)
alt_type = var_info.name[(var_info.name.index("_", 6) + 1) :]
alt_var = getattr(self, "inner_out_{}".format(alt_type))[var_info.index]
return alt_var
def find_among_fields(self, i, field_filter=default_filter):
"""Find the type and indices of the field containing a given element.
NOTE: This only returns the *first* field containing the given element.
Parameters
----------
i: theano.gof.graph.Variable
The element to find among this object's fields.
field_filter: function
A function passed to `filter` that determines which fields to
consider. It must take a string field name and return a truthy
value.
Returns
-------
A tuple of length 4 containing the field name string, the first index,
the second index (for nested lists), and the "major" index (i.e. the
index within the aggregate lists like `self.inner_inputs`,
`self.outer_outputs`, etc.), or a triple of `None` when no match is
found.
"""
field_names = filter(field_filter, self.field_names)
for field_name in field_names:
lst = getattr(self, field_name)
field_prefix = field_name[:8]
if field_prefix.endswith("in"):
agg_field_name = "{}puts".format(field_prefix)
else:
agg_field_name = "{}tputs".format(field_prefix)
agg_list = getattr(self, agg_field_name)
if field_name in self.nested_list_fields:
for n, sub_lst in enumerate(lst):
idx = safe_index(sub_lst, i)
if idx is not None:
agg_idx = safe_index(agg_list, i)
return FieldInfo(field_name, agg_field_name, n, idx, agg
_idx)
else:
idx = safe_index(lst, i)
if idx is not None:
agg_idx = safe_index(agg_list, i)
return FieldInfo(field_name, agg_field_name, idx, None, agg_
idx)
return None
def _remove_from_fields(self, i, field_filter=default_filter):
field_info = self.find_among_fields(i, field_filter=field_filter)
if field_info is None:
return None
if field_info.inner_index is not None:
getattr(self, field_info.name)[field_info.index].remove(i)
else:
getattr(self, field_info.name).remove(i)
return field_info
def get_dependent_nodes(self, i, seen=None):
from aesara.graph import inputs as at_inputs
if seen is None:
seen = {i}
else:
seen.add(i)
var_mappings = self.var_mappings
field_info = self.find_among_fields(i)
if field_info is None:
raise ValueError("{} not found among fields.".format(i))
# Find the `var_mappings` key suffix that matches the field/set of
# arguments containing our source node
if field_info.name[:8].endswith("_in"):
map_key_suffix = "{}p".format(field_info.name[:8])
else:
map_key_suffix = field_info.name[:9]
dependent_nodes = set()
for k, v in var_mappings.items():
if not k.endswith(map_key_suffix):
continue
dependent_idx = v[field_info.agg_index]
dependent_idx = (
dependent_idx if isinstance(dependent_idx, list) else [dependent
_idx]
)
# Get the `ScanArgs` field name for the aggregate list property
# corresponding to these dependent argument types (i.e. either
# "outer_inputs", "inner_inputs", "inner_outputs", or
# "outer_outputs").
# To do this, we need to parse the "shared" prefix of the
# current `var_mappings` key and append the missing parts so that
# it either forms `"*_inputs"` or `"*_outputs"`.
to_agg_field_prefix = k[:9]
if to_agg_field_prefix.endswith("p"):
to_agg_field_name = "{}uts".format(to_agg_field_prefix)
else:
to_agg_field_name = "{}puts".format(to_agg_field_prefix)
to_agg_field = getattr(self, to_agg_field_name)
for d_id in dependent_idx:
if d_id < 0:
continue
dependent_var = to_agg_field[d_id]
if dependent_var not in seen:
dependent_nodes.add(dependent_var)
if field_info.name.startswith("inner_in"):
# If starting from an inner-input, then we need to find any
# inner-outputs that depend on it.
for out_n in self.inner_outputs:
if i in at_inputs([out_n]):
if out_n not in seen:
dependent_nodes.add(out_n)
for n in tuple(dependent_nodes):
if n in seen:
continue
sub_dependent_nodes = self.get_dependent_nodes(n, seen=seen)
dependent_nodes |= sub_dependent_nodes
seen |= sub_dependent_nodes
return dependent_nodes
def remove_from_fields(self, i, rm_dependents=True):
if rm_dependents:
vars_to_remove = self.get_dependent_nodes(i) | {i}
else:
vars_to_remove = {i}
rm_info = []
for v in vars_to_remove:
dependent_rm_info = self._remove_from_fields(v)
rm_info.append((v, dependent_rm_info))
return rm_info
def __copy__(self): def __copy__(self):
res = object.__new__(type(self)) res = object.__new__(type(self))
res.__dict__.update(self.__dict__) res.__dict__.update(self.__dict__)
# also copy mutable attrs # also copy mutable attrs
for attr in self.__dict__: for attr in self.__dict__:
if ( if (
attr.startswith("inner_in") attr.startswith("inner_in")
or attr.startswith("inner_out") or attr.startswith("inner_out")
or attr.startswith("outer_in") or attr.startswith("outer_in")
or attr.startswith("outer_out") or attr.startswith("outer_out")
skipping to change at line 1088 skipping to change at line 1348
attr.startswith("inner_in") attr.startswith("inner_in")
or attr.startswith("inner_out") or attr.startswith("inner_out")
or attr.startswith("outer_in") or attr.startswith("outer_in")
or attr.startswith("outer_out") or attr.startswith("outer_out")
or attr or attr
in ("mit_mot_out_slices", "mit_mot_in_slices", "mit_sot_in_slice s") in ("mit_mot_out_slices", "mit_mot_in_slices", "mit_sot_in_slice s")
): ):
getattr(res, attr).extend(getattr(other, attr)) getattr(res, attr).extend(getattr(other, attr))
return res return res
def __str__(self):
inner_arg_strs = [
"\t{}={}".format(p, getattr(self, p))
for p in self.field_names
if p.startswith("outer_in") or p == "n_steps"
]
inner_arg_strs += [
"\t{}={}".format(p, getattr(self, p))
for p in self.field_names
if p.startswith("inner_in")
]
inner_arg_strs += [
"\tmit_mot_in_slices={}".format(self.mit_mot_in_slices),
"\tmit_sot_in_slices={}".format(self.mit_sot_in_slices),
]
inner_arg_strs += [
"\t{}={}".format(p, getattr(self, p))
for p in self.field_names
if p.startswith("inner_out")
]
inner_arg_strs += [
"\tmit_mot_out_slices={}".format(self.mit_mot_out_slices),
]
inner_arg_strs += [
"\t{}={}".format(p, getattr(self, p))
for p in self.field_names
if p.startswith("outer_out")
]
res = "ScanArgs(\n{})".format(",\n".join(inner_arg_strs))
return res
def __repr__(self):
return self.__str__()
def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented
for field_name in self.field_names:
if not hasattr(other, field_name) or getattr(self, field_name) != ge
tattr(
other, field_name
):
return False
return True
def forced_replace(out, x, y): def forced_replace(out, x, y):
""" """
Check all internal values of the graph that compute the variable ``out`` Check all internal values of the graph that compute the variable ``out``
for occurrences of values identical with ``x``. If such occurrences are for occurrences of values identical with ``x``. If such occurrences are
encountered then they are replaced with variable ``y``. encountered then they are replaced with variable ``y``.
Parameters Parameters
---------- ----------
out : Aesara Variable out : Aesara Variable
x : Aesara Variable x : Aesara Variable
 End of changes. 6 change blocks. 
5 lines changed or deleted 322 lines changed or added

Home  |  About  |  Features  |  All  |  Newest  |  Dox  |  Diffs  |  RSS Feeds  |  Screenshots  |  Comments  |  Imprint  |  Privacy  |  HTTP(S)