|
1 | 1 | """Configure pytest.""" |
2 | 2 | import asyncio |
| 3 | +import os |
3 | 4 | import platform |
4 | 5 | import sys |
5 | 6 | from collections import deque |
6 | 7 | from threading import enumerate as thread_enumerate |
| 8 | +from unittest import mock |
7 | 9 |
|
8 | 10 | import pytest |
| 11 | +import pytest_asyncio |
9 | 12 |
|
10 | 13 | from pymodbus.datastore import ModbusBaseSlaveContext |
| 14 | +from pymodbus.server import ServerAsyncStop |
| 15 | +from pymodbus.transport import NULLMODEM_HOST, CommParams, CommType, ModbusProtocol |
11 | 16 | from pymodbus.transport.transport import NullModem |
12 | 17 |
|
13 | 18 |
|
| 19 | +sys.path.extend(["examples", "../examples", "../../examples"]) |
| 20 | + |
| 21 | +from examples.server_async import ( # noqa: E402 # pylint: disable=wrong-import-position |
| 22 | + run_async_server, |
| 23 | + setup_server, |
| 24 | +) |
| 25 | + |
| 26 | + |
14 | 27 | def pytest_configure(): |
15 | 28 | """Configure pytest.""" |
16 | 29 | pytest.IS_DARWIN = platform.system().lower() == "darwin" |
@@ -42,6 +55,188 @@ def get_base_ports(): |
42 | 55 | return BASE_PORTS |
43 | 56 |
|
44 | 57 |
|
| 58 | +@pytest.fixture(name="use_comm_type") |
| 59 | +def prepare_dummy_use_comm_type(): |
| 60 | + """Return default comm_type.""" |
| 61 | + return CommType.TCP |
| 62 | + |
| 63 | + |
| 64 | +@pytest.fixture(name="use_host") |
| 65 | +def define_use_host(): |
| 66 | + """Set default host.""" |
| 67 | + return NULLMODEM_HOST |
| 68 | + |
| 69 | + |
| 70 | +@pytest.fixture(name="use_cls") |
| 71 | +def prepare_commparams_server(use_port, use_host, use_comm_type): |
| 72 | + """Prepare CommParamsClass object.""" |
| 73 | + if use_host == NULLMODEM_HOST and use_comm_type == CommType.SERIAL: |
| 74 | + use_host = f"{NULLMODEM_HOST}:{use_port}" |
| 75 | + return CommParams( |
| 76 | + comm_name="test comm", |
| 77 | + comm_type=use_comm_type, |
| 78 | + reconnect_delay=0, |
| 79 | + reconnect_delay_max=0, |
| 80 | + timeout_connect=0, |
| 81 | + source_address=(use_host, use_port), |
| 82 | + baudrate=9600, |
| 83 | + bytesize=8, |
| 84 | + parity="E", |
| 85 | + stopbits=2, |
| 86 | + ) |
| 87 | + |
| 88 | + |
| 89 | +@pytest.fixture(name="use_clc") |
| 90 | +def prepare_commparams_client(use_port, use_host, use_comm_type): |
| 91 | + """Prepare CommParamsClass object.""" |
| 92 | + if use_host == NULLMODEM_HOST and use_comm_type == CommType.SERIAL: |
| 93 | + use_host = f"{NULLMODEM_HOST}:{use_port}" |
| 94 | + timeout = 10 if not pytest.IS_WINDOWS else 2 |
| 95 | + return CommParams( |
| 96 | + comm_name="test comm", |
| 97 | + comm_type=use_comm_type, |
| 98 | + reconnect_delay=0.1, |
| 99 | + reconnect_delay_max=0.35, |
| 100 | + timeout_connect=timeout, |
| 101 | + host=use_host, |
| 102 | + port=use_port, |
| 103 | + baudrate=9600, |
| 104 | + bytesize=8, |
| 105 | + parity="E", |
| 106 | + stopbits=2, |
| 107 | + ) |
| 108 | + |
| 109 | + |
| 110 | +@pytest.fixture(name="client") |
| 111 | +def prepare_protocol(use_clc): |
| 112 | + """Prepare transport object.""" |
| 113 | + transport = ModbusProtocol(use_clc, False) |
| 114 | + transport.callback_connected = mock.Mock() |
| 115 | + transport.callback_disconnected = mock.Mock() |
| 116 | + transport.callback_data = mock.Mock(return_value=0) |
| 117 | + if use_clc.comm_type == CommType.TLS: |
| 118 | + cwd = os.path.dirname(__file__) + "/../examples/certificates/pymodbus." |
| 119 | + transport.comm_params.sslctx = use_clc.generate_ssl( |
| 120 | + False, certfile=cwd + "crt", keyfile=cwd + "key" |
| 121 | + ) |
| 122 | + if use_clc.comm_type == CommType.SERIAL: |
| 123 | + transport.comm_params.host = f"socket://localhost:{transport.comm_params.port}" |
| 124 | + return transport |
| 125 | + |
| 126 | + |
| 127 | +@pytest.fixture(name="server") |
| 128 | +def prepare_transport_server(use_cls): |
| 129 | + """Prepare transport object.""" |
| 130 | + transport = ModbusProtocol(use_cls, True) |
| 131 | + transport.callback_connected = mock.Mock() |
| 132 | + transport.callback_disconnected = mock.Mock() |
| 133 | + transport.callback_data = mock.Mock(return_value=0) |
| 134 | + if use_cls.comm_type == CommType.TLS: |
| 135 | + cwd = os.path.dirname(__file__) + "/../examples/certificates/pymodbus." |
| 136 | + transport.comm_params.sslctx = use_cls.generate_ssl( |
| 137 | + True, certfile=cwd + "crt", keyfile=cwd + "key" |
| 138 | + ) |
| 139 | + return transport |
| 140 | + |
| 141 | + |
| 142 | +class DummyProtocol(ModbusProtocol): |
| 143 | + """Use in connection_made calls.""" |
| 144 | + |
| 145 | + def __init__(self, is_server=False): # pylint: disable=super-init-not-called |
| 146 | + """Initialize.""" |
| 147 | + self.comm_params = CommParams() |
| 148 | + self.transport = None |
| 149 | + self.is_server = is_server |
| 150 | + self.is_closing = False |
| 151 | + self.data = b"" |
| 152 | + self.connection_made = mock.Mock() |
| 153 | + self.connection_lost = mock.Mock() |
| 154 | + self.reconnect_task: asyncio.Task = None |
| 155 | + |
| 156 | + def handle_new_connection(self): |
| 157 | + """Handle incoming connect.""" |
| 158 | + if not self.is_server: |
| 159 | + # Clients reuse the same object. |
| 160 | + return self |
| 161 | + return DummyProtocol() |
| 162 | + |
| 163 | + def close(self): |
| 164 | + """Simulate close.""" |
| 165 | + self.is_closing = True |
| 166 | + |
| 167 | + def data_received(self, data): |
| 168 | + """Call when some data is received.""" |
| 169 | + self.data += data |
| 170 | + |
| 171 | + |
| 172 | +@pytest.fixture(name="dummy_protocol") |
| 173 | +def prepare_dummy_protocol(): |
| 174 | + """Return transport object.""" |
| 175 | + return DummyProtocol |
| 176 | + |
| 177 | + |
| 178 | +@pytest.fixture(name="mock_clc") |
| 179 | +def define_commandline_client( |
| 180 | + use_comm, |
| 181 | + use_framer, |
| 182 | + use_port, |
| 183 | + use_host, |
| 184 | +): |
| 185 | + """Define commandline.""" |
| 186 | + my_port = str(use_port) |
| 187 | + cmdline = ["--comm", use_comm, "--framer", use_framer, "--timeout", "0.1"] |
| 188 | + if use_comm == "serial": |
| 189 | + if use_host == NULLMODEM_HOST: |
| 190 | + use_host = f"{use_host}:{my_port}" |
| 191 | + else: |
| 192 | + use_host = f"socket://{use_host}:{my_port}" |
| 193 | + cmdline.extend(["--baudrate", "9600", "--port", use_host]) |
| 194 | + else: |
| 195 | + cmdline.extend(["--port", my_port, "--host", use_host]) |
| 196 | + return cmdline |
| 197 | + |
| 198 | + |
| 199 | +@pytest.fixture(name="mock_cls") |
| 200 | +def define_commandline_server( |
| 201 | + use_comm, |
| 202 | + use_framer, |
| 203 | + use_port, |
| 204 | + use_host, |
| 205 | +): |
| 206 | + """Define commandline.""" |
| 207 | + my_port = str(use_port) |
| 208 | + cmdline = [ |
| 209 | + "--comm", |
| 210 | + use_comm, |
| 211 | + "--framer", |
| 212 | + use_framer, |
| 213 | + ] |
| 214 | + if use_comm == "serial": |
| 215 | + if use_host == NULLMODEM_HOST: |
| 216 | + use_host = f"{use_host}:{my_port}" |
| 217 | + else: |
| 218 | + use_host = f"socket://{use_host}:{my_port}" |
| 219 | + cmdline.extend(["--baudrate", "9600", "--port", use_host]) |
| 220 | + else: |
| 221 | + cmdline.extend(["--port", my_port, "--host", use_host]) |
| 222 | + return cmdline |
| 223 | + |
| 224 | + |
| 225 | +@pytest_asyncio.fixture(name="mock_server") |
| 226 | +async def _run_server( |
| 227 | + mock_cls, |
| 228 | +): |
| 229 | + """Run server.""" |
| 230 | + run_args = setup_server(cmdline=mock_cls) |
| 231 | + task = asyncio.create_task(run_async_server(run_args)) |
| 232 | + task.set_name("mock_server") |
| 233 | + await asyncio.sleep(0.1) |
| 234 | + yield mock_cls |
| 235 | + await ServerAsyncStop() |
| 236 | + task.cancel() |
| 237 | + await task |
| 238 | + |
| 239 | + |
45 | 240 | @pytest.fixture(name="system_health_check", autouse=True) |
46 | 241 | async def _check_system_health(): |
47 | 242 | """Check Thread, asyncio.task and NullModem for leftovers.""" |
@@ -193,14 +388,3 @@ def sendto(self, msg, *_args): |
193 | 388 | def setblocking(self, _flag): |
194 | 389 | """Set blocking.""" |
195 | 390 | return None |
196 | | - |
197 | | - |
198 | | -_CURRENT_PORT = 5200 |
199 | | - |
200 | | - |
201 | | -@pytest.fixture(name="use_port") |
202 | | -def get_port(): |
203 | | - """Get next port.""" |
204 | | - global _CURRENT_PORT # pylint: disable=global-statement |
205 | | - _CURRENT_PORT += 1 |
206 | | - return _CURRENT_PORT |
|
0 commit comments