"Fossies" - the Fresh Open Source Software Archive  

Source code changes of the file "aesara/scan/opt.py" between
aesara-rel-2.1.1.tar.gz and aesara-rel-2.1.2.tar.gz

About: Aesara is a Python library that allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays. It can use GPUs and perform efficient symbolic differentiation (formerly "Theano-PyMC"; a fork of the no longer developed original Theano library).

opt.py  (aesara-rel-2.1.1):opt.py  (aesara-rel-2.1.2)
skipping to change at line 83 skipping to change at line 83
is_in_ancestors, is_in_ancestors,
) )
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import InconsistencyError
from aesara.graph.op import compute_test_value from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.utils import ( from aesara.scan.utils import (
ScanArgs,
compress_outs, compress_outs,
expand_empty, expand_empty,
reconstruct_graph, reconstruct_graph,
safe_new, safe_new,
scan_args,
scan_can_remove_outs, scan_can_remove_outs,
) )
from aesara.tensor import basic_opt, math_opt from aesara.tensor import basic_opt, math_opt
from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, dot, maximum, minimum from aesara.tensor.math import Dot, dot, maximum, minimum
from aesara.tensor.shape import shape from aesara.tensor.shape import shape
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
IncSubtensor, IncSubtensor,
skipping to change at line 748 skipping to change at line 748
] ]
for node in nodelist: for node in nodelist:
# Process the node as long as something gets optimized # Process the node as long as something gets optimized
while node is not None: while node is not None:
node = self.process_node(fgraph, node) node = self.process_node(fgraph, node)
def process_node(self, fgraph, node): def process_node(self, fgraph, node):
op = node.op op = node.op
# Use scan_args to parse the inputs and outputs of scan for ease of # Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use # use
args = scan_args(node.inputs, node.outputs, op.inputs, op.outputs, op.in fo) args = ScanArgs(node.inputs, node.outputs, op.inputs, op.outputs, op.inf o)
new_scan_node = None new_scan_node = None
clients = {} clients = {}
local_fgraph_topo = io_toposort( local_fgraph_topo = io_toposort(
args.inner_inputs, args.inner_outputs, clients=clients args.inner_inputs, args.inner_outputs, clients=clients
) )
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if ( if (
isinstance(nd.op, Elemwise) isinstance(nd.op, Elemwise)
skipping to change at line 914 skipping to change at line 914
i for i in range(len(outer_vars)) if outer_vars[i] is None i for i in range(len(outer_vars)) if outer_vars[i] is None
] ]
add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots] add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots]
if len(add_as_nitsots) > 0: if len(add_as_nitsots) > 0:
new_scan_node = self.add_nitsot_outputs( new_scan_node = self.add_nitsot_outputs(
fgraph, old_scan_node, old_scan_args, add_as_nitsots fgraph, old_scan_node, old_scan_args, add_as_nitsots
) )
new_scan_args = scan_args( new_scan_args = ScanArgs(
new_scan_node.inputs, new_scan_node.inputs,
new_scan_node.outputs, new_scan_node.outputs,
new_scan_node.op.inputs, new_scan_node.op.inputs,
new_scan_node.op.outputs, new_scan_node.op.outputs,
new_scan_node.op.info, new_scan_node.op.info,
) )
new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :] new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :]
for i in range(len(new_outs)): for i in range(len(new_outs)):
outer_vars[idx_add_as_nitsots[i]] = new_outs[i] outer_vars[idx_add_as_nitsots[i]] = new_outs[i]
skipping to change at line 941 skipping to change at line 941
nb_new_outs = len(new_outputs_inner) nb_new_outs = len(new_outputs_inner)
# Create the initial values for the new nitsot outputs # Create the initial values for the new nitsot outputs
# (the initial value is the nb of steps to store. For a nistot, # (the initial value is the nb of steps to store. For a nistot,
# it should be the number of steps performed by scan) # it should be the number of steps performed by scan)
new_nitsots_initial_value = [ new_nitsots_initial_value = [
old_scan_node.inputs[0] for i in range(nb_new_outs) old_scan_node.inputs[0] for i in range(nb_new_outs)
] ]
# Create the scan_args corresponding to the new scan op to # Create the `ScanArgs` corresponding to the new `Scan` `Op` to create
# create
new_scan_args = copy.copy(old_scan_args) new_scan_args = copy.copy(old_scan_args)
new_scan_args.inner_out_nit_sot.extend(new_outputs_inner) new_scan_args.inner_out_nit_sot.extend(new_outputs_inner)
new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value) new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value)
# Create the scan op from the scan_args # Create the `Scan` `Op` from the `ScanArgs`
new_scan_op = Scan( new_scan_op = Scan(
new_scan_args.inner_inputs, new_scan_args.inner_outputs, new_scan_ar gs.info new_scan_args.inner_inputs, new_scan_args.inner_outputs, new_scan_ar gs.info
) )
# Create the Apply node for the scan op # Create the Apply node for the scan op
new_scan_node = new_scan_op( new_scan_node = new_scan_op(
*new_scan_args.outer_inputs, **dict(return_list=True) *new_scan_args.outer_inputs, **dict(return_list=True)
)[0].owner )[0].owner
# Modify the outer graph to make sure the outputs of the new scan are # Modify the outer graph to make sure the outputs of the new scan are
skipping to change at line 1961 skipping to change at line 1960
return left, right return left, right
@local_optimizer([Scan]) @local_optimizer([Scan])
def scan_merge_inouts(fgraph, node): def scan_merge_inouts(fgraph, node):
if not isinstance(node.op, Scan): if not isinstance(node.op, Scan):
return False return False
# Do a first pass to merge identical external inputs. # Do a first pass to merge identical external inputs.
# Equivalent inputs will be stored in inp_equiv, then a new # Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates. # scan node created without duplicates.
a = scan_args( a = ScanArgs(
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info
) )
inp_equiv = OrderedDict() inp_equiv = OrderedDict()
if has_duplicates(a.outer_in_seqs): if has_duplicates(a.outer_in_seqs):
new_outer_seqs = [] new_outer_seqs = []
new_inner_seqs = [] new_inner_seqs = []
for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs): for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs):
if out_seq in new_outer_seqs: if out_seq in new_outer_seqs:
skipping to change at line 2007 skipping to change at line 2006
info = a.info info = a.info
a_inner_outs = a.inner_outputs a_inner_outs = a.inner_outputs
inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv) inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv)
op = Scan(inner_inputs, inner_outputs, info) op = Scan(inner_inputs, inner_outputs, info)
outputs = op(*outer_inputs) outputs = op(*outer_inputs)
if not isinstance(outputs, (list, tuple)): if not isinstance(outputs, (list, tuple)):
outputs = [outputs] outputs = [outputs]
na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info) na = ScanArgs(outer_inputs, outputs, op.inputs, op.outputs, op.info)
remove = [node] remove = [node]
else: else:
na = a na = a
remove = [] remove = []
# Now that the identical external inputs have been merged, we do a new # Now that the identical external inputs have been merged, we do a new
# loop in order to merge external outputs that compute the same things # loop in order to merge external outputs that compute the same things
# from the same inputs. # from the same inputs.
left = [] left = []
right = [] right = []
 End of changes. 9 change blocks. 
9 lines changed or deleted 8 lines changed or added

Home  |  About  |  Features  |  All  |  Newest  |  Dox  |  Diffs  |  RSS Feeds  |  Screenshots  |  Comments  |  Imprint  |  Privacy  |  HTTP(S)