diff --git a/distarray/globalapi/tests/test_distributed_io.py b/distarray/globalapi/tests/test_distributed_io.py index c95abdcf..8fdbce85 100644 --- a/distarray/globalapi/tests/test_distributed_io.py +++ b/distarray/globalapi/tests/test_distributed_io.py @@ -20,7 +20,7 @@ from distarray.externals.six.moves import range -from distarray.testing import import_or_skip, DefaultContextTestCase +from distarray.testing import import_parallel_h5py, DefaultContextTestCase from distarray.globalapi.distarray import DistArray from distarray.globalapi.maps import Distribution @@ -216,7 +216,7 @@ class TestHdf5FileSave(DefaultContextTestCase): def setUp(self): super(TestHdf5FileSave, self).setUp() - self.h5py = import_or_skip('h5py') + self.h5py = import_parallel_h5py() self.output_path = self.context.apply(engine_temp_path, ('.hdf5',), targets=[self.context.targets[0]])[0] @@ -280,7 +280,7 @@ class TestHdf5FileLoad(DefaultContextTestCase): @classmethod def setUpClass(cls): - cls.h5py = import_or_skip('h5py') + cls.h5py = import_parallel_h5py() super(TestHdf5FileLoad, cls).setUpClass() cls.output_path = cls.context.apply(engine_temp_path, ('.hdf5',), targets=[cls.context.targets[0]])[0] diff --git a/distarray/localapi/tests/paralleltest_io.py b/distarray/localapi/tests/paralleltest_io.py index a114f989..b3e988e9 100644 --- a/distarray/localapi/tests/paralleltest_io.py +++ b/distarray/localapi/tests/paralleltest_io.py @@ -8,7 +8,8 @@ import numpy from numpy.testing import assert_allclose, assert_equal -from distarray.testing import ParallelTestCase, import_or_skip, temp_filepath +from distarray.testing import (ParallelTestCase, import_parallel_h5py, + temp_filepath) from distarray.localapi import LocalArray, ndenumerate from distarray.localapi import (save_dnpy, load_dnpy, save_hdf5, load_hdf5, load_npy) @@ -188,7 +189,7 @@ class TestHdf5FileSave(ParallelTestCase): def setUp(self): self.rank = self.comm.Get_rank() - self.h5py = import_or_skip('h5py') + self.h5py = import_parallel_h5py() self.key = "data" # set up a common file to work with @@ -235,7 +236,7 @@ class TestHdf5FileLoad(ParallelTestCase): def setUp(self): self.rank = self.comm.Get_rank() - self.h5py = import_or_skip('h5py') + self.h5py = import_parallel_h5py() self.key = "data" self.expected = numpy.arange(20).reshape(2, 10) diff --git a/distarray/testing.py b/distarray/testing.py index e25f4986..7bd1df6e 100644 --- a/distarray/testing.py +++ b/distarray/testing.py @@ -20,7 +20,6 @@ from distarray.externals import six from distarray.externals import protocol_validator from distarray.globalapi.context import Context, ContextCreationError -from distarray.globalapi.ipython_utils import IPythonClient from distarray.error import InvalidCommSizeError from distarray.localapi.mpiutils import MPI, create_comm_of_size @@ -91,6 +90,32 @@ def import_or_skip(name): raise unittest.SkipTest(errmsg) +def import_parallel_h5py(): + """Import h5py built against parallel HDF5. Raise SkipTest on failure. + + Returns + ------- + h5py_module: module object + Module object imported by importlib. + + Raises + ------ + unittest.SkipTest + If the attempted import raises an ImportError. + + Examples + -------- + >>> h5py = import_parallel_h5py('h5py') + >>> h5py.get_config() + + + """ + h5py = import_or_skip('h5py') + if not h5py.get_config().mpi: + errmsg = 'h5py not built against parallel hdf5... skipping.' + raise unittest.SkipTest(errmsg) + + def comm_null_passes(fn): """Decorator. If `self.comm` is COMM_NULL, pass.