"Fossies" - the Fresh Open Source Software Archive

Member "scikit-image-0.19.3/skimage/restoration/tests/test_denoise.py" (12 Jun 2022, 49069 Bytes) of package /linux/misc/scikit-image-0.19.3.tar.gz:


As a special service, "Fossies" has tried to format the requested source page into HTML using (guessed) Python syntax highlighting (style: standard) with prefixed line numbers. Alternatively, you can view or download the uninterpreted source code file here. See also the latest Fossies "Diffs" side-by-side code changes report for "test_denoise.py": 0.19.2_vs_0.19.3.

    1 import functools
    2 import itertools
    3 
    4 import numpy as np
    5 import pytest
    6 import pywt
    7 from numpy.testing import (assert_array_almost_equal, assert_array_equal,
    8                            assert_warns)
    9 
   10 from skimage import color, data, img_as_float, restoration
   11 from skimage._shared._warnings import expected_warnings
   12 from skimage._shared.utils import _supported_float_type, slice_at_axis
   13 from skimage.metrics import peak_signal_noise_ratio, structural_similarity
   14 from skimage.restoration._denoise import _wavelet_threshold
   15 
   16 try:
   17     import dask  # noqa
   18 except ImportError:
   19     DASK_NOT_INSTALLED_WARNING = (
   20         'The optional dask dependency is not installed'
   21     )
   22 else:
   23     DASK_NOT_INSTALLED_WARNING = None
   24 
   25 
   26 np.random.seed(1234)
   27 
   28 
   29 astro = img_as_float(data.astronaut()[:128, :128])
   30 astro_gray = color.rgb2gray(astro)
   31 checkerboard_gray = img_as_float(data.checkerboard())
   32 checkerboard = color.gray2rgb(checkerboard_gray)
   33 # versions with one odd-sized dimension
   34 astro_gray_odd = astro_gray[:, :-1]
   35 astro_odd = astro[:, :-1]
   36 
   37 
   38 float_dtypes = [np.float16, np.float32, np.float64]
   39 try:
   40     float_dtypes += [np.float128]
   41 except AttributeError:
   42     pass
   43 
   44 
   45 @pytest.mark.parametrize('dtype', float_dtypes)
   46 def test_denoise_tv_chambolle_2d(dtype):
   47     # astronaut image
   48     img = astro_gray.astype(dtype, copy=True)
   49     # add noise to astronaut
   50     img += 0.5 * img.std() * np.random.rand(*img.shape)
   51     # clip noise so that it does not exceed allowed range for float images.
   52     img = np.clip(img, 0, 1)
   53     # denoise
   54     denoised_astro = restoration.denoise_tv_chambolle(img, weight=0.1)
   55     assert denoised_astro.dtype == _supported_float_type(img)
   56 
   57     from scipy import ndimage as ndi
   58 
   59     # Convert to a floating point type supported by scipy.ndimage
   60     float_dtype = _supported_float_type(img)
   61     img = img.astype(float_dtype, copy=False)
   62 
   63     grad = ndi.morphological_gradient(img, size=((3, 3)))
   64     grad_denoised = ndi.morphological_gradient(denoised_astro, size=((3, 3)))
   65     # test if the total variation has decreased
   66     assert grad_denoised.dtype == float_dtype
   67     assert np.sqrt((grad_denoised**2).sum()) < np.sqrt((grad**2).sum())
   68 
   69 
   70 @pytest.mark.parametrize('channel_axis', [0, 1, 2, -1])
   71 def test_denoise_tv_chambolle_multichannel(channel_axis):
   72     denoised0 = restoration.denoise_tv_chambolle(astro[..., 0], weight=0.1)
   73 
   74     img = np.moveaxis(astro, -1, channel_axis)
   75     denoised = restoration.denoise_tv_chambolle(img, weight=0.1,
   76                                                 channel_axis=channel_axis)
   77     _at = functools.partial(slice_at_axis, axis=channel_axis % img.ndim)
   78     assert_array_equal(denoised[_at(0)], denoised0)
   79 
   80     # tile astronaut subset to generate 3D+channels data
   81     astro3 = np.tile(astro[:64, :64, np.newaxis, :], [1, 1, 2, 1])
   82     # modify along tiled dimension to give non-zero gradient on 3rd axis
   83     astro3[:, :, 0, :] = 2*astro3[:, :, 0, :]
   84     denoised0 = restoration.denoise_tv_chambolle(astro3[..., 0], weight=0.1)
   85 
   86     astro3 = np.moveaxis(astro3, -1, channel_axis)
   87     denoised = restoration.denoise_tv_chambolle(astro3, weight=0.1,
   88                                                 channel_axis=channel_axis)
   89     _at = functools.partial(slice_at_axis,
   90                             axis=channel_axis % astro3.ndim)
   91     assert_array_equal(denoised[_at(0)], denoised0)
   92 
   93 
   94 def test_denoise_tv_chambolle_multichannel_deprecation():
   95     denoised0 = restoration.denoise_tv_chambolle(astro[..., 0], weight=0.1)
   96 
   97     with expected_warnings(["`multichannel` is a deprecated argument"]):
   98         restoration.denoise_tv_chambolle(astro, weight=0.1,
   99                                          multichannel=True)
  100 
  101     # providing multichannel argument positionally also warns
  102     with expected_warnings(["Providing the `multichannel` argument"]):
  103         denoised = restoration.denoise_tv_chambolle(astro, 0.1, 2e-4, 200,
  104                                                     True)
  105 
  106     assert_array_equal(denoised[..., 0], denoised0)
  107 
  108 
  109 def test_denoise_tv_chambolle_n_iter_max_deprecation():
  110     expected = restoration.denoise_tv_chambolle(astro[..., 0], weight=0.1,
  111                                                 max_num_iter=10)
  112 
  113     with expected_warnings(["`n_iter_max` is a deprecated argument"]):
  114         denoised = restoration.denoise_tv_chambolle(astro[..., 0], weight=0.1,
  115                                                     n_iter_max=10)
  116 
  117     assert_array_equal(expected, denoised)
  118 
  119 
  120 def test_denoise_tv_chambolle_float_result_range():
  121     # astronaut image
  122     img = astro_gray
  123     int_astro = np.multiply(img, 255).astype(np.uint8)
  124     assert np.max(int_astro) > 1
  125     denoised_int_astro = restoration.denoise_tv_chambolle(int_astro,
  126                                                           weight=0.1)
  127     # test if the value range of output float data is within [0.0:1.0]
  128     assert denoised_int_astro.dtype == float
  129     assert np.max(denoised_int_astro) <= 1.0
  130     assert np.min(denoised_int_astro) >= 0.0
  131 
  132 
  133 def test_denoise_tv_chambolle_3d():
  134     """Apply the TV denoising algorithm on a 3D image representing a sphere."""
  135     x, y, z = np.ogrid[0:40, 0:40, 0:40]
  136     mask = (x - 22)**2 + (y - 20)**2 + (z - 17)**2 < 8**2
  137     mask = 100 * mask.astype(float)
  138     mask += 60
  139     mask += 20 * np.random.rand(*mask.shape)
  140     mask[mask < 0] = 0
  141     mask[mask > 255] = 255
  142     res = restoration.denoise_tv_chambolle(mask.astype(np.uint8), weight=0.1)
  143     assert res.dtype == float
  144     assert res.std() * 255 < mask.std()
  145 
  146 
  147 def test_denoise_tv_chambolle_1d():
  148     """Apply the TV denoising algorithm on a 1D sinusoid."""
  149     x = 125 + 100*np.sin(np.linspace(0, 8*np.pi, 1000))
  150     x += 20 * np.random.rand(x.size)
  151     x = np.clip(x, 0, 255)
  152     res = restoration.denoise_tv_chambolle(x.astype(np.uint8), weight=0.1)
  153     assert res.dtype == float
  154     assert res.std() * 255 < x.std()
  155 
  156 
  157 def test_denoise_tv_chambolle_4d():
  158     """ TV denoising for a 4D input."""
  159     im = 255 * np.random.rand(8, 8, 8, 8)
  160     res = restoration.denoise_tv_chambolle(im.astype(np.uint8), weight=0.1)
  161     assert res.dtype == float
  162     assert res.std() * 255 < im.std()
  163 
  164 
  165 def test_denoise_tv_chambolle_weighting():
  166     # make sure a specified weight gives consistent results regardless of
  167     # the number of input image dimensions
  168     rstate = np.random.default_rng(1234)
  169     img2d = astro_gray.copy()
  170     img2d += 0.15 * rstate.standard_normal(img2d.shape)
  171     img2d = np.clip(img2d, 0, 1)
  172 
  173     # generate 4D image by tiling
  174     img4d = np.tile(img2d[..., None, None], (1, 1, 2, 2))
  175 
  176     w = 0.2
  177     denoised_2d = restoration.denoise_tv_chambolle(img2d, weight=w)
  178     denoised_4d = restoration.denoise_tv_chambolle(img4d, weight=w)
  179     assert structural_similarity(denoised_2d,
  180                                  denoised_4d[:, :, 0, 0]) > 0.99
  181 
  182 
  183 def test_denoise_tv_bregman_2d():
  184     img = checkerboard_gray.copy()
  185     # add some random noise
  186     img += 0.5 * img.std() * np.random.rand(*img.shape)
  187     img = np.clip(img, 0, 1)
  188 
  189     out1 = restoration.denoise_tv_bregman(img, weight=10)
  190     out2 = restoration.denoise_tv_bregman(img, weight=5)
  191 
  192     # make sure noise is reduced in the checkerboard cells
  193     assert img[30:45, 5:15].std() > out1[30:45, 5:15].std()
  194     assert out1[30:45, 5:15].std() > out2[30:45, 5:15].std()
  195 
  196 
  197 def test_denoise_tv_bregman_float_result_range():
  198     # astronaut image
  199     img = astro_gray.copy()
  200     int_astro = np.multiply(img, 255).astype(np.uint8)
  201     assert np.max(int_astro) > 1
  202     denoised_int_astro = restoration.denoise_tv_bregman(int_astro, weight=60.0)
  203     # test if the value range of output float data is within [0.0:1.0]
  204     assert denoised_int_astro.dtype == float
  205     assert np.max(denoised_int_astro) <= 1.0
  206     assert np.min(denoised_int_astro) >= 0.0
  207 
  208 
  209 def test_denoise_tv_bregman_3d():
  210     img = checkerboard.copy()
  211     # add some random noise
  212     img += 0.5 * img.std() * np.random.rand(*img.shape)
  213     img = np.clip(img, 0, 1)
  214 
  215     out1 = restoration.denoise_tv_bregman(img, weight=10)
  216     out2 = restoration.denoise_tv_bregman(img, weight=5)
  217 
  218     # make sure noise is reduced in the checkerboard cells
  219     assert img[30:45, 5:15].std() > out1[30:45, 5:15].std()
  220     assert out1[30:45, 5:15].std() > out2[30:45, 5:15].std()
  221 
  222 
  223 @pytest.mark.parametrize('channel_axis', [0, 1, 2, -1])
  224 def test_denoise_tv_bregman_3d_multichannel(channel_axis):
  225     img_astro = astro.copy()
  226     denoised0 = restoration.denoise_tv_bregman(img_astro[..., 0], weight=60.0)
  227     img_astro = np.moveaxis(img_astro, -1, channel_axis)
  228     denoised = restoration.denoise_tv_bregman(img_astro, weight=60.0,
  229                                               channel_axis=channel_axis)
  230     _at = functools.partial(slice_at_axis,
  231                             axis=channel_axis % img_astro.ndim)
  232     assert_array_equal(denoised0, denoised[_at(0)])
  233 
  234 
  235 def test_denoise_tv_bregman_3d_multichannel_deprecation():
  236     img_astro = astro.copy()
  237     denoised0 = restoration.denoise_tv_bregman(img_astro[..., 0], weight=60.0)
  238     with expected_warnings(["`multichannel` is a deprecated argument"]):
  239         denoised = restoration.denoise_tv_bregman(img_astro, weight=60.0,
  240                                                   multichannel=True)
  241 
  242     assert_array_equal(denoised0, denoised[..., 0])
  243 
  244 
  245 def test_denoise_tv_bregman_max_iter_deprecation():
  246     with expected_warnings(["`max_iter` is a deprecated argument"]):
  247         restoration.denoise_tv_bregman(astro_gray, weight=60.0, max_iter=5)
  248 
  249 
  250 def test_denoise_tv_bregman_multichannel():
  251     img = checkerboard_gray.copy()[:50, :50]
  252     # add some random noise
  253     img += 0.5 * img.std() * np.random.rand(*img.shape)
  254     img = np.clip(img, 0, 1)
  255 
  256     out1 = restoration.denoise_tv_bregman(img, weight=60.0)
  257     out2 = restoration.denoise_tv_bregman(img, weight=60.0, channel_axis=-1)
  258 
  259     assert_array_equal(out1, out2)
  260 
  261 
  262 def test_denoise_bilateral_null():
  263     img = np.zeros((50, 50))
  264     out = restoration.denoise_bilateral(img)
  265 
  266     # image full of zeros should return identity
  267     assert_array_equal(out, img)
  268 
  269 
  270 def test_denoise_bilateral_negative():
  271     img = -np.ones((50, 50))
  272     out = restoration.denoise_bilateral(img)
  273 
  274     # image with only negative values should be ok
  275     assert_array_equal(out, img)
  276 
  277 
  278 def test_denoise_bilateral_negative2():
  279     img = np.ones((50, 50))
  280     img[2, 2] = 2
  281 
  282     out1 = restoration.denoise_bilateral(img)
  283     out2 = restoration.denoise_bilateral(img - 10)  # contains negative values
  284 
  285     # 2 images with a given offset should give the same result (with the same
  286     # offset)
  287     assert_array_almost_equal(out1, out2 + 10)
  288 
  289 
  290 def test_denoise_bilateral_2d():
  291     img = checkerboard_gray.copy()[:50, :50]
  292     # add some random noise
  293     img += 0.5 * img.std() * np.random.rand(*img.shape)
  294     img = np.clip(img, 0, 1)
  295 
  296     out1 = restoration.denoise_bilateral(img, sigma_color=0.1,
  297                                          sigma_spatial=10, channel_axis=None)
  298     out2 = restoration.denoise_bilateral(img, sigma_color=0.2,
  299                                          sigma_spatial=20, channel_axis=None)
  300 
  301     # make sure noise is reduced in the checkerboard cells
  302     assert img[30:45, 5:15].std() > out1[30:45, 5:15].std()
  303     assert out1[30:45, 5:15].std() > out2[30:45, 5:15].std()
  304 
  305 
  306 def test_denoise_bilateral_pad():
  307     """This test checks if the bilateral filter is returning an image
  308     correctly padded."""
  309     img = img_as_float(data.chelsea())[100:200, 100:200]
  310     img_bil = restoration.denoise_bilateral(img, sigma_color=0.1,
  311                                             sigma_spatial=10,
  312                                             channel_axis=-1)
  313     condition_padding = np.count_nonzero(np.isclose(img_bil,
  314                                                     0,
  315                                                     atol=0.001))
  316     assert_array_equal(condition_padding, 0)
  317 
  318 
  319 @pytest.mark.parametrize('dtype', [np.float32, np.double])
  320 def test_denoise_bilateral_types(dtype):
  321     img = checkerboard_gray.copy()[:50, :50]
  322     # add some random noise
  323     img += 0.5 * img.std() * np.random.rand(*img.shape)
  324     img = np.clip(img, 0, 1).astype(dtype)
  325 
  326     # check that we can process multiple float types
  327     restoration.denoise_bilateral(img, sigma_color=0.1,
  328                                   sigma_spatial=10, channel_axis=None)
  329 
  330 
  331 @pytest.mark.parametrize('dtype', [np.float32, np.double])
  332 def test_denoise_bregman_types(dtype):
  333     img = checkerboard_gray.copy()[:50, :50]
  334     # add some random noise
  335     img += 0.5 * img.std() * np.random.rand(*img.shape)
  336     img = np.clip(img, 0, 1).astype(dtype)
  337 
  338     # check that we can process multiple float types
  339     restoration.denoise_tv_bregman(img, weight=5)
  340 
  341 
  342 def test_denoise_bilateral_zeros():
  343     img = np.zeros((10, 10))
  344     assert_array_equal(img,
  345                        restoration.denoise_bilateral(img, channel_axis=None))
  346 
  347 
  348 def test_denoise_bilateral_constant():
  349     img = np.ones((10, 10)) * 5
  350     assert_array_equal(img,
  351                        restoration.denoise_bilateral(img, channel_axis=None))
  352 
  353 
  354 @pytest.mark.parametrize('channel_axis', [0, 1, -1])
  355 def test_denoise_bilateral_color(channel_axis):
  356     img = checkerboard.copy()[:50, :50]
  357     # add some random noise
  358     img += 0.5 * img.std() * np.random.rand(*img.shape)
  359     img = np.clip(img, 0, 1)
  360 
  361     img = np.moveaxis(img, -1, channel_axis)
  362     out1 = restoration.denoise_bilateral(img, sigma_color=0.1,
  363                                          sigma_spatial=10,
  364                                          channel_axis=channel_axis)
  365     out2 = restoration.denoise_bilateral(img, sigma_color=0.2,
  366                                          sigma_spatial=20,
  367                                          channel_axis=channel_axis)
  368     img = np.moveaxis(img, channel_axis, -1)
  369     out1 = np.moveaxis(out1, channel_axis, -1)
  370     out2 = np.moveaxis(out2, channel_axis, -1)
  371 
  372     # make sure noise is reduced in the checkerboard cells
  373     assert img[30:45, 5:15].std() > out1[30:45, 5:15].std()
  374     assert out1[30:45, 5:15].std() > out2[30:45, 5:15].std()
  375 
  376 
  377 def test_denoise_bilateral_multichannel_deprecation():
  378     img = checkerboard.copy()[:50, :50]
  379     # add some random noise
  380     img += 0.5 * img.std() * np.random.rand(*img.shape)
  381     img = np.clip(img, 0, 1)
  382 
  383     with expected_warnings(["`multichannel` is a deprecated argument"]):
  384         out1 = restoration.denoise_bilateral(img, sigma_color=0.1,
  385                                              sigma_spatial=10,
  386                                              multichannel=True)
  387     with expected_warnings(["`multichannel` is a deprecated argument"]):
  388         out2 = restoration.denoise_bilateral(img, sigma_color=0.2,
  389                                              sigma_spatial=20,
  390                                              multichannel=True)
  391 
  392     # make sure noise is reduced in the checkerboard cells
  393     assert img[30:45, 5:15].std() > out1[30:45, 5:15].std()
  394     assert out1[30:45, 5:15].std() > out2[30:45, 5:15].std()
  395 
  396 
  397 def test_denoise_bilateral_3d_grayscale():
  398     img = np.ones((50, 50, 3))
  399     with pytest.raises(ValueError):
  400         restoration.denoise_bilateral(img, channel_axis=None)
  401 
  402 
  403 def test_denoise_bilateral_3d_multichannel():
  404     img = np.ones((50, 50, 50))
  405     with expected_warnings(["grayscale"]):
  406         result = restoration.denoise_bilateral(img, channel_axis=-1)
  407 
  408     assert_array_equal(result, img)
  409 
  410 
  411 def test_denoise_bilateral_multidimensional():
  412     img = np.ones((10, 10, 10, 10))
  413     with pytest.raises(ValueError):
  414         restoration.denoise_bilateral(img, channel_axis=None)
  415     with pytest.raises(ValueError):
  416         restoration.denoise_bilateral(img, channel_axis=-1)
  417 
  418 
  419 def test_denoise_bilateral_nan():
  420     img = np.full((50, 50), np.NaN)
  421     # This is in fact an optional warning for our test suite.
  422     # Python 3.5 will not trigger a warning.
  423     with expected_warnings([r'invalid|\A\Z']):
  424         out = restoration.denoise_bilateral(img, channel_axis=None)
  425     assert_array_equal(img, out)
  426 
  427 
  428 @pytest.mark.parametrize('fast_mode', [False, True])
  429 def test_denoise_nl_means_2d(fast_mode):
  430     img = np.zeros((40, 40))
  431     img[10:-10, 10:-10] = 1.
  432     sigma = 0.3
  433     img += sigma * np.random.standard_normal(img.shape)
  434     img_f32 = img.astype('float32')
  435     for s in [sigma, 0]:
  436         denoised = restoration.denoise_nl_means(img, 7, 5, 0.2,
  437                                                 fast_mode=fast_mode,
  438                                                 channel_axis=None,
  439                                                 sigma=s)
  440         # make sure noise is reduced
  441         assert img.std() > denoised.std()
  442 
  443         denoised_f32 = restoration.denoise_nl_means(img_f32, 7, 5, 0.2,
  444                                                     fast_mode=fast_mode,
  445                                                     channel_axis=None,
  446                                                     sigma=s)
  447         # make sure noise is reduced
  448         assert img.std() > denoised_f32.std()
  449 
  450         # Sheck single precision result
  451         assert np.allclose(denoised_f32, denoised, atol=1e-2)
  452 
  453 
  454 @pytest.mark.parametrize('fast_mode', [False, True])
  455 @pytest.mark.parametrize('n_channels', [2, 3, 6])
  456 @pytest.mark.parametrize('dtype', ['float64', 'float32'])
  457 def test_denoise_nl_means_2d_multichannel(fast_mode, n_channels, dtype):
  458     # reduce image size because nl means is slow
  459     img = np.copy(astro[:50, :50])
  460     img = np.concatenate((img, ) * 2, )  # 6 channels
  461     img = img.astype(dtype)
  462 
  463     # add some random noise
  464     sigma = 0.1
  465     imgn = img + sigma * np.random.standard_normal(img.shape)
  466     imgn = np.clip(imgn, 0, 1)
  467     imgn = imgn.astype(dtype)
  468 
  469     for s in [sigma, 0]:
  470         psnr_noisy = peak_signal_noise_ratio(
  471             img[..., :n_channels], imgn[..., :n_channels])
  472         denoised = restoration.denoise_nl_means(imgn[..., :n_channels],
  473                                                 3, 5, h=0.75 * sigma,
  474                                                 fast_mode=fast_mode,
  475                                                 channel_axis=-1,
  476                                                 sigma=s)
  477         psnr_denoised = peak_signal_noise_ratio(
  478             denoised[..., :n_channels], img[..., :n_channels])
  479 
  480         # make sure noise is reduced
  481         assert psnr_denoised > psnr_noisy
  482 
  483 
  484 def test_denoise_nl_means_2d_multichannel_deprecated():
  485     # reduce image size because nl means is slow
  486     img = np.copy(astro[:50, :50])
  487 
  488     # add some random noise
  489     sigma = 0.1
  490     imgn = img + sigma * np.random.standard_normal(img.shape)
  491     imgn = np.clip(imgn, 0, 1)
  492 
  493     psnr_noisy = peak_signal_noise_ratio(img, imgn)
  494     with expected_warnings(["`multichannel` is a deprecated argument"]):
  495         denoised = restoration.denoise_nl_means(imgn,
  496                                                 3, 5, h=0.75 * sigma,
  497                                                 multichannel=True,
  498                                                 sigma=sigma)
  499     psnr_denoised = peak_signal_noise_ratio(denoised, img)
  500 
  501     # make sure noise is reduced
  502     assert psnr_denoised > psnr_noisy
  503 
  504     # providing multichannel argument positionally also warns
  505     with expected_warnings(["Providing the `multichannel` argument"]):
  506         restoration.denoise_nl_means(imgn, 3, 5, 0.75 * sigma, True,
  507                                      sigma=sigma)
  508 
  509 
  510 @pytest.mark.parametrize('fast_mode', [False, True])
  511 @pytest.mark.parametrize('dtype', ['float64', 'float32'])
  512 def test_denoise_nl_means_3d(fast_mode, dtype):
  513     img = np.zeros((12, 12, 8), dtype=dtype)
  514     img[5:-5, 5:-5, 2:-2] = 1.
  515     sigma = 0.3
  516     imgn = img + sigma * np.random.standard_normal(img.shape)
  517     imgn = imgn.astype(dtype)
  518     psnr_noisy = peak_signal_noise_ratio(img, imgn)
  519     for s in [sigma, 0]:
  520         denoised = restoration.denoise_nl_means(imgn, 3, 4, h=0.75 * sigma,
  521                                                 fast_mode=fast_mode,
  522                                                 channel_axis=None, sigma=s)
  523         # make sure noise is reduced
  524         assert peak_signal_noise_ratio(img, denoised) > psnr_noisy
  525 
  526 
  527 @pytest.mark.parametrize('fast_mode', [False, True])
  528 @pytest.mark.parametrize('dtype', ['float64', 'float32', 'float16'])
  529 @pytest.mark.parametrize('channel_axis', [0, -1])
  530 def test_denoise_nl_means_multichannel(fast_mode, dtype, channel_axis):
  531     # for true 3D data, 3D denoising is better than denoising as 2D+channels
  532 
  533     # synthetic 3d volume
  534     img = data.binary_blobs(length=32, n_dim=3, seed=5)
  535     img = img[:, :24, :16].astype(dtype, copy=False)
  536 
  537     sigma = 0.2
  538     rng = np.random.default_rng(5)
  539     imgn = img + sigma * rng.standard_normal(img.shape)
  540     imgn = imgn.astype(dtype)
  541 
  542     # test 3D denoising (channel_axis = None)
  543     denoised_ok_multichannel = restoration.denoise_nl_means(
  544         imgn.copy(), 3, 2, h=0.6 * sigma, sigma=sigma, fast_mode=fast_mode,
  545         channel_axis=None)
  546 
  547     # set a channel axis: one dimension is (incorrectly) considered "channels"
  548     imgn = np.moveaxis(imgn, -1, channel_axis)
  549     denoised_wrong_multichannel = restoration.denoise_nl_means(
  550         imgn.copy(), 3, 2, h=0.6 * sigma, sigma=sigma, fast_mode=fast_mode,
  551         channel_axis=channel_axis
  552     )
  553     denoised_wrong_multichannel = np.moveaxis(
  554         denoised_wrong_multichannel, channel_axis, -1
  555     )
  556 
  557     img = img.astype(denoised_wrong_multichannel.dtype)
  558     psnr_wrong = peak_signal_noise_ratio(img, denoised_wrong_multichannel)
  559     psnr_ok = peak_signal_noise_ratio(img, denoised_ok_multichannel)
  560     assert psnr_ok > psnr_wrong
  561 
  562 
  563 def test_denoise_nl_means_4d():
  564     rng = np.random.default_rng(5)
  565     img = np.zeros((10, 10, 8, 5))
  566     img[2:-2, 2:-2, 2:-2, :2] = 0.5
  567     img[2:-2, 2:-2, 2:-2, 2:] = 1.
  568     sigma = 0.3
  569     imgn = img + sigma * rng.standard_normal(img.shape)
  570 
  571     nlmeans_kwargs = dict(patch_size=3, patch_distance=2, h=0.3 * sigma,
  572                           sigma=sigma, fast_mode=True)
  573 
  574     psnr_noisy = peak_signal_noise_ratio(img, imgn, data_range=1.)
  575 
  576     # denoise by looping over 3D slices
  577     denoised_3d = np.zeros_like(imgn)
  578     for ch in range(img.shape[-1]):
  579         denoised_3d[..., ch] = restoration.denoise_nl_means(
  580             imgn[..., ch],
  581             channel_axis=None,
  582             **nlmeans_kwargs)
  583     psnr_3d = peak_signal_noise_ratio(img, denoised_3d, data_range=1.)
  584     assert psnr_3d > psnr_noisy
  585 
  586     # denoise as 4D
  587     denoised_4d = restoration.denoise_nl_means(imgn,
  588                                                channel_axis=None,
  589                                                **nlmeans_kwargs)
  590     psnr_4d = peak_signal_noise_ratio(img, denoised_4d, data_range=1.)
  591     assert psnr_4d > psnr_3d
  592 
  593     # denoise as 3D + channels instead
  594     denoised_3dmc = restoration.denoise_nl_means(imgn,
  595                                                  channel_axis=-1,
  596                                                  **nlmeans_kwargs)
  597     psnr_3dmc = peak_signal_noise_ratio(img, denoised_3dmc, data_range=1.)
  598     assert psnr_3dmc > psnr_3d
  599 
  600 
  601 def test_denoise_nl_means_4d_multichannel():
  602     img = np.zeros((8, 8, 8, 4, 4))
  603     img[2:-2, 2:-2, 2:-2, 1:-1, :] = 1.
  604     sigma = 0.3
  605     imgn = img + sigma * np.random.randn(*img.shape)
  606 
  607     psnr_noisy = peak_signal_noise_ratio(img, imgn, data_range=1.)
  608 
  609     denoised_4dmc = restoration.denoise_nl_means(imgn, 3, 3, h=0.35 * sigma,
  610                                                  fast_mode=True,
  611                                                  channel_axis=-1,
  612                                                  sigma=sigma)
  613     psnr_4dmc = peak_signal_noise_ratio(img, denoised_4dmc, data_range=1.)
  614     assert psnr_4dmc > psnr_noisy
  615 
  616 
  617 def test_denoise_nl_means_wrong_dimension():
  618     # 1D not implemented
  619     img = np.zeros((5, ))
  620     with pytest.raises(NotImplementedError):
  621         restoration.denoise_nl_means(img, channel_axis=None)
  622 
  623     img = np.zeros((5, 3))
  624     with pytest.raises(NotImplementedError):
  625         restoration.denoise_nl_means(img, channel_axis=-1)
  626 
  627     # 3D + channels only for fast mode
  628     img = np.zeros((5, 5, 5, 5))
  629     with pytest.raises(NotImplementedError):
  630         restoration.denoise_nl_means(img, channel_axis=-1, fast_mode=False)
  631 
  632     # 4D only for fast mode
  633     img = np.zeros((5, 5, 5, 5))
  634     with pytest.raises(NotImplementedError):
  635         restoration.denoise_nl_means(img, channel_axis=None, fast_mode=False)
  636 
  637     # 4D + channels only for fast mode
  638     img = np.zeros((5, 5, 5, 5, 5))
  639     with pytest.raises(NotImplementedError):
  640         restoration.denoise_nl_means(img, channel_axis=-1, fast_mode=False)
  641 
  642     # 5D not implemented
  643     img = np.zeros((5, 5, 5, 5, 5))
  644     with pytest.raises(NotImplementedError):
  645         restoration.denoise_nl_means(img, channel_axis=None)
  646 
  647 
  648 @pytest.mark.parametrize('fast_mode', [False, True])
  649 @pytest.mark.parametrize('dtype', ['float64', 'float32'])
  650 def test_no_denoising_for_small_h(fast_mode, dtype):
  651     img = np.zeros((40, 40))
  652     img[10:-10, 10:-10] = 1.
  653     img += 0.3 * np.random.standard_normal(img.shape)
  654     img = img.astype(dtype)
  655     # very small h should result in no averaging with other patches
  656     denoised = restoration.denoise_nl_means(img, 7, 5, 0.01,
  657                                             fast_mode=fast_mode,
  658                                             channel_axis=None)
  659     assert np.allclose(denoised, img)
  660     denoised = restoration.denoise_nl_means(img, 7, 5, 0.01,
  661                                             fast_mode=fast_mode,
  662                                             channel_axis=None)
  663     assert np.allclose(denoised, img)
  664 
  665 
  666 @pytest.mark.parametrize('fast_mode', [False, True])
  667 def test_denoise_nl_means_2d_dtype(fast_mode):
  668     img = np.zeros((40, 40), dtype=int)
  669     img_f32 = img.astype('float32')
  670     img_f64 = img.astype('float64')
  671 
  672     assert restoration.denoise_nl_means(
  673         img, fast_mode=fast_mode).dtype == 'float64'
  674 
  675     assert restoration.denoise_nl_means(
  676         img_f32, fast_mode=fast_mode).dtype == img_f32.dtype
  677 
  678     assert restoration.denoise_nl_means(
  679         img_f64, fast_mode=fast_mode).dtype == img_f64.dtype
  680 
  681 
  682 @pytest.mark.parametrize('fast_mode', [False, True])
  683 def test_denoise_nl_means_3d_dtype(fast_mode):
  684     img = np.zeros((12, 12, 8), dtype=int)
  685     img_f32 = img.astype('float32')
  686     img_f64 = img.astype('float64')
  687 
  688     assert restoration.denoise_nl_means(
  689         img, patch_distance=2, fast_mode=fast_mode).dtype == 'float64'
  690 
  691     assert restoration.denoise_nl_means(
  692         img_f32, patch_distance=2, fast_mode=fast_mode).dtype == img_f32.dtype
  693 
  694     assert restoration.denoise_nl_means(
  695         img_f64, patch_distance=2, fast_mode=fast_mode).dtype == img_f64.dtype
  696 
  697 
  698 @pytest.mark.parametrize(
  699     'img, multichannel, convert2ycbcr',
  700     [(astro_gray, False, False),
  701      (astro_gray_odd, False, False),
  702      (astro_odd, True, False),
  703      (astro_odd, True, True)]
  704 )
  705 def test_wavelet_denoising(img, multichannel, convert2ycbcr):
  706     rstate = np.random.default_rng(1234)
  707     sigma = 0.1
  708     noisy = img + sigma * rstate.standard_normal(img.shape)
  709     noisy = np.clip(noisy, 0, 1)
  710 
  711     channel_axis = -1 if multichannel else None
  712 
  713     # Verify that SNR is improved when true sigma is used
  714     denoised = restoration.denoise_wavelet(noisy, sigma=sigma,
  715                                            channel_axis=channel_axis,
  716                                            convert2ycbcr=convert2ycbcr,
  717                                            rescale_sigma=True)
  718     psnr_noisy = peak_signal_noise_ratio(img, noisy)
  719     psnr_denoised = peak_signal_noise_ratio(img, denoised)
  720     assert psnr_denoised > psnr_noisy
  721 
  722     # Verify that SNR is improved with internally estimated sigma
  723     denoised = restoration.denoise_wavelet(noisy,
  724                                            channel_axis=channel_axis,
  725                                            convert2ycbcr=convert2ycbcr,
  726                                            rescale_sigma=True)
  727     psnr_noisy = peak_signal_noise_ratio(img, noisy)
  728     psnr_denoised = peak_signal_noise_ratio(img, denoised)
  729     assert psnr_denoised > psnr_noisy
  730 
  731     # SNR is improved less with 1 wavelet level than with the default.
  732     denoised_1 = restoration.denoise_wavelet(noisy,
  733                                              channel_axis=channel_axis,
  734                                              wavelet_levels=1,
  735                                              convert2ycbcr=convert2ycbcr,
  736                                              rescale_sigma=True)
  737     psnr_denoised_1 = peak_signal_noise_ratio(img, denoised_1)
  738     assert psnr_denoised > psnr_denoised_1
  739     assert psnr_denoised_1 > psnr_noisy
  740 
  741     # Test changing noise_std (higher threshold, so less energy in signal)
  742     res1 = restoration.denoise_wavelet(noisy, sigma=2 * sigma,
  743                                        channel_axis=channel_axis,
  744                                        rescale_sigma=True)
  745     res2 = restoration.denoise_wavelet(noisy, sigma=sigma,
  746                                        channel_axis=channel_axis,
  747                                        rescale_sigma=True)
  748     assert np.sum(res1**2) <= np.sum(res2**2)
  749 
  750 
  751 @pytest.mark.parametrize('channel_axis', [0, 1, 2, -1])
  752 @pytest.mark.parametrize('convert2ycbcr', [False, True])
  753 def test_wavelet_denoising_channel_axis(channel_axis, convert2ycbcr):
  754     rstate = np.random.default_rng(1234)
  755     sigma = 0.1
  756     img = astro_odd
  757     noisy = img + sigma * rstate.standard_normal(img.shape)
  758     noisy = np.clip(noisy, 0, 1)
  759 
  760     img = np.moveaxis(img, -1, channel_axis)
  761     noisy = np.moveaxis(noisy, -1, channel_axis)
  762 
  763     # Verify that SNR is improved when true sigma is used
  764     denoised = restoration.denoise_wavelet(noisy, sigma=sigma,
  765                                            channel_axis=channel_axis,
  766                                            convert2ycbcr=convert2ycbcr,
  767                                            rescale_sigma=True)
  768     psnr_noisy = peak_signal_noise_ratio(img, noisy)
  769     psnr_denoised = peak_signal_noise_ratio(img, denoised)
  770     assert psnr_denoised > psnr_noisy
  771 
  772 
  773 def test_wavelet_denoising_deprecated():
  774     rstate = np.random.default_rng(1234)
  775     sigma = 0.1
  776     img = astro_odd
  777     noisy = img + sigma * rstate.standard_normal(img.shape)
  778     noisy = np.clip(noisy, 0, 1)
  779 
  780     with expected_warnings(["`multichannel` is a deprecated argument"]):
  781         # Verify that SNR is improved when true sigma is used
  782         denoised = restoration.denoise_wavelet(noisy, sigma=sigma,
  783                                                multichannel=True,
  784                                                rescale_sigma=True)
  785     psnr_noisy = peak_signal_noise_ratio(img, noisy)
  786     psnr_denoised = peak_signal_noise_ratio(img, denoised)
  787     assert psnr_denoised > psnr_noisy
  788 
  789     # providing multichannel argument positionally also warns
  790     with expected_warnings(["Providing the `multichannel` argument"]):
  791         restoration.denoise_wavelet(noisy, sigma, 'db1', 'soft', None, True,
  792                                     rescale_sigma=True)
  793 
  794 
  795 @pytest.mark.parametrize(
  796     'case, dtype, convert2ycbcr, estimate_sigma',
  797     itertools.product(
  798         ['1d', '2d multichannel'],
  799         [np.float16, np.float32, np.float64, np.int16, np.uint8],
  800         [True, False],
  801         [True, False])
  802 )
  803 def test_wavelet_denoising_scaling(case, dtype, convert2ycbcr,
  804                                    estimate_sigma):
  805     """Test cases for images without prescaling via img_as_float."""
  806     rstate = np.random.default_rng(1234)
  807 
  808     if case == '1d':
  809         # 1D single-channel in range [0, 255]
  810         x = np.linspace(0, 255, 1024)
  811     elif case == '2d multichannel':
  812         # 2D multichannel in range [0, 255]
  813         x = data.astronaut()[:64, :64]
  814     x = x.astype(dtype)
  815 
  816     # add noise and clip to original signal range
  817     sigma = 25.
  818     noisy = x + sigma * rstate.standard_normal(x.shape)
  819     noisy = np.clip(noisy, x.min(), x.max())
  820     noisy = noisy.astype(x.dtype)
  821 
  822     channel_axis = -1 if x.shape[-1] == 3 else None
  823 
  824     if estimate_sigma:
  825         sigma_est = restoration.estimate_sigma(noisy,
  826                                                channel_axis=channel_axis)
  827     else:
  828         sigma_est = None
  829 
  830     if convert2ycbcr and channel_axis is None:
  831         # YCbCr requires multichannel == True
  832         with pytest.raises(ValueError):
  833             denoised = restoration.denoise_wavelet(noisy,
  834                                                    sigma=sigma_est,
  835                                                    wavelet='sym4',
  836                                                    channel_axis=channel_axis,
  837                                                    convert2ycbcr=convert2ycbcr,
  838                                                    rescale_sigma=True)
  839         return
  840 
  841     denoised = restoration.denoise_wavelet(noisy, sigma=sigma_est,
  842                                            wavelet='sym4',
  843                                            channel_axis=channel_axis,
  844                                            convert2ycbcr=convert2ycbcr,
  845                                            rescale_sigma=True)
  846     assert denoised.dtype == _supported_float_type(noisy)
  847 
  848     data_range = x.max() - x.min()
  849     psnr_noisy = peak_signal_noise_ratio(x, noisy, data_range=data_range)
  850     clipped = np.dtype(dtype).kind != 'f'
  851     if not clipped:
  852         psnr_denoised = peak_signal_noise_ratio(x, denoised,
  853                                                 data_range=data_range)
  854 
  855         # output's max value is not substantially smaller than x's
  856         assert denoised.max() > 0.9 * x.max()
  857     else:
  858         # have to compare to x_as_float in integer input cases
  859         x_as_float = img_as_float(x)
  860         f_data_range = x_as_float.max() - x_as_float.min()
  861         psnr_denoised = peak_signal_noise_ratio(x_as_float, denoised,
  862                                                 data_range=f_data_range)
  863 
  864         # output has been clipped to expected range
  865         assert denoised.max() <= 1.0
  866         if np.dtype(dtype).kind == 'u':
  867             assert denoised.min() >= 0
  868         else:
  869             assert denoised.min() >= -1
  870 
  871     assert psnr_denoised > psnr_noisy
  872 
  873 
  874 def test_wavelet_threshold():
  875     rstate = np.random.default_rng(1234)
  876 
  877     img = astro_gray
  878     sigma = 0.1
  879     noisy = img + sigma * rstate.standard_normal(img.shape)
  880     noisy = np.clip(noisy, 0, 1)
  881 
  882     # employ a single, user-specified threshold instead of BayesShrink sigmas
  883     denoised = _wavelet_threshold(noisy, wavelet='db1', method=None,
  884                                   threshold=sigma)
  885     psnr_noisy = peak_signal_noise_ratio(img, noisy)
  886     psnr_denoised = peak_signal_noise_ratio(img, denoised)
  887     assert psnr_denoised > psnr_noisy
  888 
  889     # either method or threshold must be defined
  890     with pytest.raises(ValueError):
  891         _wavelet_threshold(noisy, wavelet='db1', method=None, threshold=None)
  892 
  893     # warns if a threshold is provided in a case where it would be ignored
  894     with expected_warnings(["Thresholding method "]):
  895         _wavelet_threshold(noisy, wavelet='db1', method='BayesShrink',
  896                            threshold=sigma)
  897 
  898 
  899 @pytest.mark.parametrize(
  900     'rescale_sigma, method, ndim',
  901     itertools.product(
  902         [True, False],
  903         ['VisuShrink', 'BayesShrink'],
  904         range(1, 5)
  905     )
  906 )
  907 def test_wavelet_denoising_nd(rescale_sigma, method, ndim):
  908     rstate = np.random.default_rng(1234)
  909     # Generate a very simple test image
  910     if ndim < 3:
  911         img = 0.2*np.ones((128, )*ndim)
  912     else:
  913         img = 0.2*np.ones((16, )*ndim)
  914     img[(slice(5, 13), ) * ndim] = 0.8
  915 
  916     sigma = 0.1
  917     noisy = img + sigma * rstate.standard_normal(img.shape)
  918     noisy = np.clip(noisy, 0, 1)
  919 
  920     # Mark H. 2018.08:
  921     #   The issue arises because when ndim in [1, 2]
  922     #   ``waverecn`` calls ``_match_coeff_dims``
  923     #   Which includes a numpy 1.15 deprecation.
  924     #   for larger number of dimensions _match_coeff_dims isn't called
  925     #   for some reason.
  926     # Verify that SNR is improved with internally estimated sigma
  927     denoised = restoration.denoise_wavelet(
  928         noisy, method=method,
  929         rescale_sigma=rescale_sigma)
  930     psnr_noisy = peak_signal_noise_ratio(img, noisy)
  931     psnr_denoised = peak_signal_noise_ratio(img, denoised)
  932     assert psnr_denoised > psnr_noisy
  933 
  934 
  935 def test_wavelet_invalid_method():
  936     with pytest.raises(ValueError):
  937         restoration.denoise_wavelet(np.ones(16), method='Unimplemented',
  938                                     rescale_sigma=True)
  939 
  940 
  941 @pytest.mark.parametrize('rescale_sigma', [True, False])
  942 def test_wavelet_denoising_levels(rescale_sigma):
  943     rstate = np.random.default_rng(1234)
  944     ndim = 2
  945     N = 256
  946     wavelet = 'db1'
  947     # Generate a very simple test image
  948     img = 0.2*np.ones((N, )*ndim)
  949     img[(slice(5, 13), ) * ndim] = 0.8
  950 
  951     sigma = 0.1
  952     noisy = img + sigma * rstate.standard_normal(img.shape)
  953     noisy = np.clip(noisy, 0, 1)
  954 
  955     denoised = restoration.denoise_wavelet(noisy, wavelet=wavelet,
  956                                            rescale_sigma=rescale_sigma)
  957     denoised_1 = restoration.denoise_wavelet(noisy, wavelet=wavelet,
  958                                              wavelet_levels=1,
  959                                              rescale_sigma=rescale_sigma)
  960     psnr_noisy = peak_signal_noise_ratio(img, noisy)
  961     psnr_denoised = peak_signal_noise_ratio(img, denoised)
  962     psnr_denoised_1 = peak_signal_noise_ratio(img, denoised_1)
  963 
  964     # multi-level case should outperform single level case
  965     assert psnr_denoised > psnr_denoised_1 > psnr_noisy
  966 
  967     # invalid number of wavelet levels results in a ValueError or UserWarning
  968     max_level = pywt.dwt_max_level(np.min(img.shape),
  969                                    pywt.Wavelet(wavelet).dec_len)
  970     # exceeding max_level raises a UserWarning in PyWavelets >= 1.0.0
  971     with expected_warnings([
  972             'all coefficients will experience boundary effects']):
  973         restoration.denoise_wavelet(
  974             noisy, wavelet=wavelet, wavelet_levels=max_level + 1,
  975             rescale_sigma=rescale_sigma)
  976 
  977     with pytest.raises(ValueError):
  978         restoration.denoise_wavelet(
  979             noisy,
  980             wavelet=wavelet, wavelet_levels=-1,
  981             rescale_sigma=rescale_sigma)
  982 
  983 
  984 def test_estimate_sigma_gray():
  985     rstate = np.random.default_rng(1234)
  986     # astronaut image
  987     img = astro_gray.copy()
  988     sigma = 0.1
  989     # add noise to astronaut
  990     img += sigma * rstate.standard_normal(img.shape)
  991 
  992     sigma_est = restoration.estimate_sigma(img, channel_axis=None)
  993     assert_array_almost_equal(sigma, sigma_est, decimal=2)
  994 
  995 
  996 def test_estimate_sigma_masked_image():
  997     # Verify computation on an image with a large, noise-free border.
  998     # (zero regions will be masked out by _sigma_est_dwt to avoid returning
  999     #  sigma = 0)
 1000     rstate = np.random.default_rng(1234)
 1001     # uniform image
 1002     img = np.zeros((128, 128))
 1003     center_roi = (slice(32, 96), slice(32, 96))
 1004     img[center_roi] = 0.8
 1005     sigma = 0.1
 1006 
 1007     img[center_roi] = sigma * rstate.standard_normal(img[center_roi].shape)
 1008 
 1009     sigma_est = restoration.estimate_sigma(img, channel_axis=None)
 1010     assert_array_almost_equal(sigma, sigma_est, decimal=1)
 1011 
 1012 
 1013 @pytest.mark.parametrize('channel_axis', [0, 1, 2, -1])
 1014 def test_estimate_sigma_color(channel_axis):
 1015     rstate = np.random.default_rng(1234)
 1016     # astronaut image
 1017     img = astro.copy()
 1018     sigma = 0.1
 1019     # add noise to astronaut
 1020     img += sigma * rstate.standard_normal(img.shape)
 1021     img = np.moveaxis(img, -1, channel_axis)
 1022 
 1023     sigma_est = restoration.estimate_sigma(img, channel_axis=channel_axis,
 1024                                            average_sigmas=True)
 1025     assert_array_almost_equal(sigma, sigma_est, decimal=2)
 1026 
 1027     sigma_list = restoration.estimate_sigma(img, channel_axis=channel_axis,
 1028                                             average_sigmas=False)
 1029     assert_array_equal(len(sigma_list), img.shape[channel_axis])
 1030     assert_array_almost_equal(sigma_list[0], sigma_est, decimal=2)
 1031 
 1032     if channel_axis % img.ndim == 2:
 1033         # default channel_axis=None should raise a warning about last axis size
 1034         assert_warns(UserWarning, restoration.estimate_sigma, img)
 1035 
 1036 
 1037 def test_estimate_sigma_color_deprecated_multichannel():
 1038     rstate = np.random.default_rng(1234)
 1039     # astronaut image
 1040     img = astro.copy()
 1041     sigma = 0.1
 1042     # add noise to astronaut
 1043     img += sigma * rstate.standard_normal(img.shape)
 1044 
 1045     with expected_warnings(["`multichannel` is a deprecated argument"]):
 1046         sigma_est = restoration.estimate_sigma(img, multichannel=True,
 1047                                                average_sigmas=True)
 1048     assert_array_almost_equal(sigma, sigma_est, decimal=2)
 1049 
 1050     # providing multichannel argument positionally also warns
 1051     with expected_warnings(["Providing the `multichannel` argument"]):
 1052         sigma_est = restoration.estimate_sigma(img, True, True)
 1053     assert_array_almost_equal(sigma, sigma_est, decimal=2)
 1054 
 1055 
 1056 @pytest.mark.parametrize('rescale_sigma', [True, False])
 1057 def test_wavelet_denoising_args(rescale_sigma):
 1058     """
 1059     Some of the functions inside wavelet denoising throw an error the wrong
 1060     arguments are passed. This protects against that and verifies that all
 1061     arguments can be passed.
 1062     """
 1063     img = astro
 1064     noisy = img.copy() + 0.1 * np.random.standard_normal(img.shape)
 1065 
 1066     for convert2ycbcr in [True, False]:
 1067         for multichannel in [True, False]:
 1068             channel_axis = -1 if multichannel else None
 1069             if convert2ycbcr and not multichannel:
 1070                 with pytest.raises(ValueError):
 1071                     restoration.denoise_wavelet(noisy,
 1072                                                 convert2ycbcr=convert2ycbcr,
 1073                                                 channel_axis=channel_axis,
 1074                                                 rescale_sigma=rescale_sigma)
 1075                 continue
 1076             for sigma in [0.1, [0.1, 0.1, 0.1], None]:
 1077                 if (not multichannel and not convert2ycbcr) or \
 1078                         (isinstance(sigma, list) and not multichannel):
 1079                     continue
 1080                 restoration.denoise_wavelet(noisy, sigma=sigma,
 1081                                             convert2ycbcr=convert2ycbcr,
 1082                                             channel_axis=channel_axis,
 1083                                             rescale_sigma=rescale_sigma)
 1084 
 1085 
 1086 @pytest.mark.parametrize('rescale_sigma', [True, False])
 1087 def test_denoise_wavelet_biorthogonal(rescale_sigma):
 1088     """Biorthogonal wavelets should raise a warning during thresholding."""
 1089     img = astro_gray
 1090     assert_warns(UserWarning, restoration.denoise_wavelet, img,
 1091                  wavelet='bior2.2', channel_axis=None,
 1092                  rescale_sigma=rescale_sigma)
 1093 
 1094 
 1095 @pytest.mark.parametrize('channel_axis', [-1, None])
 1096 @pytest.mark.parametrize('rescale_sigma', [True, False])
 1097 def test_cycle_spinning_multichannel(rescale_sigma, channel_axis):
 1098     sigma = 0.1
 1099     rstate = np.random.default_rng(1234)
 1100 
 1101     if channel_axis is not None:
 1102         img = astro
 1103         # can either omit or be 0 along the channels axis
 1104         valid_shifts = [1, (0, 1), (1, 0), (1, 1), (1, 1, 0)]
 1105         # can either omit or be 1 on channels axis.
 1106         valid_steps = [1, 2, (1, 2), (1, 2, 1)]
 1107         # too few or too many shifts or non-zero shift on channels
 1108         invalid_shifts = [(1, 1, 2), (1, ), (1, 1, 0, 1)]
 1109         # too few or too many shifts or any shifts <= 0
 1110         invalid_steps = [(1, ), (1, 1, 1, 1), (0, 1), (-1, -1)]
 1111     else:
 1112         img = astro_gray
 1113         valid_shifts = [1, (0, 1), (1, 0), (1, 1)]
 1114         valid_steps = [1, 2, (1, 2)]
 1115         invalid_shifts = [(1, 1, 2), (1, )]
 1116         invalid_steps = [(1, ), (1, 1, 1), (0, 1), (-1, -1)]
 1117 
 1118     noisy = img.copy() + 0.1 * rstate.standard_normal(img.shape)
 1119 
 1120     denoise_func = restoration.denoise_wavelet
 1121     func_kw = dict(sigma=sigma, channel_axis=channel_axis,
 1122                    rescale_sigma=rescale_sigma)
 1123 
 1124     # max_shifts=0 is equivalent to just calling denoise_func
 1125     with expected_warnings([DASK_NOT_INSTALLED_WARNING]):
 1126         dn_cc = restoration.cycle_spin(noisy, denoise_func, max_shifts=0,
 1127                                        func_kw=func_kw,
 1128                                        channel_axis=channel_axis)
 1129         dn = denoise_func(noisy, **func_kw)
 1130     assert_array_equal(dn, dn_cc)
 1131 
 1132     # denoising with cycle spinning will give better PSNR than without
 1133     for max_shifts in valid_shifts:
 1134         with expected_warnings([DASK_NOT_INSTALLED_WARNING]):
 1135             dn_cc = restoration.cycle_spin(noisy, denoise_func,
 1136                                            max_shifts=max_shifts,
 1137                                            func_kw=func_kw,
 1138                                            channel_axis=channel_axis)
 1139         psnr = peak_signal_noise_ratio(img, dn)
 1140         psnr_cc = peak_signal_noise_ratio(img, dn_cc)
 1141         assert psnr_cc > psnr
 1142 
 1143     for shift_steps in valid_steps:
 1144         with expected_warnings([DASK_NOT_INSTALLED_WARNING]):
 1145             dn_cc = restoration.cycle_spin(noisy, denoise_func,
 1146                                            max_shifts=2,
 1147                                            shift_steps=shift_steps,
 1148                                            func_kw=func_kw,
 1149                                            channel_axis=channel_axis)
 1150         psnr = peak_signal_noise_ratio(img, dn)
 1151         psnr_cc = peak_signal_noise_ratio(img, dn_cc)
 1152         assert psnr_cc > psnr
 1153 
 1154     for max_shifts in invalid_shifts:
 1155         with pytest.raises(ValueError):
 1156             dn_cc = restoration.cycle_spin(noisy, denoise_func,
 1157                                            max_shifts=max_shifts,
 1158                                            func_kw=func_kw,
 1159                                            channel_axis=channel_axis)
 1160     for shift_steps in invalid_steps:
 1161         with pytest.raises(ValueError):
 1162             dn_cc = restoration.cycle_spin(noisy, denoise_func,
 1163                                            max_shifts=2,
 1164                                            shift_steps=shift_steps,
 1165                                            func_kw=func_kw,
 1166                                            channel_axis=channel_axis)
 1167 
 1168 
 1169 def test_cycle_spinning_num_workers():
 1170     img = astro_gray
 1171     sigma = 0.1
 1172     rstate = np.random.default_rng(1234)
 1173     noisy = img.copy() + 0.1 * rstate.standard_normal(img.shape)
 1174 
 1175     denoise_func = restoration.denoise_wavelet
 1176     func_kw = dict(sigma=sigma, channel_axis=-1, rescale_sigma=True)
 1177 
 1178     # same results are expected whether using 1 worker or multiple workers
 1179     dn_cc1 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
 1180                                     func_kw=func_kw, channel_axis=None,
 1181                                     num_workers=1)
 1182 
 1183     # Repeat dn_cc1 computation, but without channel_axis specified to
 1184     # verify that the default behavior is channel_axis=None
 1185     dn_cc1_ = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
 1186                                      func_kw=func_kw, num_workers=1)
 1187     assert_array_equal(dn_cc1, dn_cc1_)
 1188 
 1189     with expected_warnings([DASK_NOT_INSTALLED_WARNING]):
 1190         dn_cc2 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
 1191                                         func_kw=func_kw, channel_axis=None,
 1192                                         num_workers=4)
 1193         dn_cc3 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
 1194                                         func_kw=func_kw, channel_axis=None,
 1195                                         num_workers=None)
 1196     assert_array_almost_equal(dn_cc1, dn_cc2)
 1197     assert_array_almost_equal(dn_cc1, dn_cc3)
 1198 
 1199 
 1200 def test_cycle_spinning_num_workers_deprecated_multichannel():
 1201     img = astro_gray[:32, :32]
 1202     sigma = 0.1
 1203     rstate = np.random.default_rng(1234)
 1204     noisy = img.copy() + 0.1 * rstate.standard_normal(img.shape)
 1205 
 1206     denoise_func = restoration.denoise_wavelet
 1207 
 1208     func_kw = dict(sigma=sigma, channel_axis=-1, rescale_sigma=True)
 1209 
 1210     mc_warn_str = "`multichannel` is a deprecated argument"
 1211 
 1212     # same results are expected whether using 1 worker or multiple workers
 1213     with expected_warnings([mc_warn_str]):
 1214         dn_cc1 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
 1215                                         func_kw=func_kw, multichannel=False,
 1216                                         num_workers=1)
 1217 
 1218     if DASK_NOT_INSTALLED_WARNING is None:
 1219         exp_warn = [mc_warn_str]
 1220     else:
 1221         exp_warn = [mc_warn_str, DASK_NOT_INSTALLED_WARNING]
 1222     with expected_warnings(exp_warn):
 1223         dn_cc2 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
 1224                                         func_kw=func_kw, multichannel=False,
 1225                                         num_workers=2)
 1226     assert_array_almost_equal(dn_cc1, dn_cc2)
 1227 
 1228     # providing multichannel argument positionally also warns
 1229     mc_warn_str = "Providing the `multichannel` argument"
 1230     if DASK_NOT_INSTALLED_WARNING is None:
 1231         exp_warn = [mc_warn_str]
 1232     else:
 1233         exp_warn = [mc_warn_str, DASK_NOT_INSTALLED_WARNING]
 1234 
 1235     with expected_warnings(exp_warn):
 1236         restoration.cycle_spin(noisy, denoise_func, 1, 1, None, False)