diff --git a/bitarray/__init__.py b/bitarray/__init__.py index f0d95db0..0094dc55 100644 --- a/bitarray/__init__.py +++ b/bitarray/__init__.py @@ -9,12 +9,13 @@ Author: Ilan Schnell """ +import contextlib from collections import namedtuple from bitarray._bitarray import ( bitarray, decodetree, bits2bytes, _bitarray_reconstructor, get_default_endian, _set_default_endian, _sysinfo, - BITARRAY_VERSION as __version__ + BITARRAY_VERSION as __version__, _default_endian_contextvar ) __all__ = ['bitarray', 'frozenbitarray', 'decodetree', 'bits2bytes'] @@ -64,3 +65,30 @@ def test(verbosity=1): """ from bitarray import test_bitarray return test_bitarray.run(verbosity=verbosity) + +# from bitarray.h: +# #define ENDIAN_LITTLE 0 +# #define ENDIAN_BIG 1 +ENDIANNESS_MAPPING = {'little': 0, 'big': 1} + +@contextlib.contextmanager +def default_endian(endian): + """default_endian(endian, /) -> str + Context manager for controlling the default endianness for bitarrays. + + Set the default endianness for the scope of a ``with`` block and restore the + original default endianness at the end. Possible values are 'big' and 'little'. + """ + if not isinstance(endian, str): + raise TypeError( + f"default endianness must be 'big' or 'little', got {endian}" + f"with type {type(endian)}" + ) + if endian not in ['big', 'little']: + raise ValueError( + f"default endianness must be 'big' or 'little', got {endian}") + token = _default_endian_contextvar.set(ENDIANNESS_MAPPING[endian]) + try: + yield endian + finally: + _default_endian_contextvar.reset(token) diff --git a/bitarray/_bitarray.c b/bitarray/_bitarray.c index d4345265..bd8406c9 100644 --- a/bitarray/_bitarray.c +++ b/bitarray/_bitarray.c @@ -17,7 +17,7 @@ #define BLOCKSIZE 65536 /* default bit-endianness */ -static int default_endian = ENDIAN_BIG; +static PyObject* default_endian; /* translation table - setup during module initialization */ static char reverse_trans[256]; @@ -3107,6 +3107,9 @@ decodetree_new(PyTypeObject *type, PyObject *args, PyObject *kwds) return obj; } +// forward declaration +static int get_c_default_endian(void); + static PyObject * decodetree_todict(decodetreeobject *self) { @@ -3116,7 +3119,7 @@ decodetree_todict(decodetreeobject *self) if ((dict = PyDict_New()) == NULL) return NULL; - prefix = newbitarrayobject(&Bitarray_Type, 0, default_endian); + prefix = newbitarrayobject(&Bitarray_Type, 0, get_c_default_endian()); if (prefix == NULL) goto error; @@ -3595,15 +3598,29 @@ static PyMethodDef bitarray_methods[] = { /* ------------------------ bitarray initialization -------------------- */ +static int +get_c_default_endian(void) +{ + PyObject *current_default_endian; + if (PyContextVar_Get(default_endian, NULL, ¤t_default_endian) < 0) + { + return -1; + } + int ret = (int)PyLong_AsLong(current_default_endian); + Py_DECREF(current_default_endian); + return ret; +} + /* Given string 'str', return an integer representing the bit-endianness. If the string is invalid, set exception and return -1. */ static int endian_from_string(const char *str) { - assert(default_endian == ENDIAN_LITTLE || default_endian == ENDIAN_BIG); + int c_default_endian = get_c_default_endian(); + assert(c_default_endian == ENDIAN_LITTLE || c_default_endian == ENDIAN_BIG); if (str == NULL) - return default_endian; + return c_default_endian; if (strcmp(str, "little") == 0) return ENDIAN_LITTLE; @@ -4126,7 +4143,7 @@ reconstructor(PyObject *module, PyObject *args) static PyObject * get_default_endian(PyObject *module) { - return PyUnicode_FromString(ENDIAN_STR(default_endian)); + return PyUnicode_FromString(ENDIAN_STR(get_c_default_endian())); } PyDoc_STRVAR(get_default_endian_doc, @@ -4150,8 +4167,15 @@ set_default_endian(PyObject *module, PyObject *args) in a temporary variable before setting default_endian. */ if ((t = endian_from_string(endian_str)) < 0) return NULL; - default_endian = t; + PyObject *py_t = PyLong_FromLong(t); + if (py_t == NULL) + return NULL; + + if (PyContextVar_Set(default_endian, py_t) == NULL) { + return NULL; + } + Py_DECREF(py_t); Py_RETURN_NONE; } @@ -4264,7 +4288,15 @@ PyInit__bitarray(void) return NULL; Py_SET_TYPE(&Bitarray_Type, &PyType_Type); Py_INCREF((PyObject *) &Bitarray_Type); - PyModule_AddObject(m, "bitarray", (PyObject *) &Bitarray_Type); + PyModule_AddObject(m, "bitarray", (PyObject *)&Bitarray_Type); + + PyObject *Py_ENDIAN_BIG = PyLong_FromLong(ENDIAN_BIG); + + default_endian = PyContextVar_New("default_endian", Py_ENDIAN_BIG); + + Py_DECREF(Py_ENDIAN_BIG); + + PyModule_AddObject(m, "_default_endian_contextvar", default_endian); if (register_abc() < 0) return NULL; diff --git a/bitarray/test_bitarray.py b/bitarray/test_bitarray.py index f7ecfcd9..71b3420d 100644 --- a/bitarray/test_bitarray.py +++ b/bitarray/test_bitarray.py @@ -20,11 +20,14 @@ # imports needed inside tests import array +import asyncio +import concurrent.futures import copy import itertools import mmap import pickle import shelve +import threading import weakref @@ -35,7 +38,7 @@ from bitarray import (bitarray, frozenbitarray, bits2bytes, decodetree, get_default_endian, _set_default_endian, _bitarray_reconstructor, _sysinfo as sysinfo, - BufferInfo, __version__) + BufferInfo, __version__, default_endian) def skipIf(condition): "Skip a test if the condition is true." @@ -176,35 +179,52 @@ def test_sysinfo_byteorder(self): sysinfo("PY_BIG_ENDIAN")) def test_set_default_endian(self): - for default_endian in 'big', 'little': - _set_default_endian(default_endian) - a = bitarray() - self.assertEqual(a.endian, default_endian) - for x in None, 0, 64, '10111', [1, 0]: - a = bitarray(x) - self.assertEqual(a.endian, default_endian) - - for endian in 'big', 'little', None: - a = bitarray(endian=endian) - self.assertEqual(a.endian, - default_endian if endian is None else endian) - - # make sure that wrong calling _set_default_endian() does not - # change the default endianness - self.assertRaises(ValueError, _set_default_endian, 'foobar') - self.assertEqual(bitarray().endian, default_endian) + for default in 'big', 'little': + with default_endian(default) as used_endian: + assert used_endian == default + a = bitarray() + self.assertEqual(a.endian, default) + for x in None, 0, 64, '10111', [1, 0]: + a = bitarray(x) + self.assertEqual(a.endian, default) + + for endian in 'big', 'little', None: + a = bitarray(endian=endian) + self.assertEqual(a.endian, + default if endian is None else endian) + + # make sure that wrong calling _set_default_endian() does not + # change the default endianness + self.assertRaises(ValueError, _set_default_endian, 'foobar') + self.assertEqual(bitarray().endian, default) def test_set_default_endian_errors(self): self.assertRaises(TypeError, _set_default_endian, 0) self.assertRaises(TypeError, _set_default_endian, 'little', 0) self.assertRaises(ValueError, _set_default_endian, 'foo') + def test_default_endian_errors(self): + def integer_arg(): + with default_endian(0): + pass + self.assertRaises(TypeError, integer_arg) + + def too_many_args(): + with default_endian('little', 0): + pass + self.assertRaises(TypeError, too_many_args) + + def invalid_string_arg(): + with default_endian('foo'): + pass + self.assertRaises(ValueError, invalid_string_arg) + def test_get_default_endian(self): - for default_endian in 'big', 'little': - _set_default_endian(default_endian) - endian = get_default_endian() - self.assertEqual(endian, default_endian) - self.assertEqual(type(endian), str) + for default in 'big', 'little': + with default_endian(default): + endian = get_default_endian() + self.assertEqual(endian, default) + self.assertEqual(type(endian), str) def test_get_default_endian_errors(self): # takes no arguments @@ -258,11 +278,10 @@ def test_endian(self): self.assertEqual(a.tobytes(), b.tobytes()) def test_endian_default(self): - _set_default_endian('big') - a_big = bitarray() - _set_default_endian('little') - a_little = bitarray() - _set_default_endian('big') + with default_endian('big'): + a_big = bitarray() + with default_endian('little'): + a_little = bitarray() self.assertEqual(a_big.endian, 'big') self.assertEqual(a_little.endian, 'little') @@ -283,8 +302,8 @@ def test_buffer_endian(self): a = bitarray(buffer=b'', endian=endian) self.assertEQUAL(a, bitarray(0, endian)) - _set_default_endian(endian) - a = bitarray(buffer=b'A') + with default_endian(endian): + a = bitarray(buffer=b'A') self.assertEqual(a.endian, endian) self.assertEqual(len(a), 8) @@ -1535,6 +1554,58 @@ def test_basic(self): self.assertRaises(IndexError, a.__delitem__, [10]) self.assertRaises(TypeError, a.__delitem__, (1, 3)) + # asyncio.Barrier was introduced in Python 3.11 + @skipIf(sys.version_info < (3, 11)) + def test_default_endian_async_safe(self): + endian_start = get_default_endian() + b = asyncio.Barrier(10) + + async def big(): + _set_default_endian('big') + await b.wait() + endian = get_default_endian() + assert endian == 'big' + + async def little(): + _set_default_endian('little') + await b.wait() + endian = get_default_endian() + assert endian == 'little' + + async def main(): + await asyncio.gather( + big(), big(), big(), big(), big(), + little(), little(), little(), little(), little() + ) + + loop = asyncio.new_event_loop() + asyncio.run(main()) + loop.close() + assert get_default_endian() == endian_start + + def test_default_endian_thread_safe(self): + endian_start = get_default_endian() + b = threading.Barrier(10) + + def big(): + b.wait() + _set_default_endian('big') + endian = get_default_endian() + assert endian == 'big' + + def little(): + b.wait() + _set_default_endian('little') + endian = get_default_endian() + assert endian == 'little' + + tpe = concurrent.futures.ThreadPoolExecutor(max_workers=10) + futures = [tpe.submit(func) for func in [big]*5 + [little]*5] + for f in futures: + f.result() + + assert get_default_endian() == endian_start + def test_delete_one(self): for a in self.randombitarrays(start=1): b = a.copy() @@ -2421,9 +2492,9 @@ def test_frozenbitarray(self): self.assertEqual(type(b), frozenbitarray) self.assertRaises(TypeError, a.__ilshift__, 4) + @default_endian("big") @skipIf(is_pypy) def test_imported(self): - _set_default_endian("big") a = bytearray([0xf0, 0x01, 0x02, 0x0f]) b = bitarray(buffer=a) self.assertFalse(b.readonly) @@ -2938,19 +3009,19 @@ class PackTests(unittest.TestCase, Util): def test_pack_simple(self): for endian in 'little', 'big': - _set_default_endian(endian) - a = bitarray() - a.pack(bytes()) - self.assertEQUAL(a, bitarray()) - a.pack(b'\x00') - self.assertEQUAL(a, bitarray('0')) - a.pack(b'\xff') - self.assertEQUAL(a, bitarray('01')) - a.pack(b'\x01\x00\x7a') - self.assertEQUAL(a, bitarray('01101')) - a.pack(bytearray([0x01, 0x00, 0xff, 0xa7])) - self.assertEQUAL(a, bitarray('01101 1011')) - self.check_obj(a) + with default_endian(endian): + a = bitarray() + a.pack(bytes()) + self.assertEQUAL(a, bitarray()) + a.pack(b'\x00') + self.assertEQUAL(a, bitarray('0')) + a.pack(b'\xff') + self.assertEQUAL(a, bitarray('01')) + a.pack(b'\x01\x00\x7a') + self.assertEQUAL(a, bitarray('01101')) + a.pack(bytearray([0x01, 0x00, 0xff, 0xa7])) + self.assertEQUAL(a, bitarray('01101 1011')) + self.check_obj(a) def test_pack_types(self): a = bitarray() diff --git a/bitarray/test_util.py b/bitarray/test_util.py index 5e865909..495ba411 100644 --- a/bitarray/test_util.py +++ b/bitarray/test_util.py @@ -24,7 +24,7 @@ from collections import Counter from bitarray import (bitarray, frozenbitarray, decodetree, bits2bytes, - _set_default_endian) + default_endian) from bitarray.test_bitarray import Util, skipIf, is_pypy, urandom_2, PTRSIZE from bitarray.util import ( @@ -46,24 +46,24 @@ class ZerosOnesTests(unittest.TestCase): def test_basic(self): for _ in range(50): - default_endian = choice(['little', 'big']) - _set_default_endian(default_endian) - a = choice([zeros(0), zeros(0, None), zeros(0, endian=None), - ones(0), ones(0, None), ones(0, endian=None)]) - self.assertEqual(a, bitarray()) - self.assertEqual(a.endian, default_endian) - self.assertEqual(type(a), bitarray) + default_endian_choice = choice(['little', 'big']) + with default_endian(default_endian_choice): + a = choice([zeros(0), zeros(0, None), zeros(0, endian=None), + ones(0), ones(0, None), ones(0, endian=None)]) + self.assertEqual(a, bitarray()) + self.assertEqual(a.endian, default_endian_choice) + self.assertEqual(type(a), bitarray) - endian = choice(['little', 'big', None]) - n = randrange(100) + endian = choice(['little', 'big', None]) + n = randrange(100) - a = choice([zeros(n, endian), zeros(n, endian=endian)]) - self.assertEqual(a.to01(), n * "0") - self.assertEqual(a.endian, endian or default_endian) + a = choice([zeros(n, endian), zeros(n, endian=endian)]) + self.assertEqual(a.to01(), n * "0") + self.assertEqual(a.endian, endian or default_endian_choice) - b = choice([ones(n, endian), ones(n, endian=endian)]) - self.assertEqual(b.to01(), n * "1") - self.assertEqual(b.endian, endian or default_endian) + b = choice([ones(n, endian), ones(n, endian=endian)]) + self.assertEqual(b.to01(), n * "1") + self.assertEqual(b.endian, endian or default_endian_choice) def test_errors(self): for f in zeros, ones: @@ -86,19 +86,19 @@ class URandomTests(unittest.TestCase): def test_basic(self): for _ in range(20): - default_endian = choice(['little', 'big']) - _set_default_endian(default_endian) - a = choice([urandom(0), urandom(0, endian=None)]) - self.assertEqual(a, bitarray()) - self.assertEqual(a.endian, default_endian) + default_endian_choice = choice(['little', 'big']) + with default_endian(default_endian_choice): + a = choice([urandom(0), urandom(0, endian=None)]) + self.assertEqual(a, bitarray()) + self.assertEqual(a.endian, default_endian_choice) - endian = choice(['little', 'big', None]) - n = randrange(100) + endian = choice(['little', 'big', None]) + n = randrange(100) - a = choice([urandom(n, endian), urandom(n, endian=endian)]) - self.assertEqual(len(a), n) - self.assertEqual(a.endian, endian or default_endian) - self.assertEqual(type(a), bitarray) + a = choice([urandom(n, endian), urandom(n, endian=endian)]) + self.assertEqual(len(a), n) + self.assertEqual(a.endian, endian or default_endian_choice) + self.assertEqual(type(a), bitarray) def test_errors(self): U = urandom @@ -122,16 +122,16 @@ class Random_K_Tests(unittest.TestCase): def test_basic(self): for _ in range(250): - default_endian = choice(['little', 'big']) - _set_default_endian(default_endian) - endian = choice(['little', 'big', None]) - n = randrange(120) - k = randint(0, n) - a = random_k(n, k, endian) - self.assertTrue(type(a), bitarray) - self.assertEqual(len(a), n) - self.assertEqual(a.count(), k) - self.assertEqual(a.endian, endian or default_endian) + default_endian_choice = choice(['little', 'big']) + with default_endian(default_endian_choice): + endian = choice(['little', 'big', None]) + n = randrange(120) + k = randint(0, n) + a = random_k(n, k, endian) + self.assertTrue(type(a), bitarray) + self.assertEqual(len(a), n) + self.assertEqual(a.count(), k) + self.assertEqual(a.endian, endian or default_endian_choice) def test_inputs_and_edge_cases(self): R = random_k @@ -206,11 +206,11 @@ def collect_code_branches(self): res.append(random_k(5_000, k)) return res + @default_endian("little") def test_seed(self): # We ensure that after setting a seed value, random_k() will # always return the same random bitarrays. However, we do not ensure # that these results will not change in future versions of bitarray. - _set_default_endian("little") a = [] for val in 654321, 654322, 654321, 654322: seed(val) @@ -289,15 +289,15 @@ class Random_P_Tests(unittest.TestCase): def test_basic(self): for _ in range(250): - default_endian = choice(['little', 'big']) - _set_default_endian(default_endian) - endian = choice(['little', 'big', None]) - n = randrange(120) - p = choice([0.0, 0.0001, 0.2, 0.5, 0.9, 1.0]) - a = random_p(n, p, endian) - self.assertTrue(type(a), bitarray) - self.assertEqual(len(a), n) - self.assertEqual(a.endian, endian or default_endian) + default_endian_choice = choice(['little', 'big']) + with default_endian(default_endian_choice): + endian = choice(['little', 'big', None]) + n = randrange(120) + p = choice([0.0, 0.0001, 0.2, 0.5, 0.9, 1.0]) + a = random_p(n, p, endian) + self.assertTrue(type(a), bitarray) + self.assertEqual(len(a), n) + self.assertEqual(a.endian, endian or default_endian_choice) def test_inputs_and_edge_cases(self): R = random_p @@ -343,11 +343,11 @@ def collect_code_branches(self): res.append(random_p(150, p)) return res + @default_endian("little") def test_seed(self): # We ensure that after setting a seed value, random_p() will always # return the same random bitarrays. However, we do not ensure that # these results will not change in future versions of bitarray. - _set_default_endian("little") a = [] for val in 123456, 123457, 123456, 123457: seed(val) @@ -391,18 +391,18 @@ def test_errors(self): def test_explitcit(self): for n in range(230): - default_endian = choice(['little', 'big']) - _set_default_endian(default_endian) - endian = choice(["little", "big", None]) - odd = getrandbits(1) - a = gen_primes(n, endian, odd) - self.assertEqual(len(a), n) - self.assertEqual(a.endian, endian or default_endian) - if odd: - lst = [2] + [2 * i + 1 for i in a.search(1)] - else: - lst = [i for i in a.search(1)] - self.assertEqual(lst, self.primes[:len(lst)]) + default_endian_choice = choice(['little', 'big']) + with default_endian(default_endian_choice): + endian = choice(["little", "big", None]) + odd = getrandbits(1) + a = gen_primes(n, endian, odd) + self.assertEqual(len(a), n) + self.assertEqual(a.endian, endian or default_endian_choice) + if odd: + lst = [2] + [2 * i + 1 for i in a.search(1)] + else: + lst = [i for i in a.search(1)] + self.assertEqual(lst, self.primes[:len(lst)]) def test_cmp(self): N = 10_000 @@ -530,16 +530,16 @@ def test_simple(self): self.assertRaises(TypeError, strip, '0110') self.assertRaises(TypeError, strip, bitarray(), 123) self.assertRaises(ValueError, strip, bitarray(), 'up') - for default_endian in 'big', 'little': - _set_default_endian(default_endian) - a = bitarray('00010110000') - self.assertEQUAL(strip(a), bitarray('0001011')) - self.assertEQUAL(strip(a, 'left'), bitarray('10110000')) - self.assertEQUAL(strip(a, 'both'), bitarray('1011')) - b = frozenbitarray('00010110000') - c = strip(b, 'both') - self.assertEqual(c, bitarray('1011')) - self.assertEqual(type(c), frozenbitarray) + for endian in 'big', 'little': + with default_endian(endian): + a = bitarray('00010110000') + self.assertEQUAL(strip(a), bitarray('0001011')) + self.assertEQUAL(strip(a, 'left'), bitarray('10110000')) + self.assertEQUAL(strip(a, 'both'), bitarray('1011')) + b = frozenbitarray('00010110000') + c = strip(b, 'both') + self.assertEqual(c, bitarray('1011')) + self.assertEqual(type(c), frozenbitarray) def test_zeros_ones(self): for _ in range(50): @@ -1411,8 +1411,8 @@ def test_ba2hex_errors(self): # embedded null character in sep self.assertRaises(ValueError, ba2hex, a, 2, " \0") + @default_endian("big") def test_hex2ba_whitespace(self): - _set_default_endian('big') self.assertEqual(hex2ba("F1 FA %s f3 c0" % whitespace), bitarray("11110001 11111010 11110011 11000000")) self.assertEQUAL(hex2ba(b' a F ', 'big'), @@ -1441,15 +1441,15 @@ def test_hex2ba_types(self): def test_random(self): for _ in range(100): - default_endian = self.random_endian() - _set_default_endian(default_endian) - endian = choice(["little", "big", None]) - a = urandom_2(4 * randrange(100), endian) - s = ba2hex(a, group=randrange(10), sep=choice(whitespace)) - b = hex2ba(s, endian) - self.assertEqual(b.endian, endian or default_endian) - self.assertEqual(a, b) - self.check_obj(b) + random_endian = self.random_endian() + with default_endian(random_endian): + endian = choice(["little", "big", None]) + a = urandom_2(4 * randrange(100), endian) + s = ba2hex(a, group=randrange(10), sep=choice(whitespace)) + b = hex2ba(s, endian) + self.assertEqual(b.endian, endian or random_endian) + self.assertEqual(a, b) + self.check_obj(b) def test_hexdigits(self): a = hex2ba(hexdigits) @@ -1978,24 +1978,23 @@ def test_explicit(self): (b'\xb5\xa7\x18', '0101 0100111 0011'), (b'\x95\xb7\x1c', '0101 0110111 001110'), ]: - default_endian = self.random_endian() - _set_default_endian(default_endian) - - a = bitarray(s) - self.assertEqual(vl_encode(a), blob) - c = vl_decode(blob) - self.assertEqual(c, a) - self.assertEqual(c.endian, default_endian) - - for endian in 'big', 'little', None: - a = bitarray(s, endian) - c = vl_encode(a) - self.assertEqual(type(c), bytes) - self.assertEqual(c, blob) - - c = vl_decode(blob, endian) + random_endian = self.random_endian() + with default_endian(random_endian): + a = bitarray(s) + self.assertEqual(vl_encode(a), blob) + c = vl_decode(blob) self.assertEqual(c, a) - self.assertEqual(c.endian, endian or default_endian) + self.assertEqual(c.endian, random_endian) + + for endian in 'big', 'little', None: + a = bitarray(s, endian) + c = vl_encode(a) + self.assertEqual(type(c), bytes) + self.assertEqual(c, blob) + + c = vl_decode(blob, endian) + self.assertEqual(c, a) + self.assertEqual(c.endian, endian or random_endian) def test_encode_types(self): s = "0011 01" @@ -2263,8 +2262,8 @@ def test_int2ba_length(self): self.assertEqual(int2ba(2 ** n - 1, endian='little'), bitarray(n * '1')) + @default_endian("big") def test_explicit(self): - _set_default_endian('big') for i, sa in [( 0, '0'), (1, '1'), ( 2, '10'), (3, '11'), (25, '11001'), (265, '100001001'),