From 20422092a8317444ab6132f01b7ef76381af39dc Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Tue, 19 Nov 2024 16:07:11 -0600 Subject: [PATCH] Replace more string with types. --- makefile | 17 +- pynetsnmp/errors.py | 61 +++++++ pynetsnmp/netsnmp.py | 25 +-- pynetsnmp/oids.py | 94 ++++++++++ pynetsnmp/security.py | 118 ++++++++---- pynetsnmp/twistedsnmp.py | 277 ++++++++++++---------------- pynetsnmp/usm.py | 17 +- tests/__init__.py | 0 tests/test_security.py | 382 +++++++++++++++++++++++++++++++++++++++ tests/test_usm.py | 114 ++++++++++++ 10 files changed, 874 insertions(+), 231 deletions(-) create mode 100644 pynetsnmp/errors.py create mode 100644 pynetsnmp/oids.py create mode 100644 tests/__init__.py create mode 100644 tests/test_security.py create mode 100644 tests/test_usm.py diff --git a/makefile b/makefile index 2956abe..ef8b6ac 100644 --- a/makefile +++ b/makefile @@ -25,11 +25,14 @@ clean: rm -rf *.pyc dist build pynetsnmp.egg-info .PHONY: test -HOST ?= 127.0.0.1 test: - docker run --rm -v $(PWD):/mnt -w /mnt $(TAG) \ - bash -c "python setup.py bdist_wheel \ - && pip install dist/pynetsnmp*py2-none-any.whl ipaddr Twisted==20.3.0 \ - && cd test \ - && python test_runner.py --host $(HOST) \ - && chown -R $(UID):$(GID) /mnt" ; + @$(DOCKER_COMMAND) bash -c "pip --no-python-version-warning install -q .; cd tests; python -m unittest discover" + +# HOST ?= 127.0.0.1 +# test: +# docker run --rm -v $(PWD):/mnt -w /mnt $(TAG) \ +# bash -c "python setup.py bdist_wheel \ +# && pip install dist/pynetsnmp*py2-none-any.whl ipaddr Twisted==20.3.0 \ +# && cd test \ +# && python test_runner.py --host $(HOST) \ +# && chown -R $(UID):$(GID) /mnt" ; diff --git a/pynetsnmp/errors.py b/pynetsnmp/errors.py new file mode 100644 index 0000000..6cc73e2 --- /dev/null +++ b/pynetsnmp/errors.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import + +from . import oids + + +class SnmpTimeoutError(Exception): + pass + + +class ArgumentParseError(Exception): + pass + + +class TransportError(Exception): + pass + + +class SnmpNameError(Exception): + def __init__(self, oid): + Exception.__init__(self, "Bad Name", oid) + + +class SnmpError(Exception): + def __init__(self, message, *args, **kwargs): + self.message = message + + def __str__(self): + return self.message + + def __repr__(self): + return self.message + + +class SnmpUsmError(SnmpError): + pass + + +class SnmpUsmStatsError(SnmpUsmError): + def __init__(self, mesg, oid): + super(SnmpUsmStatsError, self).__init__(mesg) + self.oid = oid + + +_stats_oid_error_map = { + oids.WrongDigest: SnmpUsmStatsError( + "unexpected authentication digest", oids.WrongDigest + ), + oids.UnknownUserName: SnmpUsmStatsError( + "unknown user", oids.UnknownUserName + ), + oids.UnknownSecurityLevel: SnmpUsmStatsError( + "unknown or unavailable security level", oids.UnknownSecurityLevel + ), + oids.DecryptionError: SnmpUsmStatsError( + "privacy decryption error", oids.DecryptionError + ), +} + + +def get_stats_error(oid): + return _stats_oid_error_map.get(oid) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index 4240a94..f893ae5 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -87,6 +87,7 @@ USM_AUTH_KU_LEN, USM_PRIV_KU_LEN, ) +from .errors import ArgumentParseError, SnmpTimeoutError def _getLogger(name): @@ -304,7 +305,7 @@ class netsnmp_container_s(Structure): ("securityAuthLocalKeyLen", c_size_t), ("securityPrivProto", POINTER(oid)), ("securityPrivProtoLen", c_size_t), - ("securityPrivKey", c_char * USM_PRIV_KU_LEN), + ("securityPrivKey", u_char * USM_PRIV_KU_LEN), ("securityPrivKeyLen", c_size_t), ("securityPrivLocalKey", c_char_p), ("securityPrivLocalKeyLen", c_size_t), @@ -638,16 +639,12 @@ def getResult(pdu, log): return result -class SnmpError(Exception): +class NetSnmpError(Exception): def __init__(self, why): lib.snmp_perror(why) Exception.__init__(self, why) -class SnmpTimeoutError(Exception): - pass - - sessionMap = {} @@ -676,14 +673,6 @@ def _callback(operation, sp, reqid, pdu, magic): _callback = netsnmp_callback(_callback) -class ArgumentParseError(Exception): - pass - - -class TransportError(Exception): - pass - - def _doNothingProc(argc, argv, arg): return 0 @@ -807,7 +796,7 @@ def open(self): ref = byref(sess) self.sess = lib.snmp_open(ref) if not self.sess: - raise SnmpError("snmp_open") + raise NetSnmpError("snmp_open") def awaitTraps( self, peername, fileno=-1, pre_parse_callback=None, debug=False @@ -834,7 +823,7 @@ def awaitTraps( lib.setup_engineID(None, None) transport = lib.netsnmp_tdomain_transport(peername, 1, "udp") if not transport: - raise SnmpError( + raise NetSnmpError( "Unable to create transport {peername}".format( peername=peername ) @@ -861,7 +850,7 @@ def awaitTraps( sess.isAuthoritative = SNMP_SESS_UNKNOWNAUTH rc = lib.snmp_add(self.sess, transport, pre_parse_callback, None) if not rc: - raise SnmpError("snmp_add") + raise NetSnmpError("snmp_add") def create_users(self, users): self._log.debug("create_users: Creating %s users.", len(users)) @@ -991,7 +980,7 @@ def _handle_send_status(self, req, send_status, send_type): lib.snmp_free_pdu(req) if snmperr.value == SNMPERR_TIMEOUT: raise SnmpTimeoutError() - raise SnmpError(msg_fmt % msg_args) + raise NetSnmpError(msg_fmt % msg_args) def get(self, oids): req = self._create_request(SNMP_MSG_GET) diff --git a/pynetsnmp/oids.py b/pynetsnmp/oids.py new file mode 100644 index 0000000..873991b --- /dev/null +++ b/pynetsnmp/oids.py @@ -0,0 +1,94 @@ +from __future__ import absolute_import + +from .conversions import asOidStr + + +class OID(object): + __slots__ = ("oid",) + + def __init__(self, oid): + super(OID, self).__setattr__("oid", oid) + + def __setattr__(self, key, value): + if key in OID.__slots__: + raise AttributeError( + "can't set attribute '{}' on 'OID' object".format(key) + ) + super(OID, self).__setattr__(key, value) + + def __eq__(this, that): + if isinstance(that, (tuple, list)): + return this.oid == that + if isinstance(that, OID): + return this.oid == that.oid + return NotImplemented + + def __ne__(this, that): + if isinstance(that, (tuple, list)): + return this.oid != that + if isinstance(that, OID): + return this.oid != that.oid + return NotImplemented + + def __hash__(self): + return hash(self.oid) + + def __repr__(self): + return "<{0.__module__}.{0.__class__.__name__} {1}>".format( + self, asOidStr(self.oid) + ) + + def __str__(self): + return asOidStr(self.oid) + + +_base_status_oid = (1, 3, 6, 1, 6, 3, 15, 1, 1) + + +class UnknownSecurityLevel(OID): + __slots__ = () + + +UnknownSecurityLevel = UnknownSecurityLevel(_base_status_oid + (1, 0)) + + +class NotInTimeWindow(OID): + __slots__ = () + + +NotInTimeWindow = NotInTimeWindow(_base_status_oid + (2, 0)) + + +class UnknownUserName(OID): + __slots__ = () + + +UnknownUserName = UnknownUserName(_base_status_oid + (3, 0)) + + +class UnknownEngineId(OID): + __slots__ = () + + +UnknownEngineId = UnknownEngineId(_base_status_oid + (4, 0)) + + +class WrongDigest(OID): + __slots__ = () + + +WrongDigest = WrongDigest(_base_status_oid + (5, 0)) + + +class DecryptionError(OID): + __slots__ = () + + +DecryptionError = DecryptionError(_base_status_oid + (6, 0)) + + +class SysDescr(OID): + __slots__ = () + + +SysDescr = SysDescr((1, 3, 6, 1, 2, 1, 1, 1, 0)) diff --git a/pynetsnmp/security.py b/pynetsnmp/security.py index 21e4b4f..0a48b40 100644 --- a/pynetsnmp/security.py +++ b/pynetsnmp/security.py @@ -1,7 +1,9 @@ from __future__ import absolute_import from .CONSTANTS import SNMP_VERSION_1, SNMP_VERSION_2c, SNMP_VERSION_3 -from .usm import auth_protocols, priv_protocols +from .usm import AUTH_NOAUTH, auth_protocols, PRIV_NOPRIV, priv_protocols + +__all__ = ("Community", "UsmUser", "Authentication", "Privacy") class Community(object): @@ -10,11 +12,13 @@ class Community(object): """ def __init__(self, name, version=SNMP_VERSION_2c): - version = _version_map.get(version) - if version is None: - raise ValueError("Unsupported SNMP version '{}'".format(version)) + mapped = _version_map.get(version) + if mapped is None or mapped == "3": + raise ValueError( + "SNMP version '{}' not supported for Community".format(version) + ) self.name = name - self.version = version + self.version = mapped def getArguments(self): community = ("-c", str(self.name)) if self.name else () @@ -28,43 +32,44 @@ class UsmUser(object): def __init__(self, name, auth=None, priv=None, engine=None, context=None): self.name = name - if not isinstance(auth, (type(None), Authentication)): - raise ValueError("invalid authentication protocol") + if auth is None: + auth = Authentication.new_noauth() + if not isinstance(auth, Authentication): + raise ValueError("invalid authentication object") self.auth = auth - if not isinstance(priv, (type(None), Privacy)): - raise ValueError("invalid privacy protocol") + if priv is None: + priv = Privacy.new_nopriv() + if not isinstance(priv, Privacy): + raise ValueError("invalid privacy object") self.priv = priv self.engine = engine self.context = context self.version = _version_map.get(SNMP_VERSION_3) def getArguments(self): - auth = ( + auth_args = ( ("-a", self.auth.protocol.name, "-A", self.auth.passphrase) if self.auth else () ) - if auth: + if auth_args: # The privacy arguments are only given if the authentication # arguments are also provided. - priv = ( + priv_args = ( ("-x", self.priv.protocol.name, "-X", self.priv.passphrase) if self.priv else () ) else: - priv = () - seclevel = ( - "-l", - _sec_level.get((bool(auth), bool(priv)), "noAuthNoPriv"), - ) + priv_args = () + seclevel_arg = ("-l", _sec_level[(bool(self.auth), bool(self.priv))]) return ( ("-v", self.version) + (("-u", self.name) if self.name else ()) - + seclevel - + auth - + priv + + seclevel_arg + + auth_args + + priv_args + (("-e", self.engine) if self.engine else ()) + (("-n", self.context) if self.context else ()) ) @@ -95,8 +100,15 @@ def __str__(self): ) -_sec_level = {(True, True): "authPriv", (True, False): "authNoPriv"} +_sec_level = { + (True, True): "authPriv", + (True, False): "authNoPriv", + (False, False): "noAuthNoPriv", +} _version_map = { + "1": "1", + "2c": "2c", + "3": "3", SNMP_VERSION_1: "1", SNMP_VERSION_2c: "2c", SNMP_VERSION_3: "3", @@ -113,15 +125,25 @@ class Authentication(object): __slots__ = ("protocol", "passphrase") + @classmethod + def new_noauth(cls): + return cls(None, None) + def __init__(self, protocol, passphrase): - if protocol is None: - raise ValueError( - "Invalid Authentication protocol '{}'".format(protocol) - ) - self.protocol = auth_protocols[protocol] - if not passphrase: - raise ValueError("Authentication protocol requires a passphrase") - self.passphrase = passphrase + if ( + not protocol + or protocol is AUTH_NOAUTH + or protocol == "AUTH_NOAUTH" + ): + self.protocol = AUTH_NOAUTH + self.passphrase = None + else: + self.protocol = auth_protocols[protocol] + if not passphrase: + raise ValueError( + "Authentication protocol requires a passphrase" + ) + self.passphrase = passphrase def __eq__(self, other): if not isinstance(other, Authentication): @@ -131,6 +153,14 @@ def __eq__(self, other): and self.passphrase == other.passphrase ) + def __nonzero__(self): + return self.protocol is not AUTH_NOAUTH + + def __repr__(self): + return ( + "<{0.__module__}.{0.__class__.__name__} protocol={0.protocol}>" + ).format(self) + def __str__(self): return "{0.__class__.__name__}(protocol={0.protocol})".format(self) @@ -142,13 +172,23 @@ class Privacy(object): __slots__ = ("protocol", "passphrase") + @classmethod + def new_nopriv(cls): + return cls(None, None) + def __init__(self, protocol, passphrase): - if protocol is None: - raise ValueError("Invalid Privacy protocol '{}'".format(protocol)) - self.protocol = priv_protocols[protocol] - if not passphrase: - raise ValueError("Privacy protocol requires a passphrase") - self.passphrase = passphrase + if ( + not protocol + or protocol is PRIV_NOPRIV + or protocol == "PRIV_NOPRIV" + ): + self.protocol = PRIV_NOPRIV + self.passphrase = None + else: + self.protocol = priv_protocols[protocol] + if not passphrase: + raise ValueError("Privacy protocol requires a passphrase") + self.passphrase = passphrase def __eq__(self, other): if not isinstance(other, Privacy): @@ -158,5 +198,13 @@ def __eq__(self, other): and self.passphrase == other.passphrase ) + def __nonzero__(self): + return self.protocol is not PRIV_NOPRIV + + def __repr__(self): + return ( + "<{0.__module__}.{0.__class__.__name__} protocol={0.protocol}>" + ).format(self) + def __str__(self): return "{0.__class__.__name__}(protocol={0.protocol})".format(self) diff --git a/pynetsnmp/twistedsnmp.py b/pynetsnmp/twistedsnmp.py index 9a4592f..52142de 100644 --- a/pynetsnmp/twistedsnmp.py +++ b/pynetsnmp/twistedsnmp.py @@ -8,7 +8,7 @@ from twisted.internet.error import TimeoutError from twisted.python import failure -from . import netsnmp +from . import netsnmp, oids from .CONSTANTS import ( SNMP_ERR_AUTHORIZATIONERROR, SNMP_ERR_BADVALUE, @@ -31,8 +31,11 @@ SNMP_ERR_WRONGVALUE, ) from .conversions import asAgent, asOidStr, asOid +from .errors import SnmpError, SnmpUsmError, get_stats_error from .tableretriever import TableRetriever +log = netsnmp._getLogger("agentproxy") + class Timer(object): callLater = None @@ -131,47 +134,6 @@ def updateReactor(): timer.callLater = reactor.callLater(t, checkTimeouts) -class SnmpNameError(Exception): - def __init__(self, oid): - Exception.__init__(self, "Bad Name", oid) - - -class SnmpError(Exception): - def __init__(self, message, *args, **kwargs): - self.message = message - - def __str__(self): - return self.message - - def __repr__(self): - return self.message - - -class Snmpv3Error(SnmpError): - pass - - -USM_STATS_OIDS = { - # usmStatsWrongDigests - ".1.3.6.1.6.3.15.1.1.5.0": ( - "check zSnmpAuthType and zSnmpAuthPassword, " - "packet did not include the expected digest value" - ), - # usmStatsUnknownUserNames - ".1.3.6.1.6.3.15.1.1.3.0": ( - "check zSnmpSecurityName, packet referenced an unknown user" - ), - # usmStatsUnsupportedSecLevels - ".1.3.6.1.6.3.15.1.1.1.0": ( - "packet requested an unknown or unavailable security level" - ), - # usmStatsDecryptionErrors - ".1.3.6.1.6.3.15.1.1.6.0": ( - "check zSnmpPrivType, packet could not be decrypted" - ), -} - - class AgentProxy(object): """The public methods on AgentProxy (get, walk, getbulk) expect input OIDs to be strings, and the result they produce is a dictionary. The @@ -229,42 +191,53 @@ def __init__( self.timeout = timeout self.tries = tries self.cmdLineArgs = cmdLineArgs - self.defers = {} + self.defers = _DeferredMap() self.session = None - self._log = netsnmp._getLogger("agentproxy") - def _signSafePop(self, d, intkey): - """ - Attempt to pop the item at intkey from dictionary d. - Upon failure, try to convert intkey from a signed to an unsigned - integer and try to pop again. + def open(self): + if self.session is not None: + self.session.close() + self.session = None + updateReactor() - This addresses potential integer rollover issues caused by the fact - that netsnmp_pdu.reqid is a c_long and the netsnmp_callback function - pointer definition specifies it as a c_int. See ZEN-4481. - """ - try: - return d.pop(intkey) - except KeyError as ex: - if intkey < 0: - self._log.debug("Negative ID for _signSafePop: %s", intkey) - # convert to unsigned, try that key - uintkey = struct.unpack("I", struct.pack("i", intkey))[0] - try: - return d.pop(uintkey) - except KeyError: - # Nothing by the unsigned key either, - # throw the original KeyError for consistency - raise ex - raise + if self._security: + agent = asAgent(self.ip, self.port) + cmdlineargs = self._security.getArguments() + ( + ("-t", str(self.timeout), "-r", str(self.tries), agent) + ) + self.session = netsnmp.Session(cmdLineArgs=cmdlineargs) + else: + self.session = netsnmp.Session( + version=netsnmp.SNMP_VERSION_MAP.get( + self.snmpVersion, netsnmp.SNMP_VERSION_2c + ), + timeout=int(self.timeout), + retries=int(self.tries), + peername="%s:%d" % (self.ip, self.port), + community=self.community, + community_len=len(self.community), + cmdLineArgs=self._getCmdLineArgs(), + ) + + self.session.callback = self.callback + self.session.timeout = self._handle_timeout + self.session.open() + updateReactor() + + def close(self): + if self.session is not None: + self.session.close() + self.session = None + updateReactor() def callback(self, pdu): """netsnmp session callback""" - response = netsnmp.getResult(pdu, self._log) + response = netsnmp.getResult(pdu, log) try: - d, oids_requested = self._pop_requested_oids(pdu, response) - except RuntimeError: + d, oids_requested = self.defers.pop(pdu.reqid) + except KeyError: + self._handle_missing_request(response) return result = tuple( @@ -273,23 +246,21 @@ def callback(self, pdu): ) if len(result) == 1 and result[0][0] not in oids_requested: - usmStatsOidStr = asOidStr(result[0][0]) - if usmStatsOidStr in USM_STATS_OIDS: - msg = USM_STATS_OIDS.get(usmStatsOidStr) - reactor.callLater( - 0, d.errback, failure.Failure(Snmpv3Error(msg)) - ) + statsOid = result[0][0] + error = get_stats_error(statsOid) + if error: + reactor.callLater(0, d.errback, failure.Failure(error)) return - elif usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": + if statsOid == oids.NotInTimeWindow: # we may get a subsequent snmp result with the correct value # if not the timeout will be called at some point self.defers[pdu.reqid] = (d, oids_requested) return if pdu.errstat != SNMP_ERR_NOERROR: pduError = PDU_ERRORS.get( - pdu.errstat, "Unknown error (%d)" % pdu.errstat + pdu.errstat, "unknown error (%d)" % pdu.errstat ) - message = "Packet for %s has error: %s" % (self.ip, pduError) + message = "packet for %s has error: %s" % (self.ip, pduError) if pdu.errstat in ( SNMP_ERR_NOACCESS, SNMP_ERR_RESOURCEUNAVAILABLE, @@ -301,63 +272,53 @@ def callback(self, pdu): return else: result = [] - self._log.warning(message + ". OIDS: %s", oids_requested) + log.warning(message + ". OIDS: %s", oids_requested) reactor.callLater(0, d.callback, result) - def _pop_requested_oids(self, pdu, response): - try: - return self._signSafePop(self.defers, pdu.reqid) - except KeyError: - # We seem to end up here if we use bad credentials with authPriv. - # The only reasonable thing to do is call all of the deferreds with - # Snmpv3Errors. - for usmStatsOid, _ in response: - usmStatsOidStr = asOidStr(usmStatsOid) - - if usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": - # Some devices use usmStatsNotInTimeWindows as a normal - # part of the SNMPv3 handshake (JIRA-1565). - # net-snmp automatically retries the request with the - # previous request_id and the values for - # msgAuthoritativeEngineBoots and - # msgAuthoritativeEngineTime from this error packet. - self._log.debug( - "Received a usmStatsNotInTimeWindows error. Some " - "devices use usmStatsNotInTimeWindows as a normal " - "part of the SNMPv3 handshake." - ) - raise RuntimeError("usmStatsNotInTimeWindows error") - - if usmStatsOidStr == ".1.3.6.1.2.1.1.1.0": - # Some devices (Cisco Nexus/MDS) use sysDescr as a normal - # part of the SNMPv3 handshake (JIRA-7943) - self._log.debug( - "Received sysDescr during handshake. Some devices use " - "sysDescr as a normal part of the SNMPv3 handshake." - ) - raise RuntimeError("sysDescr during handshake") - - default_msg = "packet dropped (OID: {0})".format( - usmStatsOidStr - ) - message = USM_STATS_OIDS.get(usmStatsOidStr, default_msg) - break - else: - message = "packet dropped" + def _handle_missing_request(self, response): + usmStatsOid, _ = next(iter(response), (None, None)) + + if usmStatsOid == oids.NotInTimeWindow: + # Some devices use usmStatsNotInTimeWindows as a normal part of + # the SNMPv3 handshake (JIRA-1565). net-snmp automatically + # retries the request with the previous request_id and the + # values for msgAuthoritativeEngineBoots and + # msgAuthoritativeEngineTime from this error packet. + log.debug( + "Received a usmStatsNotInTimeWindows error. Some " + "devices use usmStatsNotInTimeWindows as a normal " + "part of the SNMPv3 handshake." + ) + return - for d in ( - d for d, rOids in self.defers.itervalues() if not d.called - ): - reactor.callLater( - 0, d.errback, failure.Failure(Snmpv3Error(message)) + if usmStatsOid == oids.SysDescr: + # Some devices (Cisco Nexus/MDS) use sysDescr as a normal + # part of the SNMPv3 handshake (JIRA-7943) + log.debug( + "Received sysDescr during handshake. Some devices use " + "sysDescr as a normal part of the SNMPv3 handshake." + ) + return + + if usmStatsOid is not None: + error = get_stats_error(usmStatsOid) + if not error: + error = SnmpUsmError( + "packet dropped (OID: {0})".format(asOidStr(usmStatsOid)) ) + else: + error = SnmpUsmError("packet dropped") - raise RuntimeError(message) + for d in (d for d, _ in self.defers.itervalues() if not d.called): + reactor.callLater(0, d.errback, failure.Failure(error)) - def timeout_(self, reqid): - d = self._signSafePop(self.defers, reqid)[0] - reactor.callLater(0, d.errback, failure.Failure(TimeoutError())) + def _handle_timeout(self, reqid): + try: + d = self.defers.pop(reqid)[0] + reactor.callLater(0, d.errback, failure.Failure(TimeoutError())) + except KeyError: + log.warning("handled timeout for unknown request") def _getCmdLineArgs(self): if not self.cmdLineArgs: @@ -382,42 +343,6 @@ def _getCmdLineArgs(self): ] return cmdLineArgs - def open(self): - if self.session is not None: - self.session.close() - self.session = None - updateReactor() - - if self._security: - agent = asAgent(self.ip, self.port) - cmdlineargs = self._security.getArguments() + ( - ("-t", str(self.timeout), "-r", str(self.tries), agent) - ) - self.session = netsnmp.Session(cmdLineArgs=cmdlineargs) - else: - self.session = netsnmp.Session( - version=netsnmp.SNMP_VERSION_MAP.get( - self.snmpVersion, netsnmp.SNMP_VERSION_2c - ), - timeout=int(self.timeout), - retries=int(self.tries), - peername="%s:%d" % (self.ip, self.port), - community=self.community, - community_len=len(self.community), - cmdLineArgs=self._getCmdLineArgs(), - ) - - self.session.callback = self.callback - self.session.timeout = self.timeout_ - self.session.open() - updateReactor() - - def close(self): - if self.session is not None: - self.session.close() - self.session = None - updateReactor() - def _get(self, oids, timeout=None, retryCount=None): d = defer.Deferred() try: @@ -488,3 +413,25 @@ def port(self): snmpprotocol = _FakeProtocol() + + +class _DeferredMap(dict): + """ + Wrap the dict type to add extra behavior. + """ + + def pop(self, key): + """ + Attempt to pop the item at key from the dictionary. + """ + # Check for negative key to address potential integer rollover issues + # caused by the fact that netsnmp_pdu.reqid is a c_long and the + # netsnmp_callback function pointer definition specifies it as a + # c_int. See ZEN-4481. + if key not in self and key < 0: + log.debug("try negative ID for deferred map: %s", key) + # convert to unsigned, try that key + uintkey = struct.unpack("I", struct.pack("i", key))[0] + if uintkey in self: + key = uintkey + return super(_DeferredMap, self).pop(key) diff --git a/pynetsnmp/usm.py b/pynetsnmp/usm.py index 7455e9f..3606eb3 100644 --- a/pynetsnmp/usm.py +++ b/pynetsnmp/usm.py @@ -38,6 +38,8 @@ def __iter__(self): return iter(self.__protocols) def __contains__(self, proto): + if not proto: + proto = self.__noargs if proto not in self.__protocols: return any(str(p) == proto for p in self.__protocols) return True @@ -46,7 +48,9 @@ def __getitem__(self, name): name = str(name) proto = next((p for p in self.__protocols if str(p) == name), None) if proto is None: - raise KeyError("No {} protocol '{}'".format(self.__kind, name)) + raise KeyError( + "unknown {} protocol '{}'".format(self.__kind, name) + ) return proto def __repr__(self): @@ -65,6 +69,7 @@ def __repr__(self): auth_protocols = _Protocols( ( + AUTH_NOAUTH, AUTH_MD5, AUTH_SHA, AUTH_SHA_224, @@ -82,25 +87,25 @@ def __repr__(self): PRIV_AES_256 = _Protocol("AES-256", (1, 3, 6, 1, 4, 1, 14832, 1, 4)) priv_protocols = _Protocols( - (PRIV_DES, PRIV_AES, PRIV_AES_192, PRIV_AES_256), "privacy" + (PRIV_NOPRIV, PRIV_DES, PRIV_AES, PRIV_AES_192, PRIV_AES_256), "privacy" ) del _Protocol del _Protocols __all__ = ( - "AUTH_NOAUTH", "AUTH_MD5", + "AUTH_NOAUTH", + "auth_protocols", "AUTH_SHA", "AUTH_SHA_224", "AUTH_SHA_256", "AUTH_SHA_384", "AUTH_SHA_512", - "auth_protocols", - "PRIV_NOPRIV", - "PRIV_DES", "PRIV_AES", "PRIV_AES_192", "PRIV_AES_256", + "PRIV_DES", + "PRIV_NOPRIV", "priv_protocols", ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..3783922 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,382 @@ +import unittest + +from pynetsnmp import security, usm + + +class TestCommunity(unittest.TestCase): + name = "public" + + def test_default(t): + c = security.Community(t.name) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v1_constant(t): + c = security.Community(t.name, security.SNMP_VERSION_1) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "1") + expected = ("-v", "1", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v1_v1(t): + c = security.Community(t.name, "v1") + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "1") + expected = ("-v", "1", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v2c_constant(t): + c = security.Community(t.name, security.SNMP_VERSION_2c) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v2c_v2c(t): + c = security.Community(t.name, "v2c") + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v3_constant(t): + with t.assertRaises(ValueError): + security.Community(t.name, security.SNMP_VERSION_3) + + def test_v3_v3(t): + with t.assertRaises(ValueError): + security.Community(t.name, "v3") + + def test_none_version(t): + with t.assertRaises(ValueError): + security.Community(t.name, None) + + def test_not_a_version_str(t): + with t.assertRaises(ValueError): + security.Community(t.name, "oi") + + def test_not_a_version_number(t): + with t.assertRaises(ValueError): + security.Community(t.name, 3947) + + +class TestUsmUser(unittest.TestCase): + name = "john_doe" + passwd = "secured123" # noqa: S105 + + def test_default(t): + user = security.UsmUser(t.name) + t.assertEqual(t.name, user.name) + t.assertEqual(security.Authentication.new_noauth(), user.auth) + t.assertEqual(security.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ("-v", "3", "-u", t.name, "-l", "noAuthNoPriv") + t.assertSequenceEqual(expected, user.getArguments()) + + def test_engineid(t): + engineid = hex(3443489794829589283483234)[2:].strip("L") + user = security.UsmUser(t.name, engine=engineid) + t.assertEqual(t.name, user.name) + t.assertEqual(security.Authentication.new_noauth(), user.auth) + t.assertEqual(security.Privacy.new_nopriv(), user.priv) + t.assertEqual(engineid, user.engine) + t.assertIsNone(user.context) + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "noAuthNoPriv", + "-e", + engineid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_contextid(t): + contextid = hex(9084090984572743455234)[2:].strip("L") + user = security.UsmUser(t.name, context=contextid) + t.assertEqual(t.name, user.name) + t.assertEqual(security.Authentication.new_noauth(), user.auth) + t.assertEqual(security.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertEqual(contextid, user.context) + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "noAuthNoPriv", + "-n", + contextid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_auth(t): + auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) + user = security.UsmUser(t.name, auth=auth) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(security.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authNoPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_authpriv(t): + auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) + priv = security.Privacy(usm.PRIV_AES_256, t.passwd) + user = security.UsmUser(t.name, auth=auth, priv=priv) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(priv, user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + "-x", + priv.protocol.name, + "-X", + priv.passphrase, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_all_args(t): + auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) + priv = security.Privacy(usm.PRIV_AES_256, t.passwd) + contextid = hex(9084090984572743455234)[2:].strip("L") + engineid = hex(3443489794829589283483234)[2:].strip("L") + user = security.UsmUser( + t.name, auth=auth, priv=priv, engine=engineid, context=contextid + ) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(priv, user.priv) + t.assertEqual(engineid, user.engine) + t.assertEqual(contextid, user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + "-x", + priv.protocol.name, + "-X", + priv.passphrase, + "-e", + engineid, + "-n", + contextid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_equality(t): + auth1 = security.Authentication(usm.AUTH_SHA_224, t.passwd) + user1 = security.UsmUser(t.name, auth=auth1) + auth2 = security.Authentication(usm.AUTH_SHA_224, t.passwd) + user2 = security.UsmUser(t.name, auth=auth2) + auth3 = security.Authentication(usm.AUTH_SHA_256, t.passwd) + priv3 = security.Privacy(usm.PRIV_AES_256, t.passwd) + user3 = security.UsmUser(t.name, auth=auth3, priv=priv3) + t.assertEqual(user1, user2) + t.assertNotEqual(user1, user3) + + +class TestAuthentication(unittest.TestCase): + passwd = "security123" # noqa: S105 + + def test_noauth_classmethod(t): + auth = security.Authentication.new_noauth() + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_none_init(t): + auth = security.Authentication(None, None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = security.Authentication(None, t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_init(t): + auth = security.Authentication(usm.AUTH_NOAUTH, None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = security.Authentication(usm.AUTH_NOAUTH, t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_str_init(t): + auth = security.Authentication("AUTH_NOAUTH", None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = security.Authentication("AUTH_NOAUTH", t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_is_false(t): + auth = security.Authentication.new_noauth() + t.assertFalse(auth) + + def test_md5(t): + auth = security.Authentication(usm.AUTH_MD5, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_MD5) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha(t): + auth = security.Authentication(usm.AUTH_SHA, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_224(t): + auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_224) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_256(t): + auth = security.Authentication(usm.AUTH_SHA_256, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_256) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_384(t): + auth = security.Authentication(usm.AUTH_SHA_384, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_384) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_512(t): + auth = security.Authentication(usm.AUTH_SHA_512, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_512) + t.assertEqual(auth.passphrase, t.passwd) + + def test_equal(t): + auth1 = security.Authentication(usm.AUTH_MD5, t.passwd) + auth2 = security.Authentication(usm.AUTH_MD5, t.passwd) + t.assertEqual(auth1, auth2) + + def test_not_equal(t): + auth1 = security.Authentication(usm.AUTH_MD5, t.passwd) + auth2 = security.Authentication(usm.AUTH_SHA, t.passwd) + t.assertNotEqual(auth1, auth2) + + auth3 = security.Authentication(usm.AUTH_SHA, t.passwd + "456") + t.assertNotEqual(auth2, auth3) + + +class TestPrivacy(unittest.TestCase): + passwd = "security123" # noqa: S105 + + def test_nopriv_classmethod(t): + priv = security.Privacy.new_nopriv() + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_none_init(t): + priv = security.Privacy(None, None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = security.Privacy(None, t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_init(t): + priv = security.Privacy(usm.PRIV_NOPRIV, None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = security.Privacy(usm.PRIV_NOPRIV, t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_str_init(t): + priv = security.Privacy("PRIV_NOPRIV", None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = security.Privacy("PRIV_NOPRIV", t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_is_false(t): + priv = security.Privacy.new_nopriv() + t.assertFalse(priv) + + def test_des(t): + priv = security.Privacy(usm.PRIV_DES, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_DES) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes(t): + priv = security.Privacy(usm.PRIV_AES, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes_192(t): + priv = security.Privacy(usm.PRIV_AES_192, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES_192) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes_256(t): + priv = security.Privacy(usm.PRIV_AES_256, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES_256) + t.assertEqual(priv.passphrase, t.passwd) + + def test_equal(t): + priv1 = security.Privacy(usm.PRIV_DES, t.passwd) + priv2 = security.Privacy(usm.PRIV_DES, t.passwd) + t.assertEqual(priv1, priv2) + + def test_not_equal(t): + priv1 = security.Privacy(usm.PRIV_DES, t.passwd) + priv2 = security.Privacy(usm.PRIV_AES, t.passwd) + t.assertNotEqual(priv1, priv2) + + priv3 = security.Privacy(usm.PRIV_AES, t.passwd + "456") + t.assertNotEqual(priv2, priv3) diff --git a/tests/test_usm.py b/tests/test_usm.py new file mode 100644 index 0000000..3a6e4e9 --- /dev/null +++ b/tests/test_usm.py @@ -0,0 +1,114 @@ +import unittest + +from pynetsnmp import usm + +_sorted_auth_names = sorted( + ["NOAUTH", "MD5", "SHA", "SHA-224", "SHA-256", "SHA-384", "SHA-512"] +) + + +class TestAuthProtocols(unittest.TestCase): + def test_noauth_contained(t): + t.assertIn(usm.AUTH_NOAUTH, usm.auth_protocols) + + def test_md5_contained(t): + t.assertIn(usm.AUTH_MD5, usm.auth_protocols) + + def test_sha_contained(t): + t.assertIn(usm.AUTH_SHA, usm.auth_protocols) + + def test_sha_224_contained(t): + t.assertIn(usm.AUTH_SHA_224, usm.auth_protocols) + + def test_sha_256_contained(t): + t.assertIn(usm.AUTH_SHA_256, usm.auth_protocols) + + def test_sha_384_contained(t): + t.assertIn(usm.AUTH_SHA_384, usm.auth_protocols) + + def test_sha_512_contained(t): + t.assertIn(usm.AUTH_SHA_512, usm.auth_protocols) + + def test_length(t): + t.assertEqual(7, len(usm.auth_protocols)) + + def test_iterable(t): + names = sorted(str(p) for p in usm.auth_protocols) + t.assertEqual(7, len(names)) + t.assertListEqual(_sorted_auth_names, names) + + def test_noauth_getitem(t): + proto = usm.auth_protocols[usm.AUTH_NOAUTH.name] + t.assertEqual(usm.AUTH_NOAUTH, proto) + + def test_md5_getitem(t): + proto = usm.auth_protocols[usm.AUTH_MD5.name] + t.assertEqual(usm.AUTH_MD5, proto) + + def test_sha_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA.name] + t.assertEqual(usm.AUTH_SHA, proto) + + def test_sha_224_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_224.name] + t.assertEqual(usm.AUTH_SHA_224, proto) + + def test_sha_256_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_256.name] + t.assertEqual(usm.AUTH_SHA_256, proto) + + def test_sha_384_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_384.name] + t.assertEqual(usm.AUTH_SHA_384, proto) + + def test_sha_512_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_512.name] + t.assertEqual(usm.AUTH_SHA_512, proto) + + +_sorted_priv_names = sorted(["NOPRIV", "DES", "AES", "AES-192", "AES-256"]) + + +class TestPrivProtocols(unittest.TestCase): + def test_nopriv_contained(t): + t.assertIn(usm.PRIV_NOPRIV, usm.priv_protocols) + + def test_des_contained(t): + t.assertIn(usm.PRIV_DES, usm.priv_protocols) + + def test_aes_contained(t): + t.assertIn(usm.PRIV_AES, usm.priv_protocols) + + def test_aes_192_contained(t): + t.assertIn(usm.PRIV_AES_192, usm.priv_protocols) + + def test_aes_256_contained(t): + t.assertIn(usm.PRIV_AES_256, usm.priv_protocols) + + def test_length(t): + t.assertEqual(5, len(usm.priv_protocols)) + + def test_iterable(t): + names = sorted(str(p) for p in usm.priv_protocols) + t.assertEqual(5, len(names)) + t.assertListEqual(_sorted_priv_names, names) + + def test_nopriv_getitem(t): + proto = usm.priv_protocols[usm.PRIV_NOPRIV.name] + t.assertEqual(usm.PRIV_NOPRIV, proto) + + def test_des_getitem(t): + proto = usm.priv_protocols[usm.PRIV_DES.name] + t.assertEqual(usm.PRIV_DES, proto) + + def test_aes_getitem(t): + proto = usm.priv_protocols[usm.PRIV_AES.name] + t.assertEqual(usm.PRIV_AES, proto) + + def test_aes_192_getitem(t): + proto = usm.priv_protocols[usm.PRIV_AES_192.name] + t.assertEqual(usm.PRIV_AES_192, proto) + + def test_aes_256_getitem(t): + proto = usm.priv_protocols[usm.PRIV_AES_256.name] + t.assertEqual(usm.PRIV_AES_256, proto)