diff --git a/python/sbp/client/handler.py b/python/sbp/client/handler.py index ff2500d848..98908b73ee 100644 --- a/python/sbp/client/handler.py +++ b/python/sbp/client/handler.py @@ -180,10 +180,15 @@ def add_callback(self, callback, msg_type=None): ---------- callback : fn Callback function - msg_type : int + msg_type : int | iterable Message type to register callback against. Default `None` means global callback. + Iterable type adds the callback to all the message types. """ - self.callbacks[msg_type].add(callback) + try: + for mt in iter(msg_type): + self.callbacks[mt].add(callback) + except TypeError: + self.callbacks[msg_type].add(callback) def remove_callback(self, callback, msg_type=None): """ @@ -193,10 +198,15 @@ def remove_callback(self, callback, msg_type=None): ---------- callback : fn Callback function - msg_type : int + msg_type : int | iterable Message type to remove callback from. Default `None` means global callback. + Iterable type removes the callback from all the message types. """ - self.callbacks[msg_type].remove(callback) + try: + for mt in iter(msg_type): + self.callbacks[mt].remove(callback) + except TypeError: + self.callbacks[msg_type].remove(callback) def get_callbacks(self, msg_type): """ diff --git a/python/tests/sbp/client/test_handler.py b/python/tests/sbp/client/test_handler.py index d02d087082..df12c72cba 100644 --- a/python/tests/sbp/client/test_handler.py +++ b/python/tests/sbp/client/test_handler.py @@ -101,3 +101,30 @@ def test_handler_callbacks(): assert global_counter2.value == 2 assert msg_type_counter1.value == 1 assert msg_type_counter2.value == 0 + handler.remove_callback(global_counter1) + handler.remove_callback(global_counter2) + handler.remove_callback(msg_type_counter1, 0x55) + handler.remove_callback(msg_type_counter2, 0x66) + handler.call(SBP(0x11, None, None, None, None)) + handler.call(SBP(0x55, None, None, None, None)) + assert global_counter1.value == 2 + assert global_counter2.value == 2 + assert msg_type_counter1.value == 1 + assert msg_type_counter2.value == 0 + +def test_multiple_handler_callbacks(): + handler = Handler(None, None) + msg_type_counter1 = TestCallbackCounter() + msg_type_counter2 = TestCallbackCounter() + handler.add_callback(msg_type_counter1, [0x55, 0x66]) + handler.add_callback(msg_type_counter2, [0x11, 0x55]) + handler.call(SBP(0x11, None, None, None, None)) + handler.call(SBP(0x55, None, None, None, None)) + assert msg_type_counter1.value == 1 + assert msg_type_counter2.value == 2 + handler.remove_callback(msg_type_counter1, [0x55, 0x66]) + handler.remove_callback(msg_type_counter2, [0x11, 0x55]) + handler.call(SBP(0x11, None, None, None, None)) + handler.call(SBP(0x55, None, None, None, None)) + assert msg_type_counter1.value == 1 + assert msg_type_counter2.value == 2