Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pymodbus.framer import FRAMER_NAME_TO_CLASS, Framer, ModbusFramer
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest, ModbusResponse
from pymodbus.transaction import DictTransactionManager
from pymodbus.transaction import ModbusTransactionManager
from pymodbus.transport import CommParams, ModbusProtocol
from pymodbus.utilities import ModbusTransactionState

Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(
self.framer = FRAMER_NAME_TO_CLASS.get(
framer, cast(Type[ModbusFramer], framer)
)(ClientDecoder(), self)
self.transaction = DictTransactionManager(
self.transaction = ModbusTransactionManager(
self, retries=retries, retry_on_empty=retry_on_empty, **kwargs
)
self.use_udp = False
Expand Down Expand Up @@ -341,7 +341,7 @@ def __init__(
self.framer = FRAMER_NAME_TO_CLASS.get(
framer, cast(Type[ModbusFramer], framer)
)(ClientDecoder(), self)
self.transaction = DictTransactionManager(
self.transaction = ModbusTransactionManager(
self, retries=retries, retry_on_empty=retry_on_empty, **kwargs
)
self.reconnect_delay_current = self.params.reconnect_delay or 0
Expand Down
91 changes: 25 additions & 66 deletions pymodbus/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = [
"DictTransactionManager",
"ModbusTransactionManager",
"ModbusSocketFramer",
"ModbusTlsFramer",
"ModbusRtuFramer",
Expand All @@ -19,7 +20,6 @@
from pymodbus.exceptions import (
InvalidMessageReceivedException,
ModbusIOException,
NotImplementedException,
)
from pymodbus.framer.ascii_framer import ModbusAsciiFramer
from pymodbus.framer.binary_framer import ModbusBinaryFramer
Expand Down Expand Up @@ -47,6 +47,8 @@ class ModbusTransactionManager:
while (count < 3)

This module helps to abstract this away from the framer and protocol.

Results are keyed based on the supplied transaction id.
"""

def __init__(self, client, **kwargs):
Expand All @@ -62,11 +64,19 @@ def __init__(self, client, **kwargs):
self.retry_on_empty = kwargs.get("retry_on_empty", False)
self.retry_on_invalid = kwargs.get("retry_on_invalid", False)
self.retries = kwargs.get("retries", 3)
self.transactions = {}
self._transaction_lock = RLock()
self._no_response_devices = []
if client:
self._set_adu_size()

def __iter__(self):
"""Iterate over the current managed transactions.

:returns: An iterator of the managed transactions
"""
return iter(self.transactions.keys())

def _set_adu_size(self):
"""Set adu size."""
# base ADU size of modbus frame in bytes
Expand Down Expand Up @@ -422,27 +432,33 @@ def addTransaction(self, request, tid=None):

:param request: The request to hold on to
:param tid: The overloaded transaction id to use
:raises NotImplementedException:
"""
raise NotImplementedException("addTransaction")
tid = tid if tid is not None else request.transaction_id
Log.debug("Adding transaction {}", tid)
self.transactions[tid] = request

def getTransaction(self, tid):
"""Return a transaction matching the referenced tid.

If the transaction does not exist, None is returned

:param tid: The transaction to retrieve
:raises NotImplementedException:

"""
raise NotImplementedException("getTransaction")
Log.debug("Getting transaction {}", tid)
if not tid:
if self.transactions:
return self.transactions.popitem()[1]
return None
return self.transactions.pop(tid, None)

def delTransaction(self, tid):
"""Remove a transaction matching the referenced tid.

:param tid: The transaction to remove
:raises NotImplementedException:
"""
raise NotImplementedException("delTransaction")
Log.debug("deleting transaction {}", tid)
self.transactions.pop(tid, None)

def getNextTID(self):
"""Retrieve the next unique transaction identifier.
Expand All @@ -458,64 +474,7 @@ def getNextTID(self):
def reset(self):
"""Reset the transaction identifier."""
self.tid = 0
self.transactions = type( # pylint: disable=attribute-defined-outside-init
self.transactions
)()


class DictTransactionManager(ModbusTransactionManager):
"""Implements a transaction for a manager.

Where the results are keyed based on the supplied transaction id.
"""

def __init__(self, client, **kwargs):
"""Initialize an instance of the ModbusTransactionManager.

:param client: The client socket wrapper
"""
self.transactions = {}
super().__init__(client, **kwargs)

def __iter__(self):
"""Iterate over the current managed transactions.

:returns: An iterator of the managed transactions
"""
return iter(self.transactions.keys())

def addTransaction(self, request, tid=None):
"""Add a transaction to the handler.

This holds the requests in case it needs to be resent.
After being sent, the request is removed.

:param request: The request to hold on to
:param tid: The overloaded transaction id to use
"""
tid = tid if tid is not None else request.transaction_id
Log.debug("Adding transaction {}", tid)
self.transactions[tid] = request

def getTransaction(self, tid):
"""Return a transaction matching the referenced tid.

If the transaction does not exist, None is returned

:param tid: The transaction to retrieve

"""
Log.debug("Getting transaction {}", tid)
if not tid:
if self.transactions:
return self.transactions.popitem()[1]
return None
return self.transactions.pop(tid, None)

def delTransaction(self, tid):
"""Remove a transaction matching the referenced tid.

:param tid: The transaction to remove
"""
Log.debug("deleting transaction {}", tid)
self.transactions.pop(tid, None)
class DictTransactionManager(ModbusTransactionManager):
"""Old alias for ModbusTransactionManager."""
54 changes: 24 additions & 30 deletions test/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pymodbus.factory import ServerDecoder
from pymodbus.pdu import ModbusRequest
from pymodbus.transaction import (
DictTransactionManager,
ModbusAsciiFramer,
ModbusBinaryFramer,
ModbusRtuFramer,
Expand Down Expand Up @@ -50,24 +49,23 @@ def setup_method(self):
self._rtu = ModbusRtuFramer(decoder=self.decoder, client=None)
self._ascii = ModbusAsciiFramer(decoder=self.decoder, client=None)
self._binary = ModbusBinaryFramer(decoder=self.decoder, client=None)
self._manager = DictTransactionManager(self.client)
self._tm = ModbusTransactionManager(self.client)
self._manager = ModbusTransactionManager(self.client)

# ----------------------------------------------------------------------- #
# Base transaction manager
# Modbus transaction manager
# ----------------------------------------------------------------------- #

def test_calculate_expected_response_length(self):
"""Test calculate expected response length."""
self._tm.client = mock.MagicMock()
self._tm.client.framer = mock.MagicMock()
self._tm._set_adu_size() # pylint: disable=protected-access
assert not self._tm._calculate_response_length( # pylint: disable=protected-access
self._manager.client = mock.MagicMock()
self._manager.client.framer = mock.MagicMock()
self._manager._set_adu_size() # pylint: disable=protected-access
assert not self._manager._calculate_response_length( # pylint: disable=protected-access
0
)
self._tm.base_adu_size = 10
self._manager.base_adu_size = 10
assert (
self._tm._calculate_response_length(5) # pylint: disable=protected-access
self._manager._calculate_response_length(5) # pylint: disable=protected-access
== 15
)

Expand All @@ -81,23 +79,23 @@ def test_calculate_exception_length(self):
("tls", 2),
("dummy", None),
):
self._tm.client = mock.MagicMock()
self._manager.client = mock.MagicMock()
if framer == "ascii":
self._tm.client.framer = self._ascii
self._manager.client.framer = self._ascii
elif framer == "binary":
self._tm.client.framer = self._binary
self._manager.client.framer = self._binary
elif framer == "rtu":
self._tm.client.framer = self._rtu
self._manager.client.framer = self._rtu
elif framer == "tcp":
self._tm.client.framer = self._tcp
self._manager.client.framer = self._tcp
elif framer == "tls":
self._tm.client.framer = self._tls
self._manager.client.framer = self._tls
else:
self._tm.client.framer = mock.MagicMock()
self._manager.client.framer = mock.MagicMock()

self._tm._set_adu_size() # pylint: disable=protected-access
self._manager._set_adu_size() # pylint: disable=protected-access
assert (
self._tm._calculate_exception_length() # pylint: disable=protected-access
self._manager._calculate_exception_length() # pylint: disable=protected-access
== exception_length
)

Expand Down Expand Up @@ -140,7 +138,7 @@ def test_execute(self, mock_time):
trans._recv = mock.MagicMock( # pylint: disable=protected-access
return_value=b"abcdef"
)
trans.transactions = []
trans.transactions = {}
trans.getTransaction = mock.MagicMock()
trans.getTransaction.return_value = None
response = trans.execute(request)
Expand Down Expand Up @@ -198,19 +196,15 @@ def test_execute(self, mock_time):
recv.assert_called_once_with(8, False)
client.comm_params.handle_local_echo = False

# ----------------------------------------------------------------------- #
# Dictionary based transaction manager
# ----------------------------------------------------------------------- #

def test_dict_transaction_manager_tid(self):
"""Test the dict transaction manager TID."""
def test_transaction_manager_tid(self):
"""Test the transaction manager TID."""
for tid in range(1, self._manager.getNextTID() + 10):
assert tid + 1 == self._manager.getNextTID()
self._manager.reset()
assert self._manager.getNextTID() == 1

def test_get_dict_fifo_transaction_manager_transaction(self):
"""Test the dict transaction manager."""
def test_get_transaction_manager_transaction(self):
"""Test the getting a transaction from the transaction manager."""

class Request: # pylint: disable=too-few-public-methods
"""Request."""
Expand All @@ -225,8 +219,8 @@ class Request: # pylint: disable=too-few-public-methods
result = self._manager.getTransaction(handle.transaction_id)
assert handle.message == result.message

def test_delete_dict_fifo_transaction_manager_transaction(self):
"""Test the dict transaction manager."""
def test_delete_transaction_manager_transaction(self):
"""Test deleting a transaction from the dict transaction manager."""

class Request: # pylint: disable=too-few-public-methods
"""Request."""
Expand Down