diff --git a/pymodbus/client/base.py b/pymodbus/client/base.py index c86153f8d..d496b4146 100644 --- a/pymodbus/client/base.py +++ b/pymodbus/client/base.py @@ -12,7 +12,7 @@ from pymodbus.framer import FRAMER_NAME_TO_CLASS, Framer, ModbusFramer from pymodbus.logging import Log from pymodbus.pdu import ModbusRequest, ModbusResponse -from pymodbus.transaction import DictTransactionManager +from pymodbus.transaction import ModbusTransactionManager from pymodbus.transport import CommParams, ModbusProtocol from pymodbus.utilities import ModbusTransactionState @@ -92,7 +92,7 @@ def __init__( self.framer = FRAMER_NAME_TO_CLASS.get( framer, cast(Type[ModbusFramer], framer) )(ClientDecoder(), self) - self.transaction = DictTransactionManager( + self.transaction = ModbusTransactionManager( self, retries=retries, retry_on_empty=retry_on_empty, **kwargs ) self.use_udp = False @@ -341,7 +341,7 @@ def __init__( self.framer = FRAMER_NAME_TO_CLASS.get( framer, cast(Type[ModbusFramer], framer) )(ClientDecoder(), self) - self.transaction = DictTransactionManager( + self.transaction = ModbusTransactionManager( self, retries=retries, retry_on_empty=retry_on_empty, **kwargs ) self.reconnect_delay_current = self.params.reconnect_delay or 0 diff --git a/pymodbus/transaction.py b/pymodbus/transaction.py index 5bb2ec69c..e0239f5ca 100644 --- a/pymodbus/transaction.py +++ b/pymodbus/transaction.py @@ -2,6 +2,7 @@ __all__ = [ "DictTransactionManager", + "ModbusTransactionManager", "ModbusSocketFramer", "ModbusTlsFramer", "ModbusRtuFramer", @@ -19,7 +20,6 @@ from pymodbus.exceptions import ( InvalidMessageReceivedException, ModbusIOException, - NotImplementedException, ) from pymodbus.framer.ascii_framer import ModbusAsciiFramer from pymodbus.framer.binary_framer import ModbusBinaryFramer @@ -47,6 +47,8 @@ class ModbusTransactionManager: while (count < 3) This module helps to abstract this away from the framer and protocol. + + Results are keyed based on the supplied transaction id. """ def __init__(self, client, **kwargs): @@ -62,11 +64,19 @@ def __init__(self, client, **kwargs): self.retry_on_empty = kwargs.get("retry_on_empty", False) self.retry_on_invalid = kwargs.get("retry_on_invalid", False) self.retries = kwargs.get("retries", 3) + self.transactions = {} self._transaction_lock = RLock() self._no_response_devices = [] if client: self._set_adu_size() + def __iter__(self): + """Iterate over the current managed transactions. + + :returns: An iterator of the managed transactions + """ + return iter(self.transactions.keys()) + def _set_adu_size(self): """Set adu size.""" # base ADU size of modbus frame in bytes @@ -422,9 +432,10 @@ def addTransaction(self, request, tid=None): :param request: The request to hold on to :param tid: The overloaded transaction id to use - :raises NotImplementedException: """ - raise NotImplementedException("addTransaction") + tid = tid if tid is not None else request.transaction_id + Log.debug("Adding transaction {}", tid) + self.transactions[tid] = request def getTransaction(self, tid): """Return a transaction matching the referenced tid. @@ -432,17 +443,22 @@ def getTransaction(self, tid): If the transaction does not exist, None is returned :param tid: The transaction to retrieve - :raises NotImplementedException: + """ - raise NotImplementedException("getTransaction") + Log.debug("Getting transaction {}", tid) + if not tid: + if self.transactions: + return self.transactions.popitem()[1] + return None + return self.transactions.pop(tid, None) def delTransaction(self, tid): """Remove a transaction matching the referenced tid. :param tid: The transaction to remove - :raises NotImplementedException: """ - raise NotImplementedException("delTransaction") + Log.debug("deleting transaction {}", tid) + self.transactions.pop(tid, None) def getNextTID(self): """Retrieve the next unique transaction identifier. @@ -458,64 +474,7 @@ def getNextTID(self): def reset(self): """Reset the transaction identifier.""" self.tid = 0 - self.transactions = type( # pylint: disable=attribute-defined-outside-init - self.transactions - )() - - -class DictTransactionManager(ModbusTransactionManager): - """Implements a transaction for a manager. - - Where the results are keyed based on the supplied transaction id. - """ - - def __init__(self, client, **kwargs): - """Initialize an instance of the ModbusTransactionManager. - - :param client: The client socket wrapper - """ self.transactions = {} - super().__init__(client, **kwargs) - - def __iter__(self): - """Iterate over the current managed transactions. - - :returns: An iterator of the managed transactions - """ - return iter(self.transactions.keys()) - - def addTransaction(self, request, tid=None): - """Add a transaction to the handler. - - This holds the requests in case it needs to be resent. - After being sent, the request is removed. - - :param request: The request to hold on to - :param tid: The overloaded transaction id to use - """ - tid = tid if tid is not None else request.transaction_id - Log.debug("Adding transaction {}", tid) - self.transactions[tid] = request - - def getTransaction(self, tid): - """Return a transaction matching the referenced tid. - - If the transaction does not exist, None is returned - :param tid: The transaction to retrieve - - """ - Log.debug("Getting transaction {}", tid) - if not tid: - if self.transactions: - return self.transactions.popitem()[1] - return None - return self.transactions.pop(tid, None) - - def delTransaction(self, tid): - """Remove a transaction matching the referenced tid. - - :param tid: The transaction to remove - """ - Log.debug("deleting transaction {}", tid) - self.transactions.pop(tid, None) +class DictTransactionManager(ModbusTransactionManager): + """Old alias for ModbusTransactionManager.""" diff --git a/test/test_transaction.py b/test/test_transaction.py index 3385647aa..95bbdc5b3 100755 --- a/test/test_transaction.py +++ b/test/test_transaction.py @@ -12,7 +12,6 @@ from pymodbus.factory import ServerDecoder from pymodbus.pdu import ModbusRequest from pymodbus.transaction import ( - DictTransactionManager, ModbusAsciiFramer, ModbusBinaryFramer, ModbusRtuFramer, @@ -50,24 +49,23 @@ def setup_method(self): self._rtu = ModbusRtuFramer(decoder=self.decoder, client=None) self._ascii = ModbusAsciiFramer(decoder=self.decoder, client=None) self._binary = ModbusBinaryFramer(decoder=self.decoder, client=None) - self._manager = DictTransactionManager(self.client) - self._tm = ModbusTransactionManager(self.client) + self._manager = ModbusTransactionManager(self.client) # ----------------------------------------------------------------------- # - # Base transaction manager + # Modbus transaction manager # ----------------------------------------------------------------------- # def test_calculate_expected_response_length(self): """Test calculate expected response length.""" - self._tm.client = mock.MagicMock() - self._tm.client.framer = mock.MagicMock() - self._tm._set_adu_size() # pylint: disable=protected-access - assert not self._tm._calculate_response_length( # pylint: disable=protected-access + self._manager.client = mock.MagicMock() + self._manager.client.framer = mock.MagicMock() + self._manager._set_adu_size() # pylint: disable=protected-access + assert not self._manager._calculate_response_length( # pylint: disable=protected-access 0 ) - self._tm.base_adu_size = 10 + self._manager.base_adu_size = 10 assert ( - self._tm._calculate_response_length(5) # pylint: disable=protected-access + self._manager._calculate_response_length(5) # pylint: disable=protected-access == 15 ) @@ -81,23 +79,23 @@ def test_calculate_exception_length(self): ("tls", 2), ("dummy", None), ): - self._tm.client = mock.MagicMock() + self._manager.client = mock.MagicMock() if framer == "ascii": - self._tm.client.framer = self._ascii + self._manager.client.framer = self._ascii elif framer == "binary": - self._tm.client.framer = self._binary + self._manager.client.framer = self._binary elif framer == "rtu": - self._tm.client.framer = self._rtu + self._manager.client.framer = self._rtu elif framer == "tcp": - self._tm.client.framer = self._tcp + self._manager.client.framer = self._tcp elif framer == "tls": - self._tm.client.framer = self._tls + self._manager.client.framer = self._tls else: - self._tm.client.framer = mock.MagicMock() + self._manager.client.framer = mock.MagicMock() - self._tm._set_adu_size() # pylint: disable=protected-access + self._manager._set_adu_size() # pylint: disable=protected-access assert ( - self._tm._calculate_exception_length() # pylint: disable=protected-access + self._manager._calculate_exception_length() # pylint: disable=protected-access == exception_length ) @@ -140,7 +138,7 @@ def test_execute(self, mock_time): trans._recv = mock.MagicMock( # pylint: disable=protected-access return_value=b"abcdef" ) - trans.transactions = [] + trans.transactions = {} trans.getTransaction = mock.MagicMock() trans.getTransaction.return_value = None response = trans.execute(request) @@ -198,19 +196,15 @@ def test_execute(self, mock_time): recv.assert_called_once_with(8, False) client.comm_params.handle_local_echo = False - # ----------------------------------------------------------------------- # - # Dictionary based transaction manager - # ----------------------------------------------------------------------- # - - def test_dict_transaction_manager_tid(self): - """Test the dict transaction manager TID.""" + def test_transaction_manager_tid(self): + """Test the transaction manager TID.""" for tid in range(1, self._manager.getNextTID() + 10): assert tid + 1 == self._manager.getNextTID() self._manager.reset() assert self._manager.getNextTID() == 1 - def test_get_dict_fifo_transaction_manager_transaction(self): - """Test the dict transaction manager.""" + def test_get_transaction_manager_transaction(self): + """Test the getting a transaction from the transaction manager.""" class Request: # pylint: disable=too-few-public-methods """Request.""" @@ -225,8 +219,8 @@ class Request: # pylint: disable=too-few-public-methods result = self._manager.getTransaction(handle.transaction_id) assert handle.message == result.message - def test_delete_dict_fifo_transaction_manager_transaction(self): - """Test the dict transaction manager.""" + def test_delete_transaction_manager_transaction(self): + """Test deleting a transaction from the dict transaction manager.""" class Request: # pylint: disable=too-few-public-methods """Request."""