Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions pymodbus/transport/serialtransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, loop, protocol, *args, **kwargs) -> None:
def setup(self):
"""Prepare to read/write."""
if os.name == "nt" or self.force_poll:
self.poll_task = asyncio.create_task(self._polling_task())
self.poll_task = asyncio.create_task(self.polling_task())
self.poll_task.set_name("SerialTransport poll")
else:
self.async_loop.add_reader(self.sync_serial.fileno(), self._read_ready)
Expand All @@ -41,9 +41,6 @@ def close(self, exc: Exception | None = None) -> None:
"""Close the transport gracefully."""
if not self.sync_serial:
return
with contextlib.suppress(Exception):
self.sync_serial.flush()

self.flush()
if self.poll_task:
self.poll_task.cancel()
Expand All @@ -61,7 +58,7 @@ def write(self, data) -> None:
"""Write some data to the transport."""
self._write_buffer.append(data)
if not self.poll_task:
self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)
self.async_loop.add_writer(self.sync_serial.fileno(), self.write_ready)

def flush(self) -> None:
"""Clear output buffer and stops any more data being written."""
Expand Down Expand Up @@ -131,15 +128,15 @@ def _read_ready(self):
except serial.SerialException as exc:
self.close(exc=exc)

def _write_ready(self):
def write_ready(self):
"""Asynchronously write buffered data."""
data = b"".join(self._write_buffer)
try:
if (nlen := self.sync_serial.write(data)) < len(data):
self._write_buffer = [data[nlen:]]
if not self.poll_task:
self.async_loop.add_writer(
self.sync_serial.fileno(), self._write_ready
self.sync_serial.fileno(), self.write_ready
)
return
self.flush()
Expand All @@ -148,19 +145,17 @@ def _write_ready(self):
except serial.SerialException as exc:
self.close(exc=exc)

async def _polling_task(self):
async def polling_task(self):
"""Poll and try to read/write."""
try:
while True:
while self.sync_serial:
await asyncio.sleep(self._poll_wait_time)
while self._write_buffer:
self._write_ready()
self.write_ready()
if self.sync_serial.in_waiting:
self._read_ready()
except serial.SerialException as exc:
self.close(exc=exc)
except asyncio.CancelledError:
pass
self.close("Cancelled")


async def create_serial_connection(
Expand Down
15 changes: 1 addition & 14 deletions pymodbus/transport/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import asyncio
import dataclasses
import ssl
import sys
from contextlib import suppress
from enum import Enum
from typing import Any, Callable, Coroutine
Expand All @@ -62,14 +61,6 @@

NULLMODEM_HOST = "__pymodbus_nullmodem"

if sys.version_info.minor == 11:
USEEXCEPTIONS: tuple[type[Any], type[Any]] | type[Any] = OSError
else:
USEEXCEPTIONS = (
asyncio.TimeoutError,
OSError,
)


class CommType(Enum):
"""Type of transport."""
Expand Down Expand Up @@ -254,13 +245,9 @@ async def transport_connect(self) -> bool:
self.call_create(),
timeout=self.comm_params.timeout_connect,
)
except USEEXCEPTIONS as exc:
except (asyncio.TimeoutError, OSError) as exc: # pylint: disable=overlapping-except
Log.warning("Failed to connect {}", exc)
# self.transport_close(intern=True, reconnect=True)
return False
except Exception as exc:
Log.warning("Failed to connect UNKNOWN EXCEPTION {}", exc)
raise
return bool(self.transport)

async def transport_listen(self) -> bool:
Expand Down
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ documentation = [
development = [
"build>=1.0.3",
"codespell>=2.2.2",
"coverage>=7.1.0",
"coverage>=7.4.0",
"mypy>=1.6.0",
"pylint>=2.17.2",
"pytest>=7.3.1",
Expand Down Expand Up @@ -232,6 +232,14 @@ include = [
]
omit = ["examples/contrib/"]

[tool.coverage.report]
exclude_lines = [
"_check_system_health",
"if __name__ == .__main__.:",
]

ignore_errors = true

[tool.codespell]
skip = "./build,./doc/source/_static,venv,.venv,.git,htmlcov,CHANGELOG.rst,.mypy_cache"
ignore-words-list = "asend"
Expand Down
4 changes: 2 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ async def _check_system_health():
start_tasks = {task.get_name(): task for task in asyncio.all_tasks()}
yield
await asyncio.sleep(0.1)
all_clean = True
for count in range(10):
all_clean = True
error_text = f"ERROR tasks/threads hanging after {count} retries:\n"
for thread in thread_enumerate():
name = thread.getName()
Expand All @@ -72,7 +72,7 @@ async def _check_system_health():
all_clean = False
if all_clean:
break
await asyncio.sleep(1)
await asyncio.sleep(0.3)
assert all_clean, error_text
assert not NullModem.is_dirty()

Expand Down
5 changes: 1 addition & 4 deletions test/sub_examples/test_client_server_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
from pymodbus.server import ServerStop


if os.name == "nt":
SLEEPING = 5
else:
SLEEPING = 1
SLEEPING = 5 if os.name == "nt" else 1


@pytest.mark.parametrize("use_host", ["localhost"])
Expand Down
19 changes: 0 additions & 19 deletions test/sub_transport/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
CommType,
ModbusProtocol,
)
from pymodbus.transport.transport import NullModem


class DummyProtocol(ModbusProtocol):
Expand Down Expand Up @@ -50,12 +49,6 @@ def prepare_dummy_protocol():
return DummyProtocol


@pytest.fixture(name="cwd_certificate")
def prepare_cwd_certificate():
"""Prepare path to certificate."""
return os.path.dirname(__file__) + "/../../examples/certificates/pymodbus."


@pytest.fixture(name="use_comm_type")
def prepare_dummy_use_comm_type():
"""Return default comm_type."""
Expand Down Expand Up @@ -138,15 +131,3 @@ def prepare_transport_server(use_cls):
True, certfile=cwd + "crt", keyfile=cwd + "key"
)
return transport


@pytest.fixture(name="nullmodem")
def prepare_nullmodem():
"""Prepare nullmodem object."""
return NullModem(mock.Mock())


@pytest.fixture(name="nullmodem_server")
def prepare_nullmodem_server():
"""Prepare nullmodem object."""
return NullModem(mock.Mock())
48 changes: 40 additions & 8 deletions test/sub_transport/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Test transport."""
import asyncio
import os
from unittest import mock

import pytest
import serial

from pymodbus.transport import (
CommType,
Expand All @@ -22,7 +24,7 @@
]


class TestBasicModbusProtocol:
class TestBasicModbusProtocol: # pylint: disable=too-many-public-methods
"""Test transport module."""

@staticmethod
Expand Down Expand Up @@ -53,6 +55,11 @@ async def test_init_serial(self, client, server):
server.comm_params.sslctx = None
assert server.is_server

async def test_init_source_addr(self, use_clc):
"""Test callbacks."""
_client = ModbusProtocol(use_clc, True)


async def test_connect(self, client, dummy_protocol):
"""Test properties."""
client.loop = None
Expand Down Expand Up @@ -279,13 +286,6 @@ def test_generate_ssl(self, use_clc):
class TestBasicSerial:
"""Test transport serial module."""

@staticmethod
@pytest.fixture(name="use_port")
def get_port_in_class(base_ports):
"""Return next port."""
base_ports[__class__.__name__] += 1
return base_ports[__class__.__name__]

async def test_init(self):
"""Test null modem init."""
SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
Expand Down Expand Up @@ -330,3 +330,35 @@ async def test_external_methods(self):
assert transport
assert protocol
transport.close()

async def test_serial_polling(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.read.side_effect = asyncio.CancelledError("test")
await comm.polling_task()

async def test_serial_ready(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.read.side_effect = serial.SerialException("test")
await comm.polling_task()

async def test_serial_write_ready(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.write.side_effect = BlockingIOError("test")
comm.write_ready()
comm.sync_serial.write.side_effect = serial.SerialException("test")
comm.write_ready()
18 changes: 9 additions & 9 deletions test/sub_transport/test_comm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Test transport."""
import asyncio
import sys
import time
from unittest import mock

import pytest

from pymodbus.logging import Log
from pymodbus.transport import (
CommType,
ModbusProtocol,
Expand Down Expand Up @@ -37,7 +37,7 @@ def get_port_in_class(base_ports):
)
async def test_connect(self, client, use_port):
"""Test connect()."""
print(f"JAN test_connect --> {use_port}", file=sys.stderr)
Log.debug("test_connect {}", use_port)
start = time.time()
assert not await client.transport_connect()
delta = time.time() - start
Expand All @@ -55,7 +55,7 @@ async def test_connect(self, client, use_port):
)
async def test_connect_not_ok(self, client, use_port):
"""Test connect()."""
print(f"JAN test_connect_not_ok --> {use_port}", file=sys.stderr)
Log.debug("test_connect_not_ok {}", use_port)
start = time.time()
assert not await client.transport_connect()
delta = time.time() - start
Expand All @@ -73,7 +73,7 @@ async def test_connect_not_ok(self, client, use_port):
)
async def test_listen(self, server, use_port):
"""Test listen()."""
print(f"JAN test_listen --> {use_port}", file=sys.stderr)
Log.debug("test_listen {}", use_port)
assert await server.transport_listen()
assert server.transport
server.transport_close()
Expand All @@ -89,7 +89,7 @@ async def test_listen(self, server, use_port):
)
async def test_listen_not_ok(self, server, use_port):
"""Test listen()."""
print(f"JAN test_listen_not_ok --> {use_port}", file=sys.stderr)
Log.debug("test_listen_not_ok {}", use_port)
assert not await server.transport_listen()
assert not server.transport
server.transport_close()
Expand All @@ -105,7 +105,7 @@ async def test_listen_not_ok(self, server, use_port):
)
async def test_connected(self, client, server, use_comm_type, use_port):
"""Test connection and data exchange."""
print(f"JAN test_connected --> {use_port}", file=sys.stderr)
Log.debug("test_connected {}", use_port)
assert await server.transport_listen()
assert await client.transport_connect()
await asyncio.sleep(0.5)
Expand Down Expand Up @@ -144,7 +144,7 @@ def wrapped_write(self, data):
)
async def test_split_serial_packet(self, client, server, use_port):
"""Test connection and data exchange."""
print(f"JAN test_split_serial_packet --> {use_port}", file=sys.stderr)
Log.debug("test_split_serial_packet {}", use_port)
assert await server.transport_listen()
assert await client.transport_connect()
await asyncio.sleep(0.5)
Expand Down Expand Up @@ -173,7 +173,7 @@ async def test_split_serial_packet(self, client, server, use_port):
)
async def test_serial_poll(self, client, server, use_port):
"""Test connection and data exchange."""
print(f"JAN test_serial_poll --> {use_port}", file=sys.stderr)
Log.debug("test_serial_poll {}", use_port)
assert await server.transport_listen()
SerialTransport.force_poll = True
assert await client.transport_connect()
Expand All @@ -200,7 +200,7 @@ async def test_serial_poll(self, client, server, use_port):
)
async def test_connected_multiple(self, client, server, use_port):
"""Test connection and data exchange."""
print(f"JAN test_connected_multiple --> {use_port}", file=sys.stderr)
Log.debug("test_connected {}", use_port)
client.comm_params.reconnect_delay = 0.0
assert await server.transport_listen()
assert await client.transport_connect()
Expand Down
5 changes: 0 additions & 5 deletions test/test_framers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ def fixture_rtu_framer():
"""RTU framer."""
return ModbusRtuFramer(ClientDecoder())

@pytest.fixture(name="socket_framer")
def fixture_socket_framer():
"""Socket framer."""
return ModbusSocketFramer(ClientDecoder())

@pytest.fixture(name="ascii_framer")
def fixture_ascii_framer():
"""Ascii framer."""
Expand Down