Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion bitarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@

Author: Ilan Schnell
"""
import contextlib
from collections import namedtuple

from bitarray._bitarray import (
bitarray, decodetree, bits2bytes, _bitarray_reconstructor,
get_default_endian, _set_default_endian, _sysinfo,
BITARRAY_VERSION as __version__
BITARRAY_VERSION as __version__, _default_endian_contextvar
)

__all__ = ['bitarray', 'frozenbitarray', 'decodetree', 'bits2bytes']
Expand Down Expand Up @@ -64,3 +65,30 @@ def test(verbosity=1):
"""
from bitarray import test_bitarray
return test_bitarray.run(verbosity=verbosity)

# from bitarray.h:
# #define ENDIAN_LITTLE 0
# #define ENDIAN_BIG 1
ENDIANNESS_MAPPING = {'little': 0, 'big': 1}

@contextlib.contextmanager
def default_endian(endian):
"""default_endian(endian, /) -> str
Context manager for controlling the default endianness for bitarrays.

Set the default endianness for the scope of a ``with`` block and restore the
original default endianness at the end. Possible values are 'big' and 'little'.
"""
if not isinstance(endian, str):
raise TypeError(
f"default endianness must be 'big' or 'little', got {endian}"
f"with type {type(endian)}"
)
if endian not in ['big', 'little']:
raise ValueError(
f"default endianness must be 'big' or 'little', got {endian}")
token = _default_endian_contextvar.set(ENDIANNESS_MAPPING[endian])
try:
yield endian
finally:
_default_endian_contextvar.reset(token)
46 changes: 39 additions & 7 deletions bitarray/_bitarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#define BLOCKSIZE 65536

/* default bit-endianness */
static int default_endian = ENDIAN_BIG;
static PyObject* default_endian;

/* translation table - setup during module initialization */
static char reverse_trans[256];
Expand Down Expand Up @@ -3107,6 +3107,9 @@ decodetree_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
return obj;
}

// forward declaration
static int get_c_default_endian(void);

static PyObject *
decodetree_todict(decodetreeobject *self)
{
Expand All @@ -3116,7 +3119,7 @@ decodetree_todict(decodetreeobject *self)
if ((dict = PyDict_New()) == NULL)
return NULL;

prefix = newbitarrayobject(&Bitarray_Type, 0, default_endian);
prefix = newbitarrayobject(&Bitarray_Type, 0, get_c_default_endian());
if (prefix == NULL)
goto error;

Expand Down Expand Up @@ -3595,15 +3598,29 @@ static PyMethodDef bitarray_methods[] = {

/* ------------------------ bitarray initialization -------------------- */

static int
get_c_default_endian(void)
{
PyObject *current_default_endian;
if (PyContextVar_Get(default_endian, NULL, &current_default_endian) < 0)
{
return -1;
}
int ret = (int)PyLong_AsLong(current_default_endian);
Py_DECREF(current_default_endian);
return ret;
}

/* Given string 'str', return an integer representing the bit-endianness.
If the string is invalid, set exception and return -1. */
static int
endian_from_string(const char *str)
{
assert(default_endian == ENDIAN_LITTLE || default_endian == ENDIAN_BIG);
int c_default_endian = get_c_default_endian();
assert(c_default_endian == ENDIAN_LITTLE || c_default_endian == ENDIAN_BIG);

if (str == NULL)
return default_endian;
return c_default_endian;

if (strcmp(str, "little") == 0)
return ENDIAN_LITTLE;
Expand Down Expand Up @@ -4126,7 +4143,7 @@ reconstructor(PyObject *module, PyObject *args)
static PyObject *
get_default_endian(PyObject *module)
{
return PyUnicode_FromString(ENDIAN_STR(default_endian));
return PyUnicode_FromString(ENDIAN_STR(get_c_default_endian()));
}

PyDoc_STRVAR(get_default_endian_doc,
Expand All @@ -4150,8 +4167,15 @@ set_default_endian(PyObject *module, PyObject *args)
in a temporary variable before setting default_endian. */
if ((t = endian_from_string(endian_str)) < 0)
return NULL;
default_endian = t;
PyObject *py_t = PyLong_FromLong(t);
if (py_t == NULL)
return NULL;

if (PyContextVar_Set(default_endian, py_t) == NULL) {
return NULL;
}

Py_DECREF(py_t);
Py_RETURN_NONE;
}

Expand Down Expand Up @@ -4264,7 +4288,15 @@ PyInit__bitarray(void)
return NULL;
Py_SET_TYPE(&Bitarray_Type, &PyType_Type);
Py_INCREF((PyObject *) &Bitarray_Type);
PyModule_AddObject(m, "bitarray", (PyObject *) &Bitarray_Type);
PyModule_AddObject(m, "bitarray", (PyObject *)&Bitarray_Type);

PyObject *Py_ENDIAN_BIG = PyLong_FromLong(ENDIAN_BIG);

default_endian = PyContextVar_New("default_endian", Py_ENDIAN_BIG);

Py_DECREF(Py_ENDIAN_BIG);

PyModule_AddObject(m, "_default_endian_contextvar", default_endian);

if (register_abc() < 0)
return NULL;
Expand Down
159 changes: 115 additions & 44 deletions bitarray/test_bitarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@

# imports needed inside tests
import array
import asyncio
import concurrent.futures
import copy
import itertools
import mmap
import pickle
import shelve
import threading
import weakref


Expand All @@ -35,7 +38,7 @@
from bitarray import (bitarray, frozenbitarray, bits2bytes, decodetree,
get_default_endian, _set_default_endian,
_bitarray_reconstructor, _sysinfo as sysinfo,
BufferInfo, __version__)
BufferInfo, __version__, default_endian)

def skipIf(condition):
"Skip a test if the condition is true."
Expand Down Expand Up @@ -176,35 +179,52 @@ def test_sysinfo_byteorder(self):
sysinfo("PY_BIG_ENDIAN"))

def test_set_default_endian(self):
for default_endian in 'big', 'little':
_set_default_endian(default_endian)
a = bitarray()
self.assertEqual(a.endian, default_endian)
for x in None, 0, 64, '10111', [1, 0]:
a = bitarray(x)
self.assertEqual(a.endian, default_endian)

for endian in 'big', 'little', None:
a = bitarray(endian=endian)
self.assertEqual(a.endian,
default_endian if endian is None else endian)

# make sure that wrong calling _set_default_endian() does not
# change the default endianness
self.assertRaises(ValueError, _set_default_endian, 'foobar')
self.assertEqual(bitarray().endian, default_endian)
for default in 'big', 'little':
with default_endian(default) as used_endian:
assert used_endian == default
a = bitarray()
self.assertEqual(a.endian, default)
for x in None, 0, 64, '10111', [1, 0]:
a = bitarray(x)
self.assertEqual(a.endian, default)

for endian in 'big', 'little', None:
a = bitarray(endian=endian)
self.assertEqual(a.endian,
default if endian is None else endian)

# make sure that wrong calling _set_default_endian() does not
# change the default endianness
self.assertRaises(ValueError, _set_default_endian, 'foobar')
self.assertEqual(bitarray().endian, default)

def test_set_default_endian_errors(self):
self.assertRaises(TypeError, _set_default_endian, 0)
self.assertRaises(TypeError, _set_default_endian, 'little', 0)
self.assertRaises(ValueError, _set_default_endian, 'foo')

def test_default_endian_errors(self):
def integer_arg():
with default_endian(0):
pass
self.assertRaises(TypeError, integer_arg)

def too_many_args():
with default_endian('little', 0):
pass
self.assertRaises(TypeError, too_many_args)

def invalid_string_arg():
with default_endian('foo'):
pass
self.assertRaises(ValueError, invalid_string_arg)

def test_get_default_endian(self):
for default_endian in 'big', 'little':
_set_default_endian(default_endian)
endian = get_default_endian()
self.assertEqual(endian, default_endian)
self.assertEqual(type(endian), str)
for default in 'big', 'little':
with default_endian(default):
endian = get_default_endian()
self.assertEqual(endian, default)
self.assertEqual(type(endian), str)

def test_get_default_endian_errors(self):
# takes no arguments
Expand Down Expand Up @@ -258,11 +278,10 @@ def test_endian(self):
self.assertEqual(a.tobytes(), b.tobytes())

def test_endian_default(self):
_set_default_endian('big')
a_big = bitarray()
_set_default_endian('little')
a_little = bitarray()
_set_default_endian('big')
with default_endian('big'):
a_big = bitarray()
with default_endian('little'):
a_little = bitarray()

self.assertEqual(a_big.endian, 'big')
self.assertEqual(a_little.endian, 'little')
Expand All @@ -283,8 +302,8 @@ def test_buffer_endian(self):
a = bitarray(buffer=b'', endian=endian)
self.assertEQUAL(a, bitarray(0, endian))

_set_default_endian(endian)
a = bitarray(buffer=b'A')
with default_endian(endian):
a = bitarray(buffer=b'A')
self.assertEqual(a.endian, endian)
self.assertEqual(len(a), 8)

Expand Down Expand Up @@ -1535,6 +1554,58 @@ def test_basic(self):
self.assertRaises(IndexError, a.__delitem__, [10])
self.assertRaises(TypeError, a.__delitem__, (1, 3))

# asyncio.Barrier was introduced in Python 3.11
@skipIf(sys.version_info < (3, 11))
def test_default_endian_async_safe(self):
endian_start = get_default_endian()
b = asyncio.Barrier(10)

async def big():
_set_default_endian('big')
await b.wait()
endian = get_default_endian()
assert endian == 'big'

async def little():
_set_default_endian('little')
await b.wait()
endian = get_default_endian()
assert endian == 'little'

async def main():
await asyncio.gather(
big(), big(), big(), big(), big(),
little(), little(), little(), little(), little()
)

loop = asyncio.new_event_loop()
asyncio.run(main())
loop.close()
assert get_default_endian() == endian_start

def test_default_endian_thread_safe(self):
endian_start = get_default_endian()
b = threading.Barrier(10)

def big():
b.wait()
_set_default_endian('big')
endian = get_default_endian()
assert endian == 'big'

def little():
b.wait()
_set_default_endian('little')
endian = get_default_endian()
assert endian == 'little'

tpe = concurrent.futures.ThreadPoolExecutor(max_workers=10)
futures = [tpe.submit(func) for func in [big]*5 + [little]*5]
for f in futures:
f.result()

assert get_default_endian() == endian_start

def test_delete_one(self):
for a in self.randombitarrays(start=1):
b = a.copy()
Expand Down Expand Up @@ -2421,9 +2492,9 @@ def test_frozenbitarray(self):
self.assertEqual(type(b), frozenbitarray)
self.assertRaises(TypeError, a.__ilshift__, 4)

@default_endian("big")
@skipIf(is_pypy)
def test_imported(self):
_set_default_endian("big")
a = bytearray([0xf0, 0x01, 0x02, 0x0f])
b = bitarray(buffer=a)
self.assertFalse(b.readonly)
Expand Down Expand Up @@ -2938,19 +3009,19 @@ class PackTests(unittest.TestCase, Util):

def test_pack_simple(self):
for endian in 'little', 'big':
_set_default_endian(endian)
a = bitarray()
a.pack(bytes())
self.assertEQUAL(a, bitarray())
a.pack(b'\x00')
self.assertEQUAL(a, bitarray('0'))
a.pack(b'\xff')
self.assertEQUAL(a, bitarray('01'))
a.pack(b'\x01\x00\x7a')
self.assertEQUAL(a, bitarray('01101'))
a.pack(bytearray([0x01, 0x00, 0xff, 0xa7]))
self.assertEQUAL(a, bitarray('01101 1011'))
self.check_obj(a)
with default_endian(endian):
a = bitarray()
a.pack(bytes())
self.assertEQUAL(a, bitarray())
a.pack(b'\x00')
self.assertEQUAL(a, bitarray('0'))
a.pack(b'\xff')
self.assertEQUAL(a, bitarray('01'))
a.pack(b'\x01\x00\x7a')
self.assertEQUAL(a, bitarray('01101'))
a.pack(bytearray([0x01, 0x00, 0xff, 0xa7]))
self.assertEQUAL(a, bitarray('01101 1011'))
self.check_obj(a)

def test_pack_types(self):
a = bitarray()
Expand Down
Loading