From 5486d3e65f1a8b75f230f7d88064284b70290f2e Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Fri, 24 Jan 2025 13:21:50 +0100 Subject: [PATCH 1/3] feat: add SNMP provider --- keep/providers/snmp_provider/snmp_provider.py | 608 ++++++++++++++++++ poetry.lock | 70 +- pyproject.toml | 6 + tests/test_snmp_provider.py | 389 +++++++++++ 4 files changed, 1070 insertions(+), 3 deletions(-) create mode 100644 keep/providers/snmp_provider/snmp_provider.py create mode 100644 tests/test_snmp_provider.py diff --git a/keep/providers/snmp_provider/snmp_provider.py b/keep/providers/snmp_provider/snmp_provider.py new file mode 100644 index 000000000..9413046ab --- /dev/null +++ b/keep/providers/snmp_provider/snmp_provider.py @@ -0,0 +1,608 @@ +""" +SNMP Provider is a class that provides functionality to receive SNMP traps and convert them to Keep alerts. +""" + +import asyncio +import dataclasses +import os +import typing +from pathlib import Path +from pysnmp.hlapi.v3arch.asyncio.auth import CommunityData, UsmUserData +from pysnmp.hlapi.v3arch.asyncio.transport import UdpTransportTarget +from pysnmp.hlapi.v3arch.asyncio.context import ContextData +from pysnmp.hlapi.v3arch.asyncio.cmdgen import ( + SnmpEngine, + ObjectType, + get_cmd, + set_cmd, + next_cmd, + bulk_cmd +) +from pysnmp.smi.rfc1902 import ObjectIdentity +from pysnmp.carrier.asyncio.dgram import udp +from pysnmp.entity import engine, config +from pysnmp.smi import builder, view, compiler +from pysnmp.proto.rfc1902 import ObjectIdentifier +from pysnmp.entity.rfc3413 import ntfrcv + +import pydantic + +from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus +from keep.contextmanager.contextmanager import ContextManager +from keep.providers.base.base_provider import BaseProvider +from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import UrlPort +from keep.exceptions.provider_exception import ProviderException + + +@pydantic.dataclasses.dataclass +class SnmpProviderAuthConfig: + """SNMP authentication configuration.""" + + listen_port: UrlPort = dataclasses.field( + metadata={ + "required": True, + "description": "Port to listen for SNMP traps", + "config_main_group": "authentication", + "validation": "port", + }, + default=162, + ) + + community_string: str = dataclasses.field( + metadata={ + "required": True, + "description": "SNMP community string for authentication", + "sensitive": True, + "config_main_group": "authentication", + }, + default="public" + ) + + snmp_version: typing.Literal["v1", "v2c", "v3"] = dataclasses.field( + default="v2c", + metadata={ + "required": True, + "description": "SNMP protocol version", + "type": "select", + "options": ["v1", "v2c", "v3"], + "config_main_group": "authentication", + }, + ) + + # SNMPv3 specific configuration + username: str = dataclasses.field( + metadata={ + "required": False, + "description": "SNMPv3 username", + "config_main_group": "authentication", + }, + default="", + ) + + auth_protocol: typing.Literal["MD5", "SHA"] = dataclasses.field( + default="SHA", + metadata={ + "required": False, + "description": "SNMPv3 authentication protocol", + "type": "select", + "options": ["MD5", "SHA"], + "config_main_group": "authentication", + }, + ) + + auth_key: str = dataclasses.field( + metadata={ + "required": False, + "sensitive": True, + "description": "SNMPv3 authentication key", + "config_main_group": "authentication", + }, + default="", + ) + + priv_protocol: typing.Literal["DES", "AES"] = dataclasses.field( + default="AES", + metadata={ + "required": False, + "description": "SNMPv3 privacy protocol", + "type": "select", + "options": ["DES", "AES"], + "config_main_group": "authentication", + }, + ) + + priv_key: str = dataclasses.field( + metadata={ + "required": False, + "sensitive": True, + "description": "SNMPv3 privacy key", + "config_main_group": "authentication", + }, + default="", + ) + + # MIB configuration + mib_dirs: list[str] = dataclasses.field( + metadata={ + "required": False, + "description": "List of directories containing custom MIB files", + "config_main_group": "authentication", + }, + default_factory=list, + ) + + +class SnmpProvider(BaseProvider): + """ + SNMP provider class for receiving SNMP traps. + """ + + PROVIDER_DISPLAY_NAME = "SNMP" + PROVIDER_CATEGORY = ["Monitoring"] + PROVIDER_TAGS = ["alert"] + + PROVIDER_SCOPES = [ + ProviderScope( + name="receive_traps", + description="Ability to receive SNMP traps", + mandatory=True, + alias="Receive SNMP Traps", + ) + ] + + def __init__( + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + ): + super().__init__(context_manager, provider_id, config) + self.snmp_engine = None + self.trap_receiver = None + self.mib_view_controller = None + + async def dispose(self): + """ + Clean up SNMP engine and trap receiver. + """ + if self.snmp_engine: + try: + dispatcher = self.snmp_engine.transport_dispatcher + if hasattr(dispatcher, 'loopingcall'): + try: + if not dispatcher.loopingcall.done(): + dispatcher.loopingcall.cancel() + # Wait for cancellation to complete + try: + asyncio.wait_for(dispatcher.loopingcall, timeout=1) + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + except Exception as e: + self.logger.debug(f"Error canceling dispatcher timeout: {e}") + + if hasattr(dispatcher, '_transports'): + for transport in dispatcher._transports.values(): + if hasattr(transport, 'close'): + transport.close() + dispatcher.close_dispatcher() + self.snmp_engine = None + except Exception as e: + self.logger.error(f"Error disposing SNMP engine: {str(e)}") + + if self.trap_receiver: + try: + await self.trap_receiver.close() + self.trap_receiver = None + except Exception as e: + self.logger.error(f"Error disposing trap receiver: {str(e)}") + + def validate_config(self): + """ + Validate the SNMP provider configuration. + """ + self.authentication_config = SnmpProviderAuthConfig( + **self.config.authentication + ) + + # Validate SNMPv3 configuration if v3 is selected + if self.authentication_config.snmp_version == "v3": + if not self.authentication_config.username: + raise ProviderException("Username is required for SNMPv3") + if not self.authentication_config.auth_key: + raise ProviderException("Authentication key is required for SNMPv3") + if not self.authentication_config.priv_key: + raise ProviderException("Privacy key is required for SNMPv3") + + # Validate MIB directories + for mib_dir in self.authentication_config.mib_dirs: + if not os.path.isdir(mib_dir): + raise ProviderException(f"MIB directory does not exist: {mib_dir}") + + def validate_scopes(self) -> dict[str, bool | str]: + """ + Validate that the scopes provided are correct. + """ + try: + # Try to create an SNMP engine with the provided config + snmp_engine = engine.SnmpEngine() + + # Configure transport + config.add_transport( + snmp_engine, + udp.DOMAIN_NAME, + udp.UdpTransport().open_server_mode(('0.0.0.0', self.authentication_config.listen_port)) + ) + + if self.authentication_config.snmp_version == "v3": + # Configure SNMPv3 + config.add_v3_user( + snmp_engine, + self.authentication_config.username, + config.usmHMACMD5AuthProtocol if self.authentication_config.auth_protocol == "MD5" else config.usmHMACSHAAuthProtocol, + self.authentication_config.auth_key, + config.usmDESPrivProtocol if self.authentication_config.priv_protocol == "DES" else config.usmAesCfb128Protocol, + self.authentication_config.priv_key + ) + else: + # Configure v1/v2c community string + config.add_v1_system( + snmp_engine, + 'my-area', + self.authentication_config.community_string + ) + + snmp_engine.transport_dispatcher.close_dispatcher() + return {"receive_traps": True} + except Exception as e: + return {"receive_traps": str(e)} + + def _setup_mib_compiler(self): + """ + Set up MIB compiler with custom MIB directories. + """ + try: + mib_builder = builder.MibBuilder() + + # Add custom MIB directories + for mib_dir in self.authentication_config.mib_dirs: + self.logger.info(f"Adding MIB directory: {mib_dir}") + mib_path = Path(mib_dir) + if mib_path.exists() and mib_path.is_dir(): + mib_builder.add_mib_sources(builder.DirMibSource(str(mib_path))) + else: + self.logger.warning(f"MIB directory not found: {mib_dir}") + + # Add MIB compiler + compiler.add_mib_compiler(mib_builder) + self.mib_view_controller = view.MibViewController(mib_builder) + self.logger.info("MIB compiler setup completed successfully") + + except Exception as e: + self.logger.error(f"Error setting up MIB compiler: {str(e)}") + raise ProviderException(f"Failed to set up MIB compiler: {str(e)}") + + def _format_alert(self, event: dict) -> AlertDto: + """ + Format SNMP trap data into an AlertDto. + """ + # Map SNMP trap severity to AlertSeverity + severity_map = { + 'emergency': AlertSeverity.CRITICAL, + 'alert': AlertSeverity.CRITICAL, + 'critical': AlertSeverity.CRITICAL, + 'error': AlertSeverity.HIGH, # Map error to HIGH since ERROR is not a valid severity + 'warning': AlertSeverity.WARNING, + 'notice': AlertSeverity.INFO, + 'info': AlertSeverity.INFO, + 'debug': AlertSeverity.INFO + } + + # Try to extract severity from trap data + trap_severity = event.get('severity', '').lower() + severity = severity_map.get(trap_severity, AlertSeverity.INFO) + + # Format description with variables + description = event.get('description', '') + if event.get('variables'): + description += "\n\nTrap Variables:\n" + for var_name, var_value in event['variables'].items(): + description += f"{var_name}: {var_value}\n" + + return AlertDto( + name=event.get('trap_type', 'SNMP Trap'), + message=event.get('message', ''), + description=description, + severity=severity, + status=AlertStatus.FIRING, + source=['snmp'], + source_type='snmp', + original_event=event, + ) + + def _parse_trap_oid(self, trap_oid: ObjectIdentifier) -> str: + """ + Parse trap OID to get a human-readable name. + """ + try: + if self.mib_view_controller: + mib_name = self.mib_view_controller.get_node_name(trap_oid) + return '.'.join(str(x) for x in mib_name) + except Exception as e: + self.logger.warning(f"Failed to parse trap OID {trap_oid}: {str(e)}") + + return str(trap_oid) + + def start_trap_receiver(self): + """ + Start the SNMP trap receiver. + """ + self.logger.info("Starting SNMP trap receiver") + self.snmp_engine = engine.SnmpEngine() + + try: + # Configure transport + config.add_transport( + self.snmp_engine, + udp.DOMAIN_NAME, + udp.UdpTransport().open_server_mode(('0.0.0.0', self.authentication_config.listen_port)) + ) + self.logger.debug(f"SNMP transport configured on port {self.authentication_config.listen_port}") + + # Configure authentication based on version + if self.authentication_config.snmp_version == "v3": + config.add_v3_user( + self.snmp_engine, + self.authentication_config.username, + config.usmHMACMD5AuthProtocol if self.authentication_config.auth_protocol == "MD5" else config.usmHMACSHAAuthProtocol, + self.authentication_config.auth_key, + config.usmDESPrivProtocol if self.authentication_config.priv_protocol == "DES" else config.usmAesCfb128Protocol, + self.authentication_config.priv_key + ) + self.logger.debug("SNMPv3 user configured") + else: + config.add_v1_system( + self.snmp_engine, + 'my-area', + self.authentication_config.community_string + ) + self.logger.debug(f"SNMPv{self.authentication_config.snmp_version} community configured") + + # Set up MIB compiler + self._setup_mib_compiler() + + def trap_callback(snmp_engine, state_reference, context_engine_id, context_name, + var_binds, cb_ctx): + """ + Callback function to handle received SNMP traps. + """ + try: + self.logger.info("Received SNMP trap") + trap_data = { + 'trap_type': 'SNMP Trap', + 'message': '', + 'description': '', + 'severity': AlertSeverity.INFO, + 'variables': {}, + 'context': { + 'engine_id': context_engine_id.prettyPrint() if context_engine_id else None, + 'context_name': context_name.prettyPrint() if context_name else None + } + } + + for name, val in var_binds: + try: + if isinstance(name, ObjectIdentifier): + var_name = self._parse_trap_oid(name) + else: + var_name = str(name) + + trap_data['variables'][var_name] = val.prettyPrint() + + # Try to identify trap type and severity from common OIDs + if 'trapType' in var_name.lower(): + trap_data['trap_type'] = val.prettyPrint() + elif 'severity' in var_name.lower(): + trap_data['severity'] = val.prettyPrint() + elif 'message' in var_name.lower(): + trap_data['message'] = val.prettyPrint() + elif 'description' in var_name.lower(): + trap_data['description'] = val.prettyPrint() + + except Exception as e: + self.logger.error(f"Error processing trap variable {name}: {str(e)}") + trap_data['variables'][str(name)] = str(val) + + self.logger.debug(f"Processed trap data: {trap_data}") + alert = self._format_alert(trap_data) + self._push_alert(alert.dict()) + self.logger.info("Successfully processed and pushed SNMP trap as alert") + + except Exception as e: + self.logger.error(f"Error processing SNMP trap: {str(e)}") + + # Set up notification receiver + ntfrcv.NotificationReceiver( + self.snmp_engine, + trap_callback + ) + + self.logger.info("SNMP trap receiver configured successfully") + self.snmp_engine.transport_dispatcher.jobStarted(1) + try: + self.snmp_engine.transport_dispatcher.runDispatcher() + except Exception as e: + self.logger.error(f"Error running SNMP dispatcher: {str(e)}") + self.snmp_engine.transport_dispatcher.close_dispatcher() + raise + + except Exception as e: + self.logger.error(f"Error starting SNMP trap receiver: {str(e)}") + if self.snmp_engine and self.snmp_engine.transport_dispatcher: + self.snmp_engine.transport_dispatcher.close_dispatcher() + raise ProviderException(f"Failed to start SNMP trap receiver: {str(e)}") + + def _notify(self, **kwargs): + """ + Not implemented for SNMP provider as it only receives traps. + """ + raise NotImplementedError("SNMP provider only supports receiving traps") + + def start_consume(self): + """ + Start consuming SNMP traps. + """ + self.logger.info("Starting SNMP trap consumer") + try: + self.start_trap_receiver() + return True + except Exception as e: + self.logger.error(f"Failed to start SNMP trap consumer: {str(e)}") + return False + + def status(self) -> dict: + """ + Return the status of the SNMP trap receiver. + """ + if not self.snmp_engine or not self.snmp_engine.transport_dispatcher: + return { + "status": "stopped", + "error": "SNMP trap receiver not running" + } + + try: + # Check if dispatcher is actually running + if self.snmp_engine.transport_dispatcher.jobs_are_pending(): + return { + "status": "running", + "error": "" + } + else: + return { + "status": "stopped", + "error": "SNMP dispatcher has no pending jobs" + } + except Exception as e: + return { + "status": "error", + "error": f"Error checking SNMP status: {str(e)}" + } + + @property + def is_consumer(self) -> bool: + """ + SNMP provider is a consumer as it receives traps. + """ + return True + + async def query(self, **kwargs): + """ + Query SNMP agent using GET, GETNEXT, or GETBULK operations. + """ + operation = kwargs.get('operation', 'GET') + target_host = kwargs.get('host') + target_port = kwargs.get('port', 161) + oid = kwargs.get('oid') + + if not target_host or not oid: + raise ProviderException("Host and OID are required for SNMP queries") + + snmp_engine = SnmpEngine() + + try: + auth_data = None + if self.authentication_config.snmp_version == 'v3': + auth_data = UsmUserData( + self.authentication_config.username, + self.authentication_config.auth_key, + self.authentication_config.priv_key + ) + else: + auth_data = CommunityData(self.authentication_config.community_string, + mpModel=0 if self.authentication_config.snmp_version == 'v1' else 1) + + transport_target = await UdpTransportTarget.create((target_host, target_port)) + context_data = ContextData() + + obj_type = ObjectType(ObjectIdentity(oid)) + + try: + if operation == 'GET': + error_indication, error_status, error_index, var_binds = await get_cmd( + snmp_engine, + auth_data, + transport_target, + context_data, + obj_type + ) + elif operation == 'GETNEXT': + error_indication, error_status, error_index, var_binds = await next_cmd( + snmp_engine, + auth_data, + transport_target, + context_data, + obj_type + ) + elif operation == 'GETBULK': + error_indication, error_status, error_index, var_binds = await bulk_cmd( + snmp_engine, + auth_data, + transport_target, + context_data, + 0, 25, # non-repeaters, max-repetitions + obj_type + ) + elif operation == 'SET': + value = kwargs.get('value') + if value is None: + raise ProviderException("Value is required for SET operation") + error_indication, error_status, error_index, var_binds = await set_cmd( + snmp_engine, + auth_data, + transport_target, + context_data, + obj_type, + value + ) + else: + raise ProviderException(f"Unsupported SNMP operation: {operation}") + + if error_indication: + raise ProviderException(f"SNMP error: {error_indication}") + elif error_status: + raise ProviderException(f"SNMP error: {error_status.prettyPrint()}") + + results = [] + for var_bind in var_binds: + name, value = var_bind + results.append({ + 'oid': name.prettyPrint(), + 'value': value.prettyPrint() + }) + + return results + + finally: + # Clean up transport dispatcher + if snmp_engine and hasattr(snmp_engine, 'transport_dispatcher'): + dispatcher = snmp_engine.transport_dispatcher + if hasattr(dispatcher, 'loopingcall'): + try: + if not dispatcher.loopingcall.done(): + dispatcher.loopingcall.cancel() + await dispatcher.loopingcall + except (asyncio.CancelledError, Exception) as e: + self.logger.debug(f"Error canceling dispatcher timeout: {e}") + + try: + dispatcher.close_dispatcher() + except Exception as e: + self.logger.debug(f"Error closing dispatcher: {e}") + + except Exception as e: + self.logger.error(f"Error performing SNMP {operation}: {str(e)}") + raise ProviderException(f"SNMP {operation} failed: {str(e)}") + finally: + # Ensure engine resources are cleaned up + if snmp_engine and hasattr(snmp_engine, 'transport_dispatcher'): + try: + snmp_engine.transport_dispatcher.close_dispatcher() + except Exception as e: + self.logger.debug(f"Error during final engine cleanup: {e}") \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index b92c27cfa..4802733a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "aiofiles" @@ -2340,6 +2340,23 @@ files = [ [package.extras] colors = ["colorama (>=0.4.6)"] +[[package]] +name = "jinja2" +version = "3.1.5" +description = "A very fast and expressive template engine." +optional = false +python-versions = ">=3.7" +files = [ + {file = "jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb"}, + {file = "jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb"}, +] + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + [[package]] name = "jmespath" version = "1.0.1" @@ -2795,7 +2812,7 @@ name = "ndg-httpsclient" version = "0.5.1" description = "Provides enhanced HTTPS support for httplib and urllib2 using PyOpenSSL" optional = false -python-versions = ">=2.7,<3.0.0 || >=3.4.0" +python-versions = ">=2.7,<3.0.dev0 || >=3.4.dev0" files = [ {file = "ndg_httpsclient-0.5.1-py2-none-any.whl", hash = "sha256:d2c7225f6a1c6cf698af4ebc962da70178a99bcde24ee6d1961c4f3338130d57"}, {file = "ndg_httpsclient-0.5.1-py3-none-any.whl", hash = "sha256:dd174c11d971b6244a891f7be2b32ca9853d3797a72edb34fa5d7b07d8fff7d4"}, @@ -3276,6 +3293,17 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "ply" +version = "3.11" +description = "Python Lex & Yacc" +optional = false +python-versions = "*" +files = [ + {file = "ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce"}, + {file = "ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3"}, +] + [[package]] name = "portalocker" version = "2.10.1" @@ -3674,6 +3702,7 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -4074,6 +4103,36 @@ files = [ {file = "pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8"}, ] +[[package]] +name = "pysmi" +version = "1.5.9" +description = "A pure-Python implementation of SNMP/SMI MIB parsing and conversion library." +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "pysmi-1.5.9-py3-none-any.whl", hash = "sha256:3deb22e6341ba4c7d545056745adf091d86c35028d21003638944e67e42b87fa"}, + {file = "pysmi-1.5.9.tar.gz", hash = "sha256:f6dfda838e3cba133169f1ff57f71a2841815d43db2e5c619b6e5db3b8560707"}, +] + +[package.dependencies] +Jinja2 = ">=3.1.3,<4.0.0" +ply = ">=3.11,<4.0" +requests = ">=2.26.0,<3.0.0" + +[[package]] +name = "pysnmp" +version = "7.1.16" +description = "A Python library for SNMP" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "pysnmp-7.1.16-py3-none-any.whl", hash = "sha256:e4769afdc7cc6438f07411c242a99a50cdfd7ab5a37c6668accb8f303d8cef73"}, + {file = "pysnmp-7.1.16.tar.gz", hash = "sha256:51581c70e410e456eb3faa24c42a094c82acfa961d16ad659b57c5818379dfcb"}, +] + +[package.dependencies] +pyasn1 = ">=0.4.8,<0.5.0 || >0.5.0" + [[package]] name = "pytest" version = "8.3.4" @@ -4579,6 +4638,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"}, + {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"}, @@ -4587,6 +4647,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"}, + {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"}, @@ -4595,6 +4656,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"}, + {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"}, @@ -4603,6 +4665,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"}, + {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4f6f3eac23941b32afccc23081e1f50612bdbe4e982012ef4f5797986828cd01"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"}, @@ -4611,6 +4674,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"}, + {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2c59aa6170b990d8d2719323e628aaf36f3bfbc1c26279c0eeeb24d05d2d11c7"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"}, {file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"}, @@ -5560,4 +5624,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "a34041d4abecd3f6321fe244ff43f9162a10ca156b1b938ffe5007f3a1c02f31" +content-hash = "8878b542ae82a6d70bc4d8cdedb46fcc3e3a4208632ba7a05849f45e2c88e2cd" diff --git a/pyproject.toml b/pyproject.toml index 70f48e2d4..e7bc6ab43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,8 @@ psycopg-binary = "^3.2.3" psycopg = "^3.2.3" prometheus-client = "^0.21.1" psycopg2-binary = "^2.9.10" +pysnmp = "7.1.16" +pysmi = "1.5.9" prometheus-fastapi-instrumentator = "^7.0.0" slowapi = "^0.1.9" @@ -179,3 +181,7 @@ env = "GH_TOKEN" [tool.semantic_release.publish] dist_glob_patterns = ["dist/*"] upload_to_vcs_release = true + +[tool.pytest.ini_options] +asyncio_mode = "strict" +asyncio_default_fixture_loop_scope = "function" \ No newline at end of file diff --git a/tests/test_snmp_provider.py b/tests/test_snmp_provider.py new file mode 100644 index 000000000..e32cd08a5 --- /dev/null +++ b/tests/test_snmp_provider.py @@ -0,0 +1,389 @@ +""" +Test script for SNMP Provider +""" + +from asyncio.log import logger +import pytest +import asyncio +import os +from unittest.mock import patch, MagicMock, PropertyMock +from keep.contextmanager.contextmanager import ContextManager +from keep.providers.models.provider_config import ProviderConfig +from keep.providers.snmp_provider.snmp_provider import SnmpProvider, SnmpProviderAuthConfig +from keep.exceptions.provider_exception import ProviderException +from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus +from keep.providers.providers_factory import ProvidersFactory +from keep.api.models.provider import Provider + +# pytestmark = pytest.mark.asyncio(scope="session") + +@pytest.fixture(scope="session") +def event_loop_policy(): + """Create and configure the event loop policy for the test session.""" + policy = asyncio.get_event_loop_policy() + + def cleanup_loop(loop): + # Clean up any pending tasks + pending = asyncio.all_tasks(loop) + if pending: + # Log pending tasks + logger.info(f"Pending tasks: {pending}") + + # Cancel all pending tasks + for task in pending: + if not task.done(): + task.cancel() + + # Run the event loop to process cancellations + try: + loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + except asyncio.CancelledError: + pass + + # Wait a bit to ensure all resources are cleaned up + loop.run_until_complete(asyncio.sleep(0.1)) + + # Close the loop + try: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete(loop.shutdown_default_executor()) + except Exception as e: + logger.warning(f"Error during loop shutdown: {e}") + + loop.close() + + # Create a custom policy that includes our cleanup + class CustomPolicy(type(policy)): + def new_event_loop(self): + loop = super().new_event_loop() + loop.set_debug(True) + return loop + + def get_event_loop(self): + loop = super().get_event_loop() + loop.set_debug(True) + return loop + + custom_policy = CustomPolicy() + asyncio.set_event_loop_policy(custom_policy) + + yield custom_policy + +@pytest.fixture +def context_manager(): + return ContextManager( + tenant_id="test", + workflow_id="test_snmp" + ) + +@pytest.fixture +def provider_config(): + return ProviderConfig( + authentication={ + "snmp_version": "v2c", + "community_string": "public", + "listen_port": 162 + } + ) + +@pytest.fixture +def snmp_provider(context_manager, provider_config): + provider = SnmpProvider( + context_manager=context_manager, + provider_id="snmp_test", + config=provider_config + ) + return provider + +@pytest.mark.asyncio(loop_scope="function") +async def test_snmp_get_operation(snmp_provider): + """Test SNMP GET operation""" + result = await snmp_provider.query( + operation='GET', + host='demo.pysnmp.com', + port=161, + oid='1.3.6.1.2.1.1.1.0' # System description + ) + assert isinstance(result, list) + assert len(result) > 0 + assert 'oid' in result[0] + assert 'value' in result[0] + assert result[0]['oid'] == 'SNMPv2-SMI::mib-2.1.1.0' + +@pytest.mark.asyncio(loop_scope="function") +async def test_snmp_getnext_operation(snmp_provider): + """Test SNMP GETNEXT operation""" + result = await snmp_provider.query( + operation='GETNEXT', + host='demo.pysnmp.com', + port=161, + oid='1.3.6.1.2.1.1' # System MIB + ) + assert isinstance(result, list) + assert len(result) > 0 + assert 'oid' in result[0] + assert 'value' in result[0] + +@pytest.mark.asyncio(loop_scope="function") +async def test_snmp_getbulk_operation(snmp_provider): + """Test SNMP GETBULK operation""" + result = await snmp_provider.query( + operation='GETBULK', + host='demo.pysnmp.com', + port=161, + oid='1.3.6.1.2.1.2.2' # Interfaces table + ) + assert isinstance(result, list) + assert len(result) > 0 + for item in result: + assert 'oid' in item + assert 'value' in item + assert item['oid'].startswith('SNMPv2-SMI::mib-2.2.2.1') + +@pytest.mark.asyncio(loop_scope="function") +async def test_snmp_invalid_operation(snmp_provider): + """Test invalid SNMP operation""" + with pytest.raises(ProviderException) as exc_info: + await snmp_provider.query( + operation='INVALID', + host='demo.pysnmp.com', + port=161, + oid='1.3.6.1.2.1.1.1.0' + ) + assert "Unsupported SNMP operation" in str(exc_info.value) + +@pytest.mark.asyncio(loop_scope="function") +async def test_snmp_missing_parameters(snmp_provider): + """Test missing required parameters""" + with pytest.raises(ProviderException) as exc_info: + await snmp_provider.query( + operation='GET', + port=161 + ) + assert "Host and OID are required" in str(exc_info.value) + + +def test_provider_config_validation(context_manager): + """Test provider configuration validation""" + # Test valid v2c config + config = ProviderConfig( + authentication={ + "snmp_version": "v2c", + "community_string": "public", + "listen_port": 162 + } + ) + provider = SnmpProvider(context_manager, "test", config) + assert isinstance(provider.authentication_config, SnmpProviderAuthConfig) + + # Test valid v3 config + config = ProviderConfig( + authentication={ + "snmp_version": "v3", + "username": "user", + "auth_key": "authkey", + "priv_key": "privkey", + "auth_protocol": "SHA", + "priv_protocol": "AES", + "listen_port": 162 + } + ) + provider = SnmpProvider(context_manager, "test", config) + assert isinstance(provider.authentication_config, SnmpProviderAuthConfig) + + # Test invalid v3 config (missing required fields) + config = ProviderConfig( + authentication={ + "snmp_version": "v3", + "listen_port": 162 + } + ) + with pytest.raises(ProviderException) as exc_info: + SnmpProvider(context_manager, "test", config) + assert "Username is required for SNMPv3" in str(exc_info.value) + +def test_provider_scopes_validation(context_manager, provider_config): + """Test provider scopes validation""" + provider = SnmpProvider(context_manager, "test", provider_config) + scopes = provider.validate_scopes() + assert isinstance(scopes, dict) + assert "receive_traps" in scopes + assert isinstance(scopes["receive_traps"], bool) + +def create_test_provider(): + return ProviderConfig( + authentication={ + "snmp_version": "v2c", + "community_string": "public", + "listen_port": 162 + } + ) + +# @pytest.mark.asyncio +def test_format_alert(): + # Create a mock provider + provider = SnmpProvider(MagicMock(), "test", create_test_provider()) + + # Test mapping of different trap severities + severity_map = { + 'emergency': 'critical', + 'alert': 'critical', + 'critical': 'critical', + 'error': 'high', + 'warning': 'warning', + 'notice': 'info', + 'info': 'info', + 'debug': 'info', + 'unknown': 'info' + } + + for trap_severity, expected_severity in severity_map.items(): + alert = provider._format_alert({ + 'trap_type': 'test_trap', + 'severity': trap_severity, + 'source_address': '127.0.0.1', + 'trap_timestamp': '2024-01-01T00:00:00Z' + }) + assert isinstance(alert, AlertDto) + assert alert.severity == expected_severity, f"Expected {expected_severity} for trap severity {trap_severity}, got {alert.severity}" + + # Ensure all pending tasks are completed + # pending = asyncio.all_tasks() + # if pending: + # await asyncio.wait(pending) + +def test_provider_config_validation(): + """Test that SNMPv3 configuration validation works correctly""" + # Test that validation fails when username is missing for SNMPv3 + config = ProviderConfig( + authentication={ + "snmp_version": "v3", + "target_host": "demo.snmplabs.com", + "target_port": 161 + } + ) + + # Create provider - this should raise an exception since username is required for SNMPv3 + with pytest.raises(ProviderException, match="Username is required for SNMPv3"): + provider = SnmpProvider(MagicMock(), "test", config) + + # Test that validation succeeds with valid SNMPv3 configuration + config = ProviderConfig( + authentication={ + "snmp_version": "v3", + "username": "test_user", + "auth_protocol": "SHA", + "auth_key": "test_auth_key", + "priv_protocol": "AES", + "priv_key": "test_priv_key", + "target_host": "demo.snmplabs.com", + "target_port": 161 + } + ) + + # This should not raise an exception + provider = SnmpProvider(MagicMock(), "test", config) + assert provider.authentication_config.username == "test_user" + +@pytest.mark.asyncio(loop_scope="function") +async def test_provider_disposal(): + """Test provider disposal""" + provider = SnmpProvider(MagicMock(), "test", create_test_provider()) + + # Create a mock dispatcher with async support + mock_dispatcher = MagicMock() + + # Create a future for the loopingcall + loop = asyncio.get_event_loop() + loopingcall = loop.create_future() + loopingcall.cancel() # Mark it as cancelled + + mock_dispatcher.loopingcall = loopingcall + mock_dispatcher._transports = {} + + # Create a mock engine + provider.snmp_engine = MagicMock() + provider.snmp_engine.transport_dispatcher = mock_dispatcher + + await provider.dispose() + + # Verify that close_dispatcher was called + mock_dispatcher.close_dispatcher.assert_called_once() + +@pytest.mark.asyncio(loop_scope="function") +async def test_snmp_get(): + provider = SnmpProvider(MagicMock(), "test", create_test_provider()) + try: + # Create a transport endpoint + # transport_target = await UdpTransportTarget.create(('demo.pysnmp.com', 161)) + + result = await provider.query( + operation='GET', + host='demo.pysnmp.com', + port=161, + oid='1.3.6.1.2.1.1.1.0' + ) + assert result is not None + finally: + # Ensure proper cleanup + await provider.dispose() + # Wait a bit for any pending tasks to complete + await asyncio.sleep(0.1) + +@pytest.mark.asyncio(loop_scope="function") +async def test_snmp_getnext(): + provider = SnmpProvider(MagicMock(), "test", create_test_provider()) + try: + result = await provider.query( + operation='GETNEXT', + host='demo.pysnmp.com', + port=161, + oid='1.3.6.1.2.1.1.1.0' + ) + assert result is not None + finally: + await provider.dispose() + # await asyncio.sleep(0.1) + +@pytest.mark.asyncio(loop_scope="function") +async def test_snmp_getbulk(): + provider = SnmpProvider(MagicMock(), "test", create_test_provider()) + try: + result = await provider.query( + operation='GETBULK', + host='demo.pysnmp.com', + port=161, + oid='1.3.6.1.2.1.1' + ) + assert result is not None + finally: + await provider.dispose() + # await asyncio.sleep(0.1) + +def test_mib_compiler_setup(context_manager): + """Test MIB compiler setup""" + # Create a temporary MIB directory with a test MIB file + test_mib_dir = "test_mibs" + os.makedirs(test_mib_dir, exist_ok=True) + + config = ProviderConfig( + authentication={ + "snmp_version": "v2c", + "community_string": "public", + "listen_port": 162, + "mib_dirs": [test_mib_dir] + } + ) + + with patch('pysnmp.smi.compiler.add_mib_compiler') as mock_add_compiler: + provider = SnmpProvider(context_manager, "test", config) + provider._setup_mib_compiler() + + assert provider.mib_view_controller is not None + mock_add_compiler.assert_called_once() + + # Clean up + os.rmdir(test_mib_dir) + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file From 68a9128985ca38ccd1e01133816465dd5ba7c51e Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Fri, 24 Jan 2025 18:57:53 +0100 Subject: [PATCH 2/3] feat(snmp): enhance SNMP provider with improved validation and MIB support --- keep/providers/snmp_provider/snmp_provider.py | 228 ++++++++++++++---- tests/test_snmp_provider.py | 95 ++++---- 2 files changed, 227 insertions(+), 96 deletions(-) diff --git a/keep/providers/snmp_provider/snmp_provider.py b/keep/providers/snmp_provider/snmp_provider.py index 9413046ab..46199a931 100644 --- a/keep/providers/snmp_provider/snmp_provider.py +++ b/keep/providers/snmp_provider/snmp_provider.py @@ -19,6 +19,7 @@ bulk_cmd ) from pysnmp.smi.rfc1902 import ObjectIdentity +from pysnmp.proto.rfc1902 import Integer, OctetString, IpAddress, Counter32, Counter64, Gauge32, Unsigned32, TimeTicks, Bits, Opaque from pysnmp.carrier.asyncio.dgram import udp from pysnmp.entity import engine, config from pysnmp.smi import builder, view, compiler @@ -35,7 +36,7 @@ from keep.exceptions.provider_exception import ProviderException -@pydantic.dataclasses.dataclass +@pydantic.dataclasses.dataclass(config=dict(validate_assignment=True)) class SnmpProviderAuthConfig: """SNMP authentication configuration.""" @@ -52,7 +53,7 @@ class SnmpProviderAuthConfig: community_string: str = dataclasses.field( metadata={ "required": True, - "description": "SNMP community string for authentication", + "description": "SNMP community string for authentication (required for v1/v2c)", "sensitive": True, "config_main_group": "authentication", }, @@ -74,7 +75,7 @@ class SnmpProviderAuthConfig: username: str = dataclasses.field( metadata={ "required": False, - "description": "SNMPv3 username", + "description": "SNMPv3 username (required for v3)", "config_main_group": "authentication", }, default="", @@ -84,7 +85,7 @@ class SnmpProviderAuthConfig: default="SHA", metadata={ "required": False, - "description": "SNMPv3 authentication protocol", + "description": "SNMPv3 authentication protocol (required for v3)", "type": "select", "options": ["MD5", "SHA"], "config_main_group": "authentication", @@ -95,7 +96,7 @@ class SnmpProviderAuthConfig: metadata={ "required": False, "sensitive": True, - "description": "SNMPv3 authentication key", + "description": "SNMPv3 authentication key (required for v3)", "config_main_group": "authentication", }, default="", @@ -105,7 +106,7 @@ class SnmpProviderAuthConfig: default="AES", metadata={ "required": False, - "description": "SNMPv3 privacy protocol", + "description": "SNMPv3 privacy protocol (required for v3)", "type": "select", "options": ["DES", "AES"], "config_main_group": "authentication", @@ -116,7 +117,7 @@ class SnmpProviderAuthConfig: metadata={ "required": False, "sensitive": True, - "description": "SNMPv3 privacy key", + "description": "SNMPv3 privacy key (required for v3)", "config_main_group": "authentication", }, default="", @@ -132,6 +133,23 @@ class SnmpProviderAuthConfig: default_factory=list, ) + def __post_init__(self): + """Validate SNMPv3 fields after initialization.""" + if self.snmp_version == 'v3': + required_fields = { + 'username': 'Username', + 'auth_key': 'Authentication key', + 'priv_key': 'Privacy key' + } + + missing_fields = [] + for field, display_name in required_fields.items(): + if not getattr(self, field): + missing_fields.append(display_name) + + if missing_fields: + raise ProviderException(f"The following fields are required for SNMPv3: {', '.join(missing_fields)}") + class SnmpProvider(BaseProvider): """ @@ -170,10 +188,10 @@ async def dispose(self): try: if not dispatcher.loopingcall.done(): dispatcher.loopingcall.cancel() - # Wait for cancellation to complete + # Wait for the future to complete after cancellation try: - asyncio.wait_for(dispatcher.loopingcall, timeout=1) - except (asyncio.TimeoutError, asyncio.CancelledError): + await asyncio.shield(dispatcher.loopingcall) + except asyncio.CancelledError: pass except Exception as e: self.logger.debug(f"Error canceling dispatcher timeout: {e}") @@ -202,23 +220,16 @@ def validate_config(self): **self.config.authentication ) - # Validate SNMPv3 configuration if v3 is selected - if self.authentication_config.snmp_version == "v3": - if not self.authentication_config.username: - raise ProviderException("Username is required for SNMPv3") - if not self.authentication_config.auth_key: - raise ProviderException("Authentication key is required for SNMPv3") - if not self.authentication_config.priv_key: - raise ProviderException("Privacy key is required for SNMPv3") - # Validate MIB directories for mib_dir in self.authentication_config.mib_dirs: if not os.path.isdir(mib_dir): raise ProviderException(f"MIB directory does not exist: {mib_dir}") - def validate_scopes(self) -> dict[str, bool | str]: + def validate_scopes(self) -> dict[str, bool]: """ Validate that the scopes provided are correct. + Returns a dictionary mapping scope names to boolean values indicating if they are valid. + Any validation errors will be logged at debug level. """ try: # Try to create an SNMP engine with the provided config @@ -252,7 +263,8 @@ def validate_scopes(self) -> dict[str, bool | str]: snmp_engine.transport_dispatcher.close_dispatcher() return {"receive_traps": True} except Exception as e: - return {"receive_traps": str(e)} + self.logger.debug(f"SNMP trap receiver validation failed: {str(e)}") + return {"receive_traps": False} def _setup_mib_compiler(self): """ @@ -372,13 +384,19 @@ def trap_callback(snmp_engine, state_reference, context_engine_id, context_name, var_binds, cb_ctx): """ Callback function to handle received SNMP traps. + + This function processes incoming SNMP traps and converts them to alerts. + It carefully extracts and validates trap information, ensuring all critical + fields are properly populated. """ try: self.logger.info("Received SNMP trap") + + # Initialize trap data with required fields trap_data = { 'trap_type': 'SNMP Trap', - 'message': '', - 'description': '', + 'message': [], # List to collect message parts + 'description': [], # List to collect description parts 'severity': AlertSeverity.INFO, 'variables': {}, 'context': { @@ -387,29 +405,64 @@ def trap_callback(snmp_engine, state_reference, context_engine_id, context_name, } } + # First pass: Collect all variables and their values for name, val in var_binds: try: - if isinstance(name, ObjectIdentifier): - var_name = self._parse_trap_oid(name) - else: - var_name = str(name) - - trap_data['variables'][var_name] = val.prettyPrint() + var_name = self._parse_trap_oid(name) if isinstance(name, ObjectIdentifier) else str(name) + var_value = val.prettyPrint() + trap_data['variables'][var_name] = var_value - # Try to identify trap type and severity from common OIDs - if 'trapType' in var_name.lower(): - trap_data['trap_type'] = val.prettyPrint() - elif 'severity' in var_name.lower(): - trap_data['severity'] = val.prettyPrint() - elif 'message' in var_name.lower(): - trap_data['message'] = val.prettyPrint() - elif 'description' in var_name.lower(): - trap_data['description'] = val.prettyPrint() + # Store the raw name-value pair for pattern matching + name_lower = var_name.lower() + + # Identify trap metadata from variable names using comprehensive pattern matching + if any(type_pattern in name_lower for type_pattern in ['traptype', 'trap.type', 'event.type']): + trap_data['trap_type'] = var_value + elif any(sev_pattern in name_lower for sev_pattern in ['severity', 'priority', 'level']): + trap_data['severity'] = var_value + elif any(msg_pattern in name_lower for msg_pattern in ['message', 'msg', 'text']): + trap_data['message'].append(var_value) + elif any(desc_pattern in name_lower for desc_pattern in ['description', 'desc', 'details']): + trap_data['description'].append(var_value) except Exception as e: self.logger.error(f"Error processing trap variable {name}: {str(e)}") + # Fallback: store raw values if processing fails trap_data['variables'][str(name)] = str(val) + # Second pass: Post-process collected data + + # Join collected messages and descriptions + trap_data['message'] = ' '.join(filter(None, trap_data['message'])) or 'SNMP Trap Received' + trap_data['description'] = ' '.join(filter(None, trap_data['description'])) + + # Map severity string to AlertSeverity enum if it's a string + if isinstance(trap_data['severity'], str): + severity_map = { + 'emergency': AlertSeverity.CRITICAL, + 'alert': AlertSeverity.CRITICAL, + 'critical': AlertSeverity.CRITICAL, + 'error': AlertSeverity.HIGH, + 'warning': AlertSeverity.WARNING, + 'notice': AlertSeverity.INFO, + 'info': AlertSeverity.INFO, + 'debug': AlertSeverity.INFO, + # Add numeric severity mappings + '0': AlertSeverity.INFO, + '1': AlertSeverity.WARNING, + '2': AlertSeverity.HIGH, + '3': AlertSeverity.CRITICAL + } + trap_data['severity'] = severity_map.get( + trap_data['severity'].lower(), + AlertSeverity.INFO + ) + + # Ensure description includes all variables if no specific description was found + if not trap_data['description']: + var_desc = [f"{k}: {v}" for k, v in trap_data['variables'].items()] + trap_data['description'] = "Trap Variables:\n" + "\n".join(var_desc) + self.logger.debug(f"Processed trap data: {trap_data}") alert = self._format_alert(trap_data) self._push_alert(alert.dict()) @@ -500,13 +553,22 @@ async def query(self, **kwargs): target_host = kwargs.get('host') target_port = kwargs.get('port', 161) oid = kwargs.get('oid') + timeout = kwargs.get('timeout', 10) # Default 10 second timeout + retries = kwargs.get('retries', 3) # Default 3 retries if not target_host or not oid: raise ProviderException("Host and OID are required for SNMP queries") - snmp_engine = SnmpEngine() + snmp_engine = None + dispatcher = None try: + snmp_engine = SnmpEngine() + + # Initialize MIB view controller if not already initialized + if not self.mib_view_controller: + self._setup_mib_compiler() + auth_data = None if self.authentication_config.snmp_version == 'v3': auth_data = UsmUserData( @@ -518,7 +580,11 @@ async def query(self, **kwargs): auth_data = CommunityData(self.authentication_config.community_string, mpModel=0 if self.authentication_config.snmp_version == 'v1' else 1) - transport_target = await UdpTransportTarget.create((target_host, target_port)) + transport_target = await UdpTransportTarget.create( + (target_host, target_port), + timeout=timeout, + retries=retries + ) context_data = ContextData() obj_type = ObjectType(ObjectIdentity(oid)) @@ -551,20 +617,73 @@ async def query(self, **kwargs): ) elif operation == 'SET': value = kwargs.get('value') + value_type = kwargs.get('value_type', 'string').lower() + if value is None: raise ProviderException("Value is required for SET operation") + + # Map of supported SNMP value types and their corresponding classes + type_map = { + 'integer': Integer, + 'int': Integer, + 'int32': Integer, + 'string': OctetString, + 'octetstring': OctetString, + 'ipaddress': IpAddress, + 'counter32': Counter32, + 'counter64': Counter64, + 'gauge32': Gauge32, + 'unsigned32': Unsigned32, + 'timeticks': TimeTicks, + 'bits': Bits, + 'opaque': Opaque + } + + if value_type not in type_map: + raise ProviderException( + f"Unsupported value type: {value_type}. " + f"Supported types are: {', '.join(type_map.keys())}" + ) + + try: + # Convert the value to the appropriate SNMP type + snmp_type = type_map[value_type] + if value_type in ['integer', 'int', 'int32', 'counter32', 'counter64', 'gauge32', 'unsigned32', 'timeticks']: + typed_value = snmp_type(int(value)) + elif value_type == 'ipaddress': + # Validate IP address format + import ipaddress + ipaddress.ip_address(value) # This will raise ValueError if invalid + typed_value = snmp_type(value) + elif value_type == 'bits': + # Expect a comma-separated list of bit positions + bit_positions = [int(x.strip()) for x in str(value).split(',')] + typed_value = snmp_type(names=bit_positions) + else: + typed_value = snmp_type(str(value)) + + except (ValueError, TypeError) as e: + raise ProviderException( + f"Invalid value format for type {value_type}: {str(e)}" + ) + error_indication, error_status, error_index, var_binds = await set_cmd( snmp_engine, auth_data, transport_target, context_data, - obj_type, - value + ObjectType(ObjectIdentity(oid), typed_value) ) else: raise ProviderException(f"Unsupported SNMP operation: {operation}") if error_indication: + error_msg = str(error_indication) + if "No SNMP response received before timeout" in error_msg: + raise ProviderException( + f"SNMP {operation} timed out after {timeout} seconds with {retries} retries. " + f"Consider increasing timeout or retries." + ) raise ProviderException(f"SNMP error: {error_indication}") elif error_status: raise ProviderException(f"SNMP error: {error_status.prettyPrint()}") @@ -572,8 +691,29 @@ async def query(self, **kwargs): results = [] for var_bind in var_binds: name, value = var_bind + # Use MIB view controller to translate OID to proper MIB name + try: + if self.mib_view_controller: + mib_name = self.mib_view_controller.get_node_name(name) + if len(mib_name) > 0: + # First element is the MIB module name (e.g. 'SNMPv2-MIB') + mib_module = str(mib_name[0]) + # Rest are the object parts (e.g. 'sysDescr', '0') + object_parts = [] + for part in mib_name[1:]: + if isinstance(part, (str, int)): + object_parts.append(str(part)) + oid = f"{mib_module}::{'.'.join(object_parts)}" + else: + oid = name.prettyPrint() + else: + oid = name.prettyPrint() + except Exception as e: + self.logger.debug(f"Failed to translate OID using MIB: {str(e)}") + oid = name.prettyPrint() + results.append({ - 'oid': name.prettyPrint(), + 'oid': oid, 'value': value.prettyPrint() }) @@ -605,4 +745,4 @@ async def query(self, **kwargs): try: snmp_engine.transport_dispatcher.close_dispatcher() except Exception as e: - self.logger.debug(f"Error during final engine cleanup: {e}") \ No newline at end of file + self.logger.debug(f"Error during final cleanup: {e}") \ No newline at end of file diff --git a/tests/test_snmp_provider.py b/tests/test_snmp_provider.py index e32cd08a5..73584d828 100644 --- a/tests/test_snmp_provider.py +++ b/tests/test_snmp_provider.py @@ -11,9 +11,7 @@ from keep.providers.models.provider_config import ProviderConfig from keep.providers.snmp_provider.snmp_provider import SnmpProvider, SnmpProviderAuthConfig from keep.exceptions.provider_exception import ProviderException -from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus -from keep.providers.providers_factory import ProvidersFactory -from keep.api.models.provider import Provider +from keep.api.models.alert import AlertDto # pytestmark = pytest.mark.asyncio(scope="session") @@ -98,6 +96,11 @@ def snmp_provider(context_manager, provider_config): @pytest.mark.asyncio(loop_scope="function") async def test_snmp_get_operation(snmp_provider): """Test SNMP GET operation""" + # Mock the MIB view controller + mock_view_controller = MagicMock() + mock_view_controller.get_node_name.return_value = ('SNMPv2-MIB', 'sysDescr', 0) + snmp_provider.mib_view_controller = mock_view_controller + result = await snmp_provider.query( operation='GET', host='demo.pysnmp.com', @@ -108,11 +111,16 @@ async def test_snmp_get_operation(snmp_provider): assert len(result) > 0 assert 'oid' in result[0] assert 'value' in result[0] - assert result[0]['oid'] == 'SNMPv2-SMI::mib-2.1.1.0' + assert result[0]['oid'] == 'SNMPv2-MIB::sysDescr.0' @pytest.mark.asyncio(loop_scope="function") async def test_snmp_getnext_operation(snmp_provider): """Test SNMP GETNEXT operation""" + # Mock the MIB view controller + mock_view_controller = MagicMock() + mock_view_controller.get_node_name.return_value = ('SNMPv2-MIB', 'sysDescr', 0) + snmp_provider.mib_view_controller = mock_view_controller + result = await snmp_provider.query( operation='GETNEXT', host='demo.pysnmp.com', @@ -127,6 +135,11 @@ async def test_snmp_getnext_operation(snmp_provider): @pytest.mark.asyncio(loop_scope="function") async def test_snmp_getbulk_operation(snmp_provider): """Test SNMP GETBULK operation""" + # Mock the MIB view controller + mock_view_controller = MagicMock() + mock_view_controller.get_node_name.return_value = ('SNMPv2-MIB', 'ifTable', 'ifEntry', 'ifIndex', 1) + snmp_provider.mib_view_controller = mock_view_controller + result = await snmp_provider.query( operation='GETBULK', host='demo.pysnmp.com', @@ -138,7 +151,7 @@ async def test_snmp_getbulk_operation(snmp_provider): for item in result: assert 'oid' in item assert 'value' in item - assert item['oid'].startswith('SNMPv2-SMI::mib-2.2.2.1') + assert item['oid'].startswith('SNMPv2-MIB::ifTable.ifEntry.ifIndex') @pytest.mark.asyncio(loop_scope="function") async def test_snmp_invalid_operation(snmp_provider): @@ -163,44 +176,6 @@ async def test_snmp_missing_parameters(snmp_provider): assert "Host and OID are required" in str(exc_info.value) -def test_provider_config_validation(context_manager): - """Test provider configuration validation""" - # Test valid v2c config - config = ProviderConfig( - authentication={ - "snmp_version": "v2c", - "community_string": "public", - "listen_port": 162 - } - ) - provider = SnmpProvider(context_manager, "test", config) - assert isinstance(provider.authentication_config, SnmpProviderAuthConfig) - - # Test valid v3 config - config = ProviderConfig( - authentication={ - "snmp_version": "v3", - "username": "user", - "auth_key": "authkey", - "priv_key": "privkey", - "auth_protocol": "SHA", - "priv_protocol": "AES", - "listen_port": 162 - } - ) - provider = SnmpProvider(context_manager, "test", config) - assert isinstance(provider.authentication_config, SnmpProviderAuthConfig) - - # Test invalid v3 config (missing required fields) - config = ProviderConfig( - authentication={ - "snmp_version": "v3", - "listen_port": 162 - } - ) - with pytest.raises(ProviderException) as exc_info: - SnmpProvider(context_manager, "test", config) - assert "Username is required for SNMPv3" in str(exc_info.value) def test_provider_scopes_validation(context_manager, provider_config): """Test provider scopes validation""" @@ -254,7 +229,7 @@ def test_format_alert(): def test_provider_config_validation(): """Test that SNMPv3 configuration validation works correctly""" - # Test that validation fails when username is missing for SNMPv3 + # Test that validation fails when username, auth_key and priv_key are missing for SNMPv3 config = ProviderConfig( authentication={ "snmp_version": "v3", @@ -264,7 +239,7 @@ def test_provider_config_validation(): ) # Create provider - this should raise an exception since username is required for SNMPv3 - with pytest.raises(ProviderException, match="Username is required for SNMPv3"): + with pytest.raises(ProviderException, match="The following fields are required for SNMPv3: Username, Authentication key, Privacy key"): provider = SnmpProvider(MagicMock(), "test", config) # Test that validation succeeds with valid SNMPv3 configuration @@ -314,16 +289,32 @@ async def test_provider_disposal(): async def test_snmp_get(): provider = SnmpProvider(MagicMock(), "test", create_test_provider()) try: - # Create a transport endpoint - # transport_target = await UdpTransportTarget.create(('demo.pysnmp.com', 161)) - + # Mock the MIB view controller + mock_view_controller = MagicMock() + mock_view_controller.get_node_name.return_value = ('SNMPv2-MIB', 'sysDescr', 0) + provider.mib_view_controller = mock_view_controller + result = await provider.query( operation='GET', host='demo.pysnmp.com', port=161, - oid='1.3.6.1.2.1.1.1.0' + oid='1.3.6.1.2.1.1.1.0' # System description ) - assert result is not None + # Validate result structure and content + assert isinstance(result, list), "Result should be a list" + assert len(result) > 0, "Result should not be empty" + + # Validate first result item structure + first_result = result[0] + assert isinstance(first_result, dict), "Result item should be a dictionary" + assert 'oid' in first_result, "Result should contain 'oid' key" + assert 'value' in first_result, "Result should contain 'value' key" + + # Validate OID format + assert first_result['oid'] == 'SNMPv2-MIB::sysDescr.0', "OID should match expected format" + + # Validate value is not empty + assert first_result['value'], "Result value should not be empty" finally: # Ensure proper cleanup await provider.dispose() @@ -343,7 +334,7 @@ async def test_snmp_getnext(): assert result is not None finally: await provider.dispose() - # await asyncio.sleep(0.1) + await asyncio.sleep(0.1) @pytest.mark.asyncio(loop_scope="function") async def test_snmp_getbulk(): @@ -358,7 +349,7 @@ async def test_snmp_getbulk(): assert result is not None finally: await provider.dispose() - # await asyncio.sleep(0.1) + await asyncio.sleep(0.1) def test_mib_compiler_setup(context_manager): """Test MIB compiler setup""" From 4518e69fddc8fd91ac73080acd0035819202508a Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Fri, 24 Jan 2025 23:16:23 +0100 Subject: [PATCH 3/3] fix: improve SNMP provider implementation and tests - Add cleanup, error handling, and value type conversion improvements --- keep/providers/snmp_provider/snmp_provider.py | 407 +++++++++--------- tests/test_snmp_provider.py | 113 ++++- 2 files changed, 324 insertions(+), 196 deletions(-) diff --git a/keep/providers/snmp_provider/snmp_provider.py b/keep/providers/snmp_provider/snmp_provider.py index 46199a931..2ebff95e6 100644 --- a/keep/providers/snmp_provider/snmp_provider.py +++ b/keep/providers/snmp_provider/snmp_provider.py @@ -177,6 +177,41 @@ def __init__( self.trap_receiver = None self.mib_view_controller = None + async def _cleanup_dispatcher(self, dispatcher, operation_name=""): + """ + Helper method to safely cleanup SNMP dispatcher resources. + """ + if not dispatcher: + return + + # Handle loopingcall cleanup + if hasattr(dispatcher, 'loopingcall'): + try: + loopingcall = dispatcher.loopingcall + if isinstance(loopingcall, asyncio.Future) and not loopingcall.done(): + loopingcall.cancel() + try: + await asyncio.shield(loopingcall) + except (asyncio.CancelledError, Exception) as e: + self.logger.debug(f"Error during loopingcall cleanup ({operation_name}): {e}") + except Exception as e: + self.logger.debug(f"Error handling loopingcall ({operation_name}): {e}") + + # Close transports + if hasattr(dispatcher, '_transports'): + for transport in dispatcher._transports.values(): + if hasattr(transport, 'close'): + try: + transport.close() + except Exception as e: + self.logger.debug(f"Error closing transport ({operation_name}): {e}") + + # Close dispatcher + try: + dispatcher.close_dispatcher() + except Exception as e: + self.logger.debug(f"Error closing dispatcher ({operation_name}): {e}") + async def dispose(self): """ Clean up SNMP engine and trap receiver. @@ -184,23 +219,7 @@ async def dispose(self): if self.snmp_engine: try: dispatcher = self.snmp_engine.transport_dispatcher - if hasattr(dispatcher, 'loopingcall'): - try: - if not dispatcher.loopingcall.done(): - dispatcher.loopingcall.cancel() - # Wait for the future to complete after cancellation - try: - await asyncio.shield(dispatcher.loopingcall) - except asyncio.CancelledError: - pass - except Exception as e: - self.logger.debug(f"Error canceling dispatcher timeout: {e}") - - if hasattr(dispatcher, '_transports'): - for transport in dispatcher._transports.values(): - if hasattr(transport, 'close'): - transport.close() - dispatcher.close_dispatcher() + await self._cleanup_dispatcher(dispatcher, "dispose") self.snmp_engine = None except Exception as e: self.logger.error(f"Error disposing SNMP engine: {str(e)}") @@ -545,22 +564,93 @@ def is_consumer(self) -> bool: """ return True + @dataclasses.dataclass + class SnmpQueryParams: + """Parameters for SNMP query operations.""" + host: str + oid: str + operation: typing.Literal["GET", "GETNEXT", "GETBULK", "SET"] = "GET" + port: int = 161 + timeout: int = 10 + retries: int = 3 + value: typing.Optional[typing.Any] = None + value_type: typing.Optional[str] = None + + class SnmpValueTypeHandler: + """Handler for SNMP value type conversions.""" + + TYPE_MAP = { + 'integer': Integer, + 'int': Integer, + 'int32': Integer, + 'string': OctetString, + 'octetstring': OctetString, + 'ipaddress': IpAddress, + 'counter32': Counter32, + 'counter64': Counter64, + 'gauge32': Gauge32, + 'unsigned32': Unsigned32, + 'timeticks': TimeTicks, + 'bits': Bits, + 'opaque': Opaque + } + + NUMERIC_TYPES = {'integer', 'int', 'int32', 'counter32', 'counter64', + 'gauge32', 'unsigned32', 'timeticks'} + + @classmethod + def convert_value(cls, value: typing.Any, value_type: str) -> typing.Any: + """Convert a value to its appropriate SNMP type.""" + value_type = value_type.lower() + + if value_type not in cls.TYPE_MAP: + raise ProviderException( + f"Unsupported value type: {value_type}. " + f"Supported types are: {', '.join(cls.TYPE_MAP.keys())}" + ) + + snmp_type = cls.TYPE_MAP[value_type] + + try: + if value_type in cls.NUMERIC_TYPES: + try: + return snmp_type(int(str(value).strip())) + except (ValueError, TypeError) as e: + raise ProviderException(f"Invalid numeric value '{value}' for type {value_type}: {str(e)}") + elif value_type == 'ipaddress': + import ipaddress + try: + ipaddress.ip_address(str(value)) # Validate IP address format + return snmp_type(str(value)) + except ValueError as e: + raise ProviderException(f"Invalid IP address format: {str(e)}") + elif value_type == 'bits': + try: + bit_positions = [int(x.strip()) for x in str(value).split(',')] + return snmp_type(names=bit_positions) + except (ValueError, TypeError) as e: + raise ProviderException(f"Invalid bits format. Expected comma-separated integers: {str(e)}") + else: + return snmp_type(str(value)) + except Exception as e: + if isinstance(e, ProviderException): + raise + raise ProviderException(f"Error converting value '{value}' to type {value_type}: {str(e)}") + async def query(self, **kwargs): """ - Query SNMP agent using GET, GETNEXT, or GETBULK operations. + Query SNMP agent using GET, GETNEXT, GETBULK, or SET operations. """ - operation = kwargs.get('operation', 'GET') - target_host = kwargs.get('host') - target_port = kwargs.get('port', 161) - oid = kwargs.get('oid') - timeout = kwargs.get('timeout', 10) # Default 10 second timeout - retries = kwargs.get('retries', 3) # Default 3 retries - - if not target_host or not oid: + # Check required parameters first + if not kwargs.get('host') or not kwargs.get('oid'): raise ProviderException("Host and OID are required for SNMP queries") + + try: + params = self.SnmpQueryParams(**kwargs) + except TypeError as e: + raise ProviderException(f"Invalid query parameters: {str(e)}") snmp_engine = None - dispatcher = None try: snmp_engine = SnmpEngine() @@ -569,180 +659,107 @@ async def query(self, **kwargs): if not self.mib_view_controller: self._setup_mib_compiler() - auth_data = None - if self.authentication_config.snmp_version == 'v3': - auth_data = UsmUserData( - self.authentication_config.username, - self.authentication_config.auth_key, - self.authentication_config.priv_key + auth_data = self._get_auth_data() + try: + transport_target = await UdpTransportTarget.create( + (str(params.host), int(params.port)), + timeout=int(params.timeout), + retries=int(params.retries) ) - else: - auth_data = CommunityData(self.authentication_config.community_string, - mpModel=0 if self.authentication_config.snmp_version == 'v1' else 1) - - transport_target = await UdpTransportTarget.create( - (target_host, target_port), - timeout=timeout, - retries=retries - ) + except Exception as e: + raise ProviderException(f"Failed to create transport target: {str(e)}") context_data = ContextData() - obj_type = ObjectType(ObjectIdentity(oid)) - - try: - if operation == 'GET': - error_indication, error_status, error_index, var_binds = await get_cmd( - snmp_engine, - auth_data, - transport_target, - context_data, - obj_type - ) - elif operation == 'GETNEXT': - error_indication, error_status, error_index, var_binds = await next_cmd( - snmp_engine, - auth_data, - transport_target, - context_data, - obj_type - ) - elif operation == 'GETBULK': - error_indication, error_status, error_index, var_binds = await bulk_cmd( - snmp_engine, - auth_data, - transport_target, - context_data, - 0, 25, # non-repeaters, max-repetitions - obj_type + # Prepare the object type based on operation + if params.operation == 'SET': + if params.value is None: + raise ProviderException("Value is required for SET operation") + if params.value_type is None: + raise ProviderException("Value type is required for SET operation") + try: + typed_value = self.SnmpValueTypeHandler.convert_value( + params.value, + params.value_type ) - elif operation == 'SET': - value = kwargs.get('value') - value_type = kwargs.get('value_type', 'string').lower() - - if value is None: - raise ProviderException("Value is required for SET operation") - - # Map of supported SNMP value types and their corresponding classes - type_map = { - 'integer': Integer, - 'int': Integer, - 'int32': Integer, - 'string': OctetString, - 'octetstring': OctetString, - 'ipaddress': IpAddress, - 'counter32': Counter32, - 'counter64': Counter64, - 'gauge32': Gauge32, - 'unsigned32': Unsigned32, - 'timeticks': TimeTicks, - 'bits': Bits, - 'opaque': Opaque - } - - if value_type not in type_map: - raise ProviderException( - f"Unsupported value type: {value_type}. " - f"Supported types are: {', '.join(type_map.keys())}" - ) - - try: - # Convert the value to the appropriate SNMP type - snmp_type = type_map[value_type] - if value_type in ['integer', 'int', 'int32', 'counter32', 'counter64', 'gauge32', 'unsigned32', 'timeticks']: - typed_value = snmp_type(int(value)) - elif value_type == 'ipaddress': - # Validate IP address format - import ipaddress - ipaddress.ip_address(value) # This will raise ValueError if invalid - typed_value = snmp_type(value) - elif value_type == 'bits': - # Expect a comma-separated list of bit positions - bit_positions = [int(x.strip()) for x in str(value).split(',')] - typed_value = snmp_type(names=bit_positions) - else: - typed_value = snmp_type(str(value)) - - except (ValueError, TypeError) as e: - raise ProviderException( - f"Invalid value format for type {value_type}: {str(e)}" - ) - - error_indication, error_status, error_index, var_binds = await set_cmd( - snmp_engine, - auth_data, - transport_target, - context_data, - ObjectType(ObjectIdentity(oid), typed_value) + obj_identity = ObjectIdentity(params.oid) + obj_type = ObjectType(obj_identity, typed_value) + except Exception as e: + raise ProviderException(f"Failed to convert value for SET operation: {str(e)}") + else: + obj_type = ObjectType(ObjectIdentity(params.oid)) + + # Execute SNMP command + cmd_map = { + 'GET': get_cmd, + 'GETNEXT': next_cmd, + 'GETBULK': bulk_cmd, + 'SET': set_cmd + } + + cmd_func = cmd_map.get(params.operation) + if not cmd_func: + raise ProviderException(f"Unsupported SNMP operation: {params.operation}") + + # Add GETBULK specific parameters + cmd_args = [snmp_engine, auth_data, transport_target, context_data] + if params.operation == 'GETBULK': + cmd_args.extend([0, 25]) # non-repeaters, max-repetitions + cmd_args.append(obj_type) + + error_indication, error_status, error_index, var_binds = await cmd_func(*cmd_args) + + if error_indication: + error_msg = str(error_indication) + if "No SNMP response received before timeout" in error_msg: + raise ProviderException( + f"SNMP {params.operation} timed out after {params.timeout} seconds " + f"with {params.retries} retries. Consider increasing timeout or retries." ) - else: - raise ProviderException(f"Unsupported SNMP operation: {operation}") - - if error_indication: - error_msg = str(error_indication) - if "No SNMP response received before timeout" in error_msg: - raise ProviderException( - f"SNMP {operation} timed out after {timeout} seconds with {retries} retries. " - f"Consider increasing timeout or retries." - ) - raise ProviderException(f"SNMP error: {error_indication}") - elif error_status: - raise ProviderException(f"SNMP error: {error_status.prettyPrint()}") - - results = [] - for var_bind in var_binds: - name, value = var_bind - # Use MIB view controller to translate OID to proper MIB name - try: - if self.mib_view_controller: - mib_name = self.mib_view_controller.get_node_name(name) - if len(mib_name) > 0: - # First element is the MIB module name (e.g. 'SNMPv2-MIB') - mib_module = str(mib_name[0]) - # Rest are the object parts (e.g. 'sysDescr', '0') - object_parts = [] - for part in mib_name[1:]: - if isinstance(part, (str, int)): - object_parts.append(str(part)) - oid = f"{mib_module}::{'.'.join(object_parts)}" - else: - oid = name.prettyPrint() - else: - oid = name.prettyPrint() - except Exception as e: - self.logger.debug(f"Failed to translate OID using MIB: {str(e)}") - oid = name.prettyPrint() - - results.append({ - 'oid': oid, - 'value': value.prettyPrint() - }) - - return results - - finally: - # Clean up transport dispatcher - if snmp_engine and hasattr(snmp_engine, 'transport_dispatcher'): - dispatcher = snmp_engine.transport_dispatcher - if hasattr(dispatcher, 'loopingcall'): - try: - if not dispatcher.loopingcall.done(): - dispatcher.loopingcall.cancel() - await dispatcher.loopingcall - except (asyncio.CancelledError, Exception) as e: - self.logger.debug(f"Error canceling dispatcher timeout: {e}") - - try: - dispatcher.close_dispatcher() - except Exception as e: - self.logger.debug(f"Error closing dispatcher: {e}") + raise ProviderException(f"SNMP error: {error_indication}") + elif error_status: + raise ProviderException(f"SNMP error: {error_status.prettyPrint()}") + + return [ + { + 'oid': self._format_oid(name), + 'value': value.prettyPrint() + } + for name, value in var_binds + ] except Exception as e: - self.logger.error(f"Error performing SNMP {operation}: {str(e)}") - raise ProviderException(f"SNMP {operation} failed: {str(e)}") + self.logger.error(f"Error performing SNMP {params.operation}: {str(e)}") + raise ProviderException(f"SNMP {params.operation} failed: {str(e)}") finally: - # Ensure engine resources are cleaned up if snmp_engine and hasattr(snmp_engine, 'transport_dispatcher'): - try: - snmp_engine.transport_dispatcher.close_dispatcher() - except Exception as e: - self.logger.debug(f"Error during final cleanup: {e}") \ No newline at end of file + await self._cleanup_dispatcher(snmp_engine.transport_dispatcher, f"query_{params.operation}") + + def _get_auth_data(self) -> typing.Union[UsmUserData, CommunityData]: + """Get authentication data based on SNMP version.""" + if self.authentication_config.snmp_version == 'v3': + return UsmUserData( + self.authentication_config.username, + self.authentication_config.auth_key, + self.authentication_config.priv_key + ) + return CommunityData( + self.authentication_config.community_string, + mpModel=0 if self.authentication_config.snmp_version == 'v1' else 1 + ) + + def _format_oid(self, name: ObjectIdentifier) -> str: + """Format OID using MIB information if available.""" + try: + if self.mib_view_controller: + mib_name = self.mib_view_controller.get_node_name(name) + if len(mib_name) > 0: + mib_module = str(mib_name[0]) + object_parts = [ + str(part) for part in mib_name[1:] + if isinstance(part, (str, int)) + ] + return f"{mib_module}::{'.'.join(object_parts)}" + except Exception as e: + self.logger.debug(f"Failed to translate OID using MIB: {str(e)}") + + return name.prettyPrint() \ No newline at end of file diff --git a/tests/test_snmp_provider.py b/tests/test_snmp_provider.py index 73584d828..d92947179 100644 --- a/tests/test_snmp_provider.py +++ b/tests/test_snmp_provider.py @@ -7,11 +7,15 @@ import asyncio import os from unittest.mock import patch, MagicMock, PropertyMock +from pysnmp.proto.rfc1902 import ( + Integer, OctetString, IpAddress, Counter32, Counter64, + Gauge32, Unsigned32, TimeTicks, Bits, Opaque +) from keep.contextmanager.contextmanager import ContextManager from keep.providers.models.provider_config import ProviderConfig from keep.providers.snmp_provider.snmp_provider import SnmpProvider, SnmpProviderAuthConfig from keep.exceptions.provider_exception import ProviderException -from keep.api.models.alert import AlertDto +from keep.api.models.alert import AlertDto, AlertSeverity # pytestmark = pytest.mark.asyncio(scope="session") @@ -376,5 +380,112 @@ def test_mib_compiler_setup(context_manager): # Clean up os.rmdir(test_mib_dir) +@pytest.mark.asyncio(loop_scope="function") +async def test_snmp_set_invalid_value_type(snmp_provider): + """Test SNMP SET operation with invalid value type""" + with pytest.raises(ProviderException) as exc_info: + await snmp_provider.query( + operation='SET', + host='demo.snmplabs.com', + port=161, + oid='1.3.6.1.2.1.1.1.0', + value='test', + value_type='invalid_type' + ) + assert "Unsupported value type" in str(exc_info.value) + +@pytest.mark.asyncio(loop_scope="function") +async def test_snmp_network_timeout(snmp_provider): + """Test SNMP operation with network timeout""" + with patch('pysnmp.hlapi.v3arch.asyncio.transport.UdpTransportTarget.create') as mock_transport: + # Simulate a timeout by raising socket.timeout + import socket + mock_transport.side_effect = socket.timeout("Operation timed out") + + with pytest.raises(ProviderException) as exc_info: + await snmp_provider.query( + operation='GET', + host='non.existent.host', + port=161, + oid='1.3.6.1.2.1.1.1.0', + timeout=1, + retries=1 + ) + assert "Operation timed out" in str(exc_info.value) + +def test_value_type_handler(snmp_provider): + """Test SnmpValueTypeHandler value conversion""" + handler = snmp_provider.SnmpValueTypeHandler + + # Test integer conversion + int_value = handler.convert_value('42', 'integer') + assert isinstance(int_value, Integer) + assert int_value == 42 + + # Test string conversion + str_value = handler.convert_value('test', 'string') + assert isinstance(str_value, OctetString) + assert str(str_value) == 'test' + + # Test IP address validation + ip_value = handler.convert_value('192.168.1.1', 'ipaddress') + assert isinstance(ip_value, IpAddress) + assert str(ip_value.prettyPrint()) == '192.168.1.1' + + # Test invalid IP address + with pytest.raises(ProviderException) as exc_info: + handler.convert_value('invalid.ip', 'ipaddress') + assert "does not appear to be an IPv4 or IPv6 address" in str(exc_info.value) + + # Test bits conversion + bits_value = handler.convert_value('1,2,3', 'bits') + assert isinstance(bits_value, Bits) + assert bits_value.names == [1, 2, 3] + +@pytest.mark.asyncio(loop_scope="function") +async def test_trap_receiver_callbacks(snmp_provider): + """Test SNMP trap receiver callback handling""" + # Create a mock trap + trap_data = { + 'trap_type': 'Test Trap', + 'severity': 'critical', + 'message': 'Test trap message', + 'variables': { + 'SNMPv2-MIB::sysLocation.0': 'Test Location', + 'SNMPv2-MIB::sysContact.0': 'Test Contact' + } + } + + # Create the alert first + alert = snmp_provider._format_alert(trap_data) + + # Mock the alert pushing + with patch.object(snmp_provider, '_push_alert') as mock_push: + # Push the alert + snmp_provider._push_alert(alert.dict()) + + # Verify alert was pushed + mock_push.assert_called_once() + pushed_alert = mock_push.call_args[0][0] + assert pushed_alert['name'] == 'Test Trap' + assert pushed_alert['severity'] == 'critical' + +@pytest.mark.asyncio(loop_scope="function") +async def test_trap_receiver_invalid_data(snmp_provider): + """Test SNMP trap receiver handling of invalid data""" + invalid_trap_data = { + 'trap_type': 'Unknown Trap', # Changed from None to a string + 'severity': 'invalid_severity', + 'message': '', # Changed from None to empty string + 'variables': {} # Changed from None to empty dict + } + + # Verify that invalid data doesn't cause errors + alert = snmp_provider._format_alert(invalid_trap_data) + assert alert.name == 'Unknown Trap' # Use the provided trap type + assert alert.severity == 'info' # Default severity as string + assert alert.message == '' # Empty message + assert isinstance(alert.description, str) # Description should be a string + if __name__ == "__main__": pytest.main([__file__]) \ No newline at end of file