Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions labgrid/driver/power/eaton.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def power_set(host, port, index, value):
outlet_control_oid = "{}.{}.0.{}".format(OID, cmd_id, index)

_snmp.set(outlet_control_oid, 1)
_snmp.cleanup()


def power_get(host, port, index):
Expand All @@ -24,6 +25,7 @@ def power_get(host, port, index):

value = _snmp.get(output_status_oid)

_snmp.cleanup()
if value == 1: # On
return True
if value == 0: # Off
Expand Down
2 changes: 2 additions & 0 deletions labgrid/driver/power/poe_mib.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ def power_set(host, port, index, value):
oid_value = "1" if value else "2"

_snmp.set(outlet_control_oid, oid_value)
_snmp.cleanup()

def power_get(host, port, index):
_snmp = SimpleSNMP(host, 'private', port=port)
output_status_oid = "{}.{}".format(OID, index)

value = _snmp.get(output_status_oid)

_snmp.cleanup()
if value == 1: # On
return True
if value == 2: # Off
Expand Down
2 changes: 2 additions & 0 deletions labgrid/driver/power/raritan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def power_set(host, port, index, value):
outlet_control_oid = "{}.2.1.{}".format(OID, index)

_snmp.set(outlet_control_oid, str(int(value)))
_snmp.cleanup()


def power_get(host, port, index):
Expand All @@ -25,6 +26,7 @@ def power_get(host, port, index):

value = _snmp.get(output_status_oid)

_snmp.cleanup()
if value == 7: # On
return True
if value == 8: # Off
Expand Down
1 change: 1 addition & 0 deletions labgrid/driver/powerdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def cycle(self):
def get(self):
return self.backend.power_get(self._host, self._port, self.port.index)


@target_factory.reg_driver
@attr.s(eq=False)
class DigitalOutputPowerDriver(Driver, PowerResetMixin, PowerProtocol):
Expand Down
38 changes: 1 addition & 37 deletions labgrid/remote/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import argparse
import asyncio
import contextlib
from contextvars import ContextVar
import enum
import os
import pathlib
Expand Down Expand Up @@ -45,6 +44,7 @@
from ..util import diff_dict, flat_dict, dump, atomic_replace, labgrid_version, Timeout
from ..util.proxy import proxymanager
from ..util.helper import processwrapper
from ..util.loop import ensure_event_loop
from ..driver import Mode, ExecutionError
from ..logging import basicConfig, StepLogger

Expand Down Expand Up @@ -1577,42 +1577,6 @@ def print_version(self):
print(labgrid_version())


_loop: ContextVar["asyncio.AbstractEventLoop | None"] = ContextVar("_loop", default=None)


def ensure_event_loop(external_loop=None):
"""Get the event loop for this thread, or create a new event loop."""
# get stashed loop
loop = _loop.get()

# ignore closed stashed loop
if loop and loop.is_closed():
loop = None

if external_loop:
# if a loop is stashed, expect it to be the same as the external one
if loop:
assert loop is external_loop
_loop.set(external_loop)
return external_loop

# return stashed loop
if loop:
return loop

try:
# if called from async code, try to get current's thread loop
loop = asyncio.get_running_loop()
except RuntimeError:
# no previous, external or running loop found, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

# stash it
_loop.set(loop)
return loop


def start_session(
address: str, *, extra: Dict[str, Any] = None, debug: bool = False, loop: "asyncio.AbstractEventLoop | None" = None
):
Expand Down
43 changes: 24 additions & 19 deletions labgrid/resource/ethernetport.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,34 @@

from ..factory import target_factory
from .common import ManagedResource, ResourceManager
from ..util.loop import ensure_event_loop

@attr.s
class SNMPSwitch:
"""SNMPSwitch describes a switch accessible over SNMP. This class
implements functions to query ports and the forwarding database."""
hostname = attr.ib(validator=attr.validators.instance_of(str))
loop = attr.ib()

def __attrs_post_init__(self):
import pysnmp.hlapi.v3arch.asyncio as hlapi

self.logger = logging.getLogger(f"{self}")
self.ports = {}
self.fdb = {}
self.macs_by_port = {}
self.transport = self.loop.run_until_complete(hlapi.UdpTransportTarget.create((self.hostname, 161)))
self._autodetect()

def _autodetect(self):
from pysnmp import hlapi
import pysnmp.hlapi.v3arch.asyncio as hlapi

for (errorIndication, errorStatus, _, varBindTable) in hlapi.getCmd(
for (errorIndication, errorStatus, _, varBindTable) in self.loop.run_until_complete(hlapi.getCmd(
hlapi.SnmpEngine(),
hlapi.CommunityData('public'),
hlapi.UdpTransportTarget((self.hostname, 161)),
self.transport,
hlapi.ContextData(),
hlapi.ObjectType(hlapi.ObjectIdentity('SNMPv2-MIB', 'sysDescr', 0))):
hlapi.ObjectType(hlapi.ObjectIdentity('SNMPv2-MIB', 'sysDescr', 0)))):
if errorIndication:
raise Exception(f"snmp error {errorIndication}")
elif errorStatus:
Expand All @@ -51,7 +56,7 @@ def _get_ports(self):
Returns:
Dict[Dict[]]: ports and their values
"""
from pysnmp import hlapi
import pysnmp.hlapi.v3arch.asyncio as hlapi

variables = [
(hlapi.ObjectType(hlapi.ObjectIdentity('IF-MIB', 'ifIndex')), 'index'),
Expand All @@ -64,14 +69,14 @@ def _get_ports(self):
]
ports = {}

for (errorIndication, errorStatus, _, varBindTable) in hlapi.bulkCmd(
for (errorIndication, errorStatus, _, varBindTable) in self.loop.run_until_complete(hlapi.bulkCmd(
hlapi.SnmpEngine(),
hlapi.CommunityData('public'),
hlapi.UdpTransportTarget((self.hostname, 161)),
self.transport,
hlapi.ContextData(),
0, 20,
*[x[0] for x in variables],
lexicographicMode=False):
lexicographicMode=False)):
if errorIndication:
raise Exception(f"snmp error {errorIndication}")
elif errorStatus:
Expand All @@ -93,18 +98,18 @@ def _get_fdb_dot1d(self):
Returns:
Dict[List[str]]: ports and their values
"""
from pysnmp import hlapi
import pysnmp.hlapi.v3arch.asyncio as hlapi

ports = {}

for (errorIndication, errorStatus, _, varBindTable) in hlapi.bulkCmd(
for (errorIndication, errorStatus, _, varBindTable) in self.loop.run_until_complete(hlapi.bulkCmd(
hlapi.SnmpEngine(),
hlapi.CommunityData('public'),
hlapi.UdpTransportTarget((self.hostname, 161)),
self.transport,
hlapi.ContextData(),
0, 50,
hlapi.ObjectType(hlapi.ObjectIdentity('BRIDGE-MIB', 'dot1dTpFdbPort')),
lexicographicMode=False):
lexicographicMode=False)):
if errorIndication:
raise Exception(f"snmp error {errorIndication}")
elif errorStatus:
Expand All @@ -126,18 +131,18 @@ def _get_fdb_dot1q(self):
Returns:
Dict[List[str]]: ports and their values
"""
from pysnmp import hlapi
import pysnmp.hlapi.v3arch.asyncio as hlapi

ports = {}

for (errorIndication, errorStatus, _, varBindTable) in hlapi.bulkCmd(
for (errorIndication, errorStatus, _, varBindTable) in self.loop.run_until_complete(hlapi.bulkCmd(
hlapi.SnmpEngine(),
hlapi.CommunityData('public'),
hlapi.UdpTransportTarget((self.hostname, 161)),
self.transport,
hlapi.ContextData(),
0, 50,
hlapi.ObjectType(hlapi.ObjectIdentity('Q-BRIDGE-MIB', 'dot1qTpFdbPort')),
lexicographicMode=False):
lexicographicMode=False)):
if errorIndication:
raise Exception(f"snmp error {errorIndication}")
elif errorStatus:
Expand Down Expand Up @@ -223,14 +228,16 @@ async def poll_neighbour(self):

await asyncio.sleep(1.0)

self.loop = ensure_event_loop()

async def poll_switches(self):
current = set(resource.switch for resource in self.resources)
removed = set(self.switches) - current
new = current - set(self.switches)
for switch in removed:
del self.switches[switch]
for switch in new:
self.switches[switch] = SNMPSwitch(switch)
self.switches[switch] = SNMPSwitch(switch, self.loop)
for switch in current:
self.switches[switch].update()
await asyncio.sleep(1.0)
Expand All @@ -248,7 +255,6 @@ async def poll(self, handler):
import traceback
traceback.print_exc(file=sys.stderr)

self.loop = asyncio.get_event_loop()
self.poll_tasks.append(self.loop.create_task(poll(self, poll_neighbour)))
self.poll_tasks.append(self.loop.create_task(poll(self, poll_switches)))

Expand Down Expand Up @@ -309,7 +315,6 @@ def poll(self):
resource.extra = extra
self.logger.debug("new information for %s: %s", resource, extra)


@target_factory.reg_resource
@attr.s
class SNMPEthernetPort(ManagedResource):
Expand Down
45 changes: 45 additions & 0 deletions labgrid/util/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import asyncio
from contextvars import ContextVar


_loop: ContextVar["asyncio.AbstractEventLoop | None"] = ContextVar("_loop", default=None)
_last_created_loop: ContextVar["asyncio.AbstractEventLoop | None"] = ContextVar("_last_created_loop:", default=None)

def ensure_event_loop(external_loop=None):
"""Get the event loop for this thread, or create a new event loop."""
# get stashed loop
loop = _loop.get()

# ignore closed stashed loop
if loop and loop.is_closed():
loop = None

if external_loop:
# if a loop is stashed, expect it to be the same as the external one
if loop:
assert loop is external_loop
_loop.set(external_loop)
return external_loop

# return stashed loop
if loop:
return loop

try:
# if called from async code, try to get current's thread loop
loop = asyncio.get_running_loop()
except RuntimeError:
# no previous, external or running loop found, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

# stash new loop
_last_created_loop.set(loop)

# stash loop
_loop.set(loop)
return loop

def is_new_loop(loop):
"""Check whether the given loop was created in ensure_event_loop() before."""
return loop is _last_created_loop.get()
42 changes: 31 additions & 11 deletions labgrid/util/snmp.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,51 @@
from pysnmp import hlapi
import pysnmp.hlapi.v3arch.asyncio as hlapi
from ..driver.exception import ExecutionError
from .loop import ensure_event_loop, is_new_loop


class SimpleSNMP:
"""A class that helps wrap pysnmp"""

def __init__(self, host, community, port=161):
if port is None:
port = 161

self.loop = ensure_event_loop()

self.engine = hlapi.SnmpEngine()
self.transport = hlapi.UdpTransportTarget((host, port))
self.transport = self.loop.run_until_complete(hlapi.UdpTransportTarget.create((host, port)))
self.community = hlapi.CommunityData(community, mpModel=0)
self.context = hlapi.ContextData()

def get(self, oid):
g = hlapi.getCmd(self.engine, self.community, self.transport,
self.context, hlapi.ObjectType(hlapi.ObjectIdentity(oid)),
lookupMib=False)
g = self.loop.run_until_complete(
hlapi.getCmd(
self.engine,
self.community,
self.transport,
self.context,
hlapi.ObjectType(hlapi.ObjectIdentity(oid)),
lookupMib=False,
)
)

error_indication, error_status, _, res = next(g)
error_indication, error_status, _, res = g
if error_indication or error_status:
raise ExecutionError("Failed to get SNMP value.")
return res[0][1]

def set(self, oid, value):
identify = hlapi.ObjectType(hlapi.ObjectIdentity(oid),
hlapi.Integer(value))
g = hlapi.setCmd(self.engine, self.community, self.transport,
self.context, identify, lookupMib=False)
next(g)
identify = hlapi.ObjectType(hlapi.ObjectIdentity(oid), hlapi.Integer(value))
g = self.loop.run_until_complete(
hlapi.setCmd(self.engine, self.community, self.transport, self.context, identify, lookupMib=False)
)

error_indication, error_status, _, _ = g
if error_indication or error_status:
raise ExecutionError("Failed to set SNMP value.")

def cleanup(self):
self.engine.closeDispatcher()
if is_new_loop(self.loop):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@ pyvisa = [
"pyvisa>=1.11.3",
"PyVISA-py>=0.5.2",
]
snmp = [
"pysnmp>=4.4.12, <6",
"pyasn1<0.6.1",
]
snmp = ["pysnmp>=6"]
vxi11 = ["python-vxi11>=0.9"]
xena = ["xenavalkyrie>=3.0.1"]
deb = ["labgrid[modbus,onewire,snmp]"]
Expand Down