From 01c4d67fc6d93701ba7a89b84c5ea9898bab789b Mon Sep 17 00:00:00 2001 From: Bastian Krause Date: Mon, 22 Sep 2025 17:27:17 +0200 Subject: [PATCH 1/5] remote/client: move ensure_event_loop() to labgrid.util.loop Moving the function and the ContextVar to labgrid.util.loop will allow non-client code to use it as well. Otherwise usage will result in circular imports such as: from labgrid import Target labgrid/__init__.py:1: in from .target import Target labgrid/target.py:10: in from .driver import Driver labgrid/driver/__init__.py:4: in from .serialdriver import SerialDriver labgrid/driver/serialdriver.py:10: in from ..util.proxy import proxymanager labgrid/util/proxy.py:7: in from ..resource.common import Resource labgrid/resource/__init__.py:2: in from .ethernetport import SNMPEthernetPort labgrid/resource/ethernetport.py:9: in from ..remote.client import ensure_event_loop labgrid/remote/client.py:41: in from .. import Environment, Target, target_factory E ImportError: cannot import name 'Environment' from partially initialized module 'labgrid' (most likely due to a circular import) (labgrid/__init__.py) Signed-off-by: Bastian Krause --- labgrid/remote/client.py | 38 +------------------------------------- labgrid/util/loop.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 37 deletions(-) create mode 100755 labgrid/util/loop.py diff --git a/labgrid/remote/client.py b/labgrid/remote/client.py index 27108a7b4..737da4d81 100755 --- a/labgrid/remote/client.py +++ b/labgrid/remote/client.py @@ -4,7 +4,6 @@ import argparse import asyncio import contextlib -from contextvars import ContextVar import enum import os import pathlib @@ -45,6 +44,7 @@ from ..util import diff_dict, flat_dict, dump, atomic_replace, labgrid_version, Timeout from ..util.proxy import proxymanager from ..util.helper import processwrapper +from ..util.loop import ensure_event_loop from ..driver import Mode, ExecutionError from ..logging import basicConfig, StepLogger @@ -1577,42 +1577,6 @@ def print_version(self): print(labgrid_version()) -_loop: ContextVar["asyncio.AbstractEventLoop | None"] = ContextVar("_loop", default=None) - - -def ensure_event_loop(external_loop=None): - """Get the event loop for this thread, or create a new event loop.""" - # get stashed loop - loop = _loop.get() - - # ignore closed stashed loop - if loop and loop.is_closed(): - loop = None - - if external_loop: - # if a loop is stashed, expect it to be the same as the external one - if loop: - assert loop is external_loop - _loop.set(external_loop) - return external_loop - - # return stashed loop - if loop: - return loop - - try: - # if called from async code, try to get current's thread loop - loop = asyncio.get_running_loop() - except RuntimeError: - # no previous, external or running loop found, create a new one - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # stash it - _loop.set(loop) - return loop - - def start_session( address: str, *, extra: Dict[str, Any] = None, debug: bool = False, loop: "asyncio.AbstractEventLoop | None" = None ): diff --git a/labgrid/util/loop.py b/labgrid/util/loop.py new file mode 100755 index 000000000..eb27c424e --- /dev/null +++ b/labgrid/util/loop.py @@ -0,0 +1,37 @@ +import asyncio +from contextvars import ContextVar + + +_loop: ContextVar["asyncio.AbstractEventLoop | None"] = ContextVar("_loop", default=None) + +def ensure_event_loop(external_loop=None): + """Get the event loop for this thread, or create a new event loop.""" + # get stashed loop + loop = _loop.get() + + # ignore closed stashed loop + if loop and loop.is_closed(): + loop = None + + if external_loop: + # if a loop is stashed, expect it to be the same as the external one + if loop: + assert loop is external_loop + _loop.set(external_loop) + return external_loop + + # return stashed loop + if loop: + return loop + + try: + # if called from async code, try to get current's thread loop + loop = asyncio.get_running_loop() + except RuntimeError: + # no previous, external or running loop found, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # stash it + _loop.set(loop) + return loop From 0afee1fdfbaf3459295af3e93471ac873f49c31b Mon Sep 17 00:00:00 2001 From: Bastian Krause Date: Mon, 22 Sep 2025 17:27:58 +0200 Subject: [PATCH 2/5] util/loop: add is_new_loop() There is currently no way to find out whether ensure_event_loop() created a new loop. This information is useful in scenarios where a resource is created without a session in a sync context, because then it needs to shut down the newly created loop in clean up code. But this must not happen if a running loop is used. Since ensure_event_loop() is already public API people use, add a separate function checking whether the passed in loop is identical to the loop stashed in the loop creation code in ensure_event_loop(). Signed-off-by: Bastian Krause --- labgrid/util/loop.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/labgrid/util/loop.py b/labgrid/util/loop.py index eb27c424e..ca438e92d 100755 --- a/labgrid/util/loop.py +++ b/labgrid/util/loop.py @@ -3,6 +3,7 @@ _loop: ContextVar["asyncio.AbstractEventLoop | None"] = ContextVar("_loop", default=None) +_last_created_loop: ContextVar["asyncio.AbstractEventLoop | None"] = ContextVar("_last_created_loop:", default=None) def ensure_event_loop(external_loop=None): """Get the event loop for this thread, or create a new event loop.""" @@ -32,6 +33,13 @@ def ensure_event_loop(external_loop=None): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - # stash it + # stash new loop + _last_created_loop.set(loop) + + # stash loop _loop.set(loop) return loop + +def is_new_loop(loop): + """Check whether the given loop was created in ensure_event_loop() before.""" + return loop is _last_created_loop.get() From ef0135ce4830b1fe5e98de1e8175c552dae73cca Mon Sep 17 00:00:00 2001 From: Rouven Czerwinski Date: Tue, 24 Sep 2024 09:42:28 +0200 Subject: [PATCH 3/5] util: convert SimpleSNMP to asyncio Wrap the now asyncio based co-routines in loop.run_until_complete(). Signed-off-by: Rouven Czerwinski Signed-off-by: Bastian Krause --- labgrid/driver/power/eaton.py | 2 ++ labgrid/driver/power/poe_mib.py | 2 ++ labgrid/driver/power/raritan.py | 2 ++ labgrid/driver/powerdriver.py | 1 + labgrid/util/snmp.py | 42 ++++++++++++++++++++++++--------- 5 files changed, 38 insertions(+), 11 deletions(-) diff --git a/labgrid/driver/power/eaton.py b/labgrid/driver/power/eaton.py index f93d8f9f3..8d682fcdc 100644 --- a/labgrid/driver/power/eaton.py +++ b/labgrid/driver/power/eaton.py @@ -14,6 +14,7 @@ def power_set(host, port, index, value): outlet_control_oid = "{}.{}.0.{}".format(OID, cmd_id, index) _snmp.set(outlet_control_oid, 1) + _snmp.cleanup() def power_get(host, port, index): @@ -24,6 +25,7 @@ def power_get(host, port, index): value = _snmp.get(output_status_oid) + _snmp.cleanup() if value == 1: # On return True if value == 0: # Off diff --git a/labgrid/driver/power/poe_mib.py b/labgrid/driver/power/poe_mib.py index 12b159e36..a4d705bf7 100644 --- a/labgrid/driver/power/poe_mib.py +++ b/labgrid/driver/power/poe_mib.py @@ -12,6 +12,7 @@ def power_set(host, port, index, value): oid_value = "1" if value else "2" _snmp.set(outlet_control_oid, oid_value) + _snmp.cleanup() def power_get(host, port, index): _snmp = SimpleSNMP(host, 'private', port=port) @@ -19,6 +20,7 @@ def power_get(host, port, index): value = _snmp.get(output_status_oid) + _snmp.cleanup() if value == 1: # On return True if value == 2: # Off diff --git a/labgrid/driver/power/raritan.py b/labgrid/driver/power/raritan.py index 597e72d0b..0103e4ec7 100644 --- a/labgrid/driver/power/raritan.py +++ b/labgrid/driver/power/raritan.py @@ -17,6 +17,7 @@ def power_set(host, port, index, value): outlet_control_oid = "{}.2.1.{}".format(OID, index) _snmp.set(outlet_control_oid, str(int(value))) + _snmp.cleanup() def power_get(host, port, index): @@ -25,6 +26,7 @@ def power_get(host, port, index): value = _snmp.get(output_status_oid) + _snmp.cleanup() if value == 7: # On return True if value == 8: # Off diff --git a/labgrid/driver/powerdriver.py b/labgrid/driver/powerdriver.py index 80c8377fb..8046a5aeb 100644 --- a/labgrid/driver/powerdriver.py +++ b/labgrid/driver/powerdriver.py @@ -232,6 +232,7 @@ def cycle(self): def get(self): return self.backend.power_get(self._host, self._port, self.port.index) + @target_factory.reg_driver @attr.s(eq=False) class DigitalOutputPowerDriver(Driver, PowerResetMixin, PowerProtocol): diff --git a/labgrid/util/snmp.py b/labgrid/util/snmp.py index 51ddaa6ef..dca68bb5f 100644 --- a/labgrid/util/snmp.py +++ b/labgrid/util/snmp.py @@ -1,31 +1,51 @@ -from pysnmp import hlapi +import pysnmp.hlapi.v3arch.asyncio as hlapi from ..driver.exception import ExecutionError +from .loop import ensure_event_loop, is_new_loop class SimpleSNMP: """A class that helps wrap pysnmp""" + def __init__(self, host, community, port=161): if port is None: port = 161 + self.loop = ensure_event_loop() + self.engine = hlapi.SnmpEngine() - self.transport = hlapi.UdpTransportTarget((host, port)) + self.transport = self.loop.run_until_complete(hlapi.UdpTransportTarget.create((host, port))) self.community = hlapi.CommunityData(community, mpModel=0) self.context = hlapi.ContextData() def get(self, oid): - g = hlapi.getCmd(self.engine, self.community, self.transport, - self.context, hlapi.ObjectType(hlapi.ObjectIdentity(oid)), - lookupMib=False) + g = self.loop.run_until_complete( + hlapi.getCmd( + self.engine, + self.community, + self.transport, + self.context, + hlapi.ObjectType(hlapi.ObjectIdentity(oid)), + lookupMib=False, + ) + ) - error_indication, error_status, _, res = next(g) + error_indication, error_status, _, res = g if error_indication or error_status: raise ExecutionError("Failed to get SNMP value.") return res[0][1] def set(self, oid, value): - identify = hlapi.ObjectType(hlapi.ObjectIdentity(oid), - hlapi.Integer(value)) - g = hlapi.setCmd(self.engine, self.community, self.transport, - self.context, identify, lookupMib=False) - next(g) + identify = hlapi.ObjectType(hlapi.ObjectIdentity(oid), hlapi.Integer(value)) + g = self.loop.run_until_complete( + hlapi.setCmd(self.engine, self.community, self.transport, self.context, identify, lookupMib=False) + ) + + error_indication, error_status, _, _ = g + if error_indication or error_status: + raise ExecutionError("Failed to set SNMP value.") + + def cleanup(self): + self.engine.closeDispatcher() + if is_new_loop(self.loop): + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + self.loop.close() From 019daad36e21a5015e2fa7d85e9a72cc245269fd Mon Sep 17 00:00:00 2001 From: Rouven Czerwinski Date: Tue, 24 Sep 2024 09:43:05 +0200 Subject: [PATCH 4/5] pyproject: allow pysnmp newer than 6.x With the new asyncio based handling we can unlock the pysnmp dependency. The lexstudio fork has also taken over maintenance of pysnmp on pypi, so switch back to pysnmp. Signed-off-by: Rouven Czerwinski --- pyproject.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 33caa6f31..d11cd8b86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,10 +68,7 @@ pyvisa = [ "pyvisa>=1.11.3", "PyVISA-py>=0.5.2", ] -snmp = [ - "pysnmp>=4.4.12, <6", - "pyasn1<0.6.1", -] +snmp = ["pysnmp>=6"] vxi11 = ["python-vxi11>=0.9"] xena = ["xenavalkyrie>=3.0.1"] deb = ["labgrid[modbus,onewire,snmp]"] From 9c7cb236b32a5298d85e9e278c7a79dee2420261 Mon Sep 17 00:00:00 2001 From: Rouven Czerwinski Date: Tue, 24 Sep 2024 12:15:58 +0200 Subject: [PATCH 5/5] resouce: ethernetport: convert to asyncio Follow the sync API deprecation and use ensure_event_loop() making sure we have an asyncio loop even if no loop can be retrieved. Signed-off-by: Rouven Czerwinski Signed-off-by: Bastian Krause --- labgrid/resource/ethernetport.py | 43 ++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/labgrid/resource/ethernetport.py b/labgrid/resource/ethernetport.py index 0020f26ba..cd7496410 100644 --- a/labgrid/resource/ethernetport.py +++ b/labgrid/resource/ethernetport.py @@ -6,29 +6,34 @@ from ..factory import target_factory from .common import ManagedResource, ResourceManager +from ..util.loop import ensure_event_loop @attr.s class SNMPSwitch: """SNMPSwitch describes a switch accessible over SNMP. This class implements functions to query ports and the forwarding database.""" hostname = attr.ib(validator=attr.validators.instance_of(str)) + loop = attr.ib() def __attrs_post_init__(self): + import pysnmp.hlapi.v3arch.asyncio as hlapi + self.logger = logging.getLogger(f"{self}") self.ports = {} self.fdb = {} self.macs_by_port = {} + self.transport = self.loop.run_until_complete(hlapi.UdpTransportTarget.create((self.hostname, 161))) self._autodetect() def _autodetect(self): - from pysnmp import hlapi + import pysnmp.hlapi.v3arch.asyncio as hlapi - for (errorIndication, errorStatus, _, varBindTable) in hlapi.getCmd( + for (errorIndication, errorStatus, _, varBindTable) in self.loop.run_until_complete(hlapi.getCmd( hlapi.SnmpEngine(), hlapi.CommunityData('public'), - hlapi.UdpTransportTarget((self.hostname, 161)), + self.transport, hlapi.ContextData(), - hlapi.ObjectType(hlapi.ObjectIdentity('SNMPv2-MIB', 'sysDescr', 0))): + hlapi.ObjectType(hlapi.ObjectIdentity('SNMPv2-MIB', 'sysDescr', 0)))): if errorIndication: raise Exception(f"snmp error {errorIndication}") elif errorStatus: @@ -51,7 +56,7 @@ def _get_ports(self): Returns: Dict[Dict[]]: ports and their values """ - from pysnmp import hlapi + import pysnmp.hlapi.v3arch.asyncio as hlapi variables = [ (hlapi.ObjectType(hlapi.ObjectIdentity('IF-MIB', 'ifIndex')), 'index'), @@ -64,14 +69,14 @@ def _get_ports(self): ] ports = {} - for (errorIndication, errorStatus, _, varBindTable) in hlapi.bulkCmd( + for (errorIndication, errorStatus, _, varBindTable) in self.loop.run_until_complete(hlapi.bulkCmd( hlapi.SnmpEngine(), hlapi.CommunityData('public'), - hlapi.UdpTransportTarget((self.hostname, 161)), + self.transport, hlapi.ContextData(), 0, 20, *[x[0] for x in variables], - lexicographicMode=False): + lexicographicMode=False)): if errorIndication: raise Exception(f"snmp error {errorIndication}") elif errorStatus: @@ -93,18 +98,18 @@ def _get_fdb_dot1d(self): Returns: Dict[List[str]]: ports and their values """ - from pysnmp import hlapi + import pysnmp.hlapi.v3arch.asyncio as hlapi ports = {} - for (errorIndication, errorStatus, _, varBindTable) in hlapi.bulkCmd( + for (errorIndication, errorStatus, _, varBindTable) in self.loop.run_until_complete(hlapi.bulkCmd( hlapi.SnmpEngine(), hlapi.CommunityData('public'), - hlapi.UdpTransportTarget((self.hostname, 161)), + self.transport, hlapi.ContextData(), 0, 50, hlapi.ObjectType(hlapi.ObjectIdentity('BRIDGE-MIB', 'dot1dTpFdbPort')), - lexicographicMode=False): + lexicographicMode=False)): if errorIndication: raise Exception(f"snmp error {errorIndication}") elif errorStatus: @@ -126,18 +131,18 @@ def _get_fdb_dot1q(self): Returns: Dict[List[str]]: ports and their values """ - from pysnmp import hlapi + import pysnmp.hlapi.v3arch.asyncio as hlapi ports = {} - for (errorIndication, errorStatus, _, varBindTable) in hlapi.bulkCmd( + for (errorIndication, errorStatus, _, varBindTable) in self.loop.run_until_complete(hlapi.bulkCmd( hlapi.SnmpEngine(), hlapi.CommunityData('public'), - hlapi.UdpTransportTarget((self.hostname, 161)), + self.transport, hlapi.ContextData(), 0, 50, hlapi.ObjectType(hlapi.ObjectIdentity('Q-BRIDGE-MIB', 'dot1qTpFdbPort')), - lexicographicMode=False): + lexicographicMode=False)): if errorIndication: raise Exception(f"snmp error {errorIndication}") elif errorStatus: @@ -223,6 +228,8 @@ async def poll_neighbour(self): await asyncio.sleep(1.0) + self.loop = ensure_event_loop() + async def poll_switches(self): current = set(resource.switch for resource in self.resources) removed = set(self.switches) - current @@ -230,7 +237,7 @@ async def poll_switches(self): for switch in removed: del self.switches[switch] for switch in new: - self.switches[switch] = SNMPSwitch(switch) + self.switches[switch] = SNMPSwitch(switch, self.loop) for switch in current: self.switches[switch].update() await asyncio.sleep(1.0) @@ -248,7 +255,6 @@ async def poll(self, handler): import traceback traceback.print_exc(file=sys.stderr) - self.loop = asyncio.get_event_loop() self.poll_tasks.append(self.loop.create_task(poll(self, poll_neighbour))) self.poll_tasks.append(self.loop.create_task(poll(self, poll_switches))) @@ -309,7 +315,6 @@ def poll(self): resource.extra = extra self.logger.debug("new information for %s: %s", resource, extra) - @target_factory.reg_resource @attr.s class SNMPEthernetPort(ManagedResource):