"Fossies" - the Fresh Open Source Software Archive  

Source code changes of the file "tests/scan/test_utils.py" between
aesara-rel-2.1.1.tar.gz and aesara-rel-2.1.2.tar.gz

About: Aesara is a Python library that allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays. It can use GPUs and perform efficient symbolic differentiation (formerly "Theano-PyMC"; a fork of the no longer developed original Theano library).

test_utils.py  (aesara-rel-2.1.1):test_utils.py  (aesara-rel-2.1.2)
import itertools import itertools
import numpy as np import numpy as np
import pytest import pytest
import aesara import aesara
from aesara import tensor as aet from aesara import tensor as aet
from aesara.scan.utils import map_variables from aesara.scan.utils import ScanArgs, map_variables
from aesara.tensor.type import scalar, vector from aesara.tensor.type import scalar, vector
class TestMapVariables: class TestMapVariables:
@staticmethod @staticmethod
def replacer(graph): def replacer(graph):
return getattr(graph.tag, "replacement", graph) return getattr(graph.tag, "replacement", graph)
def test_leaf(self): def test_leaf(self):
a = scalar("a") a = scalar("a")
b = scalar("b") b = scalar("b")
skipping to change at line 147 skipping to change at line 147
f = aesara.function([c, d, outer], [t, t2]) f = aesara.function([c, d, outer], [t, t2])
for m, n in itertools.combinations(range(10), 2): for m, n in itertools.combinations(range(10), 2):
assert f(m, n, outer=0.5) == [m + n, m - n] assert f(m, n, outer=0.5) == [m + n, m - n]
# test that the unsupported case of replacement with a shared # test that the unsupported case of replacement with a shared
# variable with updates crashes # variable with updates crashes
shared.update = shared + 1 shared.update = shared + 1
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
map_variables(self.replacer, [t]) map_variables(self.replacer, [t])
def test_ScanArgs():
scan_args = ScanArgs.create_empty()
assert scan_args.n_steps is None
for name in scan_args.field_names:
if name == "n_steps":
continue
assert len(getattr(scan_args, name)) == 0
with pytest.raises(TypeError):
ScanArgs.from_node(aet.ones(2).owner)
 End of changes. 2 change blocks. 
1 lines changed or deleted 1 lines changed or added

Home  |  About  |  Features  |  All  |  Newest  |  Dox  |  Diffs  |  RSS Feeds  |  Screenshots  |  Comments  |  Imprint  |  Privacy  |  HTTP(S)